# HG changeset patch
# User Goffi
# Date 1627591861 -7200
# Node ID 7510648e8e3a931d978654576c9619ed807a49c4
# Parent d5116197e403adfbfd7d2b96bbf528ae5f186e53
core (memory/sqla): methods to manipulate pubsub tables

diff -r d5116197e403 -r 7510648e8e3a sat/memory/sqla.py
--- a/sat/memory/sqla.py	Thu Jul 29 22:50:57 2021 +0200
+++ b/sat/memory/sqla.py	Thu Jul 29 22:51:01 2021 +0200
@@ -19,31 +19,38 @@
 import sys
 import time
 import asyncio
+from datetime import datetime
 from asyncio.subprocess import PIPE
 from pathlib import Path
 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
 from sqlalchemy.exc import IntegrityError, NoResultFound
-from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
+from sqlalchemy.orm import (
+    sessionmaker, subqueryload, joinedload, contains_eager # , aliased
+)
+from sqlalchemy.orm.decl_api import DeclarativeMeta
 from sqlalchemy.future import select
 from sqlalchemy.engine import Engine, Connection
-from sqlalchemy import update, delete, and_, or_, event
-from sqlalchemy.sql.functions import coalesce, sum as sum_
+from sqlalchemy import update, delete, and_, or_, event, func
+from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count
 from sqlalchemy.dialects.sqlite import insert
 from alembic import script as al_script, config as al_config
 from alembic.runtime import migration as al_migration
 from twisted.internet import defer
 from twisted.words.protocols.jabber import jid
+from twisted.words.xish import domish
 from sat.core.i18n import _
 from sat.core import exceptions
 from sat.core.log import getLogger
 from sat.core.constants import Const as C
 from sat.core.core_types import SatXMPPEntity
 from sat.tools.utils import aio
+from sat.tools.common import uri
 from sat.memory import migration
 from sat.memory import sqla_config
 from sat.memory.sqla_mapping import (
     NOT_IN_EXTRA,
+    SyncState,
     Base,
     Profile,
     Component,
@@ -57,7 +64,9 @@
     PrivateInd,
     PrivateGenBin,
     PrivateIndBin,
-    File
+    File,
+    PubsubNode,
+    PubsubItem,
 )
@@ -153,7 +162,7 @@
     db_config = sqla_config.getDbConfig()
     engine = create_async_engine(
         db_config["url"],
-        future=True
+        future=True,
     )
     new_base = not db_config["path"].exists()
@@ -952,3 +961,379 @@
         _("Can't update file {file_id} due to race condition")
         .format(file_id=file_id)
     )
+
+    @aio
+    async def getPubsubNode(
+        self,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        name: str,
+        with_items: bool = False,
+    ) -> Optional[PubsubNode]:
+        """Retrieve a pubsub node from cache, if it exists
+
+        @param with_items: if True, cached items of the node are eagerly
+            loaded too
+        @return: the cached node, or None if there is none
+        """
+        async with self.session() as session:
+            stmt = (
+                select(PubsubNode)
+                .filter_by(
+                    service=service,
+                    name=name,
+                    profile_id=self.profiles[client.profile],
+                )
+            )
+            if with_items:
+                stmt = stmt.options(
+                    joinedload(PubsubNode.items)
+                )
+            result = await session.execute(stmt)
+        return result.unique().scalar_one_or_none()
+
+    @aio
+    async def setPubsubNode(
+        self,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        name: str,
+        analyser: Optional[str] = None,
+        type_: Optional[str] = None,
+        subtype: Optional[str] = None,
+    ) -> PubsubNode:
+        node = PubsubNode(
+            profile_id=self.profiles[client.profile],
+            service=service,
+            name=name,
+            subscribed=False,
+            analyser=analyser,
+            type_=type_,
+            subtype=subtype,
+        )
+        async with self.session() as session:
+            async with session.begin():
+                session.add(node)
+        return node
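+
+    # Usage sketch for the two methods above (illustrative only: `storage`
+    # stands for this Storage instance; the service and node names come from
+    # the caller). A typical caller gets the cached node and creates it on
+    # first use:
+    #
+    #     node = await storage.getPubsubNode(client, service, node_name)
+    #     if node is None:
+    #         node = await storage.setPubsubNode(client, service, node_name)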
+
+    @aio
+    async def updatePubsubNodeSyncState(
+        self,
+        node: PubsubNode,
+        state: SyncState
+    ) -> None:
+        async with self.session() as session:
+            async with session.begin():
+                await session.execute(
+                    update(PubsubNode)
+                    .filter_by(id=node.id)
+                    .values(
+                        sync_state=state,
+                        sync_state_updated=time.time(),
+                    )
+                )
+
+    @aio
+    async def deletePubsubNode(
+        self,
+        profiles: Optional[List[str]],
+        services: Optional[List[jid.JID]],
+        names: Optional[List[str]]
+    ) -> None:
+        """Delete pubsub nodes from cache
+
+        @param profiles: profile names from which nodes must be deleted.
+            None to remove nodes from ALL profiles
+        @param services: JIDs of pubsub services from which nodes must be deleted.
+            None to remove nodes from ALL services
+        @param names: names of nodes which must be deleted.
+            None to remove ALL nodes, whatever their names are
+        """
+        stmt = delete(PubsubNode)
+        if profiles is not None:
+            stmt = stmt.where(
+                PubsubNode.profile_id.in_(
+                    [self.profiles[p] for p in profiles]
+                )
+            )
+        if services is not None:
+            stmt = stmt.where(PubsubNode.service.in_(services))
+        if names is not None:
+            stmt = stmt.where(PubsubNode.name.in_(names))
+        async with self.session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
+    @aio
+    async def cachePubsubItems(
+        self,
+        client: SatXMPPEntity,
+        node: PubsubNode,
+        items: List[domish.Element],
+        parsed_items: Optional[List[dict]] = None,
+    ) -> None:
+        """Add items to database, using an upsert taking care of "updated" field"""
+        if parsed_items is not None and len(items) != len(parsed_items):
+            raise exceptions.InternalError(
+                "parsed_items must have the same length as items"
+            )
+        async with self.session() as session:
+            # session.begin() commits the transaction when the block exits
+            async with session.begin():
+                for idx, item in enumerate(items):
+                    parsed = parsed_items[idx] if parsed_items else None
+                    stmt = insert(PubsubItem).values(
+                        node_id = node.id,
+                        name = item["id"],
+                        data = item,
+                        parsed = parsed,
+                    ).on_conflict_do_update(
+                        index_elements=(PubsubItem.node_id, PubsubItem.name),
+                        set_={
+                            PubsubItem.data: item,
+                            PubsubItem.parsed: parsed,
+                            PubsubItem.updated: now()
+                        }
+                    )
+                    await session.execute(stmt)
+
+    @aio
+    async def deletePubsubItems(
+        self,
+        node: PubsubNode,
+        items_names: Optional[List[str]] = None
+    ) -> None:
+        """Delete items cached for a node
+
+        @param node: node (or list of nodes) from which items must be deleted
+        @param items_names: names of items to delete
+            if None, ALL items will be deleted
+        """
+        stmt = delete(PubsubItem)
+        if node is not None:
+            if isinstance(node, list):
+                stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node]))
+            else:
+                stmt = stmt.filter_by(node_id=node.id)
+        if items_names is not None:
+            stmt = stmt.where(PubsubItem.name.in_(items_names))
+        async with self.session() as session:
+            await session.execute(stmt)
+            await session.commit()
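+
+    # Retention sketch for the method below (illustrative only: `storage`
+    # stands for this Storage instance, and the 6-month window is an
+    # arbitrary example):
+    #
+    #     from datetime import datetime, timedelta
+    #     six_months_ago = datetime.now() - timedelta(days=180)
+    #     await storage.purgePubsubItems(updated_before=six_months_ago)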
+
+    @aio
+    async def purgePubsubItems(
+        self,
+        services: Optional[List[jid.JID]] = None,
+        names: Optional[List[str]] = None,
+        types: Optional[List[str]] = None,
+        subtypes: Optional[List[str]] = None,
+        profiles: Optional[List[str]] = None,
+        created_before: Optional[datetime] = None,
+        updated_before: Optional[datetime] = None,
+    ) -> None:
+        """Delete cached pubsub items according to the given criteria
+
+        All criteria are cumulative (they are combined with AND).
+
+        @param services: only delete items from nodes of those pubsub services
+        @param names: only delete items from nodes with those names
+        @param types: only delete items from nodes with those types
+        @param subtypes: only delete items from nodes with those subtypes
+        @param profiles: only delete items from nodes of those profiles
+        @param created_before: only delete items created before this date
+        @param updated_before: only delete items last updated before this date
+        """
+        stmt = delete(PubsubItem)
+        node_fields = {
+            "service": services,
+            "name": names,
+            "type_": types,
+            "subtype": subtypes,
+        }
+        if profiles is not None:
+            # profiles are attached to the node, not to the item itself
+            node_fields["profile_id"] = [self.profiles[p] for p in profiles]
+        if any(x is not None for x in node_fields.values()):
+            sub_q = select(PubsubNode.id)
+            for col, values in node_fields.items():
+                if values is None:
+                    continue
+                sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
+            stmt = (
+                stmt
+                .where(PubsubItem.node_id.in_(sub_q))
+                .execution_options(synchronize_session=False)
+            )
+
+        if created_before is not None:
+            stmt = stmt.where(PubsubItem.created < created_before)
+
+        if updated_before is not None:
+            stmt = stmt.where(PubsubItem.updated < updated_before)
+
+        async with self.session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
+    @aio
+    async def getItems(
+        self,
+        node: PubsubNode,
+        max_items: Optional[int] = None,
+        before: Optional[str] = None,
+        after: Optional[str] = None,
+        from_index: Optional[int] = None,
+        order_by: Optional[List[str]] = None,
+        desc: bool = True,
+        force_rsm: bool = False,
+    ) -> Tuple[List[PubsubItem], dict]:
+        """Get pubsub items from cache
+
+        @param node: retrieve items from this node (must be synchronised)
+        @param max_items: maximum number of items to retrieve (20 by default)
+        @param before: get items which are before the item with this name in given order
+            empty string is not managed here, use desc order to reproduce RSM
+            behaviour.
+        @param after: get items which are after the item with this name in given order
+        @param from_index: get items with item index (as defined in RSM spec)
+            starting from this number
+        @param order_by: sorting order of items (one of C.ORDER_BY_*)
+        @param desc: direction of ordering
+        @param force_rsm: if True, force the use of the RSM workflow.
+            The RSM workflow is automatically used if any of before, after or
+            from_index is used, but if only max_items is used, it won't be
+            used by default. This parameter lets one use the RSM workflow in
+            that case. Note that in addition to the RSM metadata, the result
+            will not be the same: max_items without RSM returns the most
+            recent items (i.e. the last items in modification order), while
+            max_items with RSM returns the oldest ones (i.e. the first items
+            in modification order).
+        @return: a tuple with the list of retrieved items and a metadata dict
+            (the "rsm" and "complete" keys are only set when the RSM workflow
+            is used)
+        """
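+        # shape of the metadata returned by this method (illustrative values;
+        # "rsm" and "complete" are only present when the RSM workflow is
+        # used):
+        #
+        #     {
+        #         "service": ..., "node": ..., "uri": ...,
+        #         "rsm": {"index": 0, "count": 42, "first": "a", "last": "z"},
+        #         "complete": False,
+        #     }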
+
+        metadata = {
+            "service": node.service,
+            "node": node.name,
+            "uri": uri.buildXMPPUri(
+                "pubsub",
+                path=node.service.full(),
+                node=node.name,
+            ),
+        }
+        if max_items is None:
+            max_items = 20
+
+        use_rsm = any((before, after, from_index is not None))
+        if force_rsm and not use_rsm:
+            # the RSM workflow is not triggered by the other arguments, but it
+            # is explicitly requested: emulate it by starting from index 0
+            use_rsm = True
+            from_index = 0
+
+        stmt = (
+            select(PubsubItem)
+            .filter_by(node_id=node.id)
+            .limit(max_items)
+        )
+
+        if not order_by:
+            order_by = [C.ORDER_BY_MODIFICATION]
+
+        order = []
+        for order_type in order_by:
+            if order_type == C.ORDER_BY_MODIFICATION:
+                if desc:
+                    order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
+                else:
+                    order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
+            elif order_type == C.ORDER_BY_CREATION:
+                if desc:
+                    order.append(PubsubItem.id.desc())
+                else:
+                    order.append(PubsubItem.id.asc())
+            else:
+                raise exceptions.InternalError(f"Unknown order type {order_type!r}")
+
+        stmt = stmt.order_by(*order)
+
+        if use_rsm:
+            # CTE to have result row numbers
+            row_num_q = select(
+                PubsubItem.id,
+                PubsubItem.name,
+                # row_number starts from 1, but RSM index must start from 0
+                (func.row_number().over(order_by=order) - 1).label("item_index")
+            ).filter_by(node_id=node.id)
+
+            row_num_cte = row_num_q.cte()
+
+            if max_items > 0:
+                # as we can't simply use PubsubItem.id when we order by modification,
+                # we need to use row number
+                item_name = before or after
+                row_num_limit_q = (
+                    select(row_num_cte.c.item_index)
+                    .where(row_num_cte.c.name==item_name)
+                ).scalar_subquery()
+
+                stmt = (
+                    select(row_num_cte.c.item_index, PubsubItem)
+                    .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
+                    .limit(max_items)
+                )
+                if before:
+                    stmt = (
+                        stmt
+                        .where(row_num_cte.c.item_index < row_num_limit_q)
+                        .order_by(row_num_cte.c.item_index.desc())
+                    )
+                elif after:
+                    stmt = (
+                        stmt
+                        .where(row_num_cte.c.item_index > row_num_limit_q)
+                        .order_by(row_num_cte.c.item_index.asc())
+                    )
+                else:
+                    # neither before nor after is used, from_index is used
+                    stmt = (
+                        stmt
+                        .where(row_num_cte.c.item_index >= from_index)
+                        .order_by(row_num_cte.c.item_index.asc())
+                    )
+
+            async with self.session() as session:
+                if max_items == 0:
+                    items = result = []
+                else:
+                    result = await session.execute(stmt)
+                    result = result.all()
+                    if before:
+                        result.reverse()
+                    items = [row[-1] for row in result]
+                rows_count = (
+                    await session.execute(row_num_q.with_only_columns(count()))
+                ).scalar_one()
+
+            try:
+                index = result[0][0]
+            except IndexError:
+                index = None
+
+            try:
+                first = result[0][1].name
+            except IndexError:
+                first = None
+                last = None
+            else:
+                last = result[-1][1].name
+
+            metadata["rsm"] = {
+                "index": index,
+                "count": rows_count,
+                "first": first,
+                "last": last,
+            }
+            if index is None:
+                # no item was returned, the node is fully covered only if it
+                # is empty
+                metadata["complete"] = rows_count == 0
+            else:
+                metadata["complete"] = index + len(result) == rows_count
+
+            return items, metadata
+
+        async with self.session() as session:
+            result = await session.execute(stmt)
+
+        result = result.scalars().all()
+        if desc:
+            result.reverse()
+        return result, metadata
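
A pagination sketch tying the RSM-style parameters of getItems together
(illustrative only: `storage` stands for the Storage instance and `node` for
a PubsubNode previously retrieved with getPubsubNode; it assumes it runs
inside Twisted's coroutine integration, where the Deferred returned by the
@aio-decorated methods can be awaited):

    async def get_all_items(storage, node, page_size=10):
        # first page, with RSM metadata forced so that "complete" is set
        items, metadata = await storage.getItems(
            node, max_items=page_size, force_rsm=True
        )
        while not metadata["complete"]:
            # following pages: resume after the last item of the previous page
            page, metadata = await storage.getItems(
                node, max_items=page_size, after=metadata["rsm"]["last"]
            )
            if not page:
                # defensive stop in case the cache changed under us
                break
            items.extend(page)
        return items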