diff libervia/backend/memory/sqla.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author   | Goffi <goffi@goffi.org>
date     | Fri, 02 Jun 2023 11:49:51 +0200
parents  | sat/memory/sqla.py@524856bd7b19
children | 74c66c0d93f3
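
The file below is recorded as a copy of the former `sat/memory/sqla.py` (see the parent revision above), with its internal imports updated to the new package name. As a minimal sketch, not part of the changeset itself, this is what the rename means for code that imports the storage module (the old import line is shown only for contrast):

```python
# Before this changeset (old package layout, tracked as sat/memory/sqla.py):
# from sat.memory.sqla import Storage

# After this changeset:
from libervia.backend.memory.sqla import Storage
```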
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libervia/backend/memory/sqla.py Fri Jun 02 11:49:51 2023 +0200 @@ -0,0 +1,1704 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import asyncio +from asyncio.subprocess import PIPE +import copy +from datetime import datetime +from pathlib import Path +import sys +import time +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +from alembic import config as al_config, script as al_script +from alembic.runtime import migration as al_migration +from sqlalchemy import and_, delete, event, func, or_, update +from sqlalchemy import Integer, literal_column, text +from sqlalchemy.dialects.sqlite import insert +from sqlalchemy.engine import Connection, Engine +from sqlalchemy.exc import IntegrityError, NoResultFound +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.future import select +from sqlalchemy.orm import ( + contains_eager, + joinedload, + selectinload, + sessionmaker, + subqueryload, +) +from sqlalchemy.orm.attributes import Mapped +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.sql.functions import coalesce, count, now, sum as sum_ +from twisted.internet import defer +from twisted.words.protocols.jabber import jid +from twisted.words.xish import domish + +from libervia.backend.core import exceptions +from libervia.backend.core.constants import Const as C +from libervia.backend.core.core_types import SatXMPPEntity +from libervia.backend.core.i18n import _ +from libervia.backend.core.log import getLogger +from libervia.backend.memory import migration +from libervia.backend.memory import sqla_config +from libervia.backend.memory.sqla_mapping import ( + Base, + Component, + File, + History, + Message, + NOT_IN_EXTRA, + ParamGen, + ParamInd, + PrivateGen, + PrivateGenBin, + PrivateInd, + PrivateIndBin, + Profile, + PubsubItem, + PubsubNode, + Subject, + SyncState, + Thread, +) +from libervia.backend.tools.common import uri +from libervia.backend.tools.utils import aio, as_future + + +log = getLogger(__name__) +migration_path = Path(migration.__file__).parent +#: mapping of Libervia search query operators to SQLAlchemy method name +OP_MAP = { + "==": "__eq__", + "eq": "__eq__", + "!=": "__ne__", + "ne": "__ne__", + ">": "__gt__", + "gt": "__gt__", + "<": "__le__", + "le": "__le__", + "between": "between", + "in": "in_", + "not_in": "not_in", + "overlap": "in_", + "ioverlap": "in_", + "disjoint": "in_", + "idisjoint": "in_", + "like": "like", + "ilike": "ilike", + "not_like": "notlike", + "not_ilike": "notilike", +} + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + +class 
Storage: + + def __init__(self): + self.initialized = defer.Deferred() + # we keep cache for the profiles (key: profile name, value: profile id) + # profile id to name + self.profiles: Dict[str, int] = {} + # profile id to component entry point + self.components: Dict[int, str] = {} + + def get_profile_by_id(self, profile_id): + return self.profiles.get(profile_id) + + async def migrate_apply(self, *args: str, log_output: bool = False) -> None: + """Do a migration command + + Commands are applied by running Alembic in a subprocess. + Arguments are alembic executables commands + + @param log_output: manage stdout and stderr: + - if False, stdout and stderr are buffered, and logged only in case of error + - if True, stdout and stderr will be logged during the command execution + @raise exceptions.DatabaseError: something went wrong while running the + process + """ + stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) + proc = await asyncio.create_subprocess_exec( + sys.executable, "-m", "alembic", *args, + stdout=stdout, stderr=stderr, cwd=migration_path + ) + log_out, log_err = await proc.communicate() + if proc.returncode != 0: + msg = _( + "Can't {operation} database (exit code {exit_code})" + ).format( + operation=args[0], + exit_code=proc.returncode + ) + if log_out or log_err: + msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" + log.error(msg) + + raise exceptions.DatabaseError(msg) + + async def create_db(self, engine: AsyncEngine, db_config: dict) -> None: + """Create a new database + + The database is generated from SQLAlchemy model, then stamped by Alembic + """ + # the dir may not exist if it's not the XDG recommended one + db_config["path"].parent.mkdir(0o700, True, True) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + log.debug("stamping the database") + await self.migrate_apply("stamp", "head") + log.debug("stamping done") + + def _check_db_is_up_to_date(self, conn: Connection) -> bool: + al_ini_path = migration_path / "alembic.ini" + al_cfg = al_config.Config(al_ini_path) + directory = al_script.ScriptDirectory.from_config(al_cfg) + context = al_migration.MigrationContext.configure(conn) + return set(context.get_current_heads()) == set(directory.get_heads()) + + def _sqlite_set_journal_mode_wal(self, conn: Connection) -> None: + """Check if journal mode is WAL, and set it if necesssary""" + result = conn.execute(text("PRAGMA journal_mode")) + if result.scalar() != "wal": + log.info("WAL mode not activated, activating it") + conn.execute(text("PRAGMA journal_mode=WAL")) + + async def check_and_update_db(self, engine: AsyncEngine, db_config: dict) -> None: + """Check that database is up-to-date, and update if necessary""" + async with engine.connect() as conn: + up_to_date = await conn.run_sync(self._check_db_is_up_to_date) + if up_to_date: + log.debug("Database is up-to-date") + else: + log.info("Database needs to be updated") + log.info("updating…") + await self.migrate_apply("upgrade", "head", log_output=True) + log.info("Database is now up-to-date") + + @aio + async def initialise(self) -> None: + log.info(_("Connecting database")) + + db_config = sqla_config.get_db_config() + engine = create_async_engine( + db_config["url"], + future=True, + ) + + new_base = not db_config["path"].exists() + if new_base: + log.info(_("The database is new, creating the tables")) + await self.create_db(engine, db_config) + else: + await self.check_and_update_db(engine, db_config) + + async with engine.connect() as conn: + 
await conn.run_sync(self._sqlite_set_journal_mode_wal) + + self.session = sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + + async with self.session() as session: + result = await session.execute(select(Profile)) + for p in result.scalars(): + self.profiles[p.name] = p.id + result = await session.execute(select(Component)) + for c in result.scalars(): + self.components[c.profile_id] = c.entry_point + + self.initialized.callback(None) + + ## Generic + + @aio + async def get( + self, + client: SatXMPPEntity, + db_cls: DeclarativeMeta, + db_id_col: Mapped, + id_value: Any, + joined_loads = None + ) -> Optional[DeclarativeMeta]: + stmt = select(db_cls).where(db_id_col==id_value) + if client is not None: + stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) + if joined_loads is not None: + for joined_load in joined_loads: + stmt = stmt.options(joinedload(joined_load)) + async with self.session() as session: + result = await session.execute(stmt) + if joined_loads is not None: + result = result.unique() + return result.scalar_one_or_none() + + @aio + async def add(self, db_obj: DeclarativeMeta) -> None: + """Add an object to database""" + async with self.session() as session: + async with session.begin(): + session.add(db_obj) + + @aio + async def delete( + self, + db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], + session_add: Optional[List[DeclarativeMeta]] = None + ) -> None: + """Delete an object from database + + @param db_obj: object to delete or list of objects to delete + @param session_add: other objects to add to session. + This is useful when parents of deleted objects needs to be updated too, or if + other objects needs to be updated in the same transaction. + """ + if not db_obj: + return + if not isinstance(db_obj, list): + db_obj = [db_obj] + async with self.session() as session: + async with session.begin(): + if session_add is not None: + for obj in session_add: + session.add(obj) + for obj in db_obj: + await session.delete(obj) + await session.commit() + + ## Profiles + + def get_profiles_list(self) -> List[str]: + """"Return list of all registered profiles""" + return list(self.profiles.keys()) + + def has_profile(self, profile_name: str) -> bool: + """return True if profile_name exists + + @param profile_name: name of the profile to check + """ + return profile_name in self.profiles + + def profile_is_component(self, profile_name: str) -> bool: + try: + return self.profiles[profile_name] in self.components + except KeyError: + raise exceptions.NotFound("the requested profile doesn't exists") + + def get_entry_point(self, profile_name: str) -> str: + try: + return self.components[self.profiles[profile_name]] + except KeyError: + raise exceptions.NotFound("the requested profile doesn't exists or is not a component") + + @aio + async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: + """Create a new profile + + @param name: name of the profile + @param component: if not None, must point to a component entry point + """ + async with self.session() as session: + profile = Profile(name=name) + async with session.begin(): + session.add(profile) + self.profiles[profile.name] = profile.id + if component_ep is not None: + async with session.begin(): + component = Component(profile=profile, entry_point=component_ep) + session.add(component) + self.components[profile.id] = component_ep + return profile + + @aio + async def delete_profile(self, name: str) -> None: + """Delete profile + + @param name: name of the 
profile + """ + async with self.session() as session: + result = await session.execute(select(Profile).where(Profile.name == name)) + profile = result.scalar() + await session.delete(profile) + await session.commit() + del self.profiles[profile.name] + if profile.id in self.components: + del self.components[profile.id] + log.info(_("Profile {name!r} deleted").format(name = name)) + + ## Params + + @aio + async def load_gen_params(self, params_gen: dict) -> None: + """Load general parameters + + @param params_gen: dictionary to fill + """ + log.debug(_("loading general parameters from database")) + async with self.session() as session: + result = await session.execute(select(ParamGen)) + for p in result.scalars(): + params_gen[(p.category, p.name)] = p.value + + @aio + async def load_ind_params(self, params_ind: dict, profile: str) -> None: + """Load individual parameters + + @param params_ind: dictionary to fill + @param profile: a profile which *must* exist + """ + log.debug(_("loading individual parameters from database")) + async with self.session() as session: + result = await session.execute( + select(ParamInd).where(ParamInd.profile_id == self.profiles[profile]) + ) + for p in result.scalars(): + params_ind[(p.category, p.name)] = p.value + + @aio + async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]: + """Ask database for the value of one specific individual parameter + + @param category: category of the parameter + @param name: name of the parameter + @param profile: %(doc_profile)s + """ + async with self.session() as session: + result = await session.execute( + select(ParamInd.value) + .filter_by( + category=category, + name=name, + profile_id=self.profiles[profile] + ) + ) + return result.scalar_one_or_none() + + @aio + async def get_ind_param_values(self, category: str, name: str) -> Dict[str, str]: + """Ask database for the individual values of a parameter for all profiles + + @param category: category of the parameter + @param name: name of the parameter + @return dict: profile => value map + """ + async with self.session() as session: + result = await session.execute( + select(ParamInd) + .filter_by( + category=category, + name=name + ) + .options(subqueryload(ParamInd.profile)) + ) + return {param.profile.name: param.value for param in result.scalars()} + + @aio + async def set_gen_param(self, category: str, name: str, value: Optional[str]) -> None: + """Save the general parameters in database + + @param category: category of the parameter + @param name: name of the parameter + @param value: value to set + """ + async with self.session() as session: + stmt = insert(ParamGen).values( + category=category, + name=name, + value=value + ).on_conflict_do_update( + index_elements=(ParamGen.category, ParamGen.name), + set_={ + ParamGen.value: value + } + ) + await session.execute(stmt) + await session.commit() + + @aio + async def set_ind_param( + self, + category:str, + name: str, + value: Optional[str], + profile: str + ) -> None: + """Save the individual parameters in database + + @param category: category of the parameter + @param name: name of the parameter + @param value: value to set + @param profile: a profile which *must* exist + """ + async with self.session() as session: + stmt = insert(ParamInd).values( + category=category, + name=name, + profile_id=self.profiles[profile], + value=value + ).on_conflict_do_update( + index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), + set_={ + ParamInd.value: value + } + ) + await 
session.execute(stmt) + await session.commit() + + def _jid_filter(self, jid_: jid.JID, dest: bool = False): + """Generate condition to filter on a JID, using relevant columns + + @param dest: True if it's the destinee JID, otherwise it's the source one + @param jid_: JID to filter by + """ + if jid_.resource: + if dest: + return and_( + History.dest == jid_.userhost(), + History.dest_res == jid_.resource + ) + else: + return and_( + History.source == jid_.userhost(), + History.source_res == jid_.resource + ) + else: + if dest: + return History.dest == jid_.userhost() + else: + return History.source == jid_.userhost() + + @aio + async def history_get( + self, + from_jid: Optional[jid.JID], + to_jid: Optional[jid.JID], + limit: Optional[int] = None, + between: bool = True, + filters: Optional[Dict[str, str]] = None, + profile: Optional[str] = None, + ) -> List[Tuple[ + str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] + ]: + """Retrieve messages in history + + @param from_jid: source JID (full, or bare for catchall) + @param to_jid: dest JID (full, or bare for catchall) + @param limit: maximum number of messages to get: + - 0 for no message (returns the empty list) + - None for unlimited + @param between: confound source and dest (ignore the direction) + @param filters: pattern to filter the history results + @return: list of messages as in [message_new], minus the profile which is already + known. + """ + # we have to set a default value to profile because it's last argument + # and thus follow other keyword arguments with default values + # but None should not be used for it + assert profile is not None + if limit == 0: + return [] + if filters is None: + filters = {} + + stmt = ( + select(History) + .filter_by( + profile_id=self.profiles[profile] + ) + .outerjoin(History.messages) + .outerjoin(History.subjects) + .outerjoin(History.thread) + .options( + contains_eager(History.messages), + contains_eager(History.subjects), + contains_eager(History.thread), + ) + .order_by( + # timestamp may be identical for 2 close messages (specially when delay is + # used) that's why we order ties by received_timestamp. We'll reverse the + # order when returning the result. We use DESC here so LIMIT keep the last + # messages + History.timestamp.desc(), + History.received_timestamp.desc() + ) + ) + + + if not from_jid and not to_jid: + # no jid specified, we want all one2one communications + pass + elif between: + if not from_jid or not to_jid: + # we only have one jid specified, we check all messages + # from or to this jid + jid_ = from_jid or to_jid + stmt = stmt.where( + or_( + self._jid_filter(jid_), + self._jid_filter(jid_, dest=True) + ) + ) + else: + # we have 2 jids specified, we check all communications between + # those 2 jids + stmt = stmt.where( + or_( + and_( + self._jid_filter(from_jid), + self._jid_filter(to_jid, dest=True), + ), + and_( + self._jid_filter(to_jid), + self._jid_filter(from_jid, dest=True), + ) + ) + ) + else: + # we want one communication in specific direction (from somebody or + # to somebody). + if from_jid is not None: + stmt = stmt.where(self._jid_filter(from_jid)) + if to_jid is not None: + stmt = stmt.where(self._jid_filter(to_jid, dest=True)) + + if filters: + if 'timestamp_start' in filters: + stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) + if 'before_uid' in filters: + # orignially this query was using SQLITE's rowid. 
This has been changed + # to use coalesce(received_timestamp, timestamp) to be SQL engine independant + stmt = stmt.where( + coalesce( + History.received_timestamp, + History.timestamp + ) < ( + select(coalesce(History.received_timestamp, History.timestamp)) + .filter_by(uid=filters["before_uid"]) + ).scalar_subquery() + ) + if 'body' in filters: + # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html + stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) + if 'search' in filters: + search_term = f"%{filters['search']}%" + stmt = stmt.where(or_( + Message.message.like(search_term), + History.source_res.like(search_term) + )) + if 'types' in filters: + types = filters['types'].split() + stmt = stmt.where(History.type.in_(types)) + if 'not_types' in filters: + types = filters['not_types'].split() + stmt = stmt.where(History.type.not_in(types)) + if 'last_stanza_id' in filters: + # this request get the last message with a "stanza_id" that we + # have in history. This is mainly used to retrieve messages sent + # while we were offline, using MAM (XEP-0313). + if (filters['last_stanza_id'] is not True + or limit != 1): + raise ValueError("Unexpected values for last_stanza_id filter") + stmt = stmt.where(History.stanza_id.is_not(None)) + if 'origin_id' in filters: + stmt = stmt.where(History.origin_id == filters["origin_id"]) + + if limit is not None: + stmt = stmt.limit(limit) + + async with self.session() as session: + result = await session.execute(stmt) + + result = result.scalars().unique().all() + result.reverse() + return [h.as_tuple() for h in result] + + @aio + async def add_to_history(self, data: dict, profile: str) -> None: + """Store a new message in history + + @param data: message data as build by SatMessageProtocol.onMessage + """ + extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} + messages = [Message(message=mess, language=lang) + for lang, mess in data["message"].items()] + subjects = [Subject(subject=mess, language=lang) + for lang, mess in data["subject"].items()] + if "thread" in data["extra"]: + thread = Thread(thread_id=data["extra"]["thread"], + parent_id=data["extra"].get["thread_parent"]) + else: + thread = None + try: + async with self.session() as session: + async with session.begin(): + session.add(History( + uid=data["uid"], + origin_id=data["extra"].get("origin_id"), + stanza_id=data["extra"].get("stanza_id"), + update_uid=data["extra"].get("update_uid"), + profile_id=self.profiles[profile], + source_jid=data["from"], + dest_jid=data["to"], + timestamp=data["timestamp"], + received_timestamp=data.get("received_timestamp"), + type=data["type"], + extra=extra, + messages=messages, + subjects=subjects, + thread=thread, + )) + except IntegrityError as e: + if "unique" in str(e.orig).lower(): + log.debug( + f"message {data['uid']!r} is already in history, not storing it again" + ) + else: + log.error(f"Can't store message {data['uid']!r} in history: {e}") + except Exception as e: + log.critical( + f"Can't store message, unexpected exception (uid: {data['uid']}): {e}" + ) + + ## Private values + + def _get_private_class(self, binary, profile): + """Get ORM class to use for private values""" + if profile is None: + return PrivateGenBin if binary else PrivateGen + else: + return PrivateIndBin if binary else PrivateInd + + + @aio + async def get_privates( + self, + namespace:str, + keys: Optional[Iterable[str]] = None, + binary: bool = False, + profile: Optional[str] = None + ) -> Dict[str, 
Any]: + """Get private value(s) from databases + + @param namespace: namespace of the values + @param keys: keys of the values to get None to get all keys/values + @param binary: True to deserialise binary values + @param profile: profile to use for individual values + None to use general values + @return: gotten keys/values + """ + if keys is not None: + keys = list(keys) + log.debug( + f"getting {'general' if profile is None else 'individual'}" + f"{' binary' if binary else ''} private values from database for namespace " + f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}" + ) + cls = self._get_private_class(binary, profile) + stmt = select(cls).filter_by(namespace=namespace) + if keys: + stmt = stmt.where(cls.key.in_(list(keys))) + if profile is not None: + stmt = stmt.filter_by(profile_id=self.profiles[profile]) + async with self.session() as session: + result = await session.execute(stmt) + return {p.key: p.value for p in result.scalars()} + + @aio + async def set_private_value( + self, + namespace: str, + key:str, + value: Any, + binary: bool = False, + profile: Optional[str] = None + ) -> None: + """Set a private value in database + + @param namespace: namespace of the values + @param key: key of the value to set + @param value: value to set + @param binary: True if it's a binary values + binary values need to be serialised, used for everything but strings + @param profile: profile to use for individual value + if None, it's a general value + """ + cls = self._get_private_class(binary, profile) + + values = { + "namespace": namespace, + "key": key, + "value": value + } + index_elements = [cls.namespace, cls.key] + + if profile is not None: + values["profile_id"] = self.profiles[profile] + index_elements.append(cls.profile_id) + + async with self.session() as session: + await session.execute( + insert(cls).values(**values).on_conflict_do_update( + index_elements=index_elements, + set_={ + cls.value: value + } + ) + ) + await session.commit() + + @aio + async def del_private_value( + self, + namespace: str, + key: str, + binary: bool = False, + profile: Optional[str] = None + ) -> None: + """Delete private value from database + + @param category: category of the privateeter + @param key: key of the private value + @param binary: True if it's a binary values + @param profile: profile to use for individual value + if None, it's a general value + """ + cls = self._get_private_class(binary, profile) + + stmt = delete(cls).filter_by(namespace=namespace, key=key) + + if profile is not None: + stmt = stmt.filter_by(profile_id=self.profiles[profile]) + + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def del_private_namespace( + self, + namespace: str, + binary: bool = False, + profile: Optional[str] = None + ) -> None: + """Delete all data from a private namespace + + Be really cautious when you use this method, as all data with given namespace are + removed. 
+ Params are the same as for del_private_value + """ + cls = self._get_private_class(binary, profile) + + stmt = delete(cls).filter_by(namespace=namespace) + + if profile is not None: + stmt = stmt.filter_by(profile_id=self.profiles[profile]) + + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + ## Files + + @aio + async def get_files( + self, + client: Optional[SatXMPPEntity], + file_id: Optional[str] = None, + version: Optional[str] = '', + parent: Optional[str] = None, + type_: Optional[str] = None, + file_hash: Optional[str] = None, + hash_algo: Optional[str] = None, + name: Optional[str] = None, + namespace: Optional[str] = None, + mime_type: Optional[str] = None, + public_id: Optional[str] = None, + owner: Optional[jid.JID] = None, + access: Optional[dict] = None, + projection: Optional[List[str]] = None, + unique: bool = False + ) -> List[dict]: + """Retrieve files with with given filters + + @param file_id: id of the file + None to ignore + @param version: version of the file + None to ignore + empty string to look for current version + @param parent: id of the directory containing the files + None to ignore + empty string to look for root files/directories + @param projection: name of columns to retrieve + None to retrieve all + @param unique: if True will remove duplicates + other params are the same as for [set_file] + @return: files corresponding to filters + """ + if projection is None: + projection = [ + 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', + 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', + 'created', 'modified', 'owner', 'access', 'extra' + ] + + stmt = select(*[getattr(File, f) for f in projection]) + + if unique: + stmt = stmt.distinct() + + if client is not None: + stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) + else: + if public_id is None: + raise exceptions.InternalError( + "client can only be omitted when public_id is set" + ) + if file_id is not None: + stmt = stmt.filter_by(id=file_id) + if version is not None: + stmt = stmt.filter_by(version=version) + if parent is not None: + stmt = stmt.filter_by(parent=parent) + if type_ is not None: + stmt = stmt.filter_by(type=type_) + if file_hash is not None: + stmt = stmt.filter_by(file_hash=file_hash) + if hash_algo is not None: + stmt = stmt.filter_by(hash_algo=hash_algo) + if name is not None: + stmt = stmt.filter_by(name=name) + if namespace is not None: + stmt = stmt.filter_by(namespace=namespace) + if mime_type is not None: + if '/' in mime_type: + media_type, media_subtype = mime_type.split("/", 1) + stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) + else: + stmt = stmt.filter_by(media_type=mime_type) + if public_id is not None: + stmt = stmt.filter_by(public_id=public_id) + if owner is not None: + stmt = stmt.filter_by(owner=owner) + if access is not None: + raise NotImplementedError('Access check is not implemented yet') + # a JSON comparison is needed here + + async with self.session() as session: + result = await session.execute(stmt) + + return [dict(r) for r in result] + + @aio + async def set_file( + self, + client: SatXMPPEntity, + name: str, + file_id: str, + version: str = "", + parent: str = "", + type_: str = C.FILE_TYPE_FILE, + file_hash: Optional[str] = None, + hash_algo: Optional[str] = None, + size: int = None, + namespace: Optional[str] = None, + mime_type: Optional[str] = None, + public_id: Optional[str] = None, + created: Optional[float] = None, + modified: 
Optional[float] = None, + owner: Optional[jid.JID] = None, + access: Optional[dict] = None, + extra: Optional[dict] = None + ) -> None: + """Set a file metadata + + @param client: client owning the file + @param name: name of the file (must not contain "/") + @param file_id: unique id of the file + @param version: version of this file + @param parent: id of the directory containing this file + Empty string if it is a root file/directory + @param type_: one of: + - file + - directory + @param file_hash: unique hash of the payload + @param hash_algo: algorithm used for hashing the file (usually sha-256) + @param size: size in bytes + @param namespace: identifier (human readable is better) to group files + for instance, namespace could be used to group files in a specific photo album + @param mime_type: media type of the file, or None if not known/guessed + @param public_id: ID used to server the file publicly via HTTP + @param created: UNIX time of creation + @param modified: UNIX time of last modification, or None to use created date + @param owner: jid of the owner of the file (mainly useful for component) + @param access: serialisable dictionary with access rules. See [memory.memory] for + details + @param extra: serialisable dictionary of any extra data + will be encoded to json in database + """ + if mime_type is None: + media_type = media_subtype = None + elif '/' in mime_type: + media_type, media_subtype = mime_type.split('/', 1) + else: + media_type, media_subtype = mime_type, None + + async with self.session() as session: + async with session.begin(): + session.add(File( + id=file_id, + version=version.strip(), + parent=parent, + type=type_, + file_hash=file_hash, + hash_algo=hash_algo, + name=name, + size=size, + namespace=namespace, + media_type=media_type, + media_subtype=media_subtype, + public_id=public_id, + created=time.time() if created is None else created, + modified=modified, + owner=owner, + access=access, + extra=extra, + profile_id=self.profiles[client.profile] + )) + + @aio + async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: + async with self.session() as session: + result = await session.execute( + select(sum_(File.size)).filter_by( + owner=owner, + type=C.FILE_TYPE_FILE, + profile_id=self.profiles[client.profile] + )) + return result.scalar_one_or_none() or 0 + + @aio + async def file_delete(self, file_id: str) -> None: + """Delete file metadata from the database + + @param file_id: id of the file to delete + NOTE: file itself must still be removed, this method only handle metadata in + database + """ + async with self.session() as session: + await session.execute(delete(File).filter_by(id=file_id)) + await session.commit() + + @aio + async def file_update( + self, + file_id: str, + column: str, + update_cb: Callable[[dict], None] + ) -> None: + """Update a column value using a method to avoid race conditions + + the older value will be retrieved from database, then update_cb will be applied to + update it, and file will be updated checking that older value has not been changed + meanwhile by an other user. If it has changed, it tries again a couple of times + before failing + @param column: column name (only "access" or "extra" are allowed) + @param update_cb: method to update the value of the colum + the method will take older value as argument, and must update it in place + update_cb must not care about serialization, + it get the deserialized data (i.e. 
a Python object) directly + @raise exceptions.NotFound: there is not file with this id + """ + if column not in ('access', 'extra'): + raise exceptions.InternalError('bad column name') + orm_col = getattr(File, column) + + for i in range(5): + async with self.session() as session: + try: + value = (await session.execute( + select(orm_col).filter_by(id=file_id) + )).scalar_one() + except NoResultFound: + raise exceptions.NotFound + old_value = copy.deepcopy(value) + update_cb(value) + stmt = update(File).filter_by(id=file_id).values({column: value}) + if not old_value: + # because JsonDefaultDict convert NULL to an empty dict, we have to + # test both for empty dict and None when we have an empty dict + stmt = stmt.where((orm_col == None) | (orm_col == old_value)) + else: + stmt = stmt.where(orm_col == old_value) + result = await session.execute(stmt) + await session.commit() + + if result.rowcount == 1: + break + + log.warning( + _("table not updated, probably due to race condition, trying again " + "({tries})").format(tries=i+1) + ) + + else: + raise exceptions.DatabaseError( + _("Can't update file {file_id} due to race condition") + .format(file_id=file_id) + ) + + @aio + async def get_pubsub_node( + self, + client: SatXMPPEntity, + service: jid.JID, + name: str, + with_items: bool = False, + with_subscriptions: bool = False, + create: bool = False, + create_kwargs: Optional[dict] = None + ) -> Optional[PubsubNode]: + """Retrieve a PubsubNode from DB + + @param service: service hosting the node + @param name: node's name + @param with_items: retrieve items in the same query + @param with_subscriptions: retrieve subscriptions in the same query + @param create: if the node doesn't exist in DB, create it + @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node + needs to be created. 
+ """ + async with self.session() as session: + stmt = ( + select(PubsubNode) + .filter_by( + service=service, + name=name, + profile_id=self.profiles[client.profile], + ) + ) + if with_items: + stmt = stmt.options( + joinedload(PubsubNode.items) + ) + if with_subscriptions: + stmt = stmt.options( + joinedload(PubsubNode.subscriptions) + ) + result = await session.execute(stmt) + ret = result.unique().scalar_one_or_none() + if ret is None and create: + # we auto-create the node + if create_kwargs is None: + create_kwargs = {} + try: + return await as_future(self.set_pubsub_node( + client, service, name, **create_kwargs + )) + except IntegrityError as e: + if "unique" in str(e.orig).lower(): + # the node may already exist, if it has been created just after + # get_pubsub_node above + log.debug("ignoring UNIQUE constraint error") + cached_node = await as_future(self.get_pubsub_node( + client, + service, + name, + with_items=with_items, + with_subscriptions=with_subscriptions + )) + else: + raise e + else: + return ret + + @aio + async def set_pubsub_node( + self, + client: SatXMPPEntity, + service: jid.JID, + name: str, + analyser: Optional[str] = None, + type_: Optional[str] = None, + subtype: Optional[str] = None, + subscribed: bool = False, + ) -> PubsubNode: + node = PubsubNode( + profile_id=self.profiles[client.profile], + service=service, + name=name, + subscribed=subscribed, + analyser=analyser, + type_=type_, + subtype=subtype, + subscriptions=[], + ) + async with self.session() as session: + async with session.begin(): + session.add(node) + return node + + @aio + async def update_pubsub_node_sync_state( + self, + node: PubsubNode, + state: SyncState + ) -> None: + async with self.session() as session: + async with session.begin(): + await session.execute( + update(PubsubNode) + .filter_by(id=node.id) + .values( + sync_state=state, + sync_state_updated=time.time(), + ) + ) + + @aio + async def delete_pubsub_node( + self, + profiles: Optional[List[str]], + services: Optional[List[jid.JID]], + names: Optional[List[str]] + ) -> None: + """Delete items cached for a node + + @param profiles: profile names from which nodes must be deleted. + None to remove nodes from ALL profiles + @param services: JIDs of pubsub services from which nodes must be deleted. + None to remove nodes from ALL services + @param names: names of nodes which must be deleted. 
+ None to remove ALL nodes whatever is their names + """ + stmt = delete(PubsubNode) + if profiles is not None: + stmt = stmt.where( + PubsubNode.profile.in_( + [self.profiles[p] for p in profiles] + ) + ) + if services is not None: + stmt = stmt.where(PubsubNode.service.in_(services)) + if names is not None: + stmt = stmt.where(PubsubNode.name.in_(names)) + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def cache_pubsub_items( + self, + client: SatXMPPEntity, + node: PubsubNode, + items: List[domish.Element], + parsed_items: Optional[List[dict]] = None, + ) -> None: + """Add items to database, using an upsert taking care of "updated" field""" + if parsed_items is not None and len(items) != len(parsed_items): + raise exceptions.InternalError( + "parsed_items must have the same lenght as items" + ) + async with self.session() as session: + async with session.begin(): + for idx, item in enumerate(items): + parsed = parsed_items[idx] if parsed_items else None + stmt = insert(PubsubItem).values( + node_id = node.id, + name = item["id"], + data = item, + parsed = parsed, + ).on_conflict_do_update( + index_elements=(PubsubItem.node_id, PubsubItem.name), + set_={ + PubsubItem.data: item, + PubsubItem.parsed: parsed, + PubsubItem.updated: now() + } + ) + await session.execute(stmt) + await session.commit() + + @aio + async def delete_pubsub_items( + self, + node: PubsubNode, + items_names: Optional[List[str]] = None + ) -> None: + """Delete items cached for a node + + @param node: node from which items must be deleted + @param items_names: names of items to delete + if None, ALL items will be deleted + """ + stmt = delete(PubsubItem) + if node is not None: + if isinstance(node, list): + stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node])) + else: + stmt = stmt.filter_by(node_id=node.id) + if items_names is not None: + stmt = stmt.where(PubsubItem.name.in_(items_names)) + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def purge_pubsub_items( + self, + services: Optional[List[jid.JID]] = None, + names: Optional[List[str]] = None, + types: Optional[List[str]] = None, + subtypes: Optional[List[str]] = None, + profiles: Optional[List[str]] = None, + created_before: Optional[datetime] = None, + updated_before: Optional[datetime] = None, + ) -> None: + """Delete items cached for a node + + @param node: node from which items must be deleted + @param items_names: names of items to delete + if None, ALL items will be deleted + """ + stmt = delete(PubsubItem) + node_fields = { + "service": services, + "name": names, + "type_": types, + "subtype": subtypes, + } + if profiles is not None: + node_fields["profile_id"] = [self.profiles[p] for p in profiles] + + if any(x is not None for x in node_fields.values()): + sub_q = select(PubsubNode.id) + for col, values in node_fields.items(): + if values is None: + continue + sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) + stmt = ( + stmt + .where(PubsubItem.node_id.in_(sub_q)) + .execution_options(synchronize_session=False) + ) + + if created_before is not None: + stmt = stmt.where(PubsubItem.created < created_before) + + if updated_before is not None: + stmt = stmt.where(PubsubItem.updated < updated_before) + + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def get_items( + self, + node: PubsubNode, + max_items: Optional[int] = None, + item_ids: 
Optional[list[str]] = None, + before: Optional[str] = None, + after: Optional[str] = None, + from_index: Optional[int] = None, + order_by: Optional[List[str]] = None, + desc: bool = True, + force_rsm: bool = False, + ) -> Tuple[List[PubsubItem], dict]: + """Get Pubsub Items from cache + + @param node: retrieve items from this node (must be synchronised) + @param max_items: maximum number of items to retrieve + @param before: get items which are before the item with this name in given order + empty string is not managed here, use desc order to reproduce RSM + behaviour. + @param after: get items which are after the item with this name in given order + @param from_index: get items with item index (as defined in RSM spec) + starting from this number + @param order_by: sorting order of items (one of C.ORDER_BY_*) + @param desc: direction or ordering + @param force_rsm: if True, force the use of RSM worklow. + RSM workflow is automatically used if any of before, after or + from_index is used, but if only RSM max_items is used, it won't be + used by default. This parameter let's use RSM workflow in this + case. Note that in addition to RSM metadata, the result will not be + the same (max_items without RSM will returns most recent items, + i.e. last items in modification order, while max_items with RSM + will return the oldest ones (i.e. first items in modification + order). + to be used when max_items is used from RSM + """ + + metadata = { + "service": node.service, + "node": node.name, + "uri": uri.build_xmpp_uri( + "pubsub", + path=node.service.full(), + node=node.name, + ), + } + if max_items is None: + max_items = 20 + + use_rsm = any((before, after, from_index is not None)) + if force_rsm and not use_rsm: + # + use_rsm = True + from_index = 0 + + stmt = ( + select(PubsubItem) + .filter_by(node_id=node.id) + .limit(max_items) + ) + + if item_ids is not None: + stmt = stmt.where(PubsubItem.name.in_(item_ids)) + + if not order_by: + order_by = [C.ORDER_BY_MODIFICATION] + + order = [] + for order_type in order_by: + if order_type == C.ORDER_BY_MODIFICATION: + if desc: + order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc())) + else: + order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc())) + elif order_type == C.ORDER_BY_CREATION: + if desc: + order.append(PubsubItem.id.desc()) + else: + order.append(PubsubItem.id.asc()) + else: + raise exceptions.InternalError(f"Unknown order type {order_type!r}") + + stmt = stmt.order_by(*order) + + if use_rsm: + # CTE to have result row numbers + row_num_q = select( + PubsubItem.id, + PubsubItem.name, + # row_number starts from 1, but RSM index must start from 0 + (func.row_number().over(order_by=order)-1).label("item_index") + ).filter_by(node_id=node.id) + + row_num_cte = row_num_q.cte() + + if max_items > 0: + # as we can't simply use PubsubItem.id when we order by modification, + # we need to use row number + item_name = before or after + row_num_limit_q = ( + select(row_num_cte.c.item_index) + .where(row_num_cte.c.name==item_name) + ).scalar_subquery() + + stmt = ( + select(row_num_cte.c.item_index, PubsubItem) + .join(row_num_cte, PubsubItem.id == row_num_cte.c.id) + .limit(max_items) + ) + if before: + stmt = ( + stmt + .where(row_num_cte.c.item_index<row_num_limit_q) + .order_by(row_num_cte.c.item_index.desc()) + ) + elif after: + stmt = ( + stmt + .where(row_num_cte.c.item_index>row_num_limit_q) + .order_by(row_num_cte.c.item_index.asc()) + ) + else: + stmt = ( + stmt + .where(row_num_cte.c.item_index>=from_index) + 
.order_by(row_num_cte.c.item_index.asc()) + ) + # from_index is used + + async with self.session() as session: + if max_items == 0: + items = result = [] + else: + result = await session.execute(stmt) + result = result.all() + if before: + result.reverse() + items = [row[-1] for row in result] + rows_count = ( + await session.execute(row_num_q.with_only_columns(count())) + ).scalar_one() + + try: + index = result[0][0] + except IndexError: + index = None + + try: + first = result[0][1].name + except IndexError: + first = None + last = None + else: + last = result[-1][1].name + + metadata["rsm"] = { + k: v for k, v in { + "index": index, + "count": rows_count, + "first": first, + "last": last, + }.items() if v is not None + } + metadata["complete"] = (index or 0) + len(result) == rows_count + + return items, metadata + + async with self.session() as session: + result = await session.execute(stmt) + + result = result.scalars().all() + if desc: + result.reverse() + return result, metadata + + def _get_sqlite_path( + self, + path: List[Union[str, int]] + ) -> str: + """generate path suitable to query JSON element with SQLite""" + return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" + + @aio + async def search_pubsub_items( + self, + query: dict, + ) -> Tuple[List[PubsubItem]]: + """Search for pubsub items in cache + + @param query: search terms. Keys can be: + :fts (str): + Full-Text Search query. Currently SQLite FT5 engine is used, its query + syntax can be used, see `FTS5 Query documentation + <https://sqlite.org/fts5.html#full_text_query_syntax>`_ + :profiles (list[str]): + filter on nodes linked to those profiles + :nodes (list[str]): + filter on nodes with those names + :services (list[jid.JID]): + filter on nodes from those services + :types (list[str|None]): + filter on nodes with those types. None can be used to filter on nodes with + no type set + :subtypes (list[str|None]): + filter on nodes with those subtypes. None can be used to filter on nodes with + no subtype set + :names (list[str]): + filter on items with those names + :parsed (list[dict]): + Filter on a parsed data field. The dict must contain 3 keys: ``path`` + which is a list of str or int giving the path to the field of interest + (str for a dict key, int for a list index), ``operator`` with indicate the + operator to use to check the condition, and ``value`` which depends of + field type and operator. + + See documentation for details on operators (it's currently explained at + ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command + documentation). + + :order-by (list[dict]): + Indicates how to order results. The dict can contain either a ``order`` + for a well-know order or a ``path`` for a parsed data field path + (``order`` and ``path`` can't be used at the same time), an an optional + ``direction`` which can be ``asc`` or ``desc``. See documentation for + details on well-known orders (it's currently explained at + ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command + documentation). + + :index (int): + starting index of items to return from the query result. It's translated + to SQL's OFFSET + + :limit (int): + maximum number of items to return. It's translated to SQL's LIMIT. 
+ + @result: found items (the ``node`` attribute will be filled with suitable + PubsubNode) + """ + # TODO: FTS and parsed data filters use SQLite specific syntax + # when other DB engines will be used, this will have to be adapted + stmt = select(PubsubItem) + + # Full-Text Search + fts = query.get("fts") + if fts: + fts_select = text( + "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)" + ).bindparams(fts_query=fts).columns(rowid=Integer).subquery() + stmt = ( + stmt + .select_from(fts_select) + .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id) + ) + + # node related filters + profiles = query.get("profiles") + if (profiles + or any(query.get(k) for k in ("nodes", "services", "types", "subtypes")) + ): + stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) + if profiles: + try: + stmt = stmt.where( + PubsubNode.profile_id.in_(self.profiles[p] for p in profiles) + ) + except KeyError as e: + raise exceptions.ProfileUnknownError( + f"This profile doesn't exist: {e.args[0]!r}" + ) + for key, attr in ( + ("nodes", "name"), + ("services", "service"), + ("types", "type_"), + ("subtypes", "subtype") + ): + value = query.get(key) + if not value: + continue + if key in ("types", "subtypes") and None in value: + # NULL can't be used with SQL's IN, so we have to add a condition with + # IS NULL, and use a OR if there are other values to check + value.remove(None) + condition = getattr(PubsubNode, attr).is_(None) + if value: + condition = or_( + getattr(PubsubNode, attr).in_(value), + condition + ) + else: + condition = getattr(PubsubNode, attr).in_(value) + stmt = stmt.where(condition) + else: + stmt = stmt.options(selectinload(PubsubItem.node)) + + # names + names = query.get("names") + if names: + stmt = stmt.where(PubsubItem.name.in_(names)) + + # parsed data filters + parsed = query.get("parsed", []) + for filter_ in parsed: + try: + path = filter_["path"] + operator = filter_["op"] + value = filter_["value"] + except KeyError as e: + raise ValueError( + f'missing mandatory key {e.args[0]!r} in "parsed" filter' + ) + try: + op_attr = OP_MAP[operator] + except KeyError: + raise ValueError(f"invalid operator: {operator!r}") + sqlite_path = self._get_sqlite_path(path) + if operator in ("overlap", "ioverlap", "disjoint", "idisjoint"): + col = literal_column("json_each.value") + if operator[0] == "i": + col = func.lower(col) + value = [str(v).lower() for v in value] + condition = ( + select(1) + .select_from(func.json_each(PubsubItem.parsed, sqlite_path)) + .where(col.in_(value)) + ).scalar_subquery() + if operator in ("disjoint", "idisjoint"): + condition = condition.is_(None) + stmt = stmt.where(condition) + elif operator == "between": + try: + left, right = value + except (ValueError, TypeError): + raise ValueError(_( + "invalid value for \"between\" filter, you must use a 2 items " + "array: {value!r}" + ).format(value=value)) + col = func.json_extract(PubsubItem.parsed, sqlite_path) + stmt = stmt.where(col.between(left, right)) + else: + # we use func.json_extract instead of generic JSON way because SQLAlchemy + # add a JSON_QUOTE to the value, and we want SQL value + col = func.json_extract(PubsubItem.parsed, sqlite_path) + stmt = stmt.where(getattr(col, op_attr)(value)) + + # order + order_by = query.get("order-by") or [{"order": "creation"}] + + for order_data in order_by: + order, path = order_data.get("order"), order_data.get("path") + if order and path: + raise ValueError(_( + '"order" and "path" can\'t be used at the same time in ' + '"order-by" data' + )) 
+ if order: + if order == "creation": + col = PubsubItem.id + elif order == "modification": + col = PubsubItem.updated + elif order == "item_id": + col = PubsubItem.name + elif order == "rank": + if not fts: + raise ValueError( + "'rank' order can only be used with Full-Text Search (fts)" + ) + col = literal_column("rank") + else: + raise NotImplementedError(f"Unknown {order!r} order") + else: + # we have a JSON path + # sqlite_path = self._get_sqlite_path(path) + col = PubsubItem.parsed[path] + direction = order_data.get("direction", "ASC").lower() + if not direction in ("asc", "desc"): + raise ValueError(f"Invalid order-by direction: {direction!r}") + stmt = stmt.order_by(getattr(col, direction)()) + + # offset, limit + index = query.get("index") + if index: + stmt = stmt.offset(index) + limit = query.get("limit") + if limit: + stmt = stmt.limit(limit) + + async with self.session() as session: + result = await session.execute(stmt) + + return result.scalars().all()
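
For context, the `Storage` methods above are asyncio coroutines wrapped by the `@aio` decorator, so callers on the Twisted side get Deferreds back. The following is a minimal usage sketch, not part of the changeset: it assumes the backend's database configuration is already in place, that a profile named "default" exists, and the peer JID is made up for illustration; error handling is omitted.

```python
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from libervia.backend.memory.sqla import Storage

@defer.inlineCallbacks
def show_recent_history(profile: str = "default"):
    storage = Storage()
    # initialise() creates or migrates the database and caches profile ids;
    # thanks to @aio it returns a Deferred usable with inlineCallbacks.
    yield storage.initialise()
    peer = jid.JID("friend@example.org")
    # Last 10 messages exchanged with `peer` (both directions, since
    # `between` defaults to True), returned in chronological order.
    rows = yield storage.history_get(
        from_jid=peer, to_jid=None, limit=10, profile=profile
    )
    for row in rows:
        print(row)
```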