diff sat/memory/sqla.py @ 3664:9ae6ec74face

memory (sqla): implement `searchPubsubItems`: `searchPubsubItems` is a high level method to handle Full-Text Search queries on Pubsub cache. rel 361
author Goffi <goffi@goffi.org>
date Wed, 08 Sep 2021 17:58:48 +0200
parents 257135d5c5c2
children 72b0e4053ab0
line wrap: on
line diff
--- a/sat/memory/sqla.py	Wed Sep 08 17:58:48 2021 +0200
+++ b/sat/memory/sqla.py	Wed Sep 08 17:58:48 2021 +0200
@@ -22,11 +22,11 @@
 from datetime import datetime
 from asyncio.subprocess import PIPE
 from pathlib import Path
-from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
+from typing import Union, Dict, List, Tuple, Iterable, Any, Callable, Optional
 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
 from sqlalchemy.exc import IntegrityError, NoResultFound
 from sqlalchemy.orm import (
-    sessionmaker, subqueryload, joinedload, contains_eager # , aliased
+    sessionmaker, subqueryload, joinedload, selectinload, contains_eager
 )
 from sqlalchemy.orm.decl_api import DeclarativeMeta
 from sqlalchemy.future import select
@@ -34,6 +34,7 @@
 from sqlalchemy import update, delete, and_, or_, event, func
 from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count
 from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy import text, literal_column, Integer
 from alembic import script as al_script, config as al_config
 from alembic.runtime import migration as al_migration
 from twisted.internet import defer
@@ -72,6 +73,28 @@
 
 log = getLogger(__name__)
 migration_path = Path(migration.__file__).parent
+#: mapping of Libervia search query operators to SQLAlchemy method name
+OP_MAP = {
+    "==": "__eq__",
+    "eq": "__eq__",
+    "!=": "__ne__",
+    "ne": "__ne__",
+    ">": "__gt__",
+    "gt": "__gt__",
+    "<": "__le__",
+    "le": "__le__",
+    "between": "between",
+    "in": "in_",
+    "not_in": "not_in",
+    "overlap": "in_",
+    "ioverlap": "in_",
+    "disjoint": "in_",
+    "idisjoint": "in_",
+    "like": "like",
+    "ilike": "ilike",
+    "not_like": "notlike",
+    "not_ilike": "notilike",
+}
 
 
 @event.listens_for(Engine, "connect")
@@ -1355,3 +1378,212 @@
         if desc:
             result.reverse()
         return result, metadata
+
+    def _getSqlitePath(
+        self,
+        path: List[Union[str, int]]
+    ) -> str:
+        """generate path suitable to query JSON element with SQLite"""
+        return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}"
+
+    @aio
+    async def searchPubsubItems(
+        self,
+        query: dict,
+    ) -> Tuple[List[PubsubItem]]:
+        """Search for pubsub items in cache
+
+        @param query: search terms. Keys can be:
+            :fts (str):
+                Full-Text Search query. Currently SQLite FT5 engine is used, its query
+                syntax can be used, see `FTS5 Query documentation
+                <https://sqlite.org/fts5.html#full_text_query_syntax>`_
+            :profiles (list[str]):
+                filter on nodes linked to those profiles
+            :nodes (list[str]):
+                filter on nodes with those names
+            :services (list[jid.JID]):
+                filter on nodes from those services
+            :types (list[str|None]):
+                filter on nodes with those types. None can be used to filter on nodes with
+                no type set
+            :subtypes (list[str|None]):
+                filter on nodes with those subtypes. None can be used to filter on nodes with
+                no subtype set
+            :parsed (list[dict]):
+                Filter on a parsed data field. The dict must contain 3 keys: ``path``
+                which is a list of str or int giving the path to the field of interest
+                (str for a dict key, int for a list index), ``operator`` with indicate the
+                operator to use to check the condition, and ``value`` which depends of
+                field type and operator.
+
+                See documentation for details on operators (it's currently explained at
+                ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
+                documentation).
+
+            :order-by (list[dict]):
+                Indicates how to order results. The dict can contain either a ``order``
+                for a well-know order or a ``path`` for a parsed data field path
+                (``order`` and ``path`` can't be used at the same time), an an optional
+                ``direction`` which can be ``asc`` or ``desc``. See documentation for
+                details on well-known orders (it's currently explained at
+                ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
+                documentation).
+
+            :index (int):
+                starting index of items to return from the query result. It's translated
+                to SQL's OFFSET
+
+            :limit (int):
+                maximum number of items to return. It's translated to SQL's LIMIT.
+
+        @result: found items (the ``node`` attribute will be filled with suitable
+            PubsubNode)
+        """
+        # TODO: FTS and parsed data filters use SQLite specific syntax
+        #   when other DB engines will be used, this will have to be adapted
+        stmt = select(PubsubItem)
+
+        # Full-Text Search
+        fts = query.get("fts")
+        if fts:
+            fts_select = text(
+                "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)"
+            ).bindparams(fts_query=fts).columns(rowid=Integer).subquery()
+            stmt = (
+                stmt
+                .select_from(fts_select)
+                .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id)
+            )
+
+        # node related filters
+        profiles = query.get("profiles")
+        if (profiles
+            or any(query.get(k) for k in ("nodes", "services", "types", "subtypes"))
+        ):
+            stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node))
+            if profiles:
+                try:
+                    stmt = stmt.where(
+                        PubsubNode.profile_id.in_(self.profiles[p] for p in profiles)
+                    )
+                except KeyError as e:
+                    raise exceptions.ProfileUnknownError(
+                        f"This profile doesn't exist: {e.args[0]!r}"
+                    )
+            for key, attr in (
+                ("nodes", "name"),
+                ("services", "service"),
+                ("types", "type_"),
+                ("subtypes", "subtype")
+            ):
+                value = query.get(key)
+                if not value:
+                    continue
+                if key in ("types", "subtypes") and None in value:
+                    # NULL can't be used with SQL's IN, so we have to add a condition with
+                    # IS NULL, and use a OR if there are other values to check
+                    value.remove(None)
+                    condition = getattr(PubsubNode, attr).is_(None)
+                    if value:
+                        condition = or_(
+                            getattr(PubsubNode, attr).in_(value),
+                            condition
+                        )
+                else:
+                    condition = getattr(PubsubNode, attr).in_(value)
+                stmt = stmt.where(condition)
+        else:
+            stmt = stmt.options(selectinload(PubsubItem.node))
+
+        # parsed data filters
+        parsed = query.get("parsed", [])
+        for filter_ in parsed:
+            try:
+                path = filter_["path"]
+                operator = filter_["op"]
+                value = filter_["value"]
+            except KeyError as e:
+                raise ValueError(
+                    f'missing mandatory key {e.args[0]!r} in "parsed" filter'
+                )
+            try:
+                op_attr = OP_MAP[operator]
+            except KeyError:
+                raise ValueError(f"invalid operator: {operator!r}")
+            sqlite_path = self._getSqlitePath(path)
+            if operator in ("overlap", "ioverlap", "disjoint", "idisjoint"):
+                col = literal_column("json_each.value")
+                if operator[0] == "i":
+                    col = func.lower(col)
+                    value = [str(v).lower() for v in value]
+                condition = (
+                    select(1)
+                    .select_from(func.json_each(PubsubItem.parsed, sqlite_path))
+                    .where(col.in_(value))
+                ).scalar_subquery()
+                if operator in ("disjoint", "idisjoint"):
+                    condition = condition.is_(None)
+                stmt = stmt.where(condition)
+            elif operator == "between":
+                try:
+                    left, right = value
+                except (ValueError, TypeError):
+                    raise ValueError(_(
+                        "invalid value for \"between\" filter, you must use a 2 items "
+                        "array: {value!r}"
+                    ).format(value=value))
+                col = func.json_extract(PubsubItem.parsed, sqlite_path)
+                stmt = stmt.where(col.between(left, right))
+            else:
+                # we use func.json_extract instead of generic JSON way because SQLAlchemy
+                # add a JSON_QUOTE to the value, and we want SQL value
+                col = func.json_extract(PubsubItem.parsed, sqlite_path)
+                stmt = stmt.where(getattr(col, op_attr)(value))
+
+        # order
+        order_by = query.get("order-by") or [{"order": "creation"}]
+
+        for order_data in order_by:
+            order, path = order_data.get("order"), order_data.get("path")
+            if order and path:
+                raise ValueError(_(
+                    '"order" and "path" can\'t be used at the same time in '
+                    '"order-by" data'
+                ))
+            if order:
+                if order == "creation":
+                    col = PubsubItem.id
+                elif order == "modification":
+                    col = PubsubItem.updated
+                elif order == "item_id":
+                    col = PubsubItem.name
+                elif order == "rank":
+                    if not fts:
+                        raise ValueError(
+                            "'rank' order can only be used with Full-Text Search (fts)"
+                        )
+                    col = literal_column("rank")
+                else:
+                    raise NotImplementedError(f"Unknown {order!r} order")
+            else:
+                # we have a JSON path
+                # sqlite_path = self._getSqlitePath(path)
+                col = PubsubItem.parsed[path]
+            direction = order_data.get("direction", "ASC").lower()
+            if not direction in ("asc", "desc"):
+                raise ValueError(f"Invalid order-by direction: {direction!r}")
+            stmt = stmt.order_by(getattr(col, direction)())
+
+        # offset, limit
+        index = query.get("index")
+        if index:
+            stmt = stmt.offset(index)
+        limit = query.get("limit")
+        if limit:
+            stmt = stmt.limit(limit)
+
+        async with self.session() as session:
+            result = await session.execute(stmt)
+
+        return result.scalars().all()