view sat/memory/sqla.py @ 3861:37a1193d90db

plugin XEP-0277: add a trigger to `send` method: rel 370
author Goffi <goffi@goffi.org>
date Wed, 20 Jul 2022 17:19:53 +0200
parents 1a10b8b4f169
children 100dd30244c6
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import time
import asyncio
import copy
from datetime import datetime
from asyncio.subprocess import PIPE
from pathlib import Path
from typing import Union, Dict, List, Tuple, Iterable, Any, Callable, Optional
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import (
    sessionmaker, subqueryload, joinedload, selectinload, contains_eager
)
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.attributes import Mapped
from sqlalchemy.future import select
from sqlalchemy.engine import Engine, Connection
from sqlalchemy import update, delete, and_, or_, event, func
from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy import text, literal_column, Integer
from alembic import script as al_script, config as al_config
from alembic.runtime import migration as al_migration
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from twisted.words.xish import domish
from sat.core.i18n import _
from sat.core import exceptions
from sat.core.log import getLogger
from sat.core.constants import Const as C
from sat.core.core_types import SatXMPPEntity
from sat.tools.utils import aio
from sat.tools.common import uri
from sat.memory import migration
from sat.memory import sqla_config
from sat.memory.sqla_mapping import (
    NOT_IN_EXTRA,
    SyncState,
    Base,
    Profile,
    Component,
    History,
    Message,
    Subject,
    Thread,
    ParamGen,
    ParamInd,
    PrivateGen,
    PrivateInd,
    PrivateGenBin,
    PrivateIndBin,
    File,
    PubsubNode,
    PubsubItem,
)


# module-level logger, named after this module
log = getLogger(__name__)
# directory of the ``sat.memory.migration`` package: it is used as working directory
# for the Alembic subprocess, and to locate ``alembic.ini``
migration_path = Path(migration.__file__).parent
#: mapping of Libervia search query operators to SQLAlchemy method name
#: Each comparison operator is available as a symbol and as a short name.
OP_MAP = {
    "==": "__eq__",
    "eq": "__eq__",
    "!=": "__ne__",
    "ne": "__ne__",
    ">": "__gt__",
    "gt": "__gt__",
    ">=": "__ge__",
    "ge": "__ge__",
    # strict "lower than" must not match equal values
    "<": "__lt__",
    "lt": "__lt__",
    "<=": "__le__",
    "le": "__le__",
    "between": "between",
    "in": "in_",
    "not_in": "not_in",
    "overlap": "in_",
    "ioverlap": "in_",
    "disjoint": "in_",
    "idisjoint": "in_",
    "like": "like",
    "ilike": "ilike",
    "not_like": "notlike",
    "not_ilike": "notilike",
}


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Activate foreign keys constraints on each new DB-API connection

    @param dbapi_connection: raw DB-API connection which has just been opened
    @param connection_record: pool record of the connection (unused)
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # the cursor must be released even if the PRAGMA statement fails
        cursor.close()


