diff sat/memory/sqla.py @ 3595:7510648e8e3a

core (memory/sqla): methods to manipulate pubsub tables
author Goffi <goffi@goffi.org>
date Thu, 29 Jul 2021 22:51:01 +0200
parents 16ade4ad63f3
children 2d97c695af05
--- a/sat/memory/sqla.py	Thu Jul 29 22:50:57 2021 +0200
+++ b/sat/memory/sqla.py	Thu Jul 29 22:51:01 2021 +0200
@@ -19,31 +19,38 @@
 import sys
 import time
 import asyncio
+from datetime import datetime
 from asyncio.subprocess import PIPE
 from pathlib import Path
 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
 from sqlalchemy.exc import IntegrityError, NoResultFound
-from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
+from sqlalchemy.orm import (
+    sessionmaker, subqueryload, joinedload, contains_eager
+)
+from sqlalchemy.orm.decl_api import DeclarativeMeta
 from sqlalchemy.future import select
 from sqlalchemy.engine import Engine, Connection
-from sqlalchemy import update, delete, and_, or_, event
-from sqlalchemy.sql.functions import coalesce, sum as sum_
+from sqlalchemy import update, delete, and_, or_, event, func
+from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count
 from sqlalchemy.dialects.sqlite import insert
 from alembic import script as al_script, config as al_config
 from alembic.runtime import migration as al_migration
 from twisted.internet import defer
 from twisted.words.protocols.jabber import jid
+from twisted.words.xish import domish
 from sat.core.i18n import _
 from sat.core import exceptions
 from sat.core.log import getLogger
 from sat.core.constants import Const as C
 from sat.core.core_types import SatXMPPEntity
 from sat.tools.utils import aio
+from sat.tools.common import uri
 from sat.memory import migration
 from sat.memory import sqla_config
 from sat.memory.sqla_mapping import (
     NOT_IN_EXTRA,
+    SyncState,
     Base,
     Profile,
     Component,
@@ -57,7 +64,9 @@
     PrivateInd,
     PrivateGenBin,
     PrivateIndBin,
-    File
+    File,
+    PubsubNode,
+    PubsubItem,
 )
 
 
@@ -153,7 +162,7 @@
         db_config = sqla_config.getDbConfig()
         engine = create_async_engine(
             db_config["url"],
-            future=True
+            future=True,
         )
 
         new_base = not db_config["path"].exists()
@@ -952,3 +961,379 @@
                 _("Can't update file {file_id} due to race condition")
                 .format(file_id=file_id)
             )
+
+    @aio
+    async def getPubsubNode(
+        self,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        name: str,
+        with_items: bool = False,
+    ) -> Optional[PubsubNode]:
+        """
+        """
+        async with self.session() as session:
+            stmt = (
+                select(PubsubNode)
+                .filter_by(
+                    service=service,
+                    name=name,
+                    profile_id=self.profiles[client.profile],
+                )
+            )
+            if with_items:
+                stmt = stmt.options(
+                    joinedload(PubsubNode.items)
+                )
+            result = await session.execute(stmt)
+        return result.unique().scalar_one_or_none()
+
+    @aio
+    async def setPubsubNode(
+        self,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        name: str,
+        analyser: Optional[str] = None,
+        type_: Optional[str] = None,
+        subtype: Optional[str] = None,
+    ) -> PubsubNode:
+        node = PubsubNode(
+            profile_id=self.profiles[client.profile],
+            service=service,
+            name=name,
+            subscribed=False,
+            analyser=analyser,
+            type_=type_,
+            subtype=subtype,
+        )
+        async with self.session() as session:
+            async with session.begin():
+                session.add(node)
+        return node
+
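+    # Illustrative cache lookup (sketch with hypothetical "storage", "client",
+    # "service" and "node_name" names): get the cached node, create the cache
+    # entry on miss.
+    #
+    #     node = await storage.getPubsubNode(client, service, node_name)
+    #     if node is None:
+    #         node = await storage.setPubsubNode(client, service, node_name)
+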
+    @aio
+    async def updatePubsubNodeSyncState(
+        self,
+        node: PubsubNode,
+        state: SyncState
+    ) -> None:
+        async with self.session() as session:
+            async with session.begin():
+                await session.execute(
+                    update(PubsubNode)
+                    .filter_by(id=node.id)
+                    .values(
+                        sync_state=state,
+                        sync_state_updated=time.time(),
+                    )
+                )
+
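+    # Illustrative call (hypothetical "storage" and "node" names; the
+    # SyncState member name is assumed from sat.memory.sqla_mapping):
+    #
+    #     await storage.updatePubsubNodeSyncState(node, SyncState.IN_PROGRESS)
+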
+    @aio
+    async def deletePubsubNode(
+        self,
+        profiles: Optional[List[str]],
+        services: Optional[List[jid.JID]],
+        names: Optional[List[str]]
+    ) -> None:
+        """Delete items cached for a node
+
+        @param profiles: profile names from which nodes must be deleted.
+            None to remove nodes from ALL profiles
+        @param services: JIDs of pubsub services from which nodes must be deleted.
+            None to remove nodes from ALL services
+        @param names: names of nodes which must be deleted.
+            None to remove ALL nodes whatever is their names
+        """
+        stmt = delete(PubsubNode)
+        if profiles is not None:
+            stmt = stmt.where(
+                PubsubNode.profile_id.in_(
+                    [self.profiles[p] for p in profiles]
+                )
+            )
+        if services is not None:
+            stmt = stmt.where(PubsubNode.service.in_(services))
+        if names is not None:
+            stmt = stmt.where(PubsubNode.name.in_(names))
+        async with self.session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
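+    # Illustrative call (hypothetical "storage" name): drop every cached node
+    # named "urn:xmpp:microblog:0" for profile "alice", on any service.
+    #
+    #     await storage.deletePubsubNode(["alice"], None, ["urn:xmpp:microblog:0"])
+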
+    @aio
+    async def cachePubsubItems(
+        self,
+        client: SatXMPPEntity,
+        node: PubsubNode,
+        items: List[domish.Element],
+        parsed_items: Optional[List[dict]] = None,
+    ) -> None:
+        """Add items to database, using an upsert taking care of "updated" field"""
+        if parsed_items is not None and len(items) != len(parsed_items):
+            raise exceptions.InternalError(
+                "parsed_items must have the same lenght as items"
+            )
+        async with self.session() as session:
+            async with session.begin():
+                for idx, item in enumerate(items):
+                    parsed = parsed_items[idx] if parsed_items else None
+                    stmt = insert(PubsubItem).values(
+                        node_id=node.id,
+                        name=item["id"],
+                        data=item,
+                        parsed=parsed,
+                    ).on_conflict_do_update(
+                        index_elements=(PubsubItem.node_id, PubsubItem.name),
+                        set_={
+                            PubsubItem.data: item,
+                            PubsubItem.parsed: parsed,
+                            PubsubItem.updated: now()
+                        }
+                    )
+                    await session.execute(stmt)
+
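+    # Illustrative usage (hypothetical "storage", "client", "node" and "items"
+    # names): caching the same item names twice does not violate the
+    # (node_id, name) unique constraint, the second call only updates the
+    # "data", "parsed" and "updated" columns.
+    #
+    #     await storage.cachePubsubItems(client, node, items)
+    #     await storage.cachePubsubItems(client, node, items)  # upsert, no error
+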
+    @aio
+    async def deletePubsubItems(
+        self,
+        node: PubsubNode,
+        items_names: Optional[List[str]] = None
+    ) -> None:
+        """Delete items cached for a node
+
+        @param node: node from which items must be deleted
+        @param items_names: names of items to delete
+            if None, ALL items will be deleted
+        """
+        stmt = delete(PubsubItem)
+        if node is not None:
+            if isinstance(node, list):
+                stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node]))
+            else:
+                stmt = stmt.filter_by(node_id=node.id)
+        if items_names is not None:
+            stmt = stmt.where(PubsubItem.name.in_(items_names))
+        async with self.session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
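+    # Illustrative calls (hypothetical "storage" and "node" names):
+    #
+    #     await storage.deletePubsubItems(node)              # all items of node
+    #     await storage.deletePubsubItems(node, ["item-1"])  # only "item-1"
+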
+    @aio
+    async def purgePubsubItems(
+        self,
+        services: Optional[List[jid.JID]] = None,
+        names: Optional[List[str]] = None,
+        types: Optional[List[str]] = None,
+        subtypes: Optional[List[str]] = None,
+        profiles: Optional[List[str]] = None,
+        created_before: Optional[datetime] = None,
+        updated_before: Optional[datetime] = None,
+    ) -> None:
+        """Delete items cached for a node
+
+        @param node: node from which items must be deleted
+        @param items_names: names of items to delete
+            if None, ALL items will be deleted
+        """
+        stmt = delete(PubsubItem)
+        node_fields = {
+            "service": services,
+            "name": names,
+            "type_": types,
+            "subtype": subtypes,
+        }
+        if profiles is not None:
+            # the profile is a column of PubsubNode, not PubsubItem, thus it
+            # must be filtered through the node sub-query
+            node_fields["profile_id"] = [self.profiles[p] for p in profiles]
+        if any(x is not None for x in node_fields.values()):
+            sub_q = select(PubsubNode.id)
+            for col, values in node_fields.items():
+                if values is None:
+                    continue
+                sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
+            stmt = (
+                stmt
+                .where(PubsubItem.node_id.in_(sub_q))
+                .execution_options(synchronize_session=False)
+            )
+
+        if created_before is not None:
+            stmt = stmt.where(PubsubItem.created < created_before)
+
+        if updated_before is not None:
+            stmt = stmt.where(PubsubItem.updated < updated_before)
+
+        async with self.session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
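+    # Illustrative call (hypothetical "storage" name and "blog" node type;
+    # a timedelta import is assumed): purge cached items of "blog" nodes
+    # untouched for a week for profile "alice".
+    #
+    #     week_ago = datetime.now() - timedelta(weeks=1)
+    #     await storage.purgePubsubItems(
+    #         types=["blog"], profiles=["alice"], updated_before=week_ago
+    #     )
+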
+    @aio
+    async def getItems(
+        self,
+        node: PubsubNode,
+        max_items: Optional[int] = None,
+        before: Optional[str] = None,
+        after: Optional[str] = None,
+        from_index: Optional[int] = None,
+        order_by: Optional[List[str]] = None,
+        desc: bool = True,
+        force_rsm: bool = False,
+    ) -> Tuple[List[PubsubItem], dict]:
+        """Get Pubsub Items from cache
+
+        @param node: retrieve items from this node (must be synchronised)
+        @param max_items: maximum number of items to retrieve
+        @param before: get items which are before the item with this name in given order
+            empty string is not managed here, use desc order to reproduce RSM
+            behaviour.
+        @param after: get items which are after the item with this name in given order
+        @param from_index: get items with item index (as defined in RSM spec)
+            starting from this number
+        @param order_by: sorting order of items (one of C.ORDER_BY_*)
+        @param desc: direction or ordering
+        @param force_rsm: if True, force the use of RSM worklow.
+            RSM workflow is automatically used if any of before, after or
+            from_index is used, but if only RSM max_items is used, it won't be
+            used by default. This parameter let's use RSM workflow in this
+            case. Note that in addition to RSM metadata, the result will not be
+            the same (max_items without RSM will returns most recent items,
+            i.e. last items in modification order, while max_items with RSM
+            will return the oldest ones (i.e. first items in modification
+            order).
+            to be used when max_items is used from RSM
+        """
+
+        metadata = {
+            "service": node.service,
+            "node": node.name,
+            "uri": uri.buildXMPPUri(
+                "pubsub",
+                path=node.service.full(),
+                node=node.name,
+            ),
+        }
+        if max_items is None:
+            max_items = 20
+
+        use_rsm = any((before, after, from_index is not None))
+        if force_rsm and not use_rsm:
+            # RSM was not triggered by the arguments, but is explicitly requested
+            use_rsm = True
+            from_index = 0
+
+        stmt = (
+            select(PubsubItem)
+            .filter_by(node_id=node.id)
+            .limit(max_items)
+        )
+
+        if not order_by:
+            order_by = [C.ORDER_BY_MODIFICATION]
+
+        order = []
+        for order_type in order_by:
+            if order_type == C.ORDER_BY_MODIFICATION:
+                if desc:
+                    order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
+                else:
+                    order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
+            elif order_type == C.ORDER_BY_CREATION:
+                if desc:
+                    order.append(PubsubItem.id.desc())
+                else:
+                    order.append(PubsubItem.id.asc())
+            else:
+                raise exceptions.InternalError(f"Unknown order type {order_type!r}")
+
+        stmt = stmt.order_by(*order)
+
+        if use_rsm:
+            # CTE to have result row numbers
+            row_num_q = select(
+                PubsubItem.id,
+                PubsubItem.name,
+                # row_number starts from 1, but RSM index must start from 0
+                (func.row_number().over(order_by=order) - 1).label("item_index")
+            ).filter_by(node_id=node.id)
+
+            row_num_cte = row_num_q.cte()
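+            # e.g. for three cached items named "a", "b" and "c" in the
+            # current order, the CTE yields rows (id_a, "a", 0), (id_b, "b", 1)
+            # and (id_c, "c", 2), mapping item names to 0-based RSM indices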
+
+            if max_items > 0:
+                # as we can't simply use PubsubItem.id when we order by modification,
+                # we need to use row number
+                item_name = before or after
+                row_num_limit_q = (
+                    select(row_num_cte.c.item_index)
+                    .where(row_num_cte.c.name == item_name)
+                ).scalar_subquery()
+
+                stmt = (
+                    select(row_num_cte.c.item_index, PubsubItem)
+                    .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
+                    .limit(max_items)
+                )
+                if before:
+                    stmt = (
+                        stmt
+                        .where(row_num_cte.c.item_index < row_num_limit_q)
+                        .order_by(row_num_cte.c.item_index.desc())
+                    )
+                elif after:
+                    stmt = (
+                        stmt
+                        .where(row_num_cte.c.item_index > row_num_limit_q)
+                        .order_by(row_num_cte.c.item_index.asc())
+                    )
+                else:
+                    # neither before nor after: from_index is used
+                    stmt = (
+                        stmt
+                        .where(row_num_cte.c.item_index >= from_index)
+                        .order_by(row_num_cte.c.item_index.asc())
+                    )
+
+            async with self.session() as session:
+                if max_items == 0:
+                    items = result = []
+                else:
+                    result = await session.execute(stmt)
+                    result = result.all()
+                    if before:
+                        result.reverse()
+                    items = [row[-1] for row in result]
+                rows_count = (
+                    await session.execute(row_num_q.with_only_columns(count()))
+                ).scalar_one()
+
+            try:
+                index = result[0][0]
+            except IndexError:
+                index = None
+
+            try:
+                first = result[0][1].name
+            except IndexError:
+                first = None
+                last = None
+            else:
+                last = result[-1][1].name
+
+            metadata["rsm"] = {
+                "index": index,
+                "count": rows_count,
+                "first": first,
+                "last": last,
+            }
+            # index is None when there is no result
+            metadata["complete"] = (index or 0) + len(result) == rows_count
+
+            return items, metadata
+
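+        # non-RSM path: with desc ordering the most recent items are selected,
+        # then reversed below so they are returned in ascending order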
+        async with self.session() as session:
+            result = await session.execute(stmt)
+
+        result = result.scalars().all()
+        if desc:
+            result.reverse()
+        return result, metadata