Mercurial > libervia-backend
changeset 4130:02f0adc745c6
core: notifications implementation, first draft:
add a new table for notifications, and methods/bridge methods to manipulate them.
author | Goffi <goffi@goffi.org> |
---|---|
date | Mon, 16 Oct 2023 17:29:31 +0200 |
parents | 51744ad00a42 |
children | c38c33a44171 |
files | libervia/backend/bridge/bridge_constructor/bridge_template.ini libervia/backend/bridge/dbus_bridge.py libervia/backend/bridge/pb.py libervia/backend/core/main.py libervia/backend/memory/memory.py libervia/backend/memory/migration/versions/2ab01aa1f686_create_table_for_notifications.py libervia/backend/memory/sqla.py libervia/backend/memory/sqla_mapping.py libervia/frontends/bridge/dbus_bridge.py libervia/frontends/bridge/pb.py |
diffstat | 10 files changed, 1081 insertions(+), 363 deletions(-) [+] |
line wrap: on
line diff
--- a/libervia/backend/bridge/bridge_constructor/bridge_template.ini Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/bridge/bridge_constructor/bridge_template.ini Mon Oct 16 17:29:31 2023 +0200 @@ -147,6 +147,31 @@ doc_param_2=value: New value doc_param_3=%(doc_profile)s +[notification_new] +type=signal +category=core +sig_in=sdssssbidss +doc=A new notification has been emitted +doc_param_0=id: unique identifier of the notification +doc_param_1=timestamp: notification creation time +doc_param_2=type: type of the notification +doc_param_3=body_plain: plain text body of the notification +doc_param_4=body_rich: rich text (XHTML) body of the notification. Optional. +doc_param_5=title: optional title of the notification +doc_param_6=requires_action: True if the notification requires user action (e.g. a dialog needs to be answered), False otherwise +doc_param_7=priority: how urgent the notification is, represented as an enumeration value +doc_param_8=expire_at: expiration timestamp for the notification. Optional. +doc_param_9=extra: additional serialized data associated with the notification +doc_param_10=profile: profile associated with the notification. C.PROF_KEY_ALL can be used for global notifications. + +[notification_deleted] +type=signal +category=core +sig_in=ss +doc=A new notification has been emitted +doc_param_0=id: id of the deleted notification +doc_param_1=profile: profile of the deleted application, or C.PROF_KEY_ALL for a global notification + [progress_started] type=signal category=core @@ -1024,3 +1049,49 @@ doc_param_4=profile_key: either profile_key or empty string to use common cache this parameter is used only when dest is empty doc_return=path to the new converted image + +[notification_add] +type=method +category=core +sig_in=ssssbbsdss +sig_out= +doc=Add a new notification +doc_param_0=type_: Notification type +doc_param_1=body_plain: Plain text body of the notification +doc_param_2=body_rich: Rich text body of the notification (optional, can be empty string for default) +doc_param_3=title: Title of the notification (optional, can be empty string for default) +doc_param_4=is_global: True if the notification is for all profiles +doc_param_5=requires_action: Indicates if the notification requires action +doc_param_7=priority: Priority level of the notification (e.g. MEDIUM, HIGH, etc.) +doc_param_8=expire_at: Expiration timestamp for the notification (optional, can be 0 for none) +doc_param_9=extra: Additional details for the notification as a dictionary (optional, can be empty dictionary) +doc_param_10=%(doc_profile_key)s: Profile key (use "@ALL@" for all profiles) + +[notifications_get] +type=method +category=core +sig_in=ss +sig_out=s +doc=Retrieve notifications based on provided filters +doc_param_0=filters: a dictionary with filter criteria for notifications retrieval +doc_param_1=%(doc_profile_key)s or @ALL@ for all profiles +doc_return=list of Notification objects. The exact structure will depend on your Notification class. + +[notification_delete] +type=method +category=core +sig_in=sbs +sig_out= +doc=Delete a notification +doc_param_0=id_: ID of the notification to delete +doc_param_1=is_global: true if the notification is a global one +doc_param_2=profile_key: Profile key (use "@ALL@" for all profiles) + +[notifications_expired_clean] +type=method +category=core +sig_in=ds +sig_out= +doc=Cleans expired notifications and older profile-specific notifications +doc_param_0=limit_timestamp: Timestamp limit for older notifications. If -1.0, only truly expired notifications are removed. +doc_param_1=profile_key: Profile key (use "@NONE@" to indicate no specific profile, otherwise only notification for given profile will be expired, in addition to truly expired notifications).
--- a/libervia/backend/bridge/dbus_bridge.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/bridge/dbus_bridge.py Mon Oct 16 17:29:31 2023 +0200 @@ -124,6 +124,10 @@ Method('message_encryption_stop', arguments='ss', returns=''), Method('message_send', arguments='sa{ss}a{ss}sss', returns=''), Method('namespaces_get', arguments='', returns='a{ss}'), + Method('notification_add', arguments='ssssbbsdss', returns=''), + Method('notification_delete', arguments='sbs', returns=''), + Method('notifications_expired_clean', arguments='ds', returns=''), + Method('notifications_get', arguments='ss', returns='s'), Method('param_get_a', arguments='ssss', returns='s'), Method('param_get_a_async', arguments='sssis', returns='s'), Method('param_set', arguments='sssis', returns=''), @@ -164,6 +168,8 @@ Signal('message_encryption_started', 'sss'), Signal('message_encryption_stopped', 'sa{ss}s'), Signal('message_new', 'sdssa{ss}a{ss}sss'), + Signal('notification_deleted', 'ss'), + Signal('notification_new', 'sdssssbidss'), Signal('param_update', 'ssss'), Signal('presence_update', 'ssia{ss}s'), Signal('progress_error', 'sss'), @@ -304,6 +310,18 @@ def dbus_namespaces_get(self, ): return self._callback("namespaces_get", ) + def dbus_notification_add(self, type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra): + return self._callback("notification_add", type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra) + + def dbus_notification_delete(self, id_, is_global, profile_key): + return self._callback("notification_delete", id_, is_global, profile_key) + + def dbus_notifications_expired_clean(self, limit_timestamp, profile_key): + return self._callback("notifications_expired_clean", limit_timestamp, profile_key) + + def dbus_notifications_get(self, filters, profile_key): + return self._callback("notifications_get", filters, profile_key) + def dbus_param_get_a(self, name, category, attribute="value", profile_key="@DEFAULT@"): return self._callback("param_get_a", name, category, attribute, profile_key) @@ -447,6 +465,12 @@ def message_new(self, uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile): self._obj.emitSignal("message_new", uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile) + def notification_deleted(self, id, profile): + self._obj.emitSignal("notification_deleted", id, profile) + + def notification_new(self, id, timestamp, type, body_plain, body_rich, title, requires_action, priority, expire_at, extra, profile): + self._obj.emitSignal("notification_new", id, timestamp, type, body_plain, body_rich, title, requires_action, priority, expire_at, extra, profile) + def param_update(self, name, value, category, profile): self._obj.emitSignal("param_update", name, value, category, profile)
--- a/libervia/backend/bridge/pb.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/bridge/pb.py Mon Oct 16 17:29:31 2023 +0200 @@ -193,6 +193,12 @@ def message_new(self, uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile): self.send_signal("message_new", uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile) + def notification_deleted(self, id, profile): + self.send_signal("notification_deleted", id, profile) + + def notification_new(self, id, timestamp, type, body_plain, body_rich, title, requires_action, priority, expire_at, extra, profile): + self.send_signal("notification_new", id, timestamp, type, body_plain, body_rich, title, requires_action, priority, expire_at, extra, profile) + def param_update(self, name, value, category, profile): self.send_signal("param_update", name, value, category, profile)
--- a/libervia/backend/core/main.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/core/main.py Mon Oct 16 17:29:31 2023 +0200 @@ -221,6 +221,10 @@ self.bridge.register_method("image_resize", self._image_resize) self.bridge.register_method("image_generate_preview", self._image_generate_preview) self.bridge.register_method("image_convert", self._image_convert) + self.bridge.register_method("notification_add", self.memory._add_notification) + self.bridge.register_method("notifications_get", self.memory._get_notifications) + self.bridge.register_method("notification_delete", self.memory._delete_notification) + self.bridge.register_method("notifications_expired_clean", self.memory._notifications_expired_clean) await self.memory.initialise()
--- a/libervia/backend/memory/memory.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/memory/memory.py Mon Oct 16 17:29:31 2023 +0200 @@ -16,29 +16,39 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. -import os.path +from collections import namedtuple import copy -import shortuuid -import mimetypes -import time +from dataclasses import dataclass from functools import partial -from typing import Optional, Tuple, Dict +import mimetypes +import os.path from pathlib import Path +import time +from typing import Dict, Optional, Tuple from uuid import uuid4 -from collections import namedtuple + +import shortuuid +from twisted.internet import defer, error, reactor from twisted.python import failure -from twisted.internet import defer, reactor, error from twisted.words.protocols.jabber import jid + +from libervia.backend.core import exceptions +from libervia.backend.core.constants import Const as C +from libervia.backend.core.core_types import SatXMPPEntity from libervia.backend.core.i18n import _ from libervia.backend.core.log import getLogger -from libervia.backend.core import exceptions -from libervia.backend.core.constants import Const as C -from libervia.backend.memory.sqla import Storage -from libervia.backend.memory.persistent import PersistentDict -from libervia.backend.memory.params import Params -from libervia.backend.memory.disco import Discovery from libervia.backend.memory.crypto import BlockCipher from libervia.backend.memory.crypto import PasswordHasher +from libervia.backend.memory.disco import Discovery +from libervia.backend.memory.params import Params +from libervia.backend.memory.persistent import PersistentDict +from libervia.backend.memory.sqla import ( + Notification, + NotificationPriority, + NotificationStatus, + NotificationType, + Storage, +) from libervia.backend.tools import config as tools_config from libervia.backend.tools.common import data_format from libervia.backend.tools.common import regex @@ -1848,6 +1858,180 @@ *(regex.path_escape(a) for a in args) ) + ## Notifications ## + + + def _add_notification( + self, + type_: str, + body_plain: str, + body_rich: str, + title: str, + is_global: bool, + requires_action: bool, + priority: str, + expire_at: float, + extra_s: str, + profile_key: str + ) -> defer.Deferred: + client = self.host.get_client(profile_key) + + if not client.is_admin: + raise exceptions.PermissionError("Only admins can add a notification") + + try: + notification_type = NotificationType[type_] + notification_priority = NotificationPriority[priority] + except KeyError as e: + raise exceptions.DataError( + f"invalid notification type or priority data: {e}" + ) + + return defer.ensureDeferred( + self.add_notification( + client, + notification_type, + body_plain, + body_rich or None, + title or None, + is_global, + requires_action, + notification_priority, + expire_at or None, + data_format.deserialise(extra_s) + ) + ) + + async def add_notification( + self, + client: SatXMPPEntity, + type_: NotificationType, + body_plain: str, + body_rich: Optional[str] = None, + title: Optional[str] = None, + is_global: bool = False, + requires_action: bool = False, + priority: NotificationPriority = NotificationPriority.MEDIUM, + expire_at: Optional[float] = None, + extra: Optional[dict] = None, + ) -> None: + """Create and broadcast a new notification. + + @param client: client associated with the notification. If None, the notification + will be global (i.e. for all profiles). + @param type_: type of the notification. + @param body_plain: plain text body. + @param body_rich: rich text (XHTML) body. + @param title: optional title. + @param is_global: True if the notification is for all profiles. + @param requires_action: True if the notification requires user action (e.g. a + dialog need to be answered). + @priority: how urgent the notification is + @param expire_at: expiration timestamp for the notification. + @param extra: additional data. + """ + notification = await self.storage.add_notification( + None if is_global else client, type_, body_plain, body_rich, title, + requires_action, priority, expire_at, extra + ) + self.host.bridge.notification_new( + str(notification.id), + notification.timestamp, + type_.value, + body_plain, + body_rich or '', + title or '', + requires_action, + priority.value, + expire_at or 0, + data_format.serialise(extra) if extra else '', + C.PROF_KEY_ALL if is_global else client.profile + ) + + def _get_notifications(self, filters_s: str, profile_key: str) -> defer.Deferred: + """Fetch notifications for bridge with given filters and profile key. + + @param filters_s: serialized filter conditions. Keys can be: + :type_ (str): + Filter by type of the notification. + :status (str): + Filter by status of the notification. + :requires_action (bool): + Filter by notifications that require user action. + :min_priority (str): + Filter by minimum priority value. + @param profile_key: key of the profile to fetch notifications for. + @return: Deferred which fires with a list of serialised notifications. + """ + client = self.host.get_client(profile_key) + + filters = data_format.deserialise(filters_s) + + try: + if 'type' in filters: + filters['type_'] = NotificationType[filters.pop('type')] + if 'status' in filters: + filters['status'] = NotificationStatus[filters['status']] + if 'min_priority' in filters: + filters['min_priority'] = NotificationPriority[filters['min_priority']].value + except KeyError as e: + raise exceptions.DataError(f"invalid filter data: {e}") + + d = defer.ensureDeferred(self.storage.get_notifications(client, **filters)) + d.addCallback( + lambda notifications: data_format.serialise( + [notification.serialise() for notification in notifications] + ) + ) + return d + + def _delete_notification( + self, + id_: str, + is_global: bool, + profile_key: str + ) -> defer.Deferred: + client = self.host.get_client(profile_key) + if is_global and not client.is_admin: + raise exceptions.PermissionError( + "Only admins can delete global notifications" + ) + return defer.ensureDeferred(self.delete_notification(client, id_, is_global)) + + async def delete_notification( + self, + client: SatXMPPEntity, + id_: str, + is_global: bool=False + ) -> None: + """Delete a notification + + the notification must be from the requesting profile. + @param id_: ID of the notification + is_global: if True, a global notification will be removed. + """ + await self.storage.delete_notification(None if is_global else client, id_) + self.host.bridge.notification_deleted( + id_, + C.PROF_KEY_ALL if is_global else client.profile + ) + + def _notifications_expired_clean( + self, limit_timestamp: float, profile_key: str + ) -> defer.Deferred: + if profile_key == C.PROF_KEY_NONE: + client = None + else: + client = self.host.get_client(profile_key) + + return defer.ensureDeferred( + self.storage.clean_expired_notifications( + client, + None if limit_timestamp == -1.0 else limit_timestamp + ) + ) + + ## Misc ## def is_entity_available(self, client, entity_jid):
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libervia/backend/memory/migration/versions/2ab01aa1f686_create_table_for_notifications.py Mon Oct 16 17:29:31 2023 +0200 @@ -0,0 +1,46 @@ +"""create table for notifications + +Revision ID: 2ab01aa1f686 +Revises: 4b002773cf92 +Create Date: 2023-10-16 12:11:43.507295 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2ab01aa1f686' +down_revision = '4b002773cf92' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table('notifications', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('timestamp', sa.Float(), nullable=False), + sa.Column('expire_at', sa.Float(), nullable=True), + sa.Column('profile_id', sa.Integer(), nullable=True), + sa.Column('type', sa.Enum('chat', 'blog', 'calendar', 'file', 'call', 'service', 'other', name='notificationtype'), nullable=False), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('body_plain', sa.Text(), nullable=False), + sa.Column('body_rich', sa.Text(), nullable=True), + sa.Column('requires_action', sa.Boolean(), nullable=True), + sa.Column('priority', sa.Integer(), nullable=True), + sa.Column('extra_data', sa.JSON(), nullable=True), + sa.Column('status', sa.Enum('new', 'read', name='notificationstatus'), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], name=op.f('fk_notifications_profile_id_profiles'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_notifications')) + ) + with op.batch_alter_table('notifications', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_notifications_profile_id'), ['profile_id'], unique=False) + batch_op.create_index('notifications_profile_id_status', ['profile_id', 'status'], unique=False) + + +def downgrade(): + with op.batch_alter_table('notifications', schema=None) as batch_op: + batch_op.drop_index('notifications_profile_id_status') + batch_op.drop_index(batch_op.f('ix_notifications_profile_id')) + + op.drop_table('notifications')
--- a/libervia/backend/memory/sqla.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/memory/sqla.py Mon Oct 16 17:29:31 2023 +0200 @@ -62,6 +62,10 @@ History, Message, NOT_IN_EXTRA, + Notification, + NotificationPriority, + NotificationStatus, + NotificationType, ParamGen, ParamInd, PrivateGen, @@ -74,6 +78,8 @@ Subject, SyncState, Thread, + get_profile_by_id, + profiles, ) from libervia.backend.tools.common import uri from libervia.backend.tools.utils import aio, as_future @@ -117,13 +123,15 @@ def __init__(self): self.initialized = defer.Deferred() # we keep cache for the profiles (key: profile name, value: profile id) - # profile id to name - self.profiles: Dict[str, int] = {} # profile id to component entry point self.components: Dict[int, str] = {} + @property + def profiles(self): + return profiles + def get_profile_by_id(self, profile_id): - return self.profiles.get(profile_id) + return get_profile_by_id(profile_id) async def migrate_apply(self, *args: str, log_output: bool = False) -> None: """Do a migration command @@ -139,16 +147,18 @@ """ stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) proc = await asyncio.create_subprocess_exec( - sys.executable, "-m", "alembic", *args, - stdout=stdout, stderr=stderr, cwd=migration_path + sys.executable, + "-m", + "alembic", + *args, + stdout=stdout, + stderr=stderr, + cwd=migration_path, ) log_out, log_err = await proc.communicate() if proc.returncode != 0: - msg = _( - "Can't {operation} database (exit code {exit_code})" - ).format( - operation=args[0], - exit_code=proc.returncode + msg = _("Can't {operation} database (exit code {exit_code})").format( + operation=args[0], exit_code=proc.returncode ) if log_out or log_err: msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" @@ -216,9 +226,7 @@ async with engine.connect() as conn: await conn.run_sync(self._sqlite_set_journal_mode_wal) - self.session = sessionmaker( - engine, expire_on_commit=False, class_=AsyncSession - ) + self.session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) async with self.session() as session: result = await session.execute(select(Profile)) @@ -239,9 +247,9 @@ db_cls: DeclarativeMeta, db_id_col: Mapped, id_value: Any, - joined_loads = None + joined_loads=None, ) -> Optional[DeclarativeMeta]: - stmt = select(db_cls).where(db_id_col==id_value) + stmt = select(db_cls).where(db_id_col == id_value) if client is not None: stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) if joined_loads is not None: @@ -264,7 +272,7 @@ async def delete( self, db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], - session_add: Optional[List[DeclarativeMeta]] = None + session_add: Optional[List[DeclarativeMeta]] = None, ) -> None: """Delete an object from database @@ -289,7 +297,7 @@ ## Profiles def get_profiles_list(self) -> List[str]: - """"Return list of all registered profiles""" + """ "Return list of all registered profiles""" return list(self.profiles.keys()) def has_profile(self, profile_name: str) -> bool: @@ -309,7 +317,9 @@ try: return self.components[self.profiles[profile_name]] except KeyError: - raise exceptions.NotFound("the requested profile doesn't exists or is not a component") + raise exceptions.NotFound( + "the requested profile doesn't exists or is not a component" + ) @aio async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: @@ -344,7 +354,7 @@ del self.profiles[profile.name] if profile.id in self.components: del self.components[profile.id] - log.info(_("Profile {name!r} deleted").format(name = name)) + log.info(_("Profile {name!r} deleted").format(name=name)) ## Params @@ -376,7 +386,9 @@ params_ind[(p.category, p.name)] = p.value @aio - async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]: + async def get_ind_param( + self, category: str, name: str, profile: str + ) -> Optional[str]: """Ask database for the value of one specific individual parameter @param category: category of the parameter @@ -385,11 +397,8 @@ """ async with self.session() as session: result = await session.execute( - select(ParamInd.value) - .filter_by( - category=category, - name=name, - profile_id=self.profiles[profile] + select(ParamInd.value).filter_by( + category=category, name=name, profile_id=self.profiles[profile] ) ) return result.scalar_one_or_none() @@ -405,10 +414,7 @@ async with self.session() as session: result = await session.execute( select(ParamInd) - .filter_by( - category=category, - name=name - ) + .filter_by(category=category, name=name) .options(subqueryload(ParamInd.profile)) ) return {param.profile.name: param.value for param in result.scalars()} @@ -422,26 +428,20 @@ @param value: value to set """ async with self.session() as session: - stmt = insert(ParamGen).values( - category=category, - name=name, - value=value - ).on_conflict_do_update( - index_elements=(ParamGen.category, ParamGen.name), - set_={ - ParamGen.value: value - } + stmt = ( + insert(ParamGen) + .values(category=category, name=name, value=value) + .on_conflict_do_update( + index_elements=(ParamGen.category, ParamGen.name), + set_={ParamGen.value: value}, + ) ) await session.execute(stmt) await session.commit() @aio async def set_ind_param( - self, - category:str, - name: str, - value: Optional[str], - profile: str + self, category: str, name: str, value: Optional[str], profile: str ) -> None: """Save the individual parameters in database @@ -451,16 +451,22 @@ @param profile: a profile which *must* exist """ async with self.session() as session: - stmt = insert(ParamInd).values( - category=category, - name=name, - profile_id=self.profiles[profile], - value=value - ).on_conflict_do_update( - index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), - set_={ - ParamInd.value: value - } + stmt = ( + insert(ParamInd) + .values( + category=category, + name=name, + profile_id=self.profiles[profile], + value=value, + ) + .on_conflict_do_update( + index_elements=( + ParamInd.category, + ParamInd.name, + ParamInd.profile_id, + ), + set_={ParamInd.value: value}, + ) ) await session.execute(stmt) await session.commit() @@ -474,13 +480,11 @@ if jid_.resource: if dest: return and_( - History.dest == jid_.userhost(), - History.dest_res == jid_.resource + History.dest == jid_.userhost(), History.dest_res == jid_.resource ) else: return and_( - History.source == jid_.userhost(), - History.source_res == jid_.resource + History.source == jid_.userhost(), History.source_res == jid_.resource ) else: if dest: @@ -497,9 +501,7 @@ between: bool = True, filters: Optional[Dict[str, str]] = None, profile: Optional[str] = None, - ) -> List[Tuple[ - str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] - ]: + ) -> List[Tuple[str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]]: """Retrieve messages in history @param from_jid: source JID (full, or bare for catchall) @@ -523,9 +525,7 @@ stmt = ( select(History) - .filter_by( - profile_id=self.profiles[profile] - ) + .filter_by(profile_id=self.profiles[profile]) .outerjoin(History.messages) .outerjoin(History.subjects) .outerjoin(History.thread) @@ -540,11 +540,10 @@ # order when returning the result. We use DESC here so LIMIT keep the last # messages History.timestamp.desc(), - History.received_timestamp.desc() + History.received_timestamp.desc(), ) ) - if not from_jid and not to_jid: # no jid specified, we want all one2one communications pass @@ -554,10 +553,7 @@ # from or to this jid jid_ = from_jid or to_jid stmt = stmt.where( - or_( - self._jid_filter(jid_), - self._jid_filter(jid_, dest=True) - ) + or_(self._jid_filter(jid_), self._jid_filter(jid_, dest=True)) ) else: # we have 2 jids specified, we check all communications between @@ -571,7 +567,7 @@ and_( self._jid_filter(to_jid), self._jid_filter(from_jid, dest=True), - ) + ), ) ) else: @@ -583,44 +579,44 @@ stmt = stmt.where(self._jid_filter(to_jid, dest=True)) if filters: - if 'timestamp_start' in filters: - stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) - if 'before_uid' in filters: + if "timestamp_start" in filters: + stmt = stmt.where(History.timestamp >= float(filters["timestamp_start"])) + if "before_uid" in filters: # orignially this query was using SQLITE's rowid. This has been changed # to use coalesce(received_timestamp, timestamp) to be SQL engine independant stmt = stmt.where( - coalesce( - History.received_timestamp, - History.timestamp - ) < ( - select(coalesce(History.received_timestamp, History.timestamp)) - .filter_by(uid=filters["before_uid"]) + coalesce(History.received_timestamp, History.timestamp) + < ( + select( + coalesce(History.received_timestamp, History.timestamp) + ).filter_by(uid=filters["before_uid"]) ).scalar_subquery() ) - if 'body' in filters: + if "body" in filters: # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) - if 'search' in filters: + if "search" in filters: search_term = f"%{filters['search']}%" - stmt = stmt.where(or_( - Message.message.like(search_term), - History.source_res.like(search_term) - )) - if 'types' in filters: - types = filters['types'].split() + stmt = stmt.where( + or_( + Message.message.like(search_term), + History.source_res.like(search_term), + ) + ) + if "types" in filters: + types = filters["types"].split() stmt = stmt.where(History.type.in_(types)) - if 'not_types' in filters: - types = filters['not_types'].split() + if "not_types" in filters: + types = filters["not_types"].split() stmt = stmt.where(History.type.not_in(types)) - if 'last_stanza_id' in filters: + if "last_stanza_id" in filters: # this request get the last message with a "stanza_id" that we # have in history. This is mainly used to retrieve messages sent # while we were offline, using MAM (XEP-0313). - if (filters['last_stanza_id'] is not True - or limit != 1): + if filters["last_stanza_id"] is not True or limit != 1: raise ValueError("Unexpected values for last_stanza_id filter") stmt = stmt.where(History.stanza_id.is_not(None)) - if 'origin_id' in filters: + if "origin_id" in filters: stmt = stmt.where(History.origin_id == filters["origin_id"]) if limit is not None: @@ -640,34 +636,40 @@ @param data: message data as build by SatMessageProtocol.onMessage """ extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} - messages = [Message(message=mess, language=lang) - for lang, mess in data["message"].items()] - subjects = [Subject(subject=mess, language=lang) - for lang, mess in data["subject"].items()] + messages = [ + Message(message=mess, language=lang) for lang, mess in data["message"].items() + ] + subjects = [ + Subject(subject=mess, language=lang) for lang, mess in data["subject"].items() + ] if "thread" in data["extra"]: - thread = Thread(thread_id=data["extra"]["thread"], - parent_id=data["extra"].get["thread_parent"]) + thread = Thread( + thread_id=data["extra"]["thread"], + parent_id=data["extra"].get["thread_parent"], + ) else: thread = None try: async with self.session() as session: async with session.begin(): - session.add(History( - uid=data["uid"], - origin_id=data["extra"].get("origin_id"), - stanza_id=data["extra"].get("stanza_id"), - update_uid=data["extra"].get("update_uid"), - profile_id=self.profiles[profile], - source_jid=data["from"], - dest_jid=data["to"], - timestamp=data["timestamp"], - received_timestamp=data.get("received_timestamp"), - type=data["type"], - extra=extra, - messages=messages, - subjects=subjects, - thread=thread, - )) + session.add( + History( + uid=data["uid"], + origin_id=data["extra"].get("origin_id"), + stanza_id=data["extra"].get("stanza_id"), + update_uid=data["extra"].get("update_uid"), + profile_id=self.profiles[profile], + source_jid=data["from"], + dest_jid=data["to"], + timestamp=data["timestamp"], + received_timestamp=data.get("received_timestamp"), + type=data["type"], + extra=extra, + messages=messages, + subjects=subjects, + thread=thread, + ) + ) except IntegrityError as e: if "unique" in str(e.orig).lower(): log.debug( @@ -689,14 +691,13 @@ else: return PrivateIndBin if binary else PrivateInd - @aio async def get_privates( self, - namespace:str, + namespace: str, keys: Optional[Iterable[str]] = None, binary: bool = False, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> Dict[str, Any]: """Get private value(s) from databases @@ -728,10 +729,10 @@ async def set_private_value( self, namespace: str, - key:str, + key: str, value: Any, binary: bool = False, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> None: """Set a private value in database @@ -745,11 +746,7 @@ """ cls = self._get_private_class(binary, profile) - values = { - "namespace": namespace, - "key": key, - "value": value - } + values = {"namespace": namespace, "key": key, "value": value} index_elements = [cls.namespace, cls.key] if profile is not None: @@ -758,11 +755,10 @@ async with self.session() as session: await session.execute( - insert(cls).values(**values).on_conflict_do_update( - index_elements=index_elements, - set_={ - cls.value: value - } + insert(cls) + .values(**values) + .on_conflict_do_update( + index_elements=index_elements, set_={cls.value: value} ) ) await session.commit() @@ -773,7 +769,7 @@ namespace: str, key: str, binary: bool = False, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> None: """Delete private value from database @@ -796,10 +792,7 @@ @aio async def del_private_namespace( - self, - namespace: str, - binary: bool = False, - profile: Optional[str] = None + self, namespace: str, binary: bool = False, profile: Optional[str] = None ) -> None: """Delete all data from a private namespace @@ -825,7 +818,7 @@ self, client: Optional[SatXMPPEntity], file_id: Optional[str] = None, - version: Optional[str] = '', + version: Optional[str] = "", parent: Optional[str] = None, type_: Optional[str] = None, file_hash: Optional[str] = None, @@ -837,7 +830,7 @@ owner: Optional[jid.JID] = None, access: Optional[dict] = None, projection: Optional[List[str]] = None, - unique: bool = False + unique: bool = False, ) -> List[dict]: """Retrieve files with with given filters @@ -857,9 +850,23 @@ """ if projection is None: projection = [ - 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', - 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', - 'created', 'modified', 'owner', 'access', 'extra' + "id", + "version", + "parent", + "type", + "file_hash", + "hash_algo", + "name", + "size", + "namespace", + "media_type", + "media_subtype", + "public_id", + "created", + "modified", + "owner", + "access", + "extra", ] stmt = select(*[getattr(File, f) for f in projection]) @@ -891,7 +898,7 @@ if namespace is not None: stmt = stmt.filter_by(namespace=namespace) if mime_type is not None: - if '/' in mime_type: + if "/" in mime_type: media_type, media_subtype = mime_type.split("/", 1) stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) else: @@ -901,8 +908,8 @@ if owner is not None: stmt = stmt.filter_by(owner=owner) if access is not None: - raise NotImplementedError('Access check is not implemented yet') - # a JSON comparison is needed here + raise NotImplementedError("Access check is not implemented yet") + # a JSON comparison is needed here async with self.session() as session: result = await session.execute(stmt) @@ -928,7 +935,7 @@ modified: Optional[float] = None, owner: Optional[jid.JID] = None, access: Optional[dict] = None, - extra: Optional[dict] = None + extra: Optional[dict] = None, ) -> None: """Set a file metadata @@ -958,33 +965,35 @@ """ if mime_type is None: media_type = media_subtype = None - elif '/' in mime_type: - media_type, media_subtype = mime_type.split('/', 1) + elif "/" in mime_type: + media_type, media_subtype = mime_type.split("/", 1) else: media_type, media_subtype = mime_type, None async with self.session() as session: async with session.begin(): - session.add(File( - id=file_id, - version=version.strip(), - parent=parent, - type=type_, - file_hash=file_hash, - hash_algo=hash_algo, - name=name, - size=size, - namespace=namespace, - media_type=media_type, - media_subtype=media_subtype, - public_id=public_id, - created=time.time() if created is None else created, - modified=modified, - owner=owner, - access=access, - extra=extra, - profile_id=self.profiles[client.profile] - )) + session.add( + File( + id=file_id, + version=version.strip(), + parent=parent, + type=type_, + file_hash=file_hash, + hash_algo=hash_algo, + name=name, + size=size, + namespace=namespace, + media_type=media_type, + media_subtype=media_subtype, + public_id=public_id, + created=time.time() if created is None else created, + modified=modified, + owner=owner, + access=access, + extra=extra, + profile_id=self.profiles[client.profile], + ) + ) @aio async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: @@ -993,8 +1002,9 @@ select(sum_(File.size)).filter_by( owner=owner, type=C.FILE_TYPE_FILE, - profile_id=self.profiles[client.profile] - )) + profile_id=self.profiles[client.profile], + ) + ) return result.scalar_one_or_none() or 0 @aio @@ -1011,10 +1021,7 @@ @aio async def file_update( - self, - file_id: str, - column: str, - update_cb: Callable[[dict], None] + self, file_id: str, column: str, update_cb: Callable[[dict], None] ) -> None: """Update a column value using a method to avoid race conditions @@ -1029,16 +1036,16 @@ it get the deserialized data (i.e. a Python object) directly @raise exceptions.NotFound: there is not file with this id """ - if column not in ('access', 'extra'): - raise exceptions.InternalError('bad column name') + if column not in ("access", "extra"): + raise exceptions.InternalError("bad column name") orm_col = getattr(File, column) for i in range(5): async with self.session() as session: try: - value = (await session.execute( - select(orm_col).filter_by(id=file_id) - )).scalar_one() + value = ( + await session.execute(select(orm_col).filter_by(id=file_id)) + ).scalar_one() except NoResultFound: raise exceptions.NotFound old_value = copy.deepcopy(value) @@ -1057,14 +1064,17 @@ break log.warning( - _("table not updated, probably due to race condition, trying again " - "({tries})").format(tries=i+1) + _( + "table not updated, probably due to race condition, trying again " + "({tries})" + ).format(tries=i + 1) ) else: raise exceptions.DatabaseError( - _("Can't update file {file_id} due to race condition") - .format(file_id=file_id) + _("Can't update file {file_id} due to race condition").format( + file_id=file_id + ) ) @aio @@ -1076,7 +1086,7 @@ with_items: bool = False, with_subscriptions: bool = False, create: bool = False, - create_kwargs: Optional[dict] = None + create_kwargs: Optional[dict] = None, ) -> Optional[PubsubNode]: """Retrieve a PubsubNode from DB @@ -1089,22 +1099,15 @@ needs to be created. """ async with self.session() as session: - stmt = ( - select(PubsubNode) - .filter_by( - service=service, - name=name, - profile_id=self.profiles[client.profile], - ) + stmt = select(PubsubNode).filter_by( + service=service, + name=name, + profile_id=self.profiles[client.profile], ) if with_items: - stmt = stmt.options( - joinedload(PubsubNode.items) - ) + stmt = stmt.options(joinedload(PubsubNode.items)) if with_subscriptions: - stmt = stmt.options( - joinedload(PubsubNode.subscriptions) - ) + stmt = stmt.options(joinedload(PubsubNode.subscriptions)) result = await session.execute(stmt) ret = result.unique().scalar_one_or_none() if ret is None and create: @@ -1112,21 +1115,23 @@ if create_kwargs is None: create_kwargs = {} try: - return await as_future(self.set_pubsub_node( - client, service, name, **create_kwargs - )) + return await as_future( + self.set_pubsub_node(client, service, name, **create_kwargs) + ) except IntegrityError as e: if "unique" in str(e.orig).lower(): # the node may already exist, if it has been created just after # get_pubsub_node above log.debug("ignoring UNIQUE constraint error") - cached_node = await as_future(self.get_pubsub_node( - client, - service, - name, - with_items=with_items, - with_subscriptions=with_subscriptions - )) + cached_node = await as_future( + self.get_pubsub_node( + client, + service, + name, + with_items=with_items, + with_subscriptions=with_subscriptions, + ) + ) else: raise e else: @@ -1160,9 +1165,7 @@ @aio async def update_pubsub_node_sync_state( - self, - node: PubsubNode, - state: SyncState + self, node: PubsubNode, state: SyncState ) -> None: async with self.session() as session: async with session.begin(): @@ -1180,7 +1183,7 @@ self, profiles: Optional[List[str]], services: Optional[List[jid.JID]], - names: Optional[List[str]] + names: Optional[List[str]], ) -> None: """Delete items cached for a node @@ -1194,9 +1197,7 @@ stmt = delete(PubsubNode) if profiles is not None: stmt = stmt.where( - PubsubNode.profile.in_( - [self.profiles[p] for p in profiles] - ) + PubsubNode.profile.in_([self.profiles[p] for p in profiles]) ) if services is not None: stmt = stmt.where(PubsubNode.service.in_(services)) @@ -1223,27 +1224,29 @@ async with session.begin(): for idx, item in enumerate(items): parsed = parsed_items[idx] if parsed_items else None - stmt = insert(PubsubItem).values( - node_id = node.id, - name = item["id"], - data = item, - parsed = parsed, - ).on_conflict_do_update( - index_elements=(PubsubItem.node_id, PubsubItem.name), - set_={ - PubsubItem.data: item, - PubsubItem.parsed: parsed, - PubsubItem.updated: now() - } + stmt = ( + insert(PubsubItem) + .values( + node_id=node.id, + name=item["id"], + data=item, + parsed=parsed, + ) + .on_conflict_do_update( + index_elements=(PubsubItem.node_id, PubsubItem.name), + set_={ + PubsubItem.data: item, + PubsubItem.parsed: parsed, + PubsubItem.updated: now(), + }, + ) ) await session.execute(stmt) await session.commit() @aio async def delete_pubsub_items( - self, - node: PubsubNode, - items_names: Optional[List[str]] = None + self, node: PubsubNode, items_names: Optional[List[str]] = None ) -> None: """Delete items cached for a node @@ -1296,10 +1299,8 @@ if values is None: continue sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) - stmt = ( - stmt - .where(PubsubItem.node_id.in_(sub_q)) - .execution_options(synchronize_session=False) + stmt = stmt.where(PubsubItem.node_id.in_(sub_q)).execution_options( + synchronize_session=False ) if created_before is not None: @@ -1367,11 +1368,7 @@ use_rsm = True from_index = 0 - stmt = ( - select(PubsubItem) - .filter_by(node_id=node.id) - .limit(max_items) - ) + stmt = select(PubsubItem).filter_by(node_id=node.id).limit(max_items) if item_ids is not None: stmt = stmt.where(PubsubItem.name.in_(item_ids)) @@ -1402,7 +1399,7 @@ PubsubItem.id, PubsubItem.name, # row_number starts from 1, but RSM index must start from 0 - (func.row_number().over(order_by=order)-1).label("item_index") + (func.row_number().over(order_by=order) - 1).label("item_index"), ).filter_by(node_id=node.id) row_num_cte = row_num_q.cte() @@ -1412,8 +1409,9 @@ # we need to use row number item_name = before or after row_num_limit_q = ( - select(row_num_cte.c.item_index) - .where(row_num_cte.c.name==item_name) + select(row_num_cte.c.item_index).where( + row_num_cte.c.name == item_name + ) ).scalar_subquery() stmt = ( @@ -1422,22 +1420,16 @@ .limit(max_items) ) if before: - stmt = ( - stmt - .where(row_num_cte.c.item_index<row_num_limit_q) - .order_by(row_num_cte.c.item_index.desc()) - ) + stmt = stmt.where( + row_num_cte.c.item_index < row_num_limit_q + ).order_by(row_num_cte.c.item_index.desc()) elif after: - stmt = ( - stmt - .where(row_num_cte.c.item_index>row_num_limit_q) - .order_by(row_num_cte.c.item_index.asc()) - ) + stmt = stmt.where( + row_num_cte.c.item_index > row_num_limit_q + ).order_by(row_num_cte.c.item_index.asc()) else: - stmt = ( - stmt - .where(row_num_cte.c.item_index>=from_index) - .order_by(row_num_cte.c.item_index.asc()) + stmt = stmt.where(row_num_cte.c.item_index >= from_index).order_by( + row_num_cte.c.item_index.asc() ) # from_index is used @@ -1468,12 +1460,14 @@ last = result[-1][1].name metadata["rsm"] = { - k: v for k, v in { + k: v + for k, v in { "index": index, "count": rows_count, "first": first, "last": last, - }.items() if v is not None + }.items() + if v is not None } metadata["complete"] = (index or 0) + len(result) == rows_count @@ -1487,10 +1481,7 @@ result.reverse() return result, metadata - def _get_sqlite_path( - self, - path: List[Union[str, int]] - ) -> str: + def _get_sqlite_path(self, path: List[Union[str, int]]) -> str: """generate path suitable to query JSON element with SQLite""" return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" @@ -1557,19 +1548,20 @@ # Full-Text Search fts = query.get("fts") if fts: - fts_select = text( - "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)" - ).bindparams(fts_query=fts).columns(rowid=Integer).subquery() - stmt = ( - stmt - .select_from(fts_select) - .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id) + fts_select = ( + text("SELECT rowid, rank FROM pubsub_items_fts(:fts_query)") + .bindparams(fts_query=fts) + .columns(rowid=Integer) + .subquery() + ) + stmt = stmt.select_from(fts_select).outerjoin( + PubsubItem, fts_select.c.rowid == PubsubItem.id ) # node related filters profiles = query.get("profiles") - if (profiles - or any(query.get(k) for k in ("nodes", "services", "types", "subtypes")) + if profiles or any( + query.get(k) for k in ("nodes", "services", "types", "subtypes") ): stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) if profiles: @@ -1585,7 +1577,7 @@ ("nodes", "name"), ("services", "service"), ("types", "type_"), - ("subtypes", "subtype") + ("subtypes", "subtype"), ): value = query.get(key) if not value: @@ -1596,10 +1588,7 @@ value.remove(None) condition = getattr(PubsubNode, attr).is_(None) if value: - condition = or_( - getattr(PubsubNode, attr).in_(value), - condition - ) + condition = or_(getattr(PubsubNode, attr).in_(value), condition) else: condition = getattr(PubsubNode, attr).in_(value) stmt = stmt.where(condition) @@ -1644,10 +1633,12 @@ try: left, right = value except (ValueError, TypeError): - raise ValueError(_( - "invalid value for \"between\" filter, you must use a 2 items " - "array: {value!r}" - ).format(value=value)) + raise ValueError( + _( + 'invalid value for "between" filter, you must use a 2 items ' + "array: {value!r}" + ).format(value=value) + ) col = func.json_extract(PubsubItem.parsed, sqlite_path) stmt = stmt.where(col.between(left, right)) else: @@ -1662,10 +1653,12 @@ for order_data in order_by: order, path = order_data.get("order"), order_data.get("path") if order and path: - raise ValueError(_( - '"order" and "path" can\'t be used at the same time in ' - '"order-by" data' - )) + raise ValueError( + _( + '"order" and "path" can\'t be used at the same time in ' + '"order-by" data' + ) + ) if order: if order == "creation": col = PubsubItem.id @@ -1702,3 +1695,172 @@ result = await session.execute(stmt) return result.scalars().all() + + # Notifications + + @aio + async def add_notification( + self, + client: Optional[SatXMPPEntity], + type_: NotificationType, + body_plain: str, + body_rich: Optional[str] = None, + title: Optional[str] = None, + requires_action: bool = False, + priority: NotificationPriority = NotificationPriority.MEDIUM, + expire_at: Optional[float] = None, + extra: Optional[dict] = None, + ) -> Notification: + """Add a new notification to the DB. + + @param client: client associated with the notification. If None, the notification + will be global. + @param type_: type of the notification. + @param body_plain: plain text body. + @param body_rich: rich text (XHTML) body. + @param title: optional title. + @param requires_action: True if the notification requires user action (e.g. a + dialog need to be answered). + @priority: how urgent the notification is + @param expire_at: expiration timestamp for the notification. + @param extra: additional data. + @return: created Notification + """ + profile_id = self.profiles[client.profile] if client else None + notification = Notification( + profile_id=profile_id, + type=type_, + body_plain=body_plain, + body_rich=body_rich, + requires_action=requires_action, + priority=priority, + expire_at=expire_at, + title=title, + extra_data=extra, + status=NotificationStatus.new, + ) + async with self.session() as session: + async with session.begin(): + session.add(notification) + return notification + + @aio + async def update_notification( + self, client: SatXMPPEntity, notification_id: int, **kwargs + ) -> None: + """Update an existing notification. + + @param client: client associated with the notification. + @param notification_id: ID of the notification to update. + """ + profile_id = self.profiles[client.profile] + async with self.session() as session: + await session.execute( + update(Notification) + .where( + and_( + Notification.profile_id == profile_id, + Notification.id == notification_id, + ) + ) + .values(**kwargs) + ) + await session.commit() + + @aio + async def get_notifications( + self, + client: SatXMPPEntity, + type_: Optional[NotificationType] = None, + status: Optional[NotificationStatus] = None, + requires_action: Optional[bool] = None, + min_priority: Optional[int] = None + ) -> List[Notification]: + """Retrieve all notifications for a given profile with optional filters. + + @param client: client associated with the notifications. + @param type_: filter by type of the notification. + @param status: filter by status of the notification. + @param requires_action: filter by notifications that require user action. + @param min_priority: filter by minimum priority value. + @return: list of matching Notification instances. + """ + profile_id = self.profiles[client.profile] + filters = [or_(Notification.profile_id == profile_id, Notification.profile_id.is_(None))] + + if type_: + filters.append(Notification.type == type_) + if status: + filters.append(Notification.status == status) + if requires_action is not None: + filters.append(Notification.requires_action == requires_action) + if min_priority: + filters.append(Notification.priority >= min_priority) + + async with self.session() as session: + result = await session.execute( + select(Notification) + .where(and_(*filters)) + .order_by(Notification.id) + ) + return result.scalars().all() + + @aio + async def delete_notification( + self, client: Optional[SatXMPPEntity], notification_id: str + ) -> None: + """Delete a notification by its profile and id. + + @param client: client associated with the notification. If None, profile_id will be NULL. + @param notification_id: ID of the notification to delete. + """ + profile_id = self.profiles[client.profile] if client else None + async with self.session() as session: + await session.execute( + delete(Notification).where( + and_( + Notification.profile_id == profile_id, + Notification.id == int(notification_id), + ) + ) + ) + await session.commit() + + @aio + async def clean_expired_notifications( + self, client: Optional[SatXMPPEntity], limit_timestamp: Optional[float] = None + ) -> None: + """Cleans expired notifications and older profile-specific notifications. + + - Removes all notifications where the expiration timestamp has passed, + irrespective of their profile. + - If a limit_timestamp is provided, removes older notifications with a profile set + (i.e., not global notifications) that do not require user action. If client is + provided, only remove notification for this profile. + + @param client: if provided, only expire notification for this client (in addition + to truly expired notifications for everybody). + @param limit_timestamp: Timestamp limit for older notifications. If None, only + truly expired notifications are removed. + """ + + # Delete truly expired notifications + expired_condition = Notification.expire_at < time.time() + + # Delete older profile-specific notifications (created before the limit_timestamp) + if client is None: + profile_condition = Notification.profile_id.isnot(None) + else: + profile_condition = Notification.profile_id == self.profiles[client.profile] + older_condition = and_( + profile_condition, + Notification.timestamp < limit_timestamp if limit_timestamp else False, + Notification.requires_action == False, + ) + + # Combine the conditions + conditions = or_(expired_condition, older_condition) + + async with self.session() as session: + await session.execute(delete(Notification).where(conditions)) + await session.commit()
--- a/libervia/backend/memory/sqla_mapping.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/memory/sqla_mapping.py Mon Oct 16 17:29:31 2023 +0200 @@ -16,12 +16,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. -from typing import Dict, Any from datetime import datetime import enum import json import pickle import time +from typing import Any, Dict from sqlalchemy import ( Boolean, @@ -45,21 +45,52 @@ from twisted.words.protocols.jabber import jid from wokkel import generic +from libervia.backend.core.constants import Const as C + Base = declarative_base( metadata=MetaData( naming_convention={ - "ix": 'ix_%(column_0_label)s', + "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s" + "pk": "pk_%(table_name)s", } ) ) # keys which are in message data extra but not stored in extra field this is # because those values are stored in separate fields -NOT_IN_EXTRA = ('origin_id', 'stanza_id', 'received_timestamp', 'update_uid') +NOT_IN_EXTRA = ("origin_id", "stanza_id", "received_timestamp", "update_uid") + + +class Profiles(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.id_to_profile = {v: k for k, v in self.items()} + + def __setitem__(self, key, value): + super().__setitem__(key, value) + self.id_to_profile[value] = key + + def __delitem__(self, key): + del self.id_to_profile[self[key]] + super().__delitem__(key) + + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + self.id_to_profile = {v: k for k, v in self.items()} + + def clear(self): + super().clear() + self.id_to_profile.clear() + + +profiles = Profiles() + + +def get_profile_by_id( profile_id): + return profiles.id_to_profile.get(profile_id) class SyncState(enum.Enum): @@ -78,11 +109,34 @@ PENDING = 2 +class NotificationType(enum.Enum): + chat = "chat" + blog = "blog" + calendar = "calendar" + file = "file" + call = "call" + service = "service" + other = "other" + + +class NotificationStatus(enum.Enum): + new = "new" + read = "read" + + +class NotificationPriority(enum.IntEnum): + LOW = 10 + MEDIUM = 20 + HIGH = 30 + URGENT = 40 + + class LegacyPickle(TypeDecorator): """Handle troubles with data pickled by former version of SàT This type is temporary until we do migration to a proper data type """ + # Blob is used on SQLite but gives errors when used here, while Text works fine impl = Text cache_ok = True @@ -109,12 +163,13 @@ # JSON return pickle.loads( value.replace(b"sat.plugins", b"libervia.backend.plugins"), - encoding="utf-8" + encoding="utf-8", ) class Json(TypeDecorator): """Handle JSON field in DB independant way""" + # Blob is used on SQLite but gives errors when used here, while Text works fine impl = Text cache_ok = True @@ -156,6 +211,7 @@ class JID(TypeDecorator): """Store twisted JID in text fields""" + impl = Text cache_ok = True @@ -193,9 +249,7 @@ __tablename__ = "components" profile_id = Column( - ForeignKey("profiles.id", ondelete="CASCADE"), - nullable=True, - primary_key=True + ForeignKey("profiles.id", ondelete="CASCADE"), nullable=True, primary_key=True ) entry_point = Column(Text, nullable=False) profile = relationship("Profile") @@ -209,7 +263,7 @@ Index("history__profile_id_timestamp", "profile_id", "timestamp"), Index( "history__profile_id_received_timestamp", "profile_id", "received_timestamp" - ) + ), ) uid = Column(Text, primary_key=True) @@ -295,14 +349,14 @@ if self.thread.parent_id is not None: extra["thread_parent"] = self.thread.parent_id - return { - "from": f"{self.source}/{self.source_res}" if self.source_res - else self.source, + "from": f"{self.source}/{self.source_res}" + if self.source_res + else self.source, "to": f"{self.dest}/{self.dest_res}" if self.dest_res else self.dest, "uid": self.uid, - "message": {m.language or '': m.message for m in self.messages}, - "subject": {m.language or '': m.subject for m in self.subjects}, + "message": {m.language or "": m.message for m in self.messages}, + "subject": {m.language or "": m.subject for m in self.subjects}, "type": self.type, "extra": extra, "timestamp": self.timestamp, @@ -311,8 +365,14 @@ def as_tuple(self): d = self.serialise() return ( - d['uid'], d['timestamp'], d['from'], d['to'], d['message'], d['subject'], - d['type'], d['extra'] + d["uid"], + d["timestamp"], + d["from"], + d["to"], + d["message"], + d["subject"], + d["type"], + d["extra"], ) @staticmethod @@ -336,9 +396,7 @@ class Message(Base): __tablename__ = "message" - __table_args__ = ( - Index("message__history_uid", "history_uid"), - ) + __table_args__ = (Index("message__history_uid", "history_uid"),) id = Column( Integer, @@ -358,16 +416,14 @@ def __repr__(self): lang_str = f"[{self.language}]" if self.language else "" - msg = f"{self.message[:20]}…" if len(self.message)>20 else self.message + msg = f"{self.message[:20]}…" if len(self.message) > 20 else self.message content = f"{lang_str}{msg}" return f"Message<{content}>" class Subject(Base): __tablename__ = "subject" - __table_args__ = ( - Index("subject__history_uid", "history_uid"), - ) + __table_args__ = (Index("subject__history_uid", "history_uid"),) id = Column( Integer, @@ -387,16 +443,14 @@ def __repr__(self): lang_str = f"[{self.language}]" if self.language else "" - msg = f"{self.subject[:20]}…" if len(self.subject)>20 else self.subject + msg = f"{self.subject[:20]}…" if len(self.subject) > 20 else self.subject content = f"{lang_str}{msg}" return f"Subject<{content}>" class Thread(Base): __tablename__ = "thread" - __table_args__ = ( - Index("thread__history_uid", "history_uid"), - ) + __table_args__ = (Index("thread__history_uid", "history_uid"),) id = Column( Integer, @@ -412,6 +466,51 @@ return f"Thread<{self.thread_id} [parent: {self.parent_id}]>" +class Notification(Base): + __tablename__ = "notifications" + __table_args__ = (Index("notifications_profile_id_status", "profile_id", "status"),) + + id = Column(Integer, primary_key=True, autoincrement=True) + timestamp = Column(Float, nullable=False, default=time.time) + expire_at = Column(Float, nullable=True) + + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), index=True, nullable=True) + profile = relationship("Profile") + + type = Column(Enum(NotificationType), nullable=False) + + title = Column(Text, nullable=True) + body_plain = Column(Text, nullable=False) + body_rich = Column(Text, nullable=True) + + requires_action = Column(Boolean, default=False) + priority = Column(Integer, default=NotificationPriority.MEDIUM.value) + + extra_data = Column(JSON) + status = Column(Enum(NotificationStatus), default=NotificationStatus.new) + + def serialise(self) -> dict[str, str | float | bool | int | dict]: + """ + Serialises the Notification instance to a dictionary. + """ + result = {} + for column in self.__table__.columns: + value = getattr(self, column.name) + if value is not None: + if column.name in ("type", "status"): + result[column.name] = value.name + elif column.name == "id": + result[column.name] = str(value) + elif column.name == "profile_id": + if value is None: + result["profile"] = C.PROF_KEY_ALL + else: + result["profile"] = get_profile_by_id(value) + else: + result[column.name] = value + return result + + class ParamGen(Base): __tablename__ = "param_gen" @@ -425,9 +524,7 @@ category = Column(Text, primary_key=True) name = Column(Text, primary_key=True) - profile_id = Column( - ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True - ) + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True) value = Column(Text) profile = relationship("Profile", back_populates="params") @@ -446,9 +543,7 @@ namespace = Column(Text, primary_key=True) key = Column(Text, primary_key=True) - profile_id = Column( - ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True - ) + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True) value = Column(Text) profile = relationship("Profile", back_populates="private_data") @@ -467,9 +562,7 @@ namespace = Column(Text, primary_key=True) key = Column(Text, primary_key=True) - profile_id = Column( - ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True - ) + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True) value = Column(LegacyPickle) profile = relationship("Profile", back_populates="private_bin_data") @@ -484,8 +577,8 @@ "profile_id", "owner", "media_type", - "media_subtype" - ) + "media_subtype", + ), ) id = Column(Text, primary_key=True) @@ -493,11 +586,7 @@ version = Column(Text, primary_key=True) parent = Column(Text, nullable=False) type = Column( - Enum( - "file", "directory", - name="file_type", - create_constraint=True - ), + Enum("file", "directory", name="file_type", create_constraint=True), nullable=False, server_default="file", ) @@ -520,14 +609,10 @@ class PubsubNode(Base): __tablename__ = "pubsub_nodes" - __table_args__ = ( - UniqueConstraint("profile_id", "service", "name"), - ) + __table_args__ = (UniqueConstraint("profile_id", "service", "name"),) id = Column(Integer, primary_key=True) - profile_id = Column( - ForeignKey("profiles.id", ondelete="CASCADE") - ) + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE")) service = Column(JID) name = Column(Text, nullable=False) subscribed = Column( @@ -540,19 +625,11 @@ name="sync_state", create_constraint=True, ), - nullable=True - ) - sync_state_updated = Column( - Float, - nullable=False, - default=time.time() + nullable=True, ) - type_ = Column( - Text, name="type", nullable=True - ) - subtype = Column( - Text, nullable=True - ) + sync_state_updated = Column(Float, nullable=False, default=time.time()) + type_ = Column(Text, name="type", nullable=True) + subtype = Column(Text, nullable=True) extra = Column(JSON) items = relationship("PubsubItem", back_populates="node", passive_deletes=True) @@ -567,10 +644,9 @@ Used by components managing a pubsub service """ + __tablename__ = "pubsub_subs" - __table_args__ = ( - UniqueConstraint("node_id", "subscriber"), - ) + __table_args__ = (UniqueConstraint("node_id", "subscriber"),) id = Column(Integer, primary_key=True) node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False) @@ -581,7 +657,7 @@ name="state", create_constraint=True, ), - nullable=True + nullable=True, ) node = relationship("PubsubNode", back_populates="subscriptions") @@ -589,9 +665,7 @@ class PubsubItem(Base): __tablename__ = "pubsub_items" - __table_args__ = ( - UniqueConstraint("node_id", "name"), - ) + __table_args__ = (UniqueConstraint("node_id", "name"),) id = Column(Integer, primary_key=True) node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False) name = Column(Text, nullable=False) @@ -607,6 +681,7 @@ # create + @event.listens_for(PubsubItem.__table__, "after_create") def fts_create(target, connection, **kw): """Full-Text Search table creation""" @@ -626,13 +701,15 @@ " INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) VALUES" "('delete', old.id, old.data);" " INSERT INTO pubsub_items_fts(rowid, data) VALUES(new.id, new.data);" - "END" + "END", ] for q in queries: connection.execute(DDL(q)) + # drop + @event.listens_for(PubsubItem.__table__, "before_drop") def fts_drop(target, connection, **kw): "Full-Text Search table drop" ""
--- a/libervia/frontends/bridge/dbus_bridge.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/frontends/bridge/dbus_bridge.py Mon Oct 16 17:29:31 2023 +0200 @@ -562,6 +562,62 @@ kwargs['error_handler'] = error_handler return self.db_core_iface.namespaces_get(**kwargs) + def notification_add(self, type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra, callback=None, errback=None): + if callback is None: + error_handler = None + else: + if errback is None: + errback = log.error + error_handler = lambda err:errback(dbus_to_bridge_exception(err)) + kwargs={} + if callback is not None: + kwargs['timeout'] = const_TIMEOUT + kwargs['reply_handler'] = callback + kwargs['error_handler'] = error_handler + return self.db_core_iface.notification_add(type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra, **kwargs) + + def notification_delete(self, id_, is_global, profile_key, callback=None, errback=None): + if callback is None: + error_handler = None + else: + if errback is None: + errback = log.error + error_handler = lambda err:errback(dbus_to_bridge_exception(err)) + kwargs={} + if callback is not None: + kwargs['timeout'] = const_TIMEOUT + kwargs['reply_handler'] = callback + kwargs['error_handler'] = error_handler + return self.db_core_iface.notification_delete(id_, is_global, profile_key, **kwargs) + + def notifications_expired_clean(self, limit_timestamp, profile_key, callback=None, errback=None): + if callback is None: + error_handler = None + else: + if errback is None: + errback = log.error + error_handler = lambda err:errback(dbus_to_bridge_exception(err)) + kwargs={} + if callback is not None: + kwargs['timeout'] = const_TIMEOUT + kwargs['reply_handler'] = callback + kwargs['error_handler'] = error_handler + return self.db_core_iface.notifications_expired_clean(limit_timestamp, profile_key, **kwargs) + + def notifications_get(self, filters, profile_key, callback=None, errback=None): + if callback is None: + error_handler = None + else: + if errback is None: + errback = log.error + error_handler = lambda err:errback(dbus_to_bridge_exception(err)) + kwargs={} + if callback is not None: + kwargs['timeout'] = const_TIMEOUT + kwargs['reply_handler'] = callback + kwargs['error_handler'] = error_handler + return str(self.db_core_iface.notifications_get(filters, profile_key, **kwargs)) + def param_get_a(self, name, category, attribute="value", profile_key="@DEFAULT@", callback=None, errback=None): if callback is None: error_handler = None @@ -1271,6 +1327,38 @@ self.db_core_iface.namespaces_get(timeout=const_TIMEOUT, reply_handler=reply_handler, error_handler=error_handler) return fut + def notification_add(self, type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra): + loop = asyncio.get_running_loop() + fut = loop.create_future() + reply_handler = lambda ret=None: loop.call_soon_threadsafe(fut.set_result, ret) + error_handler = lambda err: loop.call_soon_threadsafe(fut.set_exception, dbus_to_bridge_exception(err)) + self.db_core_iface.notification_add(type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra, timeout=const_TIMEOUT, reply_handler=reply_handler, error_handler=error_handler) + return fut + + def notification_delete(self, id_, is_global, profile_key): + loop = asyncio.get_running_loop() + fut = loop.create_future() + reply_handler = lambda ret=None: loop.call_soon_threadsafe(fut.set_result, ret) + error_handler = lambda err: loop.call_soon_threadsafe(fut.set_exception, dbus_to_bridge_exception(err)) + self.db_core_iface.notification_delete(id_, is_global, profile_key, timeout=const_TIMEOUT, reply_handler=reply_handler, error_handler=error_handler) + return fut + + def notifications_expired_clean(self, limit_timestamp, profile_key): + loop = asyncio.get_running_loop() + fut = loop.create_future() + reply_handler = lambda ret=None: loop.call_soon_threadsafe(fut.set_result, ret) + error_handler = lambda err: loop.call_soon_threadsafe(fut.set_exception, dbus_to_bridge_exception(err)) + self.db_core_iface.notifications_expired_clean(limit_timestamp, profile_key, timeout=const_TIMEOUT, reply_handler=reply_handler, error_handler=error_handler) + return fut + + def notifications_get(self, filters, profile_key): + loop = asyncio.get_running_loop() + fut = loop.create_future() + reply_handler = lambda ret=None: loop.call_soon_threadsafe(fut.set_result, ret) + error_handler = lambda err: loop.call_soon_threadsafe(fut.set_exception, dbus_to_bridge_exception(err)) + self.db_core_iface.notifications_get(filters, profile_key, timeout=const_TIMEOUT, reply_handler=reply_handler, error_handler=error_handler) + return fut + def param_get_a(self, name, category, attribute="value", profile_key="@DEFAULT@"): loop = asyncio.get_running_loop() fut = loop.create_future()
--- a/libervia/frontends/bridge/pb.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/frontends/bridge/pb.py Mon Oct 16 17:29:31 2023 +0200 @@ -489,6 +489,42 @@ else: d.addErrback(self._errback, ori_errback=errback) + def notification_add(self, type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra, callback=None, errback=None): + d = self.root.callRemote("notification_add", type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra) + if callback is not None: + d.addCallback(lambda __: callback()) + if errback is None: + d.addErrback(self._generic_errback) + else: + d.addErrback(self._errback, ori_errback=errback) + + def notification_delete(self, id_, is_global, profile_key, callback=None, errback=None): + d = self.root.callRemote("notification_delete", id_, is_global, profile_key) + if callback is not None: + d.addCallback(lambda __: callback()) + if errback is None: + d.addErrback(self._generic_errback) + else: + d.addErrback(self._errback, ori_errback=errback) + + def notifications_expired_clean(self, limit_timestamp, profile_key, callback=None, errback=None): + d = self.root.callRemote("notifications_expired_clean", limit_timestamp, profile_key) + if callback is not None: + d.addCallback(lambda __: callback()) + if errback is None: + d.addErrback(self._generic_errback) + else: + d.addErrback(self._errback, ori_errback=errback) + + def notifications_get(self, filters, profile_key, callback=None, errback=None): + d = self.root.callRemote("notifications_get", filters, profile_key) + if callback is not None: + d.addCallback(callback) + if errback is None: + d.addErrback(self._generic_errback) + else: + d.addErrback(self._errback, ori_errback=errback) + def param_get_a(self, name, category, attribute="value", profile_key="@DEFAULT@", callback=None, errback=None): d = self.root.callRemote("param_get_a", name, category, attribute, profile_key) if callback is not None: @@ -969,6 +1005,26 @@ d.addErrback(self._errback) return d.asFuture(asyncio.get_event_loop()) + def notification_add(self, type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra): + d = self.root.callRemote("notification_add", type_, body_plain, body_rich, title, is_global, requires_action, arg_6, priority, expire_at, extra) + d.addErrback(self._errback) + return d.asFuture(asyncio.get_event_loop()) + + def notification_delete(self, id_, is_global, profile_key): + d = self.root.callRemote("notification_delete", id_, is_global, profile_key) + d.addErrback(self._errback) + return d.asFuture(asyncio.get_event_loop()) + + def notifications_expired_clean(self, limit_timestamp, profile_key): + d = self.root.callRemote("notifications_expired_clean", limit_timestamp, profile_key) + d.addErrback(self._errback) + return d.asFuture(asyncio.get_event_loop()) + + def notifications_get(self, filters, profile_key): + d = self.root.callRemote("notifications_get", filters, profile_key) + d.addErrback(self._errback) + return d.asFuture(asyncio.get_event_loop()) + def param_get_a(self, name, category, attribute="value", profile_key="@DEFAULT@"): d = self.root.callRemote("param_get_a", name, category, attribute, profile_key) d.addErrback(self._errback)