class Storage:

    def __init__(self):
        """Prepare the in-memory caches and the initialisation Deferred"""
        # cache of the profiles: profile name => profile id
        self.profiles: Dict[str, int] = {}
        # cache of the components: profile id => component entry point
        self.components: Dict[int, str] = {}
        # fired with None at the end of [initialise]
        self.initialized = defer.Deferred()

    def getProfileById(self, profile_id):
        return self.profiles.get(profile_id)

    async def migrateApply(self, *args: str, log_output: bool = False) -> None:
        """Do a migration command

        Commands are applied by running Alembic in a subprocess.
        Arguments are alembic executables commands

        @param args: arguments given to the ``alembic`` module; the first one is used
            as operation name in the error message
        @param log_output: manage stdout and stderr:
            - if False, stdout and stderr are buffered, and logged only in case of error
            - if True, stdout and stderr will be logged during the command execution
        @raise exceptions.DatabaseError: something went wrong while running the
            process
        """
        # None lets the subprocess inherit our stdout/stderr, PIPE buffers them
        stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
        proc = await asyncio.create_subprocess_exec(
            sys.executable, "-m", "alembic", *args,
            stdout=stdout, stderr=stderr, cwd=migration_path
        )
        log_out, log_err = await proc.communicate()
        if proc.returncode == 0:
            return

        msg = _(
            "Can't {operation} database (exit code {exit_code})"
        ).format(
            operation=args[0],
            exit_code=proc.returncode
        )
        if log_out or log_err:
            # the output may not be valid UTF-8: undecodable bytes are replaced so
            # a decoding error can't hide the actual database error
            out_txt = (log_out or b"").decode(errors="replace")
            err_txt = (log_err or b"").decode(errors="replace")
            msg += f":\nstdout: {out_txt}\nstderr: {err_txt}"
        log.error(msg)

        raise exceptions.DatabaseError(msg)

    async def createDB(self, engine: AsyncEngine, db_config: dict) -> None:
        """Create a brand new database

        Tables are created from the SQLAlchemy model, then the database is stamped
        at head revision with Alembic.

        @param engine: engine to use to create the tables
        @param db_config: database configuration, its "path" value is used to create
            the parent directory
        """
        db_dir = db_config["path"].parent
        # the dir may not exist if it's not the XDG recommended one
        db_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

        async with engine.begin() as connection:
            await connection.run_sync(Base.metadata.create_all)

        log.debug("stamping the database")
        await self.migrateApply("stamp", "head")
        log.debug("stamping done")

    def _checkDBIsUpToDate(self, conn: Connection) -> bool:
        """Compare the revision(s) of the database with the head(s) of the scripts

        @param conn: synchronous connection to the database to check
        @return: True if current heads of the database are the same as the heads of
            the migration scripts directory
        """
        alembic_cfg = al_config.Config(migration_path / "alembic.ini")
        script_dir = al_script.ScriptDirectory.from_config(alembic_cfg)
        migration_ctx = al_migration.MigrationContext.configure(conn)
        db_heads = set(migration_ctx.get_current_heads())
        script_heads = set(script_dir.get_heads())
        return db_heads == script_heads

    async def checkAndUpdateDB(self, engine: AsyncEngine, db_config: dict) -> None:
        """Upgrade the database with Alembic when it is not at head revision

        @param engine: engine used to check the current revision of the database
        @param db_config: database configuration (not used by this method)
        """
        async with engine.connect() as conn:
            is_current = await conn.run_sync(self._checkDBIsUpToDate)

        if is_current:
            log.debug("Database is up-to-date")
            return

        log.info("Database needs to be updated")
        log.info("updating…")
        await self.migrateApply("upgrade", "head", log_output=True)
        log.info("Database is now up-to-date")

    @aio
    async def initialise(self) -> None:
        """Connect the database, create or update it, and fill the caches

        ``self.initialized`` is fired once profiles and components caches are filled.
        """
        log.info(_("Connecting database"))

        db_config = sqla_config.getDbConfig()
        engine = create_async_engine(db_config["url"], future=True)

        if db_config["path"].exists():
            await self.checkAndUpdateDB(engine, db_config)
        else:
            log.info(_("The database is new, creating the tables"))
            await self.createDB(engine, db_config)

        self.session = sessionmaker(
            engine, expire_on_commit=False, class_=AsyncSession
        )

        async with self.session() as session:
            # profile name => profile id
            profiles = await session.execute(select(Profile))
            for profile in profiles.scalars():
                self.profiles[profile.name] = profile.id
            # profile id => component entry point
            components = await session.execute(select(Component))
            for component in components.scalars():
                self.components[component.profile_id] = component.entry_point

        self.initialized.callback(None)

    ## Generic

    @aio
    async def get(
        self,
        client: SatXMPPEntity,
        db_cls: DeclarativeMeta,
        db_id_col: Mapped,
        id_value: Any,
        joined_loads = None
    ) -> Optional[DeclarativeMeta]:
        """Retrieve a single object from its class and the value of one column

        @param client: if not None, only an object of the profile of this client is
            returned
        @param db_cls: mapped class of the object to retrieve
        @param db_id_col: column to compare to ``id_value``
        @param id_value: value that ``db_id_col`` must have
        @param joined_loads: if not None, iterable of relationships to load with a
            joined load
        @return: the object found, or None if nothing matches
        """
        use_joined_loads = joined_loads is not None

        stmt = select(db_cls).where(db_id_col==id_value)
        if client is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
        if use_joined_loads:
            for relationship in joined_loads:
                stmt = stmt.options(joinedload(relationship))

        async with self.session() as session:
            result = await session.execute(stmt)

        # rows are made unique when joined loads are requested
        rows = result.unique() if use_joined_loads else result
        return rows.scalar_one_or_none()

    @aio
    async def add(self, db_obj: DeclarativeMeta) -> None:
        """Store an object in database, in its own transaction

        @param db_obj: mapped object to add to the session
        """
        async with self.session() as session, session.begin():
            session.add(db_obj)

    @aio
    async def delete(
        self,
        db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]],
        session_add: Optional[List[DeclarativeMeta]] = None
    ) -> None:
        """Remove one or several objects from database

        Nothing is done if ``db_obj`` is falsy (e.g. None or an empty list).

        @param db_obj: object to delete or list of objects to delete
        @param session_add: other objects to add to session.
            This is useful when parents of deleted objects needs to be updated too, or if
            other objects needs to be updated in the same transaction.
        """
        if not db_obj:
            return
        to_delete = db_obj if isinstance(db_obj, list) else [db_obj]

        async with self.session() as session, session.begin():
            for extra_obj in session_add or ():
                session.add(extra_obj)
            for obj in to_delete:
                await session.delete(obj)
            await session.commit()

    ## Profiles

    def getProfilesList(self) -> List[str]:
        """"Return list of all registered profiles"""
        return list(self.profiles.keys())

    def hasProfile(self, profile_name: str) -> bool:
        """return True if profile_name exists

        @param profile_name: name of the profile to check
        """
        return profile_name in self.profiles

    def profileIsComponent(self, profile_name: str) -> bool:
        try:
            return self.profiles[profile_name] in self.components
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists")

    def getEntryPoint(self, profile_name: str) -> str:
        try:
            return self.components[self.profiles[profile_name]]
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists or is not a component")

    @aio
    async def createProfile(self, name: str, component_ep: Optional[str] = None) -> Profile:
        """Create a new profile

        The profiles cache is updated, and the components cache too when an entry
        point is specified.

        @param name: name of the profile
        @param component_ep: if not None, must point to a component entry point
        @return: the newly created profile
        """
        async with self.session() as session:
            profile = Profile(name=name)
            # the profile is stored in its own transaction, before the component
            async with session.begin():
                session.add(profile)
            self.profiles[profile.name] = profile.id
            if component_ep is not None:
                async with session.begin():
                    component = Component(profile=profile, entry_point=component_ep)
                    session.add(component)
                self.components[profile.id] = component_ep
        return profile

    @aio
    async def deleteProfile(self, name: str) -> None:
        """Remove a profile from database and from the caches

        @param name: name of the profile to delete
        """
        async with self.session() as session:
            query = select(Profile).where(Profile.name == name)
            profile = (await session.execute(query)).scalar()
            await session.delete(profile)
            await session.commit()

        # caches must reflect the deletion
        del self.profiles[profile.name]
        self.components.pop(profile.id, None)
        log.info(_("Profile {name!r} deleted").format(name = name))

    ## Params

    @aio
    async def loadGenParams(self, params_gen: dict) -> None:
        """Fill a dictionary with all the general parameters

        @param params_gen: dictionary to fill, (category, name) tuples are used as keys
        """
        log.debug(_("loading general parameters from database"))
        async with self.session() as session:
            query = select(ParamGen)
            result = await session.execute(query)
        for param in result.scalars():
            key = (param.category, param.name)
            params_gen[key] = param.value

    @aio
    async def loadIndParams(self, params_ind: dict, profile: str) -> None:
        """Fill a dictionary with the individual parameters of a profile

        @param params_ind: dictionary to fill, (category, name) tuples are used as keys
        @param profile: a profile which *must* exist
        """
        log.debug(_("loading individual parameters from database"))
        async with self.session() as session:
            query = select(ParamInd).where(
                ParamInd.profile_id == self.profiles[profile]
            )
            result = await session.execute(query)
        for param in result.scalars():
            key = (param.category, param.name)
            params_ind[key] = param.value

    @aio
    async def getIndParam(self, category: str, name: str, profile: str) -> Optional[str]:
        """Retrieve the value of one individual parameter of a profile

        @param category: category of the parameter
        @param name: name of the parameter
        @param profile: %(doc_profile)s
        @return: value of the parameter, or None if it is not found
        """
        async with self.session() as session:
            query = select(ParamInd.value).filter_by(
                category=category,
                name=name,
                profile_id=self.profiles[profile],
            )
            result = await session.execute(query)
        return result.scalar_one_or_none()

    @aio
    async def getIndParamValues(self, category: str, name: str) -> Dict[str, str]:
        """Retrieve the values of an individual parameter for all profiles

        @param category: category of the parameter
        @param name: name of the parameter
        @return dict: profile => value map
        """
        async with self.session() as session:
            query = (
                select(ParamInd)
                .filter_by(category=category, name=name)
                .options(subqueryload(ParamInd.profile))
            )
            result = await session.execute(query)

        values = {}
        for param in result.scalars():
            values[param.profile.name] = param.value
        return values

    @aio
    async def setGenParam(self, category: str, name: str, value: Optional[str]) -> None:
        """Store a general parameter, replacing the value if it already exists

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        """
        async with self.session() as session:
            # insert, or update the value on (category, name) conflict
            upsert = (
                insert(ParamGen)
                .values(category=category, name=name, value=value)
                .on_conflict_do_update(
                    index_elements=(ParamGen.category, ParamGen.name),
                    set_={ParamGen.value: value},
                )
            )
            await session.execute(upsert)
            await session.commit()

    @aio
    async def setIndParam(
        self,
        category:str,
        name: str,
        value: Optional[str],
        profile: str
    ) -> None:
        """Store an individual parameter, replacing the value if it already exists

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        @param profile: a profile which *must* exist
        """
        async with self.session() as session:
            # insert, or update the value on (category, name, profile_id) conflict
            upsert = (
                insert(ParamInd)
                .values(
                    category=category,
                    name=name,
                    profile_id=self.profiles[profile],
                    value=value,
                )
                .on_conflict_do_update(
                    index_elements=(
                        ParamInd.category, ParamInd.name, ParamInd.profile_id
                    ),
                    set_={ParamInd.value: value},
                )
            )
            await session.execute(upsert)
            await session.commit()

    def _jid_filter(self, jid_: jid.JID, dest: bool = False):
        """Build the History condition matching a JID

        The bare JID is always compared, the resource is only compared when the JID
        has one.

        @param jid_: JID to filter by
        @param dest: True if it's the destinee JID, otherwise it's the source one
        @return: SQLAlchemy condition to use in a WHERE clause
        """
        prefix = "dest" if dest else "source"
        bare_match = getattr(History, prefix) == jid_.userhost()
        if not jid_.resource:
            return bare_match
        resource_match = getattr(History, f"{prefix}_res") == jid_.resource
        return and_(bare_match, resource_match)

    @aio
    async def historyGet(
        self,
        from_jid: Optional[jid.JID],
        to_jid: Optional[jid.JID],
        limit: Optional[int] = None,
        between: bool = True,
        filters: Optional[Dict[str, str]] = None,
        profile: Optional[str] = None,
    ) -> List[Tuple[
        str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
    ]:
        """Retrieve messages in history

        @param from_jid: source JID (full, or bare for catchall)
        @param to_jid: dest JID (full, or bare for catchall)
        @param limit: maximum number of messages to get:
            - 0 for no message (returns the empty list)
            - None for unlimited
        @param between: confound source and dest (ignore the direction)
        @param filters: pattern to filter the history results
            Handled keys are: timestamp_start, before_uid, body, search, types,
            not_types, last_stanza_id and origin_id.
        @param profile: name of the profile owning the history
            It must be set (None is only the default because of arguments order), and
            it must be in the profiles cache.
        @return: list of messages as in [messageNew], minus the profile which is already
            known. Oldest messages come first.
        @raise ValueError: "last_stanza_id" filter is used with a value which is not
            True, or with a limit which is not 1
        """
        # we have to set a default value to profile because it's last argument
        # and thus follow other keyword arguments with default values
        # but None should not be used for it
        assert profile is not None
        if limit == 0:
            return []
        if filters is None:
            filters = {}

        # messages, subjects and thread are joined explicitly, and loaded from the
        # joined rows thanks to contains_eager
        stmt = (
            select(History)
            .filter_by(
                profile_id=self.profiles[profile]
            )
            .outerjoin(History.messages)
            .outerjoin(History.subjects)
            .outerjoin(History.thread)
            .options(
                contains_eager(History.messages),
                contains_eager(History.subjects),
                contains_eager(History.thread),
            )
            .order_by(
                # timestamp may be identical for 2 close messages (specially when delay is
                # used) that's why we order ties by received_timestamp. We'll reverse the
                # order when returning the result. We use DESC here so LIMIT keep the last
                # messages
                History.timestamp.desc(),
                History.received_timestamp.desc()
            )
        )


        if not from_jid and not to_jid:
            # no jid specified, we want all one2one communications
            pass
        elif between:
            if not from_jid or not to_jid:
                # we only have one jid specified, we check all messages
                # from or to this jid
                jid_ = from_jid or to_jid
                stmt = stmt.where(
                    or_(
                        self._jid_filter(jid_),
                        self._jid_filter(jid_, dest=True)
                    )
                )
            else:
                # we have 2 jids specified, we check all communications between
                # those 2 jids
                stmt = stmt.where(
                    or_(
                        and_(
                            self._jid_filter(from_jid),
                            self._jid_filter(to_jid, dest=True),
                        ),
                        and_(
                            self._jid_filter(to_jid),
                            self._jid_filter(from_jid, dest=True),
                        )
                    )
                )
        else:
            # we want one communication in specific direction (from somebody or
            # to somebody).
            if from_jid is not None:
                stmt = stmt.where(self._jid_filter(from_jid))
            if to_jid is not None:
                stmt = stmt.where(self._jid_filter(to_jid, dest=True))

        if filters:
            if 'timestamp_start' in filters:
                stmt = stmt.where(History.timestamp >= float(filters['timestamp_start']))
            if 'before_uid' in filters:
                # originally this query was using SQLITE's rowid. This has been changed
                # to use coalesce(received_timestamp, timestamp) to be SQL engine independent
                stmt = stmt.where(
                    coalesce(
                        History.received_timestamp,
                        History.timestamp
                    ) < (
                        select(coalesce(History.received_timestamp, History.timestamp))
                        .filter_by(uid=filters["before_uid"])
                    ).scalar_subquery()
                )
            if 'body' in filters:
                # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html
                stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
            if 'search' in filters:
                # the term is looked for in the message body or in the source resource
                search_term = f"%{filters['search']}%"
                stmt = stmt.where(or_(
                    Message.message.like(search_term),
                    History.source_res.like(search_term)
                ))
            if 'types' in filters:
                # types are separated by whitespaces
                types = filters['types'].split()
                stmt = stmt.where(History.type.in_(types))
            if 'not_types' in filters:
                types = filters['not_types'].split()
                stmt = stmt.where(History.type.not_in(types))
            if 'last_stanza_id' in filters:
                # this request get the last message with a "stanza_id" that we
                # have in history. This is mainly used to retrieve messages sent
                # while we were offline, using MAM (XEP-0313).
                if (filters['last_stanza_id'] is not True
                    or limit != 1):
                    raise ValueError("Unexpected values for last_stanza_id filter")
                stmt = stmt.where(History.stanza_id.is_not(None))
            if 'origin_id' in filters:
                stmt = stmt.where(History.origin_id == filters["origin_id"])

        if limit is not None:
            # NOTE(review): LIMIT is applied to the joined rows; an history entry
            # with several messages or subjects presumably uses several rows, so
            # fewer than ``limit`` entries may be returned — TODO confirm
            stmt = stmt.limit(limit)

        async with self.session() as session:
            result = await session.execute(stmt)

        result = result.scalars().unique().all()
        # the query is ordered with most recent messages first (for LIMIT), the order
        # is reversed to return oldest messages first
        result.reverse()
        return [h.as_tuple() for h in result]

    @aio
    async def addToHistory(self, data: dict, profile: str) -> None:
        """Store a new message in history

        Database errors are logged and not propagated: a message which is already in
        history (unicity constraint) is only reported at debug level.

        @param data: message data as build by SatMessageProtocol.onMessage
            "uid", "from", "to", "timestamp", "type", "message", "subject" and "extra"
            keys are used, "received_timestamp" is optional.
        @param profile: name of the profile owning the message
        """
        msg_extra = data["extra"]
        # keys of NOT_IN_EXTRA are not stored in the "extra" column
        extra = {k: v for k, v in msg_extra.items() if k not in NOT_IN_EXTRA}
        messages = [Message(message=mess, language=lang)
                    for lang, mess in data["message"].items()]
        subjects = [Subject(subject=mess, language=lang)
                    for lang, mess in data["subject"].items()]
        if "thread" in msg_extra:
            # "thread_parent" is optional, ``get`` must be called (not subscripted)
            thread = Thread(thread_id=msg_extra["thread"],
                            parent_id=msg_extra.get("thread_parent"))
        else:
            thread = None
        try:
            async with self.session() as session:
                async with session.begin():
                    session.add(History(
                        uid=data["uid"],
                        origin_id=msg_extra.get("origin_id"),
                        stanza_id=msg_extra.get("stanza_id"),
                        update_uid=msg_extra.get("update_uid"),
                        profile_id=self.profiles[profile],
                        source_jid=data["from"],
                        dest_jid=data["to"],
                        timestamp=data["timestamp"],
                        received_timestamp=data.get("received_timestamp"),
                        type=data["type"],
                        extra=extra,
                        messages=messages,
                        subjects=subjects,
                        thread=thread,
                    ))
        except IntegrityError as e:
            if "unique" in str(e.orig).lower():
                log.debug(
                    f"message {data['uid']!r} is already in history, not storing it again"
                )
            else:
                log.error(f"Can't store message {data['uid']!r} in history: {e}")
        except Exception as e:
            # storing history is best-effort: the failure is logged but not raised
            log.critical(
                f"Can't store message, unexpected exception (uid: {data['uid']}): {e}"
            )

    ## Private values

    def _getPrivateClass(self, binary, profile):
        """Select the ORM class to use for private values

        @param binary: True to use a class for binary values
        @param profile: None to use a class for general values, anything else to use
            a class for individual values
        @return: one of PrivateGen, PrivateGenBin, PrivateInd or PrivateIndBin
        """
        if binary:
            return PrivateGenBin if profile is None else PrivateIndBin
        return PrivateGen if profile is None else PrivateInd


    @aio
    async def getPrivates(
        self,
        namespace:str,
        keys: Optional[Iterable[str]] = None,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> Dict[str, Any]:
        """Retrieve private value(s) of a namespace

        @param namespace: namespace of the values
        @param keys: keys of the values to get
            None (or an empty iterable) to get all keys/values of the namespace
        @param binary: True to deserialise binary values
        @param profile: profile to use for individual values
            None to use general values
        @return: gotten keys/values
        """
        if keys is not None:
            # keys may be any iterable, we need to use it several times
            keys = list(keys)

        scope_txt = "general" if profile is None else "individual"
        binary_txt = " binary" if binary else ""
        keys_txt = "" if keys is None else f" with keys {keys!r}"
        log.debug(
            f"getting {scope_txt}{binary_txt} private values from database for "
            f"namespace {namespace}{keys_txt}"
        )

        cls = self._getPrivateClass(binary, profile)
        stmt = select(cls).filter_by(namespace=namespace)
        if keys:
            stmt = stmt.where(cls.key.in_(keys))
        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            result = await session.execute(stmt)

        privates = {}
        for private in result.scalars():
            privates[private.key] = private.value
        return privates

    @aio
    async def setPrivateValue(
        self,
        namespace: str,
        key:str,
        value: Any,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Store a private value, replacing it if it already exists

        @param namespace: namespace of the values
        @param key: key of the value to set
        @param value: value to set
        @param binary: True if it's a binary values
            binary values need to be serialised, used for everything but strings
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        row = {"namespace": namespace, "key": key, "value": value}
        # columns identifying an existing value
        conflict_cols = [cls.namespace, cls.key]
        if profile is not None:
            row["profile_id"] = self.profiles[profile]
            conflict_cols.append(cls.profile_id)

        async with self.session() as session:
            upsert = insert(cls).values(**row).on_conflict_do_update(
                index_elements=conflict_cols,
                set_={cls.value: value},
            )
            await session.execute(upsert)
            await session.commit()

    @aio
    async def delPrivateValue(
        self,
        namespace: str,
        key: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Remove a private value from database

        @param namespace: namespace of the private value
        @param key: key of the private value
        @param binary: True if it's a binary values
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        criteria = {"namespace": namespace, "key": key}
        if profile is not None:
            criteria["profile_id"] = self.profiles[profile]
        stmt = delete(cls).filter_by(**criteria)

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def delPrivateNamespace(
        self,
        namespace: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Remove every value of a private namespace

        Be really cautious when you use this method, as all data with given namespace are
        removed.

        @param namespace: namespace to empty
        @param binary: True if it's a binary values
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        criteria = {"namespace": namespace}
        if profile is not None:
            criteria["profile_id"] = self.profiles[profile]
        stmt = delete(cls).filter_by(**criteria)

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    ## Files

    @aio
    async def getFiles(
        self,
        client: Optional[SatXMPPEntity],
        file_id: Optional[str] = None,
        version: Optional[str] = '',
        parent: Optional[str] = None,
        type_: Optional[str] = None,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        name: Optional[str] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        projection: Optional[List[str]] = None,
        unique: bool = False
    ) -> List[dict]:
        """Retrieve metadata of the files matching given filters

        A filter set to None is ignored.

        @param client: client owning the files
            None can only be used when public_id is set
        @param file_id: id of the file
            None to ignore
        @param version: version of the file
            None to ignore
            empty string to look for current version
        @param parent: id of the directory containing the files
            None to ignore
            empty string to look for root files/directories
        @param mime_type: media type to look for, with or without subtype
        @param projection: name of columns to retrieve
            None to retrieve all
        @param unique: if True will remove duplicates
        other params are the same as for [setFile]
        @return: files corresponding to filters
        @raise exceptions.InternalError: client and public_id are both None
        @raise NotImplementedError: access filter is used
        """
        if projection is None:
            projection = [
                'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
                'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
                'created', 'modified', 'owner', 'access', 'extra'
            ]

        columns = [getattr(File, col_name) for col_name in projection]
        stmt = select(*columns)

        if unique:
            stmt = stmt.distinct()

        if client is None:
            if public_id is None:
                raise exceptions.InternalError(
                    "client can only be omitted when public_id is set"
                )
        else:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])

        # simple equality filters, applied only when a value is specified
        for col_name, wanted in (
            ("id", file_id),
            ("version", version),
            ("parent", parent),
            ("type", type_),
            ("file_hash", file_hash),
            ("hash_algo", hash_algo),
            ("name", name),
            ("namespace", namespace),
        ):
            if wanted is not None:
                stmt = stmt.filter_by(**{col_name: wanted})

        if mime_type is not None:
            media_type, separator, media_subtype = mime_type.partition("/")
            if separator:
                stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
            else:
                # no subtype specified, only the main type is checked
                stmt = stmt.filter_by(media_type=mime_type)

        for col_name, wanted in (("public_id", public_id), ("owner", owner)):
            if wanted is not None:
                stmt = stmt.filter_by(**{col_name: wanted})

        if access is not None:
            raise NotImplementedError('Access check is not implemented yet')
            # a JSON comparison is needed here

        async with self.session() as session:
            result = await session.execute(stmt)

        return [dict(row) for row in result]

    @aio
    async def setFile(
        self,
        client: SatXMPPEntity,
        name: str,
        file_id: str,
        version: str = "",
        parent: str = "",
        type_: str = C.FILE_TYPE_FILE,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        size: int = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        created: Optional[float] = None,
        modified: Optional[float] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        extra: Optional[dict] = None
    ) -> None:
        """Store the metadata of a file in database

        A new ``File`` row is added and committed in its own transaction.

        @param client: client owning the file (its profile is used for the row)
        @param name: name of the file (must not contain "/")
        @param file_id: unique id of the file
        @param version: version of this file
            surrounding whitespaces are stripped before storage
        @param parent: id of the directory containing this file
            Empty string if it is a root file/directory
        @param type_: one of:
            - file
            - directory
        @param file_hash: unique hash of the payload
        @param hash_algo: algorithm used for hashing the file (usually sha-256)
        @param size: size in bytes
        @param namespace: identifier (human readable is better) to group files
            for instance, namespace could be used to group files in a specific photo
            album
        @param mime_type: media type of the file, or None if not known/guessed
            it is split on the first "/" and stored as media type + media subtype;
            without "/", the whole value is the media type and subtype is None
        @param public_id: ID used to serve the file publicly via HTTP
        @param created: UNIX time of creation
            None to use current time
        @param modified: UNIX time of last modification, stored as given
            NOTE(review): earlier documentation stated that None means "use created
            date", but no such substitution is done in this method -- confirm where
            (or whether) it is handled
        @param owner: jid of the owner of the file (mainly useful for component)
        @param access: serialisable dictionary with access rules. See [memory.memory]
            for details
        @param extra: serialisable dictionary of any extra data
        """
        media_type: Optional[str] = None
        media_subtype: Optional[str] = None
        if mime_type is not None:
            media_type, separator, remainder = mime_type.partition('/')
            if separator:
                # there is a subtype only when a "/" is actually present
                media_subtype = remainder

        columns = {
            "id": file_id,
            "version": version.strip(),
            "parent": parent,
            "type": type_,
            "file_hash": file_hash,
            "hash_algo": hash_algo,
            "name": name,
            "size": size,
            "namespace": namespace,
            "media_type": media_type,
            "media_subtype": media_subtype,
            "public_id": public_id,
            "created": time.time() if created is None else created,
            "modified": modified,
            "owner": owner,
            "access": access,
            "extra": extra,
            "profile_id": self.profiles[client.profile],
        }
        file_row = File(**columns)

        async with self.session() as session:
            async with session.begin():
                session.add(file_row)

    @aio
    async def fileGetUsedSpace(self, client: SatXMPPEntity, owner: jid.JID) -> int:
        """Compute the total size of files belonging to an owner

        Only rows of type ``C.FILE_TYPE_FILE`` stored for client's profile are summed.

        @param client: client whose profile is used to filter files
        @param owner: jid of the owner of the files
        @return: sum of the ``size`` column of matching files
            0 is returned when the query yields nothing (or NULL)
        """
        stmt = select(sum_(File.size)).filter_by(
            owner=owner,
            type=C.FILE_TYPE_FILE,
            profile_id=self.profiles[client.profile]
        )
        async with self.session() as session:
            result = await session.execute(stmt)
        total = result.scalar_one_or_none()
        return total or 0

    @aio
    async def fileDelete(self, file_id: str) -> None:
        """Remove the metadata of a file from the database

        NOTE: only the database row is handled here, the file itself must still be
            removed by the caller

        @param file_id: id of the file to delete
        """
        removal = delete(File).filter_by(id=file_id)
        async with self.session() as session:
            await session.execute(removal)
            await session.commit()

    @aio
    async def fileUpdate(
        self,
        file_id: str,
        column: str,
        update_cb: Callable[[dict], None]
    ) -> None:
        """Update a column value with a callback, guarding against race conditions

        The current value is read from database, ``update_cb`` modifies it in place,
        then the row is updated only if the stored value is still the one which has
        been read (compare-and-set). If the value has been changed meanwhile, the
        whole read/modify/write cycle is tried again, up to 5 times in total.

        @param file_id: id of the file to update
        @param column: column name (only "access" or "extra" are allowed)
        @param update_cb: method to update the value of the column
            the method will take older value as argument, and must update it in place
            update_cb must not care about serialization,
            it gets the deserialized data (i.e. a Python object) directly
        @raise exceptions.InternalError: ``column`` is not an allowed column
        @raise exceptions.NotFound: there is no file with this id
        @raise exceptions.DatabaseError: the update could not be done after all tries
        """
        if column not in ('access', 'extra'):
            raise exceptions.InternalError('bad column name')
        orm_col = getattr(File, column)
        max_tries = 5

        for attempt in range(1, max_tries + 1):
            async with self.session() as session:
                try:
                    fetched = await session.execute(
                        select(orm_col).filter_by(id=file_id)
                    )
                    value = fetched.scalar_one()
                except NoResultFound:
                    raise exceptions.NotFound
                # keep a snapshot of what was read: update_cb works in place
                previous = copy.deepcopy(value)
                update_cb(value)
                if previous:
                    unchanged = orm_col == previous
                else:
                    # because JsonDefaultDict convert NULL to an empty dict, we have to
                    # test both for empty dict and None when we have an empty dict
                    unchanged = (orm_col == None) | (orm_col == previous)
                stmt = (
                    update(File)
                    .filter_by(id=file_id)
                    .values({column: value})
                    .where(unchanged)
                )
                result = await session.execute(stmt)
                await session.commit()

            if result.rowcount == 1:
                # the row still had the value we read: update is done
                return

            log.warning(
                _("table not updated, probably due to race condition, trying again "
                  "({tries})").format(tries=attempt)
            )

        raise exceptions.DatabaseError(
            _("Can't update file {file_id} due to race condition")
            .format(file_id=file_id)
        )

    @aio
    async def getPubsubNode(
        self,
        client: SatXMPPEntity,
        service: jid.JID,
        name: str,
        with_items: bool = False,
        with_subscriptions: bool = False,
    ) -> Optional[PubsubNode]:
        """Retrieve a cached pubsub node

        @param client: client whose profile is used to look for the node
        @param service: JID of the pubsub service hosting the node
        @param name: name of the node
        @param with_items: if True, ``PubsubNode.items`` is loaded with a joined load
        @param with_subscriptions: if True, ``PubsubNode.subscriptions`` is loaded with
            a joined load
        @return: the matching node, or None if there is no such node
        """
        eager_relations = []
        if with_items:
            eager_relations.append(PubsubNode.items)
        if with_subscriptions:
            eager_relations.append(PubsubNode.subscriptions)

        async with self.session() as session:
            stmt = select(PubsubNode).filter_by(
                service=service,
                name=name,
                profile_id=self.profiles[client.profile],
            )
            for relation in eager_relations:
                stmt = stmt.options(joinedload(relation))
            result = await session.execute(stmt)
        # unique() is used to deduplicate rows coming from the joined loads
        node = result.unique().scalar_one_or_none()
        return node

    @aio
    async def setPubsubNode(
        self,
        client: SatXMPPEntity,
        service: jid.JID,
        name: str,
        analyser: Optional[str] = None,
        type_: Optional[str] = None,
        subtype: Optional[str] = None,
        subscribed: bool = False,
    ) -> PubsubNode:
        """Create a pubsub node in cache

        The node is created with an empty list of subscriptions, and is added to
        database in its own transaction.

        @param client: client whose profile will own the node
        @param service: JID of the pubsub service hosting the node
        @param name: name of the node
        @param analyser: value stored in the ``analyser`` column of the node
        @param type_: value stored as type of the node
        @param subtype: value stored as subtype of the node
        @param subscribed: value stored in the ``subscribed`` column of the node
        @return: the newly created node
        """
        node_data = {
            "profile_id": self.profiles[client.profile],
            "service": service,
            "name": name,
            "subscribed": subscribed,
            "analyser": analyser,
            "type_": type_,
            "subtype": subtype,
            "subscriptions": [],
        }
        new_node = PubsubNode(**node_data)
        async with self.session() as session:
            async with session.begin():
                session.add(new_node)
        return new_node

    @aio
    async def updatePubsubNodeSyncState(
        self,
        node: PubsubNode,
        state: SyncState
    ) -> None:
        """Set the synchronisation state of a cached node

        ``sync_state_updated`` is set to current time along with the new state.

        @param node: node to update (only its ``id`` is used)
        @param state: new synchronisation state
        """
        async with self.session() as session:
            async with session.begin():
                stmt = update(PubsubNode).filter_by(id=node.id)
                stmt = stmt.values(
                    sync_state=state,
                    sync_state_updated=time.time(),
                )
                await session.execute(stmt)

    @aio
    async def deletePubsubNode(
        self,
        profiles: Optional[List[str]],
        services: Optional[List[jid.JID]],
        names: Optional[List[str]]
    ) -> None:
        """Delete cached pubsub nodes

        All given filters are combined (a node must match every filter which is not
        None to be deleted).

        @param profiles: profile names from which nodes must be deleted.
            None to remove nodes from ALL profiles
        @param services: JIDs of pubsub services from which nodes must be deleted.
            None to remove nodes from ALL services
        @param names: names of nodes which must be deleted.
            None to remove ALL nodes whatever is their names
        @raise KeyError: one of the profiles is not known in ``self.profiles``
        """
        stmt = delete(PubsubNode)
        if profiles is not None:
            # self.profiles maps profile names to profile ids, so the comparison must
            # be done on the "profile_id" foreign key column (as in other methods),
            # not on a "profile" attribute
            profile_ids = [self.profiles[p] for p in profiles]
            stmt = stmt.where(PubsubNode.profile_id.in_(profile_ids))
        if services is not None:
            stmt = stmt.where(PubsubNode.service.in_(services))
        if names is not None:
            stmt = stmt.where(PubsubNode.name.in_(names))
        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def cachePubsubItems(
        self,
        client: SatXMPPEntity,
        node: PubsubNode,
        items: List[domish.Element],
        parsed_items: Optional[List[dict]] = None,
    ) -> None:
        """Add items to database, using an upsert taking care of "updated" field

        Each item is inserted with its ``id`` attribute as name; if an item with the
        same (node, name) already exists, its data and parsed data are replaced and its
        "updated" field is set to current database time.

        @param client: unused in this method
        @param node: node to which the items belong (only its ``id`` is used)
        @param items: raw items to cache
        @param parsed_items: parsed data of the items, in the same order as ``items``
            None (or empty) to store no parsed data
        @raise exceptions.InternalError: ``parsed_items`` is set but has not the same
            length as ``items``
        """
        if parsed_items is not None:
            if len(items) != len(parsed_items):
                raise exceptions.InternalError(
                    "parsed_items must have the same lenght as items"
                )

        async with self.session() as session:
            async with session.begin():
                for position, element in enumerate(items):
                    parsed = None if not parsed_items else parsed_items[position]
                    upsert = insert(PubsubItem).values(
                        node_id=node.id,
                        name=element["id"],
                        data=element,
                        parsed=parsed,
                    )
                    upsert = upsert.on_conflict_do_update(
                        index_elements=(PubsubItem.node_id, PubsubItem.name),
                        set_={
                            PubsubItem.data: element,
                            PubsubItem.parsed: parsed,
                            PubsubItem.updated: now()
                        }
                    )
                    await session.execute(upsert)
                await session.commit()

    @aio
    async def deletePubsubItems(
        self,
        node: PubsubNode,
        items_names: Optional[List[str]] = None
    ) -> None:
        """Delete items cached for a node

        @param node: node from which items must be deleted
            the implementation also handles a list of nodes (items of all those nodes
            are then deleted), and None (no filtering on node is done at all)
        @param items_names: names of items to delete
            if None, ALL items will be deleted
        """
        removal = delete(PubsubItem)
        if isinstance(node, list):
            node_ids = [n.id for n in node]
            removal = removal.where(PubsubItem.node_id.in_(node_ids))
        elif node is not None:
            removal = removal.filter_by(node_id=node.id)

        if items_names is not None:
            removal = removal.where(PubsubItem.name.in_(items_names))

        async with self.session() as session:
            await session.execute(removal)
            await session.commit()

    @aio
    async def purgePubsubItems(
        self,
        services: Optional[List[jid.JID]] = None,
        names: Optional[List[str]] = None,
        types: Optional[List[str]] = None,
        subtypes: Optional[List[str]] = None,
        profiles: Optional[List[str]] = None,
        created_before: Optional[datetime] = None,
        updated_before: Optional[datetime] = None,
    ) -> None:
        """Delete cached items matching given filters

        All filters which are not None are combined. If every filter is None, ALL
        cached items are deleted.

        @param services: only delete items of nodes from those services
        @param names: only delete items of nodes with those names
        @param types: only delete items of nodes with those types
        @param subtypes: only delete items of nodes with those subtypes
        @param profiles: only delete items of nodes belonging to those profiles (names)
        @param created_before: only delete items created before this date
        @param updated_before: only delete items updated before this date
        """
        # (PubsubNode attribute, accepted values) pairs; None values mean "no filter"
        node_criteria = [
            ("service", services),
            ("name", names),
            ("type_", types),
            ("subtype", subtypes),
        ]
        if profiles is not None:
            node_criteria.append(
                ("profile_id", [self.profiles[p] for p in profiles])
            )
        active_criteria = [
            (col, values) for col, values in node_criteria if values is not None
        ]

        purge = delete(PubsubItem)
        if active_criteria:
            # items are linked to nodes through node_id, so node related filters are
            # applied with a sub-query selecting the ids of matching nodes
            node_ids_q = select(PubsubNode.id)
            for col, values in active_criteria:
                node_ids_q = node_ids_q.where(getattr(PubsubNode, col).in_(values))
            purge = purge.where(PubsubItem.node_id.in_(node_ids_q))
            purge = purge.execution_options(synchronize_session=False)

        if created_before is not None:
            purge = purge.where(PubsubItem.created < created_before)

        if updated_before is not None:
            purge = purge.where(PubsubItem.updated < updated_before)

        async with self.session() as session:
            await session.execute(purge)
            await session.commit()

    @aio
    async def getItems(
        self,
        node: PubsubNode,
        max_items: Optional[int] = None,
        item_ids: Optional[list[str]] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        from_index: Optional[int] = None,
        order_by: Optional[List[str]] = None,
        desc: bool = True,
        force_rsm: bool = False,
    ) -> Tuple[List[PubsubItem], dict]:
        """Get Pubsub Items from cache

        @param node: retrieve items from this node (must be synchronised)
        @param max_items: maximum number of items to retrieve
            None to use the default (20)
        @param item_ids: if set, only items with those names are retrieved
            NOTE(review): this filter is only applied to the non-RSM query; the RSM
            branch builds a new statement which doesn't use it -- confirm if this is
            intended
        @param before: get items which are before the item with this name in given order
            empty string is not managed here, use desc order to reproduce RSM
            behaviour.
        @param after: get items which are after the item with this name in given order
        @param from_index: get items with item index (as defined in RSM spec)
            starting from this number
        @param order_by: sorting order of items (one of C.ORDER_BY_*)
            defaults to C.ORDER_BY_MODIFICATION
        @param desc: direction or ordering
        @param force_rsm: if True, force the use of RSM workflow.
            RSM workflow is automatically used if any of before, after or
            from_index is used, but if only RSM max_items is used, it won't be
            used by default. This parameter lets use RSM workflow in this
            case. Note that in addition to RSM metadata, the result will not be
            the same (max_items without RSM will returns most recent items,
            i.e. last items in modification order, while max_items with RSM
            will return the oldest ones (i.e. first items in modification
            order).
            to be used when max_items is used from RSM
        @return: found items and metadata
            metadata always contains "service", "node" and "uri" keys; when RSM
            workflow is used, it also contains "rsm" (with "index", "count", "first"
            and "last" when they are known) and "complete"
        @raise exceptions.InternalError: an unknown order type is used in ``order_by``
        """

        metadata = {
            "service": node.service,
            "node": node.name,
            "uri": uri.buildXMPPUri(
                "pubsub",
                path=node.service.full(),
                node=node.name,
            ),
        }
        if max_items is None:
            max_items = 20

        use_rsm = any((before, after, from_index is not None))
        if force_rsm and not use_rsm:
            # only max_items is used with RSM: we start from the first item
            use_rsm = True
            from_index = 0

        # base query, used as is when RSM workflow is not used
        stmt = (
            select(PubsubItem)
            .filter_by(node_id=node.id)
            .limit(max_items)
        )

        if item_ids is not None:
            stmt = stmt.where(PubsubItem.name.in_(item_ids))

        if not order_by:
            order_by = [C.ORDER_BY_MODIFICATION]

        # ORDER BY clauses: modification order uses "updated" then "id" columns,
        # creation order uses "id" column only
        order = []
        for order_type in order_by:
            if order_type == C.ORDER_BY_MODIFICATION:
                if desc:
                    order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
                else:
                    order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
            elif order_type == C.ORDER_BY_CREATION:
                if desc:
                    order.append(PubsubItem.id.desc())
                else:
                    order.append(PubsubItem.id.asc())
            else:
                raise exceptions.InternalError(f"Unknown order type {order_type!r}")

        stmt = stmt.order_by(*order)

        if use_rsm:
            # CTE to have result row numbers
            row_num_q = select(
                PubsubItem.id,
                PubsubItem.name,
                # row_number starts from 1, but RSM index must start from 0
                (func.row_number().over(order_by=order)-1).label("item_index")
            ).filter_by(node_id=node.id)

            row_num_cte = row_num_q.cte()

            if max_items > 0:
                # as we can't simply use PubsubItem.id when we order by modification,
                # we need to use row number
                item_name = before or after
                # index of the item used as reference by "before" or "after"
                row_num_limit_q = (
                    select(row_num_cte.c.item_index)
                    .where(row_num_cte.c.name==item_name)
                ).scalar_subquery()

                # the base query is replaced: each row is (item_index, PubsubItem)
                stmt = (
                    select(row_num_cte.c.item_index, PubsubItem)
                    .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
                    .limit(max_items)
                )
                if before:
                    # descending order is used to get the items which are just before
                    # the reference one, result is put back in order after the query
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index<row_num_limit_q)
                        .order_by(row_num_cte.c.item_index.desc())
                    )
                elif after:
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index>row_num_limit_q)
                        .order_by(row_num_cte.c.item_index.asc())
                    )
                else:
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index>=from_index)
                        .order_by(row_num_cte.c.item_index.asc())
                    )
                    # from_index is used

            async with self.session() as session:
                if max_items == 0:
                    # nothing to retrieve, only the items count is requested
                    items = result = []
                else:
                    result = await session.execute(stmt)
                    result = result.all()
                    if before:
                        result.reverse()
                    # PubsubItem is the last column of each row
                    items = [row[-1] for row in result]
                # total number of items in the node
                rows_count = (
                    await session.execute(row_num_q.with_only_columns(count()))
                ).scalar_one()

            # RSM index of the first returned item, if any
            try:
                index = result[0][0]
            except IndexError:
                index = None

            # names of the first and last returned items, if any
            try:
                first = result[0][1].name
            except IndexError:
                first = None
                last = None
            else:
                last = result[-1][1].name

            # only known values are put in RSM metadata
            metadata["rsm"] = {
                k: v for k, v in {
                    "index": index,
                    "count": rows_count,
                    "first": first,
                    "last": last,
                }.items() if v is not None
            }
            # the result is complete when the last returned item is the last item of
            # the node in requested order
            metadata["complete"] = (index or 0) + len(result) == rows_count

            return items, metadata

        # RSM workflow is not used
        async with self.session() as session:
            result = await session.execute(stmt)

        result = result.scalars().all()
        if desc:
            # items have been retrieved in descending order, the list is reversed
            # before being returned
            result.reverse()
        return result, metadata

    def _getSqlitePath(
        self,
        path: List[Union[str, int]]
    ) -> str:
        """Build a path suitable to query a JSON element with SQLite

        @param path: elements of the path
            an int is rendered as an array index (``[n]``), anything else as an object
            key (``.key``)
        @return: the path, starting with ``$``
        """
        segments = []
        for element in path:
            if isinstance(element, int):
                segments.append(f"[{element}]")
            else:
                segments.append(f".{element}")
        return "$" + "".join(segments)

    @aio
    async def searchPubsubItems(
        self,
        query: dict,
    ) -> List[PubsubItem]:
        """Search for pubsub items in cache

        @param query: search terms (the dict and its values are not modified).
            Keys can be:
            :fts (str):
                Full-Text Search query. Currently SQLite FTS5 engine is used, its query
                syntax can be used, see `FTS5 Query documentation
                <https://sqlite.org/fts5.html#full_text_query_syntax>`_
            :profiles (list[str]):
                filter on nodes linked to those profiles
            :nodes (list[str]):
                filter on nodes with those names
            :services (list[jid.JID]):
                filter on nodes from those services
            :types (list[str|None]):
                filter on nodes with those types. None can be used to filter on nodes with
                no type set
            :subtypes (list[str|None]):
                filter on nodes with those subtypes. None can be used to filter on nodes with
                no subtype set
            :names (list[str]):
                filter on items with those names
            :parsed (list[dict]):
                Filter on a parsed data field. The dict must contain 3 keys: ``path``
                which is a list of str or int giving the path to the field of interest
                (str for a dict key, int for a list index), ``op`` which indicates the
                operator to use to check the condition, and ``value`` which depends of
                field type and operator.

                See documentation for details on operators (it's currently explained at
                ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
                documentation).

            :order-by (list[dict]):
                Indicates how to order results. The dict can contain either a ``order``
                for a well-know order or a ``path`` for a parsed data field path
                (``order`` and ``path`` can't be used at the same time), and an optional
                ``direction`` which can be ``asc`` or ``desc``. See documentation for
                details on well-known orders (it's currently explained at
                ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
                documentation).
                Results are ordered by creation if no order is specified.

            :index (int):
                starting index of items to return from the query result. It's translated
                to SQL's OFFSET

            :limit (int):
                maximum number of items to return. It's translated to SQL's LIMIT.

        @return: found items (the ``node`` attribute will be filled with suitable
            PubsubNode)
        @raise exceptions.ProfileUnknownError: one of the requested profiles doesn't
            exist
        @raise ValueError: a "parsed" filter or an "order-by" data is invalid
        @raise NotImplementedError: an unknown well-known order is requested
        """
        # TODO: FTS and parsed data filters use SQLite specific syntax
        #   when other DB engines will be used, this will have to be adapted
        stmt = select(PubsubItem)

        # Full-Text Search
        fts = query.get("fts")
        if fts:
            fts_select = text(
                "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)"
            ).bindparams(fts_query=fts).columns(rowid=Integer).subquery()
            stmt = (
                stmt
                .select_from(fts_select)
                .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id)
            )

        # node related filters
        profiles = query.get("profiles")
        if (profiles
            or any(query.get(k) for k in ("nodes", "services", "types", "subtypes"))
        ):
            stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node))
            if profiles:
                try:
                    # a list is built right now (and not a lazy generator), so an
                    # unknown profile is detected here
                    profile_ids = [self.profiles[p] for p in profiles]
                except KeyError as e:
                    raise exceptions.ProfileUnknownError(
                        f"This profile doesn't exist: {e.args[0]!r}"
                    ) from e
                stmt = stmt.where(PubsubNode.profile_id.in_(profile_ids))
            for key, attr in (
                ("nodes", "name"),
                ("services", "service"),
                ("types", "type_"),
                ("subtypes", "subtype")
            ):
                value = query.get(key)
                if not value:
                    continue
                node_col = getattr(PubsubNode, attr)
                if key in ("types", "subtypes") and None in value:
                    # NULL can't be used with SQL's IN, so we have to add a condition with
                    # IS NULL, and use a OR if there are other values to check
                    # we work on a filtered copy to avoid modifying caller's data
                    non_null_values = [v for v in value if v is not None]
                    condition = node_col.is_(None)
                    if non_null_values:
                        condition = or_(
                            node_col.in_(non_null_values),
                            condition
                        )
                else:
                    condition = node_col.in_(value)
                stmt = stmt.where(condition)
        else:
            stmt = stmt.options(selectinload(PubsubItem.node))

        # names
        names = query.get("names")
        if names:
            stmt = stmt.where(PubsubItem.name.in_(names))

        # parsed data filters
        # "or" is used so a "parsed" key explicitly set to None is also handled
        parsed = query.get("parsed") or []
        for filter_ in parsed:
            try:
                path = filter_["path"]
                operator = filter_["op"]
                value = filter_["value"]
            except KeyError as e:
                raise ValueError(
                    f'missing mandatory key {e.args[0]!r} in "parsed" filter'
                ) from e
            try:
                op_attr = OP_MAP[operator]
            except KeyError as e:
                raise ValueError(f"invalid operator: {operator!r}") from e
            sqlite_path = self._getSqlitePath(path)
            if operator in ("overlap", "ioverlap", "disjoint", "idisjoint"):
                col = literal_column("json_each.value")
                if operator[0] == "i":
                    # case insensitive operators: both sides are lowercased
                    col = func.lower(col)
                    value = [str(v).lower() for v in value]
                condition = (
                    select(1)
                    .select_from(func.json_each(PubsubItem.parsed, sqlite_path))
                    .where(col.in_(value))
                ).scalar_subquery()
                if operator in ("disjoint", "idisjoint"):
                    # no value of the array must be in the requested values
                    condition = condition.is_(None)
                stmt = stmt.where(condition)
            elif operator == "between":
                try:
                    left, right = value
                except (ValueError, TypeError) as e:
                    raise ValueError(_(
                        "invalid value for \"between\" filter, you must use a 2 items "
                        "array: {value!r}"
                    ).format(value=value)) from e
                col = func.json_extract(PubsubItem.parsed, sqlite_path)
                stmt = stmt.where(col.between(left, right))
            else:
                # we use func.json_extract instead of generic JSON way because SQLAlchemy
                # add a JSON_QUOTE to the value, and we want SQL value
                col = func.json_extract(PubsubItem.parsed, sqlite_path)
                stmt = stmt.where(getattr(col, op_attr)(value))

        # order
        order_by = query.get("order-by") or [{"order": "creation"}]

        for order_data in order_by:
            order, path = order_data.get("order"), order_data.get("path")
            if order and path:
                raise ValueError(_(
                    '"order" and "path" can\'t be used at the same time in '
                    '"order-by" data'
                ))
            if order:
                if order == "creation":
                    col = PubsubItem.id
                elif order == "modification":
                    col = PubsubItem.updated
                elif order == "item_id":
                    col = PubsubItem.name
                elif order == "rank":
                    if not fts:
                        raise ValueError(
                            "'rank' order can only be used with Full-Text Search (fts)"
                        )
                    col = literal_column("rank")
                else:
                    raise NotImplementedError(f"Unknown {order!r} order")
            else:
                # we have a JSON path
                col = PubsubItem.parsed[path]
            direction = order_data.get("direction", "ASC").lower()
            if direction not in ("asc", "desc"):
                raise ValueError(f"Invalid order-by direction: {direction!r}")
            stmt = stmt.order_by(getattr(col, direction)())

        # offset, limit
        index = query.get("index")
        if index:
            stmt = stmt.offset(index)
        limit = query.get("limit")
        if limit:
            stmt = stmt.limit(limit)

        async with self.session() as session:
            result = await session.execute(stmt)

        return result.scalars().all()