Mercurial > libervia-backend
view libervia/backend/memory/sqla.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 02 Jun 2023 11:49:51 +0200 |
parents | sat/memory/sqla.py@524856bd7b19 |
children | 74c66c0d93f3 |
line wrap: on
line source
#!/usr/bin/env python3 # Libervia: an XMPP client # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. import asyncio from asyncio.subprocess import PIPE import copy from datetime import datetime from pathlib import Path import sys import time from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from alembic import config as al_config, script as al_script from alembic.runtime import migration as al_migration from sqlalchemy import and_, delete, event, func, or_, update from sqlalchemy import Integer, literal_column, text from sqlalchemy.dialects.sqlite import insert from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.future import select from sqlalchemy.orm import ( contains_eager, joinedload, selectinload, sessionmaker, subqueryload, ) from sqlalchemy.orm.attributes import Mapped from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.sql.functions import coalesce, count, now, sum as sum_ from twisted.internet import defer from twisted.words.protocols.jabber import jid from twisted.words.xish import domish from libervia.backend.core import exceptions from libervia.backend.core.constants import Const as C from libervia.backend.core.core_types import SatXMPPEntity from libervia.backend.core.i18n import _ from libervia.backend.core.log import getLogger from libervia.backend.memory import migration from libervia.backend.memory import sqla_config from libervia.backend.memory.sqla_mapping import ( Base, Component, File, History, Message, NOT_IN_EXTRA, ParamGen, ParamInd, PrivateGen, PrivateGenBin, PrivateInd, PrivateIndBin, Profile, PubsubItem, PubsubNode, Subject, SyncState, Thread, ) from libervia.backend.tools.common import uri from libervia.backend.tools.utils import aio, as_future log = getLogger(__name__) migration_path = Path(migration.__file__).parent #: mapping of Libervia search query operators to SQLAlchemy method name OP_MAP = { "==": "__eq__", "eq": "__eq__", "!=": "__ne__", "ne": "__ne__", ">": "__gt__", "gt": "__gt__", "<": "__le__", "le": "__le__", "between": "between", "in": "in_", "not_in": "not_in", "overlap": "in_", "ioverlap": "in_", "disjoint": "in_", "idisjoint": "in_", "like": "like", "ilike": "ilike", "not_like": "notlike", "not_ilike": "notilike", } @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() class Storage: def __init__(self): self.initialized = defer.Deferred() # we keep cache for the profiles (key: profile name, value: profile id) # profile id to name self.profiles: Dict[str, int] = {} # profile id to component entry point self.components: Dict[int, str] = {} def get_profile_by_id(self, profile_id): return self.profiles.get(profile_id) async def migrate_apply(self, *args: str, log_output: bool = False) -> None: """Do a migration command Commands are applied by running Alembic in a subprocess. Arguments are alembic executables commands @param log_output: manage stdout and stderr: - if False, stdout and stderr are buffered, and logged only in case of error - if True, stdout and stderr will be logged during the command execution @raise exceptions.DatabaseError: something went wrong while running the process """ stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) proc = await asyncio.create_subprocess_exec( sys.executable, "-m", "alembic", *args, stdout=stdout, stderr=stderr, cwd=migration_path ) log_out, log_err = await proc.communicate() if proc.returncode != 0: msg = _( "Can't {operation} database (exit code {exit_code})" ).format( operation=args[0], exit_code=proc.returncode ) if log_out or log_err: msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" log.error(msg) raise exceptions.DatabaseError(msg) async def create_db(self, engine: AsyncEngine, db_config: dict) -> None: """Create a new database The database is generated from SQLAlchemy model, then stamped by Alembic """ # the dir may not exist if it's not the XDG recommended one db_config["path"].parent.mkdir(0o700, True, True) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) log.debug("stamping the database") await self.migrate_apply("stamp", "head") log.debug("stamping done") def _check_db_is_up_to_date(self, conn: Connection) -> bool: al_ini_path = migration_path / "alembic.ini" al_cfg = al_config.Config(al_ini_path) directory = al_script.ScriptDirectory.from_config(al_cfg) context = al_migration.MigrationContext.configure(conn) return set(context.get_current_heads()) == set(directory.get_heads()) def _sqlite_set_journal_mode_wal(self, conn: Connection) -> None: """Check if journal mode is WAL, and set it if necesssary""" result = conn.execute(text("PRAGMA journal_mode")) if result.scalar() != "wal": log.info("WAL mode not activated, activating it") conn.execute(text("PRAGMA journal_mode=WAL")) async def check_and_update_db(self, engine: AsyncEngine, db_config: dict) -> None: """Check that database is up-to-date, and update if necessary""" async with engine.connect() as conn: up_to_date = await conn.run_sync(self._check_db_is_up_to_date) if up_to_date: log.debug("Database is up-to-date") else: log.info("Database needs to be updated") log.info("updating…") await self.migrate_apply("upgrade", "head", log_output=True) log.info("Database is now up-to-date") @aio async def initialise(self) -> None: log.info(_("Connecting database")) db_config = sqla_config.get_db_config() engine = create_async_engine( db_config["url"], future=True, ) new_base = not db_config["path"].exists() if new_base: log.info(_("The database is new, creating the tables")) await self.create_db(engine, db_config) else: await self.check_and_update_db(engine, db_config) async with engine.connect() as conn: await conn.run_sync(self._sqlite_set_journal_mode_wal) self.session = sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) async with self.session() as session: result = await session.execute(select(Profile)) for p in result.scalars(): self.profiles[p.name] = p.id result = await session.execute(select(Component)) for c in result.scalars(): self.components[c.profile_id] = c.entry_point self.initialized.callback(None) ## Generic @aio async def get( self, client: SatXMPPEntity, db_cls: DeclarativeMeta, db_id_col: Mapped, id_value: Any, joined_loads = None ) -> Optional[DeclarativeMeta]: stmt = select(db_cls).where(db_id_col==id_value) if client is not None: stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) if joined_loads is not None: for joined_load in joined_loads: stmt = stmt.options(joinedload(joined_load)) async with self.session() as session: result = await session.execute(stmt) if joined_loads is not None: result = result.unique() return result.scalar_one_or_none() @aio async def add(self, db_obj: DeclarativeMeta) -> None: """Add an object to database""" async with self.session() as session: async with session.begin(): session.add(db_obj) @aio async def delete( self, db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], session_add: Optional[List[DeclarativeMeta]] = None ) -> None: """Delete an object from database @param db_obj: object to delete or list of objects to delete @param session_add: other objects to add to session. This is useful when parents of deleted objects needs to be updated too, or if other objects needs to be updated in the same transaction. """ if not db_obj: return if not isinstance(db_obj, list): db_obj = [db_obj] async with self.session() as session: async with session.begin(): if session_add is not None: for obj in session_add: session.add(obj) for obj in db_obj: await session.delete(obj) await session.commit() ## Profiles def get_profiles_list(self) -> List[str]: """"Return list of all registered profiles""" return list(self.profiles.keys()) def has_profile(self, profile_name: str) -> bool: """return True if profile_name exists @param profile_name: name of the profile to check """ return profile_name in self.profiles def profile_is_component(self, profile_name: str) -> bool: try: return self.profiles[profile_name] in self.components except KeyError: raise exceptions.NotFound("the requested profile doesn't exists") def get_entry_point(self, profile_name: str) -> str: try: return self.components[self.profiles[profile_name]] except KeyError: raise exceptions.NotFound("the requested profile doesn't exists or is not a component") @aio async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: """Create a new profile @param name: name of the profile @param component: if not None, must point to a component entry point """ async with self.session() as session: profile = Profile(name=name) async with session.begin(): session.add(profile) self.profiles[profile.name] = profile.id if component_ep is not None: async with session.begin(): component = Component(profile=profile, entry_point=component_ep) session.add(component) self.components[profile.id] = component_ep return profile @aio async def delete_profile(self, name: str) -> None: """Delete profile @param name: name of the profile """ async with self.session() as session: result = await session.execute(select(Profile).where(Profile.name == name)) profile = result.scalar() await session.delete(profile) await session.commit() del self.profiles[profile.name] if profile.id in self.components: del self.components[profile.id] log.info(_("Profile {name!r} deleted").format(name = name)) ## Params @aio async def load_gen_params(self, params_gen: dict) -> None: """Load general parameters @param params_gen: dictionary to fill """ log.debug(_("loading general parameters from database")) async with self.session() as session: result = await session.execute(select(ParamGen)) for p in result.scalars(): params_gen[(p.category, p.name)] = p.value @aio async def load_ind_params(self, params_ind: dict, profile: str) -> None: """Load individual parameters @param params_ind: dictionary to fill @param profile: a profile which *must* exist """ log.debug(_("loading individual parameters from database")) async with self.session() as session: result = await session.execute( select(ParamInd).where(ParamInd.profile_id == self.profiles[profile]) ) for p in result.scalars(): params_ind[(p.category, p.name)] = p.value @aio async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]: """Ask database for the value of one specific individual parameter @param category: category of the parameter @param name: name of the parameter @param profile: %(doc_profile)s """ async with self.session() as session: result = await session.execute( select(ParamInd.value) .filter_by( category=category, name=name, profile_id=self.profiles[profile] ) ) return result.scalar_one_or_none() @aio async def get_ind_param_values(self, category: str, name: str) -> Dict[str, str]: """Ask database for the individual values of a parameter for all profiles @param category: category of the parameter @param name: name of the parameter @return dict: profile => value map """ async with self.session() as session: result = await session.execute( select(ParamInd) .filter_by( category=category, name=name ) .options(subqueryload(ParamInd.profile)) ) return {param.profile.name: param.value for param in result.scalars()} @aio async def set_gen_param(self, category: str, name: str, value: Optional[str]) -> None: """Save the general parameters in database @param category: category of the parameter @param name: name of the parameter @param value: value to set """ async with self.session() as session: stmt = insert(ParamGen).values( category=category, name=name, value=value ).on_conflict_do_update( index_elements=(ParamGen.category, ParamGen.name), set_={ ParamGen.value: value } ) await session.execute(stmt) await session.commit() @aio async def set_ind_param( self, category:str, name: str, value: Optional[str], profile: str ) -> None: """Save the individual parameters in database @param category: category of the parameter @param name: name of the parameter @param value: value to set @param profile: a profile which *must* exist """ async with self.session() as session: stmt = insert(ParamInd).values( category=category, name=name, profile_id=self.profiles[profile], value=value ).on_conflict_do_update( index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), set_={ ParamInd.value: value } ) await session.execute(stmt) await session.commit() def _jid_filter(self, jid_: jid.JID, dest: bool = False): """Generate condition to filter on a JID, using relevant columns @param dest: True if it's the destinee JID, otherwise it's the source one @param jid_: JID to filter by """ if jid_.resource: if dest: return and_( History.dest == jid_.userhost(), History.dest_res == jid_.resource ) else: return and_( History.source == jid_.userhost(), History.source_res == jid_.resource ) else: if dest: return History.dest == jid_.userhost() else: return History.source == jid_.userhost() @aio async def history_get( self, from_jid: Optional[jid.JID], to_jid: Optional[jid.JID], limit: Optional[int] = None, between: bool = True, filters: Optional[Dict[str, str]] = None, profile: Optional[str] = None, ) -> List[Tuple[ str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] ]: """Retrieve messages in history @param from_jid: source JID (full, or bare for catchall) @param to_jid: dest JID (full, or bare for catchall) @param limit: maximum number of messages to get: - 0 for no message (returns the empty list) - None for unlimited @param between: confound source and dest (ignore the direction) @param filters: pattern to filter the history results @return: list of messages as in [message_new], minus the profile which is already known. """ # we have to set a default value to profile because it's last argument # and thus follow other keyword arguments with default values # but None should not be used for it assert profile is not None if limit == 0: return [] if filters is None: filters = {} stmt = ( select(History) .filter_by( profile_id=self.profiles[profile] ) .outerjoin(History.messages) .outerjoin(History.subjects) .outerjoin(History.thread) .options( contains_eager(History.messages), contains_eager(History.subjects), contains_eager(History.thread), ) .order_by( # timestamp may be identical for 2 close messages (specially when delay is # used) that's why we order ties by received_timestamp. We'll reverse the # order when returning the result. We use DESC here so LIMIT keep the last # messages History.timestamp.desc(), History.received_timestamp.desc() ) ) if not from_jid and not to_jid: # no jid specified, we want all one2one communications pass elif between: if not from_jid or not to_jid: # we only have one jid specified, we check all messages # from or to this jid jid_ = from_jid or to_jid stmt = stmt.where( or_( self._jid_filter(jid_), self._jid_filter(jid_, dest=True) ) ) else: # we have 2 jids specified, we check all communications between # those 2 jids stmt = stmt.where( or_( and_( self._jid_filter(from_jid), self._jid_filter(to_jid, dest=True), ), and_( self._jid_filter(to_jid), self._jid_filter(from_jid, dest=True), ) ) ) else: # we want one communication in specific direction (from somebody or # to somebody). if from_jid is not None: stmt = stmt.where(self._jid_filter(from_jid)) if to_jid is not None: stmt = stmt.where(self._jid_filter(to_jid, dest=True)) if filters: if 'timestamp_start' in filters: stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) if 'before_uid' in filters: # orignially this query was using SQLITE's rowid. This has been changed # to use coalesce(received_timestamp, timestamp) to be SQL engine independant stmt = stmt.where( coalesce( History.received_timestamp, History.timestamp ) < ( select(coalesce(History.received_timestamp, History.timestamp)) .filter_by(uid=filters["before_uid"]) ).scalar_subquery() ) if 'body' in filters: # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) if 'search' in filters: search_term = f"%{filters['search']}%" stmt = stmt.where(or_( Message.message.like(search_term), History.source_res.like(search_term) )) if 'types' in filters: types = filters['types'].split() stmt = stmt.where(History.type.in_(types)) if 'not_types' in filters: types = filters['not_types'].split() stmt = stmt.where(History.type.not_in(types)) if 'last_stanza_id' in filters: # this request get the last message with a "stanza_id" that we # have in history. This is mainly used to retrieve messages sent # while we were offline, using MAM (XEP-0313). if (filters['last_stanza_id'] is not True or limit != 1): raise ValueError("Unexpected values for last_stanza_id filter") stmt = stmt.where(History.stanza_id.is_not(None)) if 'origin_id' in filters: stmt = stmt.where(History.origin_id == filters["origin_id"]) if limit is not None: stmt = stmt.limit(limit) async with self.session() as session: result = await session.execute(stmt) result = result.scalars().unique().all() result.reverse() return [h.as_tuple() for h in result] @aio async def add_to_history(self, data: dict, profile: str) -> None: """Store a new message in history @param data: message data as build by SatMessageProtocol.onMessage """ extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} messages = [Message(message=mess, language=lang) for lang, mess in data["message"].items()] subjects = [Subject(subject=mess, language=lang) for lang, mess in data["subject"].items()] if "thread" in data["extra"]: thread = Thread(thread_id=data["extra"]["thread"], parent_id=data["extra"].get["thread_parent"]) else: thread = None try: async with self.session() as session: async with session.begin(): session.add(History( uid=data["uid"], origin_id=data["extra"].get("origin_id"), stanza_id=data["extra"].get("stanza_id"), update_uid=data["extra"].get("update_uid"), profile_id=self.profiles[profile], source_jid=data["from"], dest_jid=data["to"], timestamp=data["timestamp"], received_timestamp=data.get("received_timestamp"), type=data["type"], extra=extra, messages=messages, subjects=subjects, thread=thread, )) except IntegrityError as e: if "unique" in str(e.orig).lower(): log.debug( f"message {data['uid']!r} is already in history, not storing it again" ) else: log.error(f"Can't store message {data['uid']!r} in history: {e}") except Exception as e: log.critical( f"Can't store message, unexpected exception (uid: {data['uid']}): {e}" ) ## Private values def _get_private_class(self, binary, profile): """Get ORM class to use for private values""" if profile is None: return PrivateGenBin if binary else PrivateGen else: return PrivateIndBin if binary else PrivateInd @aio async def get_privates( self, namespace:str, keys: Optional[Iterable[str]] = None, binary: bool = False, profile: Optional[str] = None ) -> Dict[str, Any]: """Get private value(s) from databases @param namespace: namespace of the values @param keys: keys of the values to get None to get all keys/values @param binary: True to deserialise binary values @param profile: profile to use for individual values None to use general values @return: gotten keys/values """ if keys is not None: keys = list(keys) log.debug( f"getting {'general' if profile is None else 'individual'}" f"{' binary' if binary else ''} private values from database for namespace " f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}" ) cls = self._get_private_class(binary, profile) stmt = select(cls).filter_by(namespace=namespace) if keys: stmt = stmt.where(cls.key.in_(list(keys))) if profile is not None: stmt = stmt.filter_by(profile_id=self.profiles[profile]) async with self.session() as session: result = await session.execute(stmt) return {p.key: p.value for p in result.scalars()} @aio async def set_private_value( self, namespace: str, key:str, value: Any, binary: bool = False, profile: Optional[str] = None ) -> None: """Set a private value in database @param namespace: namespace of the values @param key: key of the value to set @param value: value to set @param binary: True if it's a binary values binary values need to be serialised, used for everything but strings @param profile: profile to use for individual value if None, it's a general value """ cls = self._get_private_class(binary, profile) values = { "namespace": namespace, "key": key, "value": value } index_elements = [cls.namespace, cls.key] if profile is not None: values["profile_id"] = self.profiles[profile] index_elements.append(cls.profile_id) async with self.session() as session: await session.execute( insert(cls).values(**values).on_conflict_do_update( index_elements=index_elements, set_={ cls.value: value } ) ) await session.commit() @aio async def del_private_value( self, namespace: str, key: str, binary: bool = False, profile: Optional[str] = None ) -> None: """Delete private value from database @param category: category of the privateeter @param key: key of the private value @param binary: True if it's a binary values @param profile: profile to use for individual value if None, it's a general value """ cls = self._get_private_class(binary, profile) stmt = delete(cls).filter_by(namespace=namespace, key=key) if profile is not None: stmt = stmt.filter_by(profile_id=self.profiles[profile]) async with self.session() as session: await session.execute(stmt) await session.commit() @aio async def del_private_namespace( self, namespace: str, binary: bool = False, profile: Optional[str] = None ) -> None: """Delete all data from a private namespace Be really cautious when you use this method, as all data with given namespace are removed. Params are the same as for del_private_value """ cls = self._get_private_class(binary, profile) stmt = delete(cls).filter_by(namespace=namespace) if profile is not None: stmt = stmt.filter_by(profile_id=self.profiles[profile]) async with self.session() as session: await session.execute(stmt) await session.commit() ## Files @aio async def get_files( self, client: Optional[SatXMPPEntity], file_id: Optional[str] = None, version: Optional[str] = '', parent: Optional[str] = None, type_: Optional[str] = None, file_hash: Optional[str] = None, hash_algo: Optional[str] = None, name: Optional[str] = None, namespace: Optional[str] = None, mime_type: Optional[str] = None, public_id: Optional[str] = None, owner: Optional[jid.JID] = None, access: Optional[dict] = None, projection: Optional[List[str]] = None, unique: bool = False ) -> List[dict]: """Retrieve files with with given filters @param file_id: id of the file None to ignore @param version: version of the file None to ignore empty string to look for current version @param parent: id of the directory containing the files None to ignore empty string to look for root files/directories @param projection: name of columns to retrieve None to retrieve all @param unique: if True will remove duplicates other params are the same as for [set_file] @return: files corresponding to filters """ if projection is None: projection = [ 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', 'created', 'modified', 'owner', 'access', 'extra' ] stmt = select(*[getattr(File, f) for f in projection]) if unique: stmt = stmt.distinct() if client is not None: stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) else: if public_id is None: raise exceptions.InternalError( "client can only be omitted when public_id is set" ) if file_id is not None: stmt = stmt.filter_by(id=file_id) if version is not None: stmt = stmt.filter_by(version=version) if parent is not None: stmt = stmt.filter_by(parent=parent) if type_ is not None: stmt = stmt.filter_by(type=type_) if file_hash is not None: stmt = stmt.filter_by(file_hash=file_hash) if hash_algo is not None: stmt = stmt.filter_by(hash_algo=hash_algo) if name is not None: stmt = stmt.filter_by(name=name) if namespace is not None: stmt = stmt.filter_by(namespace=namespace) if mime_type is not None: if '/' in mime_type: media_type, media_subtype = mime_type.split("/", 1) stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) else: stmt = stmt.filter_by(media_type=mime_type) if public_id is not None: stmt = stmt.filter_by(public_id=public_id) if owner is not None: stmt = stmt.filter_by(owner=owner) if access is not None: raise NotImplementedError('Access check is not implemented yet') # a JSON comparison is needed here async with self.session() as session: result = await session.execute(stmt) return [dict(r) for r in result] @aio async def set_file( self, client: SatXMPPEntity, name: str, file_id: str, version: str = "", parent: str = "", type_: str = C.FILE_TYPE_FILE, file_hash: Optional[str] = None, hash_algo: Optional[str] = None, size: int = None, namespace: Optional[str] = None, mime_type: Optional[str] = None, public_id: Optional[str] = None, created: Optional[float] = None, modified: Optional[float] = None, owner: Optional[jid.JID] = None, access: Optional[dict] = None, extra: Optional[dict] = None ) -> None: """Set a file metadata @param client: client owning the file @param name: name of the file (must not contain "/") @param file_id: unique id of the file @param version: version of this file @param parent: id of the directory containing this file Empty string if it is a root file/directory @param type_: one of: - file - directory @param file_hash: unique hash of the payload @param hash_algo: algorithm used for hashing the file (usually sha-256) @param size: size in bytes @param namespace: identifier (human readable is better) to group files for instance, namespace could be used to group files in a specific photo album @param mime_type: media type of the file, or None if not known/guessed @param public_id: ID used to server the file publicly via HTTP @param created: UNIX time of creation @param modified: UNIX time of last modification, or None to use created date @param owner: jid of the owner of the file (mainly useful for component) @param access: serialisable dictionary with access rules. See [memory.memory] for details @param extra: serialisable dictionary of any extra data will be encoded to json in database """ if mime_type is None: media_type = media_subtype = None elif '/' in mime_type: media_type, media_subtype = mime_type.split('/', 1) else: media_type, media_subtype = mime_type, None async with self.session() as session: async with session.begin(): session.add(File( id=file_id, version=version.strip(), parent=parent, type=type_, file_hash=file_hash, hash_algo=hash_algo, name=name, size=size, namespace=namespace, media_type=media_type, media_subtype=media_subtype, public_id=public_id, created=time.time() if created is None else created, modified=modified, owner=owner, access=access, extra=extra, profile_id=self.profiles[client.profile] )) @aio async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: async with self.session() as session: result = await session.execute( select(sum_(File.size)).filter_by( owner=owner, type=C.FILE_TYPE_FILE, profile_id=self.profiles[client.profile] )) return result.scalar_one_or_none() or 0 @aio async def file_delete(self, file_id: str) -> None: """Delete file metadata from the database @param file_id: id of the file to delete NOTE: file itself must still be removed, this method only handle metadata in database """ async with self.session() as session: await session.execute(delete(File).filter_by(id=file_id)) await session.commit() @aio async def file_update( self, file_id: str, column: str, update_cb: Callable[[dict], None] ) -> None: """Update a column value using a method to avoid race conditions the older value will be retrieved from database, then update_cb will be applied to update it, and file will be updated checking that older value has not been changed meanwhile by an other user. If it has changed, it tries again a couple of times before failing @param column: column name (only "access" or "extra" are allowed) @param update_cb: method to update the value of the colum the method will take older value as argument, and must update it in place update_cb must not care about serialization, it get the deserialized data (i.e. a Python object) directly @raise exceptions.NotFound: there is not file with this id """ if column not in ('access', 'extra'): raise exceptions.InternalError('bad column name') orm_col = getattr(File, column) for i in range(5): async with self.session() as session: try: value = (await session.execute( select(orm_col).filter_by(id=file_id) )).scalar_one() except NoResultFound: raise exceptions.NotFound old_value = copy.deepcopy(value) update_cb(value) stmt = update(File).filter_by(id=file_id).values({column: value}) if not old_value: # because JsonDefaultDict convert NULL to an empty dict, we have to # test both for empty dict and None when we have an empty dict stmt = stmt.where((orm_col == None) | (orm_col == old_value)) else: stmt = stmt.where(orm_col == old_value) result = await session.execute(stmt) await session.commit() if result.rowcount == 1: break log.warning( _("table not updated, probably due to race condition, trying again " "({tries})").format(tries=i+1) ) else: raise exceptions.DatabaseError( _("Can't update file {file_id} due to race condition") .format(file_id=file_id) ) @aio async def get_pubsub_node( self, client: SatXMPPEntity, service: jid.JID, name: str, with_items: bool = False, with_subscriptions: bool = False, create: bool = False, create_kwargs: Optional[dict] = None ) -> Optional[PubsubNode]: """Retrieve a PubsubNode from DB @param service: service hosting the node @param name: node's name @param with_items: retrieve items in the same query @param with_subscriptions: retrieve subscriptions in the same query @param create: if the node doesn't exist in DB, create it @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node needs to be created. """ async with self.session() as session: stmt = ( select(PubsubNode) .filter_by( service=service, name=name, profile_id=self.profiles[client.profile], ) ) if with_items: stmt = stmt.options( joinedload(PubsubNode.items) ) if with_subscriptions: stmt = stmt.options( joinedload(PubsubNode.subscriptions) ) result = await session.execute(stmt) ret = result.unique().scalar_one_or_none() if ret is None and create: # we auto-create the node if create_kwargs is None: create_kwargs = {} try: return await as_future(self.set_pubsub_node( client, service, name, **create_kwargs )) except IntegrityError as e: if "unique" in str(e.orig).lower(): # the node may already exist, if it has been created just after # get_pubsub_node above log.debug("ignoring UNIQUE constraint error") cached_node = await as_future(self.get_pubsub_node( client, service, name, with_items=with_items, with_subscriptions=with_subscriptions )) else: raise e else: return ret @aio async def set_pubsub_node( self, client: SatXMPPEntity, service: jid.JID, name: str, analyser: Optional[str] = None, type_: Optional[str] = None, subtype: Optional[str] = None, subscribed: bool = False, ) -> PubsubNode: node = PubsubNode( profile_id=self.profiles[client.profile], service=service, name=name, subscribed=subscribed, analyser=analyser, type_=type_, subtype=subtype, subscriptions=[], ) async with self.session() as session: async with session.begin(): session.add(node) return node @aio async def update_pubsub_node_sync_state( self, node: PubsubNode, state: SyncState ) -> None: async with self.session() as session: async with session.begin(): await session.execute( update(PubsubNode) .filter_by(id=node.id) .values( sync_state=state, sync_state_updated=time.time(), ) ) @aio async def delete_pubsub_node( self, profiles: Optional[List[str]], services: Optional[List[jid.JID]], names: Optional[List[str]] ) -> None: """Delete items cached for a node @param profiles: profile names from which nodes must be deleted. None to remove nodes from ALL profiles @param services: JIDs of pubsub services from which nodes must be deleted. None to remove nodes from ALL services @param names: names of nodes which must be deleted. None to remove ALL nodes whatever is their names """ stmt = delete(PubsubNode) if profiles is not None: stmt = stmt.where( PubsubNode.profile.in_( [self.profiles[p] for p in profiles] ) ) if services is not None: stmt = stmt.where(PubsubNode.service.in_(services)) if names is not None: stmt = stmt.where(PubsubNode.name.in_(names)) async with self.session() as session: await session.execute(stmt) await session.commit() @aio async def cache_pubsub_items( self, client: SatXMPPEntity, node: PubsubNode, items: List[domish.Element], parsed_items: Optional[List[dict]] = None, ) -> None: """Add items to database, using an upsert taking care of "updated" field""" if parsed_items is not None and len(items) != len(parsed_items): raise exceptions.InternalError( "parsed_items must have the same lenght as items" ) async with self.session() as session: async with session.begin(): for idx, item in enumerate(items): parsed = parsed_items[idx] if parsed_items else None stmt = insert(PubsubItem).values( node_id = node.id, name = item["id"], data = item, parsed = parsed, ).on_conflict_do_update( index_elements=(PubsubItem.node_id, PubsubItem.name), set_={ PubsubItem.data: item, PubsubItem.parsed: parsed, PubsubItem.updated: now() } ) await session.execute(stmt) await session.commit() @aio async def delete_pubsub_items( self, node: PubsubNode, items_names: Optional[List[str]] = None ) -> None: """Delete items cached for a node @param node: node from which items must be deleted @param items_names: names of items to delete if None, ALL items will be deleted """ stmt = delete(PubsubItem) if node is not None: if isinstance(node, list): stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node])) else: stmt = stmt.filter_by(node_id=node.id) if items_names is not None: stmt = stmt.where(PubsubItem.name.in_(items_names)) async with self.session() as session: await session.execute(stmt) await session.commit() @aio async def purge_pubsub_items( self, services: Optional[List[jid.JID]] = None, names: Optional[List[str]] = None, types: Optional[List[str]] = None, subtypes: Optional[List[str]] = None, profiles: Optional[List[str]] = None, created_before: Optional[datetime] = None, updated_before: Optional[datetime] = None, ) -> None: """Delete items cached for a node @param node: node from which items must be deleted @param items_names: names of items to delete if None, ALL items will be deleted """ stmt = delete(PubsubItem) node_fields = { "service": services, "name": names, "type_": types, "subtype": subtypes, } if profiles is not None: node_fields["profile_id"] = [self.profiles[p] for p in profiles] if any(x is not None for x in node_fields.values()): sub_q = select(PubsubNode.id) for col, values in node_fields.items(): if values is None: continue sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) stmt = ( stmt .where(PubsubItem.node_id.in_(sub_q)) .execution_options(synchronize_session=False) ) if created_before is not None: stmt = stmt.where(PubsubItem.created < created_before) if updated_before is not None: stmt = stmt.where(PubsubItem.updated < updated_before) async with self.session() as session: await session.execute(stmt) await session.commit() @aio async def get_items( self, node: PubsubNode, max_items: Optional[int] = None, item_ids: Optional[list[str]] = None, before: Optional[str] = None, after: Optional[str] = None, from_index: Optional[int] = None, order_by: Optional[List[str]] = None, desc: bool = True, force_rsm: bool = False, ) -> Tuple[List[PubsubItem], dict]: """Get Pubsub Items from cache @param node: retrieve items from this node (must be synchronised) @param max_items: maximum number of items to retrieve @param before: get items which are before the item with this name in given order empty string is not managed here, use desc order to reproduce RSM behaviour. @param after: get items which are after the item with this name in given order @param from_index: get items with item index (as defined in RSM spec) starting from this number @param order_by: sorting order of items (one of C.ORDER_BY_*) @param desc: direction or ordering @param force_rsm: if True, force the use of RSM worklow. RSM workflow is automatically used if any of before, after or from_index is used, but if only RSM max_items is used, it won't be used by default. This parameter let's use RSM workflow in this case. Note that in addition to RSM metadata, the result will not be the same (max_items without RSM will returns most recent items, i.e. last items in modification order, while max_items with RSM will return the oldest ones (i.e. first items in modification order). to be used when max_items is used from RSM """ metadata = { "service": node.service, "node": node.name, "uri": uri.build_xmpp_uri( "pubsub", path=node.service.full(), node=node.name, ), } if max_items is None: max_items = 20 use_rsm = any((before, after, from_index is not None)) if force_rsm and not use_rsm: # use_rsm = True from_index = 0 stmt = ( select(PubsubItem) .filter_by(node_id=node.id) .limit(max_items) ) if item_ids is not None: stmt = stmt.where(PubsubItem.name.in_(item_ids)) if not order_by: order_by = [C.ORDER_BY_MODIFICATION] order = [] for order_type in order_by: if order_type == C.ORDER_BY_MODIFICATION: if desc: order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc())) else: order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc())) elif order_type == C.ORDER_BY_CREATION: if desc: order.append(PubsubItem.id.desc()) else: order.append(PubsubItem.id.asc()) else: raise exceptions.InternalError(f"Unknown order type {order_type!r}") stmt = stmt.order_by(*order) if use_rsm: # CTE to have result row numbers row_num_q = select( PubsubItem.id, PubsubItem.name, # row_number starts from 1, but RSM index must start from 0 (func.row_number().over(order_by=order)-1).label("item_index") ).filter_by(node_id=node.id) row_num_cte = row_num_q.cte() if max_items > 0: # as we can't simply use PubsubItem.id when we order by modification, # we need to use row number item_name = before or after row_num_limit_q = ( select(row_num_cte.c.item_index) .where(row_num_cte.c.name==item_name) ).scalar_subquery() stmt = ( select(row_num_cte.c.item_index, PubsubItem) .join(row_num_cte, PubsubItem.id == row_num_cte.c.id) .limit(max_items) ) if before: stmt = ( stmt .where(row_num_cte.c.item_index<row_num_limit_q) .order_by(row_num_cte.c.item_index.desc()) ) elif after: stmt = ( stmt .where(row_num_cte.c.item_index>row_num_limit_q) .order_by(row_num_cte.c.item_index.asc()) ) else: stmt = ( stmt .where(row_num_cte.c.item_index>=from_index) .order_by(row_num_cte.c.item_index.asc()) ) # from_index is used async with self.session() as session: if max_items == 0: items = result = [] else: result = await session.execute(stmt) result = result.all() if before: result.reverse() items = [row[-1] for row in result] rows_count = ( await session.execute(row_num_q.with_only_columns(count())) ).scalar_one() try: index = result[0][0] except IndexError: index = None try: first = result[0][1].name except IndexError: first = None last = None else: last = result[-1][1].name metadata["rsm"] = { k: v for k, v in { "index": index, "count": rows_count, "first": first, "last": last, }.items() if v is not None } metadata["complete"] = (index or 0) + len(result) == rows_count return items, metadata async with self.session() as session: result = await session.execute(stmt) result = result.scalars().all() if desc: result.reverse() return result, metadata def _get_sqlite_path( self, path: List[Union[str, int]] ) -> str: """generate path suitable to query JSON element with SQLite""" return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" @aio async def search_pubsub_items( self, query: dict, ) -> Tuple[List[PubsubItem]]: """Search for pubsub items in cache @param query: search terms. Keys can be: :fts (str): Full-Text Search query. Currently SQLite FT5 engine is used, its query syntax can be used, see `FTS5 Query documentation <https://sqlite.org/fts5.html#full_text_query_syntax>`_ :profiles (list[str]): filter on nodes linked to those profiles :nodes (list[str]): filter on nodes with those names :services (list[jid.JID]): filter on nodes from those services :types (list[str|None]): filter on nodes with those types. None can be used to filter on nodes with no type set :subtypes (list[str|None]): filter on nodes with those subtypes. None can be used to filter on nodes with no subtype set :names (list[str]): filter on items with those names :parsed (list[dict]): Filter on a parsed data field. The dict must contain 3 keys: ``path`` which is a list of str or int giving the path to the field of interest (str for a dict key, int for a list index), ``operator`` with indicate the operator to use to check the condition, and ``value`` which depends of field type and operator. See documentation for details on operators (it's currently explained at ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command documentation). :order-by (list[dict]): Indicates how to order results. The dict can contain either a ``order`` for a well-know order or a ``path`` for a parsed data field path (``order`` and ``path`` can't be used at the same time), an an optional ``direction`` which can be ``asc`` or ``desc``. See documentation for details on well-known orders (it's currently explained at ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command documentation). :index (int): starting index of items to return from the query result. It's translated to SQL's OFFSET :limit (int): maximum number of items to return. It's translated to SQL's LIMIT. @result: found items (the ``node`` attribute will be filled with suitable PubsubNode) """ # TODO: FTS and parsed data filters use SQLite specific syntax # when other DB engines will be used, this will have to be adapted stmt = select(PubsubItem) # Full-Text Search fts = query.get("fts") if fts: fts_select = text( "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)" ).bindparams(fts_query=fts).columns(rowid=Integer).subquery() stmt = ( stmt .select_from(fts_select) .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id) ) # node related filters profiles = query.get("profiles") if (profiles or any(query.get(k) for k in ("nodes", "services", "types", "subtypes")) ): stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) if profiles: try: stmt = stmt.where( PubsubNode.profile_id.in_(self.profiles[p] for p in profiles) ) except KeyError as e: raise exceptions.ProfileUnknownError( f"This profile doesn't exist: {e.args[0]!r}" ) for key, attr in ( ("nodes", "name"), ("services", "service"), ("types", "type_"), ("subtypes", "subtype") ): value = query.get(key) if not value: continue if key in ("types", "subtypes") and None in value: # NULL can't be used with SQL's IN, so we have to add a condition with # IS NULL, and use a OR if there are other values to check value.remove(None) condition = getattr(PubsubNode, attr).is_(None) if value: condition = or_( getattr(PubsubNode, attr).in_(value), condition ) else: condition = getattr(PubsubNode, attr).in_(value) stmt = stmt.where(condition) else: stmt = stmt.options(selectinload(PubsubItem.node)) # names names = query.get("names") if names: stmt = stmt.where(PubsubItem.name.in_(names)) # parsed data filters parsed = query.get("parsed", []) for filter_ in parsed: try: path = filter_["path"] operator = filter_["op"] value = filter_["value"] except KeyError as e: raise ValueError( f'missing mandatory key {e.args[0]!r} in "parsed" filter' ) try: op_attr = OP_MAP[operator] except KeyError: raise ValueError(f"invalid operator: {operator!r}") sqlite_path = self._get_sqlite_path(path) if operator in ("overlap", "ioverlap", "disjoint", "idisjoint"): col = literal_column("json_each.value") if operator[0] == "i": col = func.lower(col) value = [str(v).lower() for v in value] condition = ( select(1) .select_from(func.json_each(PubsubItem.parsed, sqlite_path)) .where(col.in_(value)) ).scalar_subquery() if operator in ("disjoint", "idisjoint"): condition = condition.is_(None) stmt = stmt.where(condition) elif operator == "between": try: left, right = value except (ValueError, TypeError): raise ValueError(_( "invalid value for \"between\" filter, you must use a 2 items " "array: {value!r}" ).format(value=value)) col = func.json_extract(PubsubItem.parsed, sqlite_path) stmt = stmt.where(col.between(left, right)) else: # we use func.json_extract instead of generic JSON way because SQLAlchemy # add a JSON_QUOTE to the value, and we want SQL value col = func.json_extract(PubsubItem.parsed, sqlite_path) stmt = stmt.where(getattr(col, op_attr)(value)) # order order_by = query.get("order-by") or [{"order": "creation"}] for order_data in order_by: order, path = order_data.get("order"), order_data.get("path") if order and path: raise ValueError(_( '"order" and "path" can\'t be used at the same time in ' '"order-by" data' )) if order: if order == "creation": col = PubsubItem.id elif order == "modification": col = PubsubItem.updated elif order == "item_id": col = PubsubItem.name elif order == "rank": if not fts: raise ValueError( "'rank' order can only be used with Full-Text Search (fts)" ) col = literal_column("rank") else: raise NotImplementedError(f"Unknown {order!r} order") else: # we have a JSON path # sqlite_path = self._get_sqlite_path(path) col = PubsubItem.parsed[path] direction = order_data.get("direction", "ASC").lower() if not direction in ("asc", "desc"): raise ValueError(f"Invalid order-by direction: {direction!r}") stmt = stmt.order_by(getattr(col, direction)()) # offset, limit index = query.get("index") if index: stmt = stmt.offset(index) limit = query.get("limit") if limit: stmt = stmt.limit(limit) async with self.session() as session: result = await session.execute(stmt) return result.scalars().all()