diff libervia/backend/memory/sqla.py @ 4130:02f0adc745c6

core: notifications implementation, first draft: add a new table for notifications, and methods/bridge methods to manipulate them.
author Goffi <goffi@goffi.org>
date Mon, 16 Oct 2023 17:29:31 +0200
parents 74c66c0d93f3
children 6a0066ea5c97
--- a/libervia/backend/memory/sqla.py	Wed Oct 18 15:30:07 2023 +0200
+++ b/libervia/backend/memory/sqla.py	Mon Oct 16 17:29:31 2023 +0200
@@ -62,6 +62,10 @@
     History,
     Message,
     NOT_IN_EXTRA,
+    Notification,
+    NotificationPriority,
+    NotificationStatus,
+    NotificationType,
     ParamGen,
     ParamInd,
     PrivateGen,
@@ -74,6 +78,8 @@
     Subject,
     SyncState,
     Thread,
+    get_profile_by_id,
+    profiles,
 )
 from libervia.backend.tools.common import uri
 from libervia.backend.tools.utils import aio, as_future
@@ -117,13 +123,15 @@
     def __init__(self):
         self.initialized = defer.Deferred()
         # we keep cache for the profiles (key: profile name, value: profile id)
-        # profile id to name
-        self.profiles: Dict[str, int] = {}
         # profile id to component entry point
         self.components: Dict[int, str] = {}
 
+    @property
+    def profiles(self):
+        return profiles
+
     def get_profile_by_id(self, profile_id):
-        return self.profiles.get(profile_id)
+        return get_profile_by_id(profile_id)
 
     async def migrate_apply(self, *args: str, log_output: bool = False) -> None:
         """Do a migration command
@@ -139,16 +147,18 @@
         """
         stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
         proc = await asyncio.create_subprocess_exec(
-            sys.executable, "-m", "alembic", *args,
-            stdout=stdout, stderr=stderr, cwd=migration_path
+            sys.executable,
+            "-m",
+            "alembic",
+            *args,
+            stdout=stdout,
+            stderr=stderr,
+            cwd=migration_path,
         )
         log_out, log_err = await proc.communicate()
         if proc.returncode != 0:
-            msg = _(
-                "Can't {operation} database (exit code {exit_code})"
-            ).format(
-                operation=args[0],
-                exit_code=proc.returncode
+            msg = _("Can't {operation} database (exit code {exit_code})").format(
+                operation=args[0], exit_code=proc.returncode
             )
             if log_out or log_err:
                 msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}"
@@ -216,9 +226,7 @@
         async with engine.connect() as conn:
             await conn.run_sync(self._sqlite_set_journal_mode_wal)
 
-        self.session = sessionmaker(
-            engine, expire_on_commit=False, class_=AsyncSession
-        )
+        self.session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
 
         async with self.session() as session:
             result = await session.execute(select(Profile))
@@ -239,9 +247,9 @@
         db_cls: DeclarativeMeta,
         db_id_col: Mapped,
         id_value: Any,
-        joined_loads = None
+        joined_loads=None,
     ) -> Optional[DeclarativeMeta]:
-        stmt = select(db_cls).where(db_id_col==id_value)
+        stmt = select(db_cls).where(db_id_col == id_value)
         if client is not None:
             stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
         if joined_loads is not None:
@@ -264,7 +272,7 @@
     async def delete(
         self,
         db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]],
-        session_add: Optional[List[DeclarativeMeta]] = None
+        session_add: Optional[List[DeclarativeMeta]] = None,
     ) -> None:
         """Delete an object from database
 
@@ -289,7 +297,7 @@
     ## Profiles
 
     def get_profiles_list(self) -> List[str]:
-        """"Return list of all registered profiles"""
+        """ "Return list of all registered profiles"""
         return list(self.profiles.keys())
 
     def has_profile(self, profile_name: str) -> bool:
@@ -309,7 +317,9 @@
         try:
             return self.components[self.profiles[profile_name]]
         except KeyError:
-            raise exceptions.NotFound("the requested profile doesn't exists or is not a component")
+            raise exceptions.NotFound(
+                "the requested profile doesn't exists or is not a component"
+            )
 
     @aio
     async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None:
@@ -344,7 +354,7 @@
         del self.profiles[profile.name]
         if profile.id in self.components:
             del self.components[profile.id]
-        log.info(_("Profile {name!r} deleted").format(name = name))
+        log.info(_("Profile {name!r} deleted").format(name=name))
 
     ## Params
 
@@ -376,7 +386,9 @@
             params_ind[(p.category, p.name)] = p.value
 
     @aio
-    async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]:
+    async def get_ind_param(
+        self, category: str, name: str, profile: str
+    ) -> Optional[str]:
         """Ask database for the value of one specific individual parameter
 
         @param category: category of the parameter
@@ -385,11 +397,8 @@
         """
         async with self.session() as session:
             result = await session.execute(
-                select(ParamInd.value)
-                .filter_by(
-                    category=category,
-                    name=name,
-                    profile_id=self.profiles[profile]
+                select(ParamInd.value).filter_by(
+                    category=category, name=name, profile_id=self.profiles[profile]
                 )
             )
         return result.scalar_one_or_none()
@@ -405,10 +414,7 @@
         async with self.session() as session:
             result = await session.execute(
                 select(ParamInd)
-                .filter_by(
-                    category=category,
-                    name=name
-                )
+                .filter_by(category=category, name=name)
                 .options(subqueryload(ParamInd.profile))
             )
         return {param.profile.name: param.value for param in result.scalars()}
@@ -422,26 +428,20 @@
         @param value: value to set
         """
         async with self.session() as session:
-            stmt = insert(ParamGen).values(
-                category=category,
-                name=name,
-                value=value
-            ).on_conflict_do_update(
-                index_elements=(ParamGen.category, ParamGen.name),
-                set_={
-                    ParamGen.value: value
-                }
+            stmt = (
+                insert(ParamGen)
+                .values(category=category, name=name, value=value)
+                .on_conflict_do_update(
+                    index_elements=(ParamGen.category, ParamGen.name),
+                    set_={ParamGen.value: value},
+                )
             )
             await session.execute(stmt)
             await session.commit()
 
     @aio
     async def set_ind_param(
-        self,
-        category:str,
-        name: str,
-        value: Optional[str],
-        profile: str
+        self, category: str, name: str, value: Optional[str], profile: str
     ) -> None:
         """Save the individual parameters in database
 
@@ -451,16 +451,22 @@
         @param profile: a profile which *must* exist
         """
         async with self.session() as session:
-            stmt = insert(ParamInd).values(
-                category=category,
-                name=name,
-                profile_id=self.profiles[profile],
-                value=value
-            ).on_conflict_do_update(
-                index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id),
-                set_={
-                    ParamInd.value: value
-                }
+            stmt = (
+                insert(ParamInd)
+                .values(
+                    category=category,
+                    name=name,
+                    profile_id=self.profiles[profile],
+                    value=value,
+                )
+                .on_conflict_do_update(
+                    index_elements=(
+                        ParamInd.category,
+                        ParamInd.name,
+                        ParamInd.profile_id,
+                    ),
+                    set_={ParamInd.value: value},
+                )
             )
             await session.execute(stmt)
             await session.commit()
@@ -474,13 +480,11 @@
         if jid_.resource:
             if dest:
                 return and_(
-                    History.dest == jid_.userhost(),
-                    History.dest_res == jid_.resource
+                    History.dest == jid_.userhost(), History.dest_res == jid_.resource
                 )
             else:
                 return and_(
-                    History.source == jid_.userhost(),
-                    History.source_res == jid_.resource
+                    History.source == jid_.userhost(), History.source_res == jid_.resource
                 )
         else:
             if dest:
@@ -497,9 +501,7 @@
         between: bool = True,
         filters: Optional[Dict[str, str]] = None,
         profile: Optional[str] = None,
-    ) -> List[Tuple[
-        str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
-    ]:
+    ) -> List[Tuple[str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]]:
         """Retrieve messages in history
 
         @param from_jid: source JID (full, or bare for catchall)
@@ -523,9 +525,7 @@
 
         stmt = (
             select(History)
-            .filter_by(
-                profile_id=self.profiles[profile]
-            )
+            .filter_by(profile_id=self.profiles[profile])
             .outerjoin(History.messages)
             .outerjoin(History.subjects)
             .outerjoin(History.thread)
@@ -540,11 +540,10 @@
                 # order when returning the result. We use DESC here so LIMIT keeps the last
                 # messages
                 History.timestamp.desc(),
-                History.received_timestamp.desc()
+                History.received_timestamp.desc(),
             )
         )
 
-
         if not from_jid and not to_jid:
             # no jid specified, we want all one2one communications
             pass
@@ -554,10 +553,7 @@
                 # from or to this jid
                 jid_ = from_jid or to_jid
                 stmt = stmt.where(
-                    or_(
-                        self._jid_filter(jid_),
-                        self._jid_filter(jid_, dest=True)
-                    )
+                    or_(self._jid_filter(jid_), self._jid_filter(jid_, dest=True))
                 )
             else:
                 # we have 2 jids specified, we check all communications between
@@ -571,7 +567,7 @@
                         and_(
                             self._jid_filter(to_jid),
                             self._jid_filter(from_jid, dest=True),
-                        )
+                        ),
                     )
                 )
         else:
@@ -583,44 +579,44 @@
                 stmt = stmt.where(self._jid_filter(to_jid, dest=True))
 
         if filters:
-            if 'timestamp_start' in filters:
-                stmt = stmt.where(History.timestamp >= float(filters['timestamp_start']))
-            if 'before_uid' in filters:
+            if "timestamp_start" in filters:
+                stmt = stmt.where(History.timestamp >= float(filters["timestamp_start"]))
+            if "before_uid" in filters:
                 # originally this query was using SQLITE's rowid. This has been changed
                 # to use coalesce(received_timestamp, timestamp) to be SQL engine independent
                 stmt = stmt.where(
-                    coalesce(
-                        History.received_timestamp,
-                        History.timestamp
-                    ) < (
-                        select(coalesce(History.received_timestamp, History.timestamp))
-                        .filter_by(uid=filters["before_uid"])
+                    coalesce(History.received_timestamp, History.timestamp)
+                    < (
+                        select(
+                            coalesce(History.received_timestamp, History.timestamp)
+                        ).filter_by(uid=filters["before_uid"])
                     ).scalar_subquery()
                 )
-            if 'body' in filters:
+            if "body" in filters:
                 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html
                 stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
-            if 'search' in filters:
+            if "search" in filters:
                 search_term = f"%{filters['search']}%"
-                stmt = stmt.where(or_(
-                    Message.message.like(search_term),
-                    History.source_res.like(search_term)
-                ))
-            if 'types' in filters:
-                types = filters['types'].split()
+                stmt = stmt.where(
+                    or_(
+                        Message.message.like(search_term),
+                        History.source_res.like(search_term),
+                    )
+                )
+            if "types" in filters:
+                types = filters["types"].split()
                 stmt = stmt.where(History.type.in_(types))
-            if 'not_types' in filters:
-                types = filters['not_types'].split()
+            if "not_types" in filters:
+                types = filters["not_types"].split()
                 stmt = stmt.where(History.type.not_in(types))
-            if 'last_stanza_id' in filters:
+            if "last_stanza_id" in filters:
                 # this request get the last message with a "stanza_id" that we
                 # have in history. This is mainly used to retrieve messages sent
                 # while we were offline, using MAM (XEP-0313).
-                if (filters['last_stanza_id'] is not True
-                    or limit != 1):
+                if filters["last_stanza_id"] is not True or limit != 1:
                     raise ValueError("Unexpected values for last_stanza_id filter")
                 stmt = stmt.where(History.stanza_id.is_not(None))
-            if 'origin_id' in filters:
+            if "origin_id" in filters:
                 stmt = stmt.where(History.origin_id == filters["origin_id"])
 
         if limit is not None:
@@ -640,34 +636,40 @@
         @param data: message data as build by SatMessageProtocol.onMessage
         """
         extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
-        messages = [Message(message=mess, language=lang)
-                    for lang, mess in data["message"].items()]
-        subjects = [Subject(subject=mess, language=lang)
-                    for lang, mess in data["subject"].items()]
+        messages = [
+            Message(message=mess, language=lang) for lang, mess in data["message"].items()
+        ]
+        subjects = [
+            Subject(subject=mess, language=lang) for lang, mess in data["subject"].items()
+        ]
         if "thread" in data["extra"]:
-            thread = Thread(thread_id=data["extra"]["thread"],
-                            parent_id=data["extra"].get["thread_parent"])
+            thread = Thread(
+                thread_id=data["extra"]["thread"],
+                parent_id=data["extra"].get["thread_parent"],
+            )
         else:
             thread = None
         try:
             async with self.session() as session:
                 async with session.begin():
-                    session.add(History(
-                        uid=data["uid"],
-                        origin_id=data["extra"].get("origin_id"),
-                        stanza_id=data["extra"].get("stanza_id"),
-                        update_uid=data["extra"].get("update_uid"),
-                        profile_id=self.profiles[profile],
-                        source_jid=data["from"],
-                        dest_jid=data["to"],
-                        timestamp=data["timestamp"],
-                        received_timestamp=data.get("received_timestamp"),
-                        type=data["type"],
-                        extra=extra,
-                        messages=messages,
-                        subjects=subjects,
-                        thread=thread,
-                    ))
+                    session.add(
+                        History(
+                            uid=data["uid"],
+                            origin_id=data["extra"].get("origin_id"),
+                            stanza_id=data["extra"].get("stanza_id"),
+                            update_uid=data["extra"].get("update_uid"),
+                            profile_id=self.profiles[profile],
+                            source_jid=data["from"],
+                            dest_jid=data["to"],
+                            timestamp=data["timestamp"],
+                            received_timestamp=data.get("received_timestamp"),
+                            type=data["type"],
+                            extra=extra,
+                            messages=messages,
+                            subjects=subjects,
+                            thread=thread,
+                        )
+                    )
         except IntegrityError as e:
             if "unique" in str(e.orig).lower():
                 log.debug(
@@ -689,14 +691,13 @@
         else:
             return PrivateIndBin if binary else PrivateInd
 
-
     @aio
     async def get_privates(
         self,
-        namespace:str,
+        namespace: str,
         keys: Optional[Iterable[str]] = None,
         binary: bool = False,
-        profile: Optional[str] = None
+        profile: Optional[str] = None,
     ) -> Dict[str, Any]:
         """Get private value(s) from databases
 
@@ -728,10 +729,10 @@
     async def set_private_value(
         self,
         namespace: str,
-        key:str,
+        key: str,
         value: Any,
         binary: bool = False,
-        profile: Optional[str] = None
+        profile: Optional[str] = None,
     ) -> None:
         """Set a private value in database
 
@@ -745,11 +746,7 @@
         """
         cls = self._get_private_class(binary, profile)
 
-        values = {
-            "namespace": namespace,
-            "key": key,
-            "value": value
-        }
+        values = {"namespace": namespace, "key": key, "value": value}
         index_elements = [cls.namespace, cls.key]
 
         if profile is not None:
@@ -758,11 +755,10 @@
 
         async with self.session() as session:
             await session.execute(
-                insert(cls).values(**values).on_conflict_do_update(
-                    index_elements=index_elements,
-                    set_={
-                        cls.value: value
-                    }
+                insert(cls)
+                .values(**values)
+                .on_conflict_do_update(
+                    index_elements=index_elements, set_={cls.value: value}
                 )
             )
             await session.commit()
@@ -773,7 +769,7 @@
         namespace: str,
         key: str,
         binary: bool = False,
-        profile: Optional[str] = None
+        profile: Optional[str] = None,
     ) -> None:
         """Delete private value from database
 
@@ -796,10 +792,7 @@
 
     @aio
     async def del_private_namespace(
-        self,
-        namespace: str,
-        binary: bool = False,
-        profile: Optional[str] = None
+        self, namespace: str, binary: bool = False, profile: Optional[str] = None
     ) -> None:
         """Delete all data from a private namespace
 
@@ -825,7 +818,7 @@
         self,
         client: Optional[SatXMPPEntity],
         file_id: Optional[str] = None,
-        version: Optional[str] = '',
+        version: Optional[str] = "",
         parent: Optional[str] = None,
         type_: Optional[str] = None,
         file_hash: Optional[str] = None,
@@ -837,7 +830,7 @@
         owner: Optional[jid.JID] = None,
         access: Optional[dict] = None,
         projection: Optional[List[str]] = None,
-        unique: bool = False
+        unique: bool = False,
     ) -> List[dict]:
         """Retrieve files with with given filters
 
@@ -857,9 +850,23 @@
         """
         if projection is None:
             projection = [
-                'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
-                'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
-                'created', 'modified', 'owner', 'access', 'extra'
+                "id",
+                "version",
+                "parent",
+                "type",
+                "file_hash",
+                "hash_algo",
+                "name",
+                "size",
+                "namespace",
+                "media_type",
+                "media_subtype",
+                "public_id",
+                "created",
+                "modified",
+                "owner",
+                "access",
+                "extra",
             ]
 
         stmt = select(*[getattr(File, f) for f in projection])
@@ -891,7 +898,7 @@
         if namespace is not None:
             stmt = stmt.filter_by(namespace=namespace)
         if mime_type is not None:
-            if '/' in mime_type:
+            if "/" in mime_type:
                 media_type, media_subtype = mime_type.split("/", 1)
                 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
             else:
@@ -901,8 +908,8 @@
         if owner is not None:
             stmt = stmt.filter_by(owner=owner)
         if access is not None:
-            raise NotImplementedError('Access check is not implemented yet')
-            # a JSON comparison is needed here
+            raise NotImplementedError("Access check is not implemented yet")
+            # a JSON comparison is needed here
 
         async with self.session() as session:
             result = await session.execute(stmt)
@@ -928,7 +935,7 @@
         modified: Optional[float] = None,
         owner: Optional[jid.JID] = None,
         access: Optional[dict] = None,
-        extra: Optional[dict] = None
+        extra: Optional[dict] = None,
     ) -> None:
         """Set a file metadata
 
@@ -958,33 +965,35 @@
         """
         if mime_type is None:
             media_type = media_subtype = None
-        elif '/' in mime_type:
-            media_type, media_subtype = mime_type.split('/', 1)
+        elif "/" in mime_type:
+            media_type, media_subtype = mime_type.split("/", 1)
         else:
             media_type, media_subtype = mime_type, None
 
         async with self.session() as session:
             async with session.begin():
-                session.add(File(
-                    id=file_id,
-                    version=version.strip(),
-                    parent=parent,
-                    type=type_,
-                    file_hash=file_hash,
-                    hash_algo=hash_algo,
-                    name=name,
-                    size=size,
-                    namespace=namespace,
-                    media_type=media_type,
-                    media_subtype=media_subtype,
-                    public_id=public_id,
-                    created=time.time() if created is None else created,
-                    modified=modified,
-                    owner=owner,
-                    access=access,
-                    extra=extra,
-                    profile_id=self.profiles[client.profile]
-                ))
+                session.add(
+                    File(
+                        id=file_id,
+                        version=version.strip(),
+                        parent=parent,
+                        type=type_,
+                        file_hash=file_hash,
+                        hash_algo=hash_algo,
+                        name=name,
+                        size=size,
+                        namespace=namespace,
+                        media_type=media_type,
+                        media_subtype=media_subtype,
+                        public_id=public_id,
+                        created=time.time() if created is None else created,
+                        modified=modified,
+                        owner=owner,
+                        access=access,
+                        extra=extra,
+                        profile_id=self.profiles[client.profile],
+                    )
+                )
 
     @aio
     async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int:
@@ -993,8 +1002,9 @@
                 select(sum_(File.size)).filter_by(
                     owner=owner,
                     type=C.FILE_TYPE_FILE,
-                    profile_id=self.profiles[client.profile]
-                ))
+                    profile_id=self.profiles[client.profile],
+                )
+            )
         return result.scalar_one_or_none() or 0
 
     @aio
@@ -1011,10 +1021,7 @@
 
     @aio
     async def file_update(
-        self,
-        file_id: str,
-        column: str,
-        update_cb: Callable[[dict], None]
+        self, file_id: str, column: str, update_cb: Callable[[dict], None]
     ) -> None:
         """Update a column value using a method to avoid race conditions
 
@@ -1029,16 +1036,16 @@
             it gets the deserialized data (i.e. a Python object) directly
         @raise exceptions.NotFound: there is no file with this id
         """
-        if column not in ('access', 'extra'):
-            raise exceptions.InternalError('bad column name')
+        if column not in ("access", "extra"):
+            raise exceptions.InternalError("bad column name")
         orm_col = getattr(File, column)
 
         for i in range(5):
             async with self.session() as session:
                 try:
-                    value = (await session.execute(
-                        select(orm_col).filter_by(id=file_id)
-                    )).scalar_one()
+                    value = (
+                        await session.execute(select(orm_col).filter_by(id=file_id))
+                    ).scalar_one()
                 except NoResultFound:
                     raise exceptions.NotFound
                 old_value = copy.deepcopy(value)
@@ -1057,14 +1064,17 @@
                 break
 
             log.warning(
-                _("table not updated, probably due to race condition, trying again "
-                  "({tries})").format(tries=i+1)
+                _(
+                    "table not updated, probably due to race condition, trying again "
+                    "({tries})"
+                ).format(tries=i + 1)
             )
 
         else:
             raise exceptions.DatabaseError(
-                _("Can't update file {file_id} due to race condition")
-                .format(file_id=file_id)
+                _("Can't update file {file_id} due to race condition").format(
+                    file_id=file_id
+                )
             )
 
     @aio
@@ -1076,7 +1086,7 @@
         with_items: bool = False,
         with_subscriptions: bool = False,
         create: bool = False,
-        create_kwargs: Optional[dict] = None
+        create_kwargs: Optional[dict] = None,
     ) -> Optional[PubsubNode]:
         """Retrieve a PubsubNode from DB
 
@@ -1089,22 +1099,15 @@
             needs to be created.
         """
         async with self.session() as session:
-            stmt = (
-                select(PubsubNode)
-                .filter_by(
-                    service=service,
-                    name=name,
-                    profile_id=self.profiles[client.profile],
-                )
+            stmt = select(PubsubNode).filter_by(
+                service=service,
+                name=name,
+                profile_id=self.profiles[client.profile],
             )
             if with_items:
-                stmt = stmt.options(
-                    joinedload(PubsubNode.items)
-                )
+                stmt = stmt.options(joinedload(PubsubNode.items))
             if with_subscriptions:
-                stmt = stmt.options(
-                    joinedload(PubsubNode.subscriptions)
-                )
+                stmt = stmt.options(joinedload(PubsubNode.subscriptions))
             result = await session.execute(stmt)
         ret = result.unique().scalar_one_or_none()
         if ret is None and create:
@@ -1112,21 +1115,23 @@
             if create_kwargs is None:
                 create_kwargs = {}
             try:
-                return await as_future(self.set_pubsub_node(
-                    client, service, name, **create_kwargs
-                ))
+                return await as_future(
+                    self.set_pubsub_node(client, service, name, **create_kwargs)
+                )
             except IntegrityError as e:
                 if "unique" in str(e.orig).lower():
                     # the node may already exist, if it has been created just after
                     # get_pubsub_node above
                     log.debug("ignoring UNIQUE constraint error")
-                    cached_node = await as_future(self.get_pubsub_node(
-                        client,
-                        service,
-                        name,
-                        with_items=with_items,
-                        with_subscriptions=with_subscriptions
-                    ))
+                    cached_node = await as_future(
+                        self.get_pubsub_node(
+                            client,
+                            service,
+                            name,
+                            with_items=with_items,
+                            with_subscriptions=with_subscriptions,
+                        )
+                    )
                 else:
                     raise e
         else:
@@ -1160,9 +1165,7 @@
 
     @aio
     async def update_pubsub_node_sync_state(
-        self,
-        node: PubsubNode,
-        state: SyncState
+        self, node: PubsubNode, state: SyncState
     ) -> None:
         async with self.session() as session:
             async with session.begin():
@@ -1180,7 +1183,7 @@
         self,
         profiles: Optional[List[str]],
         services: Optional[List[jid.JID]],
-        names: Optional[List[str]]
+        names: Optional[List[str]],
     ) -> None:
         """Delete items cached for a node
 
@@ -1194,9 +1197,7 @@
         stmt = delete(PubsubNode)
         if profiles is not None:
             stmt = stmt.where(
-                PubsubNode.profile.in_(
-                    [self.profiles[p] for p in profiles]
-                )
+                PubsubNode.profile.in_([self.profiles[p] for p in profiles])
             )
         if services is not None:
             stmt = stmt.where(PubsubNode.service.in_(services))
@@ -1223,27 +1224,29 @@
             async with session.begin():
                 for idx, item in enumerate(items):
                     parsed = parsed_items[idx] if parsed_items else None
-                    stmt = insert(PubsubItem).values(
-                        node_id = node.id,
-                        name = item["id"],
-                        data = item,
-                        parsed = parsed,
-                    ).on_conflict_do_update(
-                        index_elements=(PubsubItem.node_id, PubsubItem.name),
-                        set_={
-                            PubsubItem.data: item,
-                            PubsubItem.parsed: parsed,
-                            PubsubItem.updated: now()
-                        }
+                    stmt = (
+                        insert(PubsubItem)
+                        .values(
+                            node_id=node.id,
+                            name=item["id"],
+                            data=item,
+                            parsed=parsed,
+                        )
+                        .on_conflict_do_update(
+                            index_elements=(PubsubItem.node_id, PubsubItem.name),
+                            set_={
+                                PubsubItem.data: item,
+                                PubsubItem.parsed: parsed,
+                                PubsubItem.updated: now(),
+                            },
+                        )
                     )
                     await session.execute(stmt)
                 await session.commit()
 
     @aio
     async def delete_pubsub_items(
-        self,
-        node: PubsubNode,
-        items_names: Optional[List[str]] = None
+        self, node: PubsubNode, items_names: Optional[List[str]] = None
     ) -> None:
         """Delete items cached for a node
 
@@ -1296,10 +1299,8 @@
                 if values is None:
                     continue
                 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
-            stmt = (
-                stmt
-                .where(PubsubItem.node_id.in_(sub_q))
-                .execution_options(synchronize_session=False)
+            stmt = stmt.where(PubsubItem.node_id.in_(sub_q)).execution_options(
+                synchronize_session=False
             )
 
         if created_before is not None:
@@ -1367,11 +1368,7 @@
             use_rsm = True
             from_index = 0
 
-        stmt = (
-            select(PubsubItem)
-            .filter_by(node_id=node.id)
-            .limit(max_items)
-        )
+        stmt = select(PubsubItem).filter_by(node_id=node.id).limit(max_items)
 
         if item_ids is not None:
             stmt = stmt.where(PubsubItem.name.in_(item_ids))
@@ -1402,7 +1399,7 @@
                 PubsubItem.id,
                 PubsubItem.name,
                 # row_number starts from 1, but RSM index must start from 0
-                (func.row_number().over(order_by=order)-1).label("item_index")
+                (func.row_number().over(order_by=order) - 1).label("item_index"),
             ).filter_by(node_id=node.id)
 
             row_num_cte = row_num_q.cte()
@@ -1412,8 +1409,9 @@
                 # we need to use row number
                 item_name = before or after
                 row_num_limit_q = (
-                    select(row_num_cte.c.item_index)
-                    .where(row_num_cte.c.name==item_name)
+                    select(row_num_cte.c.item_index).where(
+                        row_num_cte.c.name == item_name
+                    )
                 ).scalar_subquery()
 
                 stmt = (
@@ -1422,22 +1420,16 @@
                     .limit(max_items)
                 )
                 if before:
-                    stmt = (
-                        stmt
-                        .where(row_num_cte.c.item_index<row_num_limit_q)
-                        .order_by(row_num_cte.c.item_index.desc())
-                    )
+                    stmt = stmt.where(
+                        row_num_cte.c.item_index < row_num_limit_q
+                    ).order_by(row_num_cte.c.item_index.desc())
                 elif after:
-                    stmt = (
-                        stmt
-                        .where(row_num_cte.c.item_index>row_num_limit_q)
-                        .order_by(row_num_cte.c.item_index.asc())
-                    )
+                    stmt = stmt.where(
+                        row_num_cte.c.item_index > row_num_limit_q
+                    ).order_by(row_num_cte.c.item_index.asc())
                 else:
-                    stmt = (
-                        stmt
-                        .where(row_num_cte.c.item_index>=from_index)
-                        .order_by(row_num_cte.c.item_index.asc())
+                    stmt = stmt.where(row_num_cte.c.item_index >= from_index).order_by(
+                        row_num_cte.c.item_index.asc()
                     )
                     # from_index is used
 
@@ -1468,12 +1460,14 @@
                 last = result[-1][1].name
 
             metadata["rsm"] = {
-                k: v for k, v in {
+                k: v
+                for k, v in {
                     "index": index,
                     "count": rows_count,
                     "first": first,
                     "last": last,
-                }.items() if v is not None
+                }.items()
+                if v is not None
             }
             metadata["complete"] = (index or 0) + len(result) == rows_count
 
@@ -1487,10 +1481,7 @@
             result.reverse()
         return result, metadata
 
-    def _get_sqlite_path(
-        self,
-        path: List[Union[str, int]]
-    ) -> str:
+    def _get_sqlite_path(self, path: List[Union[str, int]]) -> str:
         """generate path suitable to query JSON element with SQLite"""
         return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}"
 
@@ -1557,19 +1548,20 @@
         # Full-Text Search
         fts = query.get("fts")
         if fts:
-            fts_select = text(
-                "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)"
-            ).bindparams(fts_query=fts).columns(rowid=Integer).subquery()
-            stmt = (
-                stmt
-                .select_from(fts_select)
-                .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id)
+            fts_select = (
+                text("SELECT rowid, rank FROM pubsub_items_fts(:fts_query)")
+                .bindparams(fts_query=fts)
+                .columns(rowid=Integer)
+                .subquery()
+            )
+            stmt = stmt.select_from(fts_select).outerjoin(
+                PubsubItem, fts_select.c.rowid == PubsubItem.id
             )
 
         # node related filters
         profiles = query.get("profiles")
-        if (profiles
-            or any(query.get(k) for k in ("nodes", "services", "types", "subtypes"))
+        if profiles or any(
+            query.get(k) for k in ("nodes", "services", "types", "subtypes")
         ):
             stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node))
             if profiles:
@@ -1585,7 +1577,7 @@
                 ("nodes", "name"),
                 ("services", "service"),
                 ("types", "type_"),
-                ("subtypes", "subtype")
+                ("subtypes", "subtype"),
             ):
                 value = query.get(key)
                 if not value:
@@ -1596,10 +1588,7 @@
                     value.remove(None)
                     condition = getattr(PubsubNode, attr).is_(None)
                     if value:
-                        condition = or_(
-                            getattr(PubsubNode, attr).in_(value),
-                            condition
-                        )
+                        condition = or_(getattr(PubsubNode, attr).in_(value), condition)
                 else:
                     condition = getattr(PubsubNode, attr).in_(value)
                 stmt = stmt.where(condition)
@@ -1644,10 +1633,12 @@
                 try:
                     left, right = value
                 except (ValueError, TypeError):
-                    raise ValueError(_(
-                        "invalid value for \"between\" filter, you must use a 2 items "
-                        "array: {value!r}"
-                    ).format(value=value))
+                    raise ValueError(
+                        _(
+                            'invalid value for "between" filter, you must use a 2 items '
+                            "array: {value!r}"
+                        ).format(value=value)
+                    )
                 col = func.json_extract(PubsubItem.parsed, sqlite_path)
                 stmt = stmt.where(col.between(left, right))
             else:
@@ -1662,10 +1653,12 @@
         for order_data in order_by:
             order, path = order_data.get("order"), order_data.get("path")
             if order and path:
-                raise ValueError(_(
-                    '"order" and "path" can\'t be used at the same time in '
-                    '"order-by" data'
-                ))
+                raise ValueError(
+                    _(
+                        '"order" and "path" can\'t be used at the same time in '
+                        '"order-by" data'
+                    )
+                )
             if order:
                 if order == "creation":
                     col = PubsubItem.id
@@ -1702,3 +1695,172 @@
             result = await session.execute(stmt)
 
         return result.scalars().all()
+
+    # Notifications
+
+    @aio
+    async def add_notification(
+        self,
+        client: Optional[SatXMPPEntity],
+        type_: NotificationType,
+        body_plain: str,
+        body_rich: Optional[str] = None,
+        title: Optional[str] = None,
+        requires_action: bool = False,
+        priority: NotificationPriority = NotificationPriority.MEDIUM,
+        expire_at: Optional[float] = None,
+        extra: Optional[dict] = None,
+    ) -> Notification:
+        """Add a new notification to the DB.
+
+        @param client: client associated with the notification. If None, the notification
+            will be global.
+        @param type_: type of the notification.
+        @param body_plain: plain text body.
+        @param body_rich: rich text (XHTML) body.
+        @param title: optional title.
+        @param requires_action: True if the notification requires user action (e.g. a
+            dialog needs to be answered).
+        @param priority: how urgent the notification is.
+        @param expire_at: expiration timestamp for the notification.
+        @param extra: additional data.
+        @return: created Notification
+        """
+        profile_id = self.profiles[client.profile] if client else None
+        notification = Notification(
+            profile_id=profile_id,
+            type=type_,
+            body_plain=body_plain,
+            body_rich=body_rich,
+            requires_action=requires_action,
+            priority=priority,
+            expire_at=expire_at,
+            title=title,
+            extra_data=extra,
+            status=NotificationStatus.new,
+        )
+        async with self.session() as session:
+            async with session.begin():
+                session.add(notification)
+        return notification
+
+    @aio
+    async def update_notification(
+        self, client: SatXMPPEntity, notification_id: int, **kwargs
+    ) -> None:
+        """Update an existing notification.
+
+        @param client: client associated with the notification.
+        @param notification_id: ID of the notification to update.
+        """
+        profile_id = self.profiles[client.profile]
+        async with self.session() as session:
+            await session.execute(
+                update(Notification)
+                .where(
+                    and_(
+                        Notification.profile_id == profile_id,
+                        Notification.id == notification_id,
+                    )
+                )
+                .values(**kwargs)
+            )
+            await session.commit()
+
+    @aio
+    async def get_notifications(
+        self,
+        client: SatXMPPEntity,
+        type_: Optional[NotificationType] = None,
+        status: Optional[NotificationStatus] = None,
+        requires_action: Optional[bool] = None,
+        min_priority: Optional[int] = None
+    ) -> List[Notification]:
+        """Retrieve all notifications for a given profile with optional filters.
+
+        @param client: client associated with the notifications.
+        @param type_: filter by type of the notification.
+        @param status: filter by status of the notification.
+        @param requires_action: filter by notifications that require user action.
+        @param min_priority: filter by minimum priority value.
+        @return: list of matching Notification instances.
+        """
+        profile_id = self.profiles[client.profile]
+        filters = [or_(Notification.profile_id == profile_id, Notification.profile_id.is_(None))]
+
+        if type_:
+            filters.append(Notification.type == type_)
+        if status:
+            filters.append(Notification.status == status)
+        if requires_action is not None:
+            filters.append(Notification.requires_action == requires_action)
+        if min_priority:
+            filters.append(Notification.priority >= min_priority)
+
+        async with self.session() as session:
+            result = await session.execute(
+                select(Notification)
+                .where(and_(*filters))
+                .order_by(Notification.id)
+            )
+            return result.scalars().all()
+
+    @aio
+    async def delete_notification(
+        self, client: Optional[SatXMPPEntity], notification_id: str
+    ) -> None:
+        """Delete a notification by its profile and id.
+
+        @param client: client associated with the notification. If None, profile_id will be NULL.
+        @param notification_id: ID of the notification to delete.
+        """
+        profile_id = self.profiles[client.profile] if client else None
+        async with self.session() as session:
+            await session.execute(
+                delete(Notification).where(
+                    and_(
+                        Notification.profile_id == profile_id,
+                        Notification.id == int(notification_id),
+                    )
+                )
+            )
+            await session.commit()
+
+    @aio
+    async def clean_expired_notifications(
+        self, client: Optional[SatXMPPEntity], limit_timestamp: Optional[float] = None
+    ) -> None:
+        """Cleans expired notifications and older profile-specific notifications.
+
+        - Removes all notifications where the expiration timestamp has passed,
+          irrespective of their profile.
+        - If a limit_timestamp is provided, removes older notifications with a profile set
+          (i.e., not global notifications) that do not require user action. If client is
+          provided, only remove notifications for this profile.
+
+        @param client: if provided, only expire notifications for this client (in addition
+            to truly expired notifications for everybody).
+        @param limit_timestamp: Timestamp limit for older notifications. If None, only
+            truly expired notifications are removed.
+        """
+
+        # Delete truly expired notifications
+        expired_condition = Notification.expire_at < time.time()
+
+        # Delete older profile-specific notifications (created before the limit_timestamp)
+        if client is None:
+            profile_condition = Notification.profile_id.isnot(None)
+        else:
+            profile_condition = Notification.profile_id == self.profiles[client.profile]
+        older_condition = and_(
+            profile_condition,
+            Notification.timestamp < limit_timestamp if limit_timestamp else False,
+            Notification.requires_action == False,
+        )
+
+        # Combine the conditions
+        conditions = or_(expired_condition, older_condition)
+
+        async with self.session() as session:
+            await session.execute(delete(Notification).where(conditions))
+            await session.commit()
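
The sketch below is not part of the changeset: it is a minimal illustration, under stated assumptions, of how the new notification storage methods might be exercised from asyncio code once this revision is applied. `storage` stands for an initialized Storage instance and `client` for a connected SatXMPPEntity; the NotificationType member and the sqla_mapping import path are assumptions, while awaiting the @aio-decorated methods through as_future() mirrors how set_pubsub_node() is already called in this module. Only NotificationStatus.new and NotificationPriority.MEDIUM appear in the diff itself.

    import time

    # Module path assumed: these are the enums imported at the top of sqla.py.
    from libervia.backend.memory.sqla_mapping import (
        NotificationPriority,
        NotificationStatus,
        NotificationType,
    )
    from libervia.backend.tools.utils import as_future


    async def notification_example(storage, client):
        # Create a profile-bound notification; NotificationType.chat is a
        # hypothetical enum member used for illustration only.
        notif = await as_future(
            storage.add_notification(
                client,
                NotificationType.chat,
                body_plain="New message received",
                title="Chat",
                priority=NotificationPriority.MEDIUM,
            )
        )

        # Update a column of the stored row (keyword arguments are passed to
        # UPDATE ... VALUES as-is).
        await as_future(
            storage.update_notification(client, notif.id, title="Chat (updated)")
        )

        # Retrieve this profile's notifications (plus global ones) still in the
        # "new" status.
        pending = await as_future(
            storage.get_notifications(client, status=NotificationStatus.new)
        )

        # Drop profile-bound notifications older than a week that don't require
        # action, plus anything whose expire_at has passed.
        await as_future(
            storage.clean_expired_notifications(
                client, limit_timestamp=time.time() - 7 * 24 * 3600
            )
        )
        return pending

From Twisted-side code the Deferreds returned by these @aio-wrapped methods could presumably be used directly, without the as_future() wrapper.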