Mercurial > libervia-backend
diff libervia/backend/memory/sqla.py @ 4130:02f0adc745c6
core: notifications implementation, first draft:
add a new table for notifications, and methods/bridge methods to manipulate them.
author | Goffi <goffi@goffi.org> |
---|---|
date | Mon, 16 Oct 2023 17:29:31 +0200 |
parents | 74c66c0d93f3 |
children | 6a0066ea5c97 |
line wrap: on
line diff
--- a/libervia/backend/memory/sqla.py Wed Oct 18 15:30:07 2023 +0200 +++ b/libervia/backend/memory/sqla.py Mon Oct 16 17:29:31 2023 +0200 @@ -62,6 +62,10 @@ History, Message, NOT_IN_EXTRA, + Notification, + NotificationPriority, + NotificationStatus, + NotificationType, ParamGen, ParamInd, PrivateGen, @@ -74,6 +78,8 @@ Subject, SyncState, Thread, + get_profile_by_id, + profiles, ) from libervia.backend.tools.common import uri from libervia.backend.tools.utils import aio, as_future @@ -117,13 +123,15 @@ def __init__(self): self.initialized = defer.Deferred() # we keep cache for the profiles (key: profile name, value: profile id) - # profile id to name - self.profiles: Dict[str, int] = {} # profile id to component entry point self.components: Dict[int, str] = {} + @property + def profiles(self): + return profiles + def get_profile_by_id(self, profile_id): - return self.profiles.get(profile_id) + return get_profile_by_id(profile_id) async def migrate_apply(self, *args: str, log_output: bool = False) -> None: """Do a migration command @@ -139,16 +147,18 @@ """ stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) proc = await asyncio.create_subprocess_exec( - sys.executable, "-m", "alembic", *args, - stdout=stdout, stderr=stderr, cwd=migration_path + sys.executable, + "-m", + "alembic", + *args, + stdout=stdout, + stderr=stderr, + cwd=migration_path, ) log_out, log_err = await proc.communicate() if proc.returncode != 0: - msg = _( - "Can't {operation} database (exit code {exit_code})" - ).format( - operation=args[0], - exit_code=proc.returncode + msg = _("Can't {operation} database (exit code {exit_code})").format( + operation=args[0], exit_code=proc.returncode ) if log_out or log_err: msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" @@ -216,9 +226,7 @@ async with engine.connect() as conn: await conn.run_sync(self._sqlite_set_journal_mode_wal) - self.session = sessionmaker( - engine, expire_on_commit=False, class_=AsyncSession - ) + self.session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) async with self.session() as session: result = await session.execute(select(Profile)) @@ -239,9 +247,9 @@ db_cls: DeclarativeMeta, db_id_col: Mapped, id_value: Any, - joined_loads = None + joined_loads=None, ) -> Optional[DeclarativeMeta]: - stmt = select(db_cls).where(db_id_col==id_value) + stmt = select(db_cls).where(db_id_col == id_value) if client is not None: stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) if joined_loads is not None: @@ -264,7 +272,7 @@ async def delete( self, db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], - session_add: Optional[List[DeclarativeMeta]] = None + session_add: Optional[List[DeclarativeMeta]] = None, ) -> None: """Delete an object from database @@ -289,7 +297,7 @@ ## Profiles def get_profiles_list(self) -> List[str]: - """"Return list of all registered profiles""" + """ "Return list of all registered profiles""" return list(self.profiles.keys()) def has_profile(self, profile_name: str) -> bool: @@ -309,7 +317,9 @@ try: return self.components[self.profiles[profile_name]] except KeyError: - raise exceptions.NotFound("the requested profile doesn't exists or is not a component") + raise exceptions.NotFound( + "the requested profile doesn't exists or is not a component" + ) @aio async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: @@ -344,7 +354,7 @@ del self.profiles[profile.name] if profile.id in self.components: del self.components[profile.id] - log.info(_("Profile {name!r} deleted").format(name = name)) + log.info(_("Profile {name!r} deleted").format(name=name)) ## Params @@ -376,7 +386,9 @@ params_ind[(p.category, p.name)] = p.value @aio - async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]: + async def get_ind_param( + self, category: str, name: str, profile: str + ) -> Optional[str]: """Ask database for the value of one specific individual parameter @param category: category of the parameter @@ -385,11 +397,8 @@ """ async with self.session() as session: result = await session.execute( - select(ParamInd.value) - .filter_by( - category=category, - name=name, - profile_id=self.profiles[profile] + select(ParamInd.value).filter_by( + category=category, name=name, profile_id=self.profiles[profile] ) ) return result.scalar_one_or_none() @@ -405,10 +414,7 @@ async with self.session() as session: result = await session.execute( select(ParamInd) - .filter_by( - category=category, - name=name - ) + .filter_by(category=category, name=name) .options(subqueryload(ParamInd.profile)) ) return {param.profile.name: param.value for param in result.scalars()} @@ -422,26 +428,20 @@ @param value: value to set """ async with self.session() as session: - stmt = insert(ParamGen).values( - category=category, - name=name, - value=value - ).on_conflict_do_update( - index_elements=(ParamGen.category, ParamGen.name), - set_={ - ParamGen.value: value - } + stmt = ( + insert(ParamGen) + .values(category=category, name=name, value=value) + .on_conflict_do_update( + index_elements=(ParamGen.category, ParamGen.name), + set_={ParamGen.value: value}, + ) ) await session.execute(stmt) await session.commit() @aio async def set_ind_param( - self, - category:str, - name: str, - value: Optional[str], - profile: str + self, category: str, name: str, value: Optional[str], profile: str ) -> None: """Save the individual parameters in database @@ -451,16 +451,22 @@ @param profile: a profile which *must* exist """ async with self.session() as session: - stmt = insert(ParamInd).values( - category=category, - name=name, - profile_id=self.profiles[profile], - value=value - ).on_conflict_do_update( - index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), - set_={ - ParamInd.value: value - } + stmt = ( + insert(ParamInd) + .values( + category=category, + name=name, + profile_id=self.profiles[profile], + value=value, + ) + .on_conflict_do_update( + index_elements=( + ParamInd.category, + ParamInd.name, + ParamInd.profile_id, + ), + set_={ParamInd.value: value}, + ) ) await session.execute(stmt) await session.commit() @@ -474,13 +480,11 @@ if jid_.resource: if dest: return and_( - History.dest == jid_.userhost(), - History.dest_res == jid_.resource + History.dest == jid_.userhost(), History.dest_res == jid_.resource ) else: return and_( - History.source == jid_.userhost(), - History.source_res == jid_.resource + History.source == jid_.userhost(), History.source_res == jid_.resource ) else: if dest: @@ -497,9 +501,7 @@ between: bool = True, filters: Optional[Dict[str, str]] = None, profile: Optional[str] = None, - ) -> List[Tuple[ - str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] - ]: + ) -> List[Tuple[str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]]: """Retrieve messages in history @param from_jid: source JID (full, or bare for catchall) @@ -523,9 +525,7 @@ stmt = ( select(History) - .filter_by( - profile_id=self.profiles[profile] - ) + .filter_by(profile_id=self.profiles[profile]) .outerjoin(History.messages) .outerjoin(History.subjects) .outerjoin(History.thread) @@ -540,11 +540,10 @@ # order when returning the result. We use DESC here so LIMIT keep the last # messages History.timestamp.desc(), - History.received_timestamp.desc() + History.received_timestamp.desc(), ) ) - if not from_jid and not to_jid: # no jid specified, we want all one2one communications pass @@ -554,10 +553,7 @@ # from or to this jid jid_ = from_jid or to_jid stmt = stmt.where( - or_( - self._jid_filter(jid_), - self._jid_filter(jid_, dest=True) - ) + or_(self._jid_filter(jid_), self._jid_filter(jid_, dest=True)) ) else: # we have 2 jids specified, we check all communications between @@ -571,7 +567,7 @@ and_( self._jid_filter(to_jid), self._jid_filter(from_jid, dest=True), - ) + ), ) ) else: @@ -583,44 +579,44 @@ stmt = stmt.where(self._jid_filter(to_jid, dest=True)) if filters: - if 'timestamp_start' in filters: - stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) - if 'before_uid' in filters: + if "timestamp_start" in filters: + stmt = stmt.where(History.timestamp >= float(filters["timestamp_start"])) + if "before_uid" in filters: # orignially this query was using SQLITE's rowid. This has been changed # to use coalesce(received_timestamp, timestamp) to be SQL engine independant stmt = stmt.where( - coalesce( - History.received_timestamp, - History.timestamp - ) < ( - select(coalesce(History.received_timestamp, History.timestamp)) - .filter_by(uid=filters["before_uid"]) + coalesce(History.received_timestamp, History.timestamp) + < ( + select( + coalesce(History.received_timestamp, History.timestamp) + ).filter_by(uid=filters["before_uid"]) ).scalar_subquery() ) - if 'body' in filters: + if "body" in filters: # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) - if 'search' in filters: + if "search" in filters: search_term = f"%{filters['search']}%" - stmt = stmt.where(or_( - Message.message.like(search_term), - History.source_res.like(search_term) - )) - if 'types' in filters: - types = filters['types'].split() + stmt = stmt.where( + or_( + Message.message.like(search_term), + History.source_res.like(search_term), + ) + ) + if "types" in filters: + types = filters["types"].split() stmt = stmt.where(History.type.in_(types)) - if 'not_types' in filters: - types = filters['not_types'].split() + if "not_types" in filters: + types = filters["not_types"].split() stmt = stmt.where(History.type.not_in(types)) - if 'last_stanza_id' in filters: + if "last_stanza_id" in filters: # this request get the last message with a "stanza_id" that we # have in history. This is mainly used to retrieve messages sent # while we were offline, using MAM (XEP-0313). - if (filters['last_stanza_id'] is not True - or limit != 1): + if filters["last_stanza_id"] is not True or limit != 1: raise ValueError("Unexpected values for last_stanza_id filter") stmt = stmt.where(History.stanza_id.is_not(None)) - if 'origin_id' in filters: + if "origin_id" in filters: stmt = stmt.where(History.origin_id == filters["origin_id"]) if limit is not None: @@ -640,34 +636,40 @@ @param data: message data as build by SatMessageProtocol.onMessage """ extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} - messages = [Message(message=mess, language=lang) - for lang, mess in data["message"].items()] - subjects = [Subject(subject=mess, language=lang) - for lang, mess in data["subject"].items()] + messages = [ + Message(message=mess, language=lang) for lang, mess in data["message"].items() + ] + subjects = [ + Subject(subject=mess, language=lang) for lang, mess in data["subject"].items() + ] if "thread" in data["extra"]: - thread = Thread(thread_id=data["extra"]["thread"], - parent_id=data["extra"].get["thread_parent"]) + thread = Thread( + thread_id=data["extra"]["thread"], + parent_id=data["extra"].get["thread_parent"], + ) else: thread = None try: async with self.session() as session: async with session.begin(): - session.add(History( - uid=data["uid"], - origin_id=data["extra"].get("origin_id"), - stanza_id=data["extra"].get("stanza_id"), - update_uid=data["extra"].get("update_uid"), - profile_id=self.profiles[profile], - source_jid=data["from"], - dest_jid=data["to"], - timestamp=data["timestamp"], - received_timestamp=data.get("received_timestamp"), - type=data["type"], - extra=extra, - messages=messages, - subjects=subjects, - thread=thread, - )) + session.add( + History( + uid=data["uid"], + origin_id=data["extra"].get("origin_id"), + stanza_id=data["extra"].get("stanza_id"), + update_uid=data["extra"].get("update_uid"), + profile_id=self.profiles[profile], + source_jid=data["from"], + dest_jid=data["to"], + timestamp=data["timestamp"], + received_timestamp=data.get("received_timestamp"), + type=data["type"], + extra=extra, + messages=messages, + subjects=subjects, + thread=thread, + ) + ) except IntegrityError as e: if "unique" in str(e.orig).lower(): log.debug( @@ -689,14 +691,13 @@ else: return PrivateIndBin if binary else PrivateInd - @aio async def get_privates( self, - namespace:str, + namespace: str, keys: Optional[Iterable[str]] = None, binary: bool = False, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> Dict[str, Any]: """Get private value(s) from databases @@ -728,10 +729,10 @@ async def set_private_value( self, namespace: str, - key:str, + key: str, value: Any, binary: bool = False, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> None: """Set a private value in database @@ -745,11 +746,7 @@ """ cls = self._get_private_class(binary, profile) - values = { - "namespace": namespace, - "key": key, - "value": value - } + values = {"namespace": namespace, "key": key, "value": value} index_elements = [cls.namespace, cls.key] if profile is not None: @@ -758,11 +755,10 @@ async with self.session() as session: await session.execute( - insert(cls).values(**values).on_conflict_do_update( - index_elements=index_elements, - set_={ - cls.value: value - } + insert(cls) + .values(**values) + .on_conflict_do_update( + index_elements=index_elements, set_={cls.value: value} ) ) await session.commit() @@ -773,7 +769,7 @@ namespace: str, key: str, binary: bool = False, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> None: """Delete private value from database @@ -796,10 +792,7 @@ @aio async def del_private_namespace( - self, - namespace: str, - binary: bool = False, - profile: Optional[str] = None + self, namespace: str, binary: bool = False, profile: Optional[str] = None ) -> None: """Delete all data from a private namespace @@ -825,7 +818,7 @@ self, client: Optional[SatXMPPEntity], file_id: Optional[str] = None, - version: Optional[str] = '', + version: Optional[str] = "", parent: Optional[str] = None, type_: Optional[str] = None, file_hash: Optional[str] = None, @@ -837,7 +830,7 @@ owner: Optional[jid.JID] = None, access: Optional[dict] = None, projection: Optional[List[str]] = None, - unique: bool = False + unique: bool = False, ) -> List[dict]: """Retrieve files with with given filters @@ -857,9 +850,23 @@ """ if projection is None: projection = [ - 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', - 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', - 'created', 'modified', 'owner', 'access', 'extra' + "id", + "version", + "parent", + "type", + "file_hash", + "hash_algo", + "name", + "size", + "namespace", + "media_type", + "media_subtype", + "public_id", + "created", + "modified", + "owner", + "access", + "extra", ] stmt = select(*[getattr(File, f) for f in projection]) @@ -891,7 +898,7 @@ if namespace is not None: stmt = stmt.filter_by(namespace=namespace) if mime_type is not None: - if '/' in mime_type: + if "/" in mime_type: media_type, media_subtype = mime_type.split("/", 1) stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) else: @@ -901,8 +908,8 @@ if owner is not None: stmt = stmt.filter_by(owner=owner) if access is not None: - raise NotImplementedError('Access check is not implemented yet') - # a JSON comparison is needed here + raise NotImplementedError("Access check is not implemented yet") + # a JSON comparison is needed here async with self.session() as session: result = await session.execute(stmt) @@ -928,7 +935,7 @@ modified: Optional[float] = None, owner: Optional[jid.JID] = None, access: Optional[dict] = None, - extra: Optional[dict] = None + extra: Optional[dict] = None, ) -> None: """Set a file metadata @@ -958,33 +965,35 @@ """ if mime_type is None: media_type = media_subtype = None - elif '/' in mime_type: - media_type, media_subtype = mime_type.split('/', 1) + elif "/" in mime_type: + media_type, media_subtype = mime_type.split("/", 1) else: media_type, media_subtype = mime_type, None async with self.session() as session: async with session.begin(): - session.add(File( - id=file_id, - version=version.strip(), - parent=parent, - type=type_, - file_hash=file_hash, - hash_algo=hash_algo, - name=name, - size=size, - namespace=namespace, - media_type=media_type, - media_subtype=media_subtype, - public_id=public_id, - created=time.time() if created is None else created, - modified=modified, - owner=owner, - access=access, - extra=extra, - profile_id=self.profiles[client.profile] - )) + session.add( + File( + id=file_id, + version=version.strip(), + parent=parent, + type=type_, + file_hash=file_hash, + hash_algo=hash_algo, + name=name, + size=size, + namespace=namespace, + media_type=media_type, + media_subtype=media_subtype, + public_id=public_id, + created=time.time() if created is None else created, + modified=modified, + owner=owner, + access=access, + extra=extra, + profile_id=self.profiles[client.profile], + ) + ) @aio async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: @@ -993,8 +1002,9 @@ select(sum_(File.size)).filter_by( owner=owner, type=C.FILE_TYPE_FILE, - profile_id=self.profiles[client.profile] - )) + profile_id=self.profiles[client.profile], + ) + ) return result.scalar_one_or_none() or 0 @aio @@ -1011,10 +1021,7 @@ @aio async def file_update( - self, - file_id: str, - column: str, - update_cb: Callable[[dict], None] + self, file_id: str, column: str, update_cb: Callable[[dict], None] ) -> None: """Update a column value using a method to avoid race conditions @@ -1029,16 +1036,16 @@ it get the deserialized data (i.e. a Python object) directly @raise exceptions.NotFound: there is not file with this id """ - if column not in ('access', 'extra'): - raise exceptions.InternalError('bad column name') + if column not in ("access", "extra"): + raise exceptions.InternalError("bad column name") orm_col = getattr(File, column) for i in range(5): async with self.session() as session: try: - value = (await session.execute( - select(orm_col).filter_by(id=file_id) - )).scalar_one() + value = ( + await session.execute(select(orm_col).filter_by(id=file_id)) + ).scalar_one() except NoResultFound: raise exceptions.NotFound old_value = copy.deepcopy(value) @@ -1057,14 +1064,17 @@ break log.warning( - _("table not updated, probably due to race condition, trying again " - "({tries})").format(tries=i+1) + _( + "table not updated, probably due to race condition, trying again " + "({tries})" + ).format(tries=i + 1) ) else: raise exceptions.DatabaseError( - _("Can't update file {file_id} due to race condition") - .format(file_id=file_id) + _("Can't update file {file_id} due to race condition").format( + file_id=file_id + ) ) @aio @@ -1076,7 +1086,7 @@ with_items: bool = False, with_subscriptions: bool = False, create: bool = False, - create_kwargs: Optional[dict] = None + create_kwargs: Optional[dict] = None, ) -> Optional[PubsubNode]: """Retrieve a PubsubNode from DB @@ -1089,22 +1099,15 @@ needs to be created. """ async with self.session() as session: - stmt = ( - select(PubsubNode) - .filter_by( - service=service, - name=name, - profile_id=self.profiles[client.profile], - ) + stmt = select(PubsubNode).filter_by( + service=service, + name=name, + profile_id=self.profiles[client.profile], ) if with_items: - stmt = stmt.options( - joinedload(PubsubNode.items) - ) + stmt = stmt.options(joinedload(PubsubNode.items)) if with_subscriptions: - stmt = stmt.options( - joinedload(PubsubNode.subscriptions) - ) + stmt = stmt.options(joinedload(PubsubNode.subscriptions)) result = await session.execute(stmt) ret = result.unique().scalar_one_or_none() if ret is None and create: @@ -1112,21 +1115,23 @@ if create_kwargs is None: create_kwargs = {} try: - return await as_future(self.set_pubsub_node( - client, service, name, **create_kwargs - )) + return await as_future( + self.set_pubsub_node(client, service, name, **create_kwargs) + ) except IntegrityError as e: if "unique" in str(e.orig).lower(): # the node may already exist, if it has been created just after # get_pubsub_node above log.debug("ignoring UNIQUE constraint error") - cached_node = await as_future(self.get_pubsub_node( - client, - service, - name, - with_items=with_items, - with_subscriptions=with_subscriptions - )) + cached_node = await as_future( + self.get_pubsub_node( + client, + service, + name, + with_items=with_items, + with_subscriptions=with_subscriptions, + ) + ) else: raise e else: @@ -1160,9 +1165,7 @@ @aio async def update_pubsub_node_sync_state( - self, - node: PubsubNode, - state: SyncState + self, node: PubsubNode, state: SyncState ) -> None: async with self.session() as session: async with session.begin(): @@ -1180,7 +1183,7 @@ self, profiles: Optional[List[str]], services: Optional[List[jid.JID]], - names: Optional[List[str]] + names: Optional[List[str]], ) -> None: """Delete items cached for a node @@ -1194,9 +1197,7 @@ stmt = delete(PubsubNode) if profiles is not None: stmt = stmt.where( - PubsubNode.profile.in_( - [self.profiles[p] for p in profiles] - ) + PubsubNode.profile.in_([self.profiles[p] for p in profiles]) ) if services is not None: stmt = stmt.where(PubsubNode.service.in_(services)) @@ -1223,27 +1224,29 @@ async with session.begin(): for idx, item in enumerate(items): parsed = parsed_items[idx] if parsed_items else None - stmt = insert(PubsubItem).values( - node_id = node.id, - name = item["id"], - data = item, - parsed = parsed, - ).on_conflict_do_update( - index_elements=(PubsubItem.node_id, PubsubItem.name), - set_={ - PubsubItem.data: item, - PubsubItem.parsed: parsed, - PubsubItem.updated: now() - } + stmt = ( + insert(PubsubItem) + .values( + node_id=node.id, + name=item["id"], + data=item, + parsed=parsed, + ) + .on_conflict_do_update( + index_elements=(PubsubItem.node_id, PubsubItem.name), + set_={ + PubsubItem.data: item, + PubsubItem.parsed: parsed, + PubsubItem.updated: now(), + }, + ) ) await session.execute(stmt) await session.commit() @aio async def delete_pubsub_items( - self, - node: PubsubNode, - items_names: Optional[List[str]] = None + self, node: PubsubNode, items_names: Optional[List[str]] = None ) -> None: """Delete items cached for a node @@ -1296,10 +1299,8 @@ if values is None: continue sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) - stmt = ( - stmt - .where(PubsubItem.node_id.in_(sub_q)) - .execution_options(synchronize_session=False) + stmt = stmt.where(PubsubItem.node_id.in_(sub_q)).execution_options( + synchronize_session=False ) if created_before is not None: @@ -1367,11 +1368,7 @@ use_rsm = True from_index = 0 - stmt = ( - select(PubsubItem) - .filter_by(node_id=node.id) - .limit(max_items) - ) + stmt = select(PubsubItem).filter_by(node_id=node.id).limit(max_items) if item_ids is not None: stmt = stmt.where(PubsubItem.name.in_(item_ids)) @@ -1402,7 +1399,7 @@ PubsubItem.id, PubsubItem.name, # row_number starts from 1, but RSM index must start from 0 - (func.row_number().over(order_by=order)-1).label("item_index") + (func.row_number().over(order_by=order) - 1).label("item_index"), ).filter_by(node_id=node.id) row_num_cte = row_num_q.cte() @@ -1412,8 +1409,9 @@ # we need to use row number item_name = before or after row_num_limit_q = ( - select(row_num_cte.c.item_index) - .where(row_num_cte.c.name==item_name) + select(row_num_cte.c.item_index).where( + row_num_cte.c.name == item_name + ) ).scalar_subquery() stmt = ( @@ -1422,22 +1420,16 @@ .limit(max_items) ) if before: - stmt = ( - stmt - .where(row_num_cte.c.item_index<row_num_limit_q) - .order_by(row_num_cte.c.item_index.desc()) - ) + stmt = stmt.where( + row_num_cte.c.item_index < row_num_limit_q + ).order_by(row_num_cte.c.item_index.desc()) elif after: - stmt = ( - stmt - .where(row_num_cte.c.item_index>row_num_limit_q) - .order_by(row_num_cte.c.item_index.asc()) - ) + stmt = stmt.where( + row_num_cte.c.item_index > row_num_limit_q + ).order_by(row_num_cte.c.item_index.asc()) else: - stmt = ( - stmt - .where(row_num_cte.c.item_index>=from_index) - .order_by(row_num_cte.c.item_index.asc()) + stmt = stmt.where(row_num_cte.c.item_index >= from_index).order_by( + row_num_cte.c.item_index.asc() ) # from_index is used @@ -1468,12 +1460,14 @@ last = result[-1][1].name metadata["rsm"] = { - k: v for k, v in { + k: v + for k, v in { "index": index, "count": rows_count, "first": first, "last": last, - }.items() if v is not None + }.items() + if v is not None } metadata["complete"] = (index or 0) + len(result) == rows_count @@ -1487,10 +1481,7 @@ result.reverse() return result, metadata - def _get_sqlite_path( - self, - path: List[Union[str, int]] - ) -> str: + def _get_sqlite_path(self, path: List[Union[str, int]]) -> str: """generate path suitable to query JSON element with SQLite""" return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" @@ -1557,19 +1548,20 @@ # Full-Text Search fts = query.get("fts") if fts: - fts_select = text( - "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)" - ).bindparams(fts_query=fts).columns(rowid=Integer).subquery() - stmt = ( - stmt - .select_from(fts_select) - .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id) + fts_select = ( + text("SELECT rowid, rank FROM pubsub_items_fts(:fts_query)") + .bindparams(fts_query=fts) + .columns(rowid=Integer) + .subquery() + ) + stmt = stmt.select_from(fts_select).outerjoin( + PubsubItem, fts_select.c.rowid == PubsubItem.id ) # node related filters profiles = query.get("profiles") - if (profiles - or any(query.get(k) for k in ("nodes", "services", "types", "subtypes")) + if profiles or any( + query.get(k) for k in ("nodes", "services", "types", "subtypes") ): stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) if profiles: @@ -1585,7 +1577,7 @@ ("nodes", "name"), ("services", "service"), ("types", "type_"), - ("subtypes", "subtype") + ("subtypes", "subtype"), ): value = query.get(key) if not value: @@ -1596,10 +1588,7 @@ value.remove(None) condition = getattr(PubsubNode, attr).is_(None) if value: - condition = or_( - getattr(PubsubNode, attr).in_(value), - condition - ) + condition = or_(getattr(PubsubNode, attr).in_(value), condition) else: condition = getattr(PubsubNode, attr).in_(value) stmt = stmt.where(condition) @@ -1644,10 +1633,12 @@ try: left, right = value except (ValueError, TypeError): - raise ValueError(_( - "invalid value for \"between\" filter, you must use a 2 items " - "array: {value!r}" - ).format(value=value)) + raise ValueError( + _( + 'invalid value for "between" filter, you must use a 2 items ' + "array: {value!r}" + ).format(value=value) + ) col = func.json_extract(PubsubItem.parsed, sqlite_path) stmt = stmt.where(col.between(left, right)) else: @@ -1662,10 +1653,12 @@ for order_data in order_by: order, path = order_data.get("order"), order_data.get("path") if order and path: - raise ValueError(_( - '"order" and "path" can\'t be used at the same time in ' - '"order-by" data' - )) + raise ValueError( + _( + '"order" and "path" can\'t be used at the same time in ' + '"order-by" data' + ) + ) if order: if order == "creation": col = PubsubItem.id @@ -1702,3 +1695,172 @@ result = await session.execute(stmt) return result.scalars().all() + + # Notifications + + @aio + async def add_notification( + self, + client: Optional[SatXMPPEntity], + type_: NotificationType, + body_plain: str, + body_rich: Optional[str] = None, + title: Optional[str] = None, + requires_action: bool = False, + priority: NotificationPriority = NotificationPriority.MEDIUM, + expire_at: Optional[float] = None, + extra: Optional[dict] = None, + ) -> Notification: + """Add a new notification to the DB. + + @param client: client associated with the notification. If None, the notification + will be global. + @param type_: type of the notification. + @param body_plain: plain text body. + @param body_rich: rich text (XHTML) body. + @param title: optional title. + @param requires_action: True if the notification requires user action (e.g. a + dialog need to be answered). + @priority: how urgent the notification is + @param expire_at: expiration timestamp for the notification. + @param extra: additional data. + @return: created Notification + """ + profile_id = self.profiles[client.profile] if client else None + notification = Notification( + profile_id=profile_id, + type=type_, + body_plain=body_plain, + body_rich=body_rich, + requires_action=requires_action, + priority=priority, + expire_at=expire_at, + title=title, + extra_data=extra, + status=NotificationStatus.new, + ) + async with self.session() as session: + async with session.begin(): + session.add(notification) + return notification + + @aio + async def update_notification( + self, client: SatXMPPEntity, notification_id: int, **kwargs + ) -> None: + """Update an existing notification. + + @param client: client associated with the notification. + @param notification_id: ID of the notification to update. + """ + profile_id = self.profiles[client.profile] + async with self.session() as session: + await session.execute( + update(Notification) + .where( + and_( + Notification.profile_id == profile_id, + Notification.id == notification_id, + ) + ) + .values(**kwargs) + ) + await session.commit() + + @aio + async def get_notifications( + self, + client: SatXMPPEntity, + type_: Optional[NotificationType] = None, + status: Optional[NotificationStatus] = None, + requires_action: Optional[bool] = None, + min_priority: Optional[int] = None + ) -> List[Notification]: + """Retrieve all notifications for a given profile with optional filters. + + @param client: client associated with the notifications. + @param type_: filter by type of the notification. + @param status: filter by status of the notification. + @param requires_action: filter by notifications that require user action. + @param min_priority: filter by minimum priority value. + @return: list of matching Notification instances. + """ + profile_id = self.profiles[client.profile] + filters = [or_(Notification.profile_id == profile_id, Notification.profile_id.is_(None))] + + if type_: + filters.append(Notification.type == type_) + if status: + filters.append(Notification.status == status) + if requires_action is not None: + filters.append(Notification.requires_action == requires_action) + if min_priority: + filters.append(Notification.priority >= min_priority) + + async with self.session() as session: + result = await session.execute( + select(Notification) + .where(and_(*filters)) + .order_by(Notification.id) + ) + return result.scalars().all() + + @aio + async def delete_notification( + self, client: Optional[SatXMPPEntity], notification_id: str + ) -> None: + """Delete a notification by its profile and id. + + @param client: client associated with the notification. If None, profile_id will be NULL. + @param notification_id: ID of the notification to delete. + """ + profile_id = self.profiles[client.profile] if client else None + async with self.session() as session: + await session.execute( + delete(Notification).where( + and_( + Notification.profile_id == profile_id, + Notification.id == int(notification_id), + ) + ) + ) + await session.commit() + + @aio + async def clean_expired_notifications( + self, client: Optional[SatXMPPEntity], limit_timestamp: Optional[float] = None + ) -> None: + """Cleans expired notifications and older profile-specific notifications. + + - Removes all notifications where the expiration timestamp has passed, + irrespective of their profile. + - If a limit_timestamp is provided, removes older notifications with a profile set + (i.e., not global notifications) that do not require user action. If client is + provided, only remove notification for this profile. + + @param client: if provided, only expire notification for this client (in addition + to truly expired notifications for everybody). + @param limit_timestamp: Timestamp limit for older notifications. If None, only + truly expired notifications are removed. + """ + + # Delete truly expired notifications + expired_condition = Notification.expire_at < time.time() + + # Delete older profile-specific notifications (created before the limit_timestamp) + if client is None: + profile_condition = Notification.profile_id.isnot(None) + else: + profile_condition = Notification.profile_id == self.profiles[client.profile] + older_condition = and_( + profile_condition, + Notification.timestamp < limit_timestamp if limit_timestamp else False, + Notification.requires_action == False, + ) + + # Combine the conditions + conditions = or_(expired_condition, older_condition) + + async with self.session() as session: + await session.execute(delete(Notification).where(conditions)) + await session.commit()