Mercurial > libervia-backend
view sat/memory/sqla.py @ 3591:d830c11eeef3
plugin XEP-0277: ignore `max_items` if `rsm_request` is set
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 29 Jul 2021 21:28:48 +0200 |
parents | 16ade4ad63f3 |
children | 7510648e8e3a |
line wrap: on
line source
#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import time
import copy
import asyncio
from asyncio.subprocess import PIPE
from pathlib import Path
from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
from sqlalchemy.future import select
from sqlalchemy.engine import Engine, Connection
from sqlalchemy import update, delete, and_, or_, event
from sqlalchemy.sql.functions import coalesce, sum as sum_
from sqlalchemy.dialects.sqlite import insert
from alembic import script as al_script, config as al_config
from alembic.runtime import migration as al_migration
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from sat.core.i18n import _
from sat.core import exceptions
from sat.core.log import getLogger
from sat.core.constants import Const as C
from sat.core.core_types import SatXMPPEntity
from sat.tools.utils import aio
from sat.memory import migration
from sat.memory import sqla_config
from sat.memory.sqla_mapping import (
    NOT_IN_EXTRA,
    Base,
    Profile,
    Component,
    History,
    Message,
    Subject,
    Thread,
    ParamGen,
    ParamInd,
    PrivateGen,
    PrivateInd,
    PrivateGenBin,
    PrivateIndBin,
    File
)


log = getLogger(__name__)
# directory containing the Alembic environment (alembic.ini and revisions)
migration_path = Path(migration.__file__).parent


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Enable foreign key enforcement on each new DBAPI connection

    The listener is registered on the Engine class, so it runs for every engine
    created in the process. Foreign keys are off by default in SQLite.
    """
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()


class Storage:
    """Persistent storage backed by SQLAlchemy (async engine) and Alembic migrations"""

    def __init__(self):
        # fired (with None) once initialise() has finished
        self.initialized = defer.Deferred()
        # we keep cache for the profiles (key: profile name, value: profile id)
        self.profiles: Dict[str, int] = {}
        # profile id to component entry point
        self.components: Dict[int, str] = {}

    async def migrateApply(self, *args: str, log_output: bool = False) -> None:
        """Do a migration command

        Commands are applied by running Alembic in a subprocess.
        Arguments are alembic executables commands

        @param log_output: manage stdout and stderr:
            - if False, stdout and stderr are buffered, and logged only in case of error
            - if True, stdout and stderr are not captured and go directly to this
              process' stdout/stderr during the command execution
        @raise exceptions.DatabaseError: something went wrong while running the
            process
        """
        stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
        proc = await asyncio.create_subprocess_exec(
            sys.executable, "-m", "alembic", *args,
            stdout=stdout, stderr=stderr, cwd=migration_path
        )
        log_out, log_err = await proc.communicate()
        if proc.returncode != 0:
            msg = _(
                "Can't {operation} database (exit code {exit_code})"
            ).format(
                operation=args[0],
                exit_code=proc.returncode
            )
            # output is only available when it has been piped (log_output False)
            if log_out or log_err:
                msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}"
            log.error(msg)

            raise exceptions.DatabaseError(msg)

    async def createDB(self, engine: AsyncEngine, db_config: dict) -> None:
        """Create a new database

        The database is generated from SQLAlchemy model, then stamped by Alembic
        (so that later migrations start from the current head).
        """
        # the dir may not exist if it's not the XDG recommended one
        db_config["path"].parent.mkdir(0o700, True, True)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        log.debug("stamping the database")
        await self.migrateApply("stamp", "head")
        log.debug("stamping done")

    def _checkDBIsUpToDate(self, conn: Connection) -> bool:
        """Tell if the database revision(s) match the Alembic script head(s)

        Synchronous: meant to be run through AsyncConnection.run_sync.
        """
        al_ini_path = migration_path / "alembic.ini"
        al_cfg = al_config.Config(al_ini_path)
        directory = al_script.ScriptDirectory.from_config(al_cfg)
        context = al_migration.MigrationContext.configure(conn)
        return set(context.get_current_heads()) == set(directory.get_heads())

    async def checkAndUpdateDB(self, engine: AsyncEngine, db_config: dict) -> None:
        """Check that database is up-to-date, and update if necessary"""
        async with engine.connect() as conn:
            up_to_date = await conn.run_sync(self._checkDBIsUpToDate)
        if up_to_date:
            log.debug("Database is up-to-date")
        else:
            log.info("Database needs to be updated")
            log.info("updating…")
            await self.migrateApply("upgrade", "head", log_output=True)
            log.info("Database is now up-to-date")

    @aio
    async def initialise(self) -> None:
        """Connect to the database, create or migrate it, and fill the caches

        self.initialized is fired once everything is ready.
        """
        log.info(_("Connecting database"))
        db_config = sqla_config.getDbConfig()
        engine = create_async_engine(
            db_config["url"],
            future=True
        )

        new_base = not db_config["path"].exists()
        if new_base:
            log.info(_("The database is new, creating the tables"))
            await self.createDB(engine, db_config)
        else:
            await self.checkAndUpdateDB(engine, db_config)

        # expire_on_commit=False so that ORM objects stay usable after commit
        self.session = sessionmaker(
            engine, expire_on_commit=False, class_=AsyncSession
        )

        async with self.session() as session:
            result = await session.execute(select(Profile))
            for p in result.scalars():
                self.profiles[p.name] = p.id
            result = await session.execute(select(Component))
            for c in result.scalars():
                self.components[c.profile_id] = c.entry_point

        self.initialized.callback(None)

    ## Profiles

    def getProfilesList(self) -> List[str]:
        """Return list of all registered profiles"""
        return list(self.profiles.keys())

    def hasProfile(self, profile_name: str) -> bool:
        """return True if profile_name exists

        @param profile_name: name of the profile to check
        """
        return profile_name in self.profiles

    def profileIsComponent(self, profile_name: str) -> bool:
        """Tell if a profile is associated with a component entry point

        @param profile_name: name of the profile to check
        @raise exceptions.NotFound: the profile doesn't exist
        """
        try:
            return self.profiles[profile_name] in self.components
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists")

    def getEntryPoint(self, profile_name: str) -> str:
        """Return the component entry point of a profile

        @param profile_name: name of the profile
        @raise exceptions.NotFound: the profile doesn't exist or is not a component
        """
        try:
            return self.components[self.profiles[profile_name]]
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists or is not a component")

    @aio
    async def createProfile(self, name: str, component_ep: Optional[str] = None) -> Profile:
        """Create a new profile

        @param name: name of the profile
        @param component_ep: if not None, must point to a component entry point
        @return: the newly created profile
        """
        async with self.session() as session:
            profile = Profile(name=name)
            async with session.begin():
                session.add(profile)
            # the cache maps profile name to profile id (see initialise)
            self.profiles[profile.name] = profile.id
            if component_ep is not None:
                async with session.begin():
                    component = Component(profile=profile, entry_point=component_ep)
                    session.add(component)
                self.components[profile.id] = component_ep
        return profile

    @aio
    async def deleteProfile(self, name: str) -> None:
        """Delete profile

        @param name: name of the profile
        @raise exceptions.NotFound: there is no profile with this name
        """
        async with self.session() as session:
            result = await session.execute(select(Profile).where(Profile.name == name))
            profile = result.scalar()
            if profile is None:
                raise exceptions.NotFound(f"there is no profile named {name!r}")
            await session.delete(profile)
            await session.commit()
        # the cache is keyed by profile name, not by id
        del self.profiles[profile.name]
        if profile.id in self.components:
            del self.components[profile.id]
        log.info(_("Profile {name!r} deleted").format(name = name))

    ## Params

    @aio
    async def loadGenParams(self, params_gen: dict) -> None:
        """Load general parameters

        @param params_gen: dictionary to fill, keys are (category, name) tuples
        """
        log.debug(_("loading general parameters from database"))
        async with self.session() as session:
            result = await session.execute(select(ParamGen))
            for p in result.scalars():
                params_gen[(p.category, p.name)] = p.value

    @aio
    async def loadIndParams(self, params_ind: dict, profile: str) -> None:
        """Load individual parameters

        @param params_ind: dictionary to fill, keys are (category, name) tuples
        @param profile: a profile which *must* exist
        """
        log.debug(_("loading individual parameters from database"))
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd).where(ParamInd.profile_id == self.profiles[profile])
            )
            for p in result.scalars():
                params_ind[(p.category, p.name)] = p.value

    @aio
    async def getIndParam(self, category: str, name: str, profile: str) -> Optional[str]:
        """Ask database for the value of one specific individual parameter

        @param category: category of the parameter
        @param name: name of the parameter
        @param profile: %(doc_profile)s
        @return: the value, or None if the parameter is not set
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd.value)
                .filter_by(
                    category=category,
                    name=name,
                    profile_id=self.profiles[profile]
                )
            )
            return result.scalar_one_or_none()

    @aio
    async def getIndParamValues(self, category: str, name: str) -> Dict[str, str]:
        """Ask database for the individual values of a parameter for all profiles

        @param category: category of the parameter
        @param name: name of the parameter
        @return dict: profile => value map
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd)
                .filter_by(
                    category=category,
                    name=name
                )
                .options(subqueryload(ParamInd.profile))
            )
            return {param.profile.name: param.value for param in result.scalars()}

    @aio
    async def setGenParam(self, category: str, name: str, value: Optional[str]) -> None:
        """Save the general parameters in database

        An existing (category, name) row is updated in place (upsert).

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        """
        async with self.session() as session:
            stmt = insert(ParamGen).values(
                category=category,
                name=name,
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamGen.category, ParamGen.name),
                set_={
                    ParamGen.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    @aio
    async def setIndParam(
        self,
        category:str,
        name: str,
        value: Optional[str],
        profile: str
    ) -> None:
        """Save the individual parameters in database

        An existing (category, name, profile) row is updated in place (upsert).

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        @param profile: a profile which *must* exist
        """
        async with self.session() as session:
            stmt = insert(ParamInd).values(
                category=category,
                name=name,
                profile_id=self.profiles[profile],
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id),
                set_={
                    ParamInd.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    def _jid_filter(self, jid_: jid.JID, dest: bool = False):
        """Generate condition to filter on a JID, using relevant columns

        A full JID filters on both bare JID and resource columns, a bare JID only on
        the bare JID column.

        @param jid_: JID to filter by
        @param dest: True if it's the destinee JID, otherwise it's the source one
        """
        if jid_.resource:
            if dest:
                return and_(
                    History.dest == jid_.userhost(),
                    History.dest_res == jid_.resource
                )
            else:
                return and_(
                    History.source == jid_.userhost(),
                    History.source_res == jid_.resource
                )
        else:
            if dest:
                return History.dest == jid_.userhost()
            else:
                return History.source == jid_.userhost()

    @aio
    async def historyGet(
        self,
        from_jid: Optional[jid.JID],
        to_jid: Optional[jid.JID],
        limit: Optional[int] = None,
        between: bool = True,
        filters: Optional[Dict[str, str]] = None,
        profile: Optional[str] = None,
    ) -> List[Tuple[
        str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
    ]:
        """Retrieve messages in history

        @param from_jid: source JID (full, or bare for catchall)
        @param to_jid: dest JID (full, or bare for catchall)
        @param limit: maximum number of messages to get:
            - 0 for no message (returns the empty list)
            - None for unlimited
        @param between: confound source and dest (ignore the direction)
        @param filters: pattern to filter the history results
        @param profile: profile to use (mandatory despite the default value)
        @return: list of messages as in [messageNew], minus the profile which is
            already known. Messages are returned in chronological order.
        @raise ValueError: "last_stanza_id" filter used with unexpected values
        """
        # we have to set a default value to profile because it's last argument
        # and thus follow other keyword arguments with default values
        # but None should not be used for it
        assert profile is not None
        if limit == 0:
            return []
        if filters is None:
            filters = {}

        stmt = (
            select(History)
            .filter_by(
                profile_id=self.profiles[profile]
            )
            .outerjoin(History.messages)
            .outerjoin(History.subjects)
            .outerjoin(History.thread)
            .options(
                contains_eager(History.messages),
                contains_eager(History.subjects),
                contains_eager(History.thread),
            )
            .order_by(
                # timestamp may be identical for 2 close messages (specially when delay is
                # used) that's why we order ties by received_timestamp. We'll reverse the
                # order when returning the result. We use DESC here so LIMIT keep the last
                # messages
                History.timestamp.desc(),
                History.received_timestamp.desc()
            )
        )

        if not from_jid and not to_jid:
            # no jid specified, we want all one2one communications
            pass
        elif between:
            if not from_jid or not to_jid:
                # we only have one jid specified, we check all messages
                # from or to this jid
                jid_ = from_jid or to_jid
                stmt = stmt.where(
                    or_(
                        self._jid_filter(jid_),
                        self._jid_filter(jid_, dest=True)
                    )
                )
            else:
                # we have 2 jids specified, we check all communications between
                # those 2 jids
                stmt = stmt.where(
                    or_(
                        and_(
                            self._jid_filter(from_jid),
                            self._jid_filter(to_jid, dest=True),
                        ),
                        and_(
                            self._jid_filter(to_jid),
                            self._jid_filter(from_jid, dest=True),
                        )
                    )
                )
        else:
            # we want one communication in specific direction (from somebody or
            # to somebody).
            if from_jid is not None:
                stmt = stmt.where(self._jid_filter(from_jid))
            if to_jid is not None:
                stmt = stmt.where(self._jid_filter(to_jid, dest=True))

        if filters:
            if 'timestamp_start' in filters:
                stmt = stmt.where(History.timestamp >= float(filters['timestamp_start']))
            if 'before_uid' in filters:
                # originally this query was using SQLITE's rowid. This has been changed
                # to use coalesce(received_timestamp, timestamp) to be SQL engine
                # independent
                stmt = stmt.where(
                    coalesce(
                        History.received_timestamp,
                        History.timestamp
                    ) < (
                        select(coalesce(History.received_timestamp, History.timestamp))
                        .filter_by(uid=filters["before_uid"])
                    ).scalar_subquery()
                )
            if 'body' in filters:
                # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html
                stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
            if 'search' in filters:
                search_term = f"%{filters['search']}%"
                stmt = stmt.where(or_(
                    Message.message.like(search_term),
                    History.source_res.like(search_term)
                ))
            if 'types' in filters:
                types = filters['types'].split()
                stmt = stmt.where(History.type.in_(types))
            if 'not_types' in filters:
                types = filters['not_types'].split()
                stmt = stmt.where(History.type.not_in(types))
            if 'last_stanza_id' in filters:
                # this request get the last message with a "stanza_id" that we
                # have in history. This is mainly used to retrieve messages sent
                # while we were offline, using MAM (XEP-0313).
                if (filters['last_stanza_id'] is not True
                    or limit != 1):
                    raise ValueError("Unexpected values for last_stanza_id filter")
                stmt = stmt.where(History.stanza_id.is_not(None))

        if limit is not None:
            # NOTE(review): LIMIT is applied to the joined rows, so a history entry
            # with several messages/subjects counts more than once and fewer entries
            # than "limit" may be returned — confirm whether this matters to callers
            stmt = stmt.limit(limit)

        async with self.session() as session:
            result = await session.execute(stmt)
            # unique() is needed because of the joined eager loading of collections
            history = result.scalars().unique().all()

        # the query is in reverse chronological order (see ORDER BY above)
        history.reverse()
        return [h.as_tuple() for h in history]

    @aio
    async def addToHistory(self, data: dict, profile: str) -> None:
        """Store a new message in history

        Storage errors are logged and not propagated (duplicates are only logged at
        debug level).

        @param data: message data as build by SatMessageProtocol.onMessage
        @param profile: profile owning the message
        """
        extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
        messages = [Message(message=mess, language=lang)
                    for lang, mess in data["message"].items()]
        subjects = [Subject(subject=mess, language=lang)
                    for lang, mess in data["subject"].items()]
        if "thread" in data["extra"]:
            thread = Thread(thread_id=data["extra"]["thread"],
                            parent_id=data["extra"].get("thread_parent"))
        else:
            thread = None
        try:
            async with self.session() as session:
                async with session.begin():
                    session.add(History(
                        uid=data["uid"],
                        stanza_id=data["extra"].get("stanza_id"),
                        update_uid=data["extra"].get("update_uid"),
                        profile_id=self.profiles[profile],
                        source_jid=data["from"],
                        dest_jid=data["to"],
                        timestamp=data["timestamp"],
                        received_timestamp=data.get("received_timestamp"),
                        type=data["type"],
                        extra=extra,
                        messages=messages,
                        subjects=subjects,
                        thread=thread,
                    ))
        except IntegrityError as e:
            if "unique" in str(e.orig).lower():
                log.debug(
                    f"message {data['uid']!r} is already in history, not storing it again"
                )
            else:
                log.error(f"Can't store message {data['uid']!r} in history: {e}")
        except Exception as e:
            log.critical(
                f"Can't store message, unexpected exception (uid: {data['uid']}): {e}"
            )

    ## Private values

    def _getPrivateClass(self, binary, profile):
        """Get ORM class to use for private values

        @param binary: True for binary (serialised) values
        @param profile: None for general values, else individual values are used
        """
        if profile is None:
            return PrivateGenBin if binary else PrivateGen
        else:
            return PrivateIndBin if binary else PrivateInd

    @aio
    async def getPrivates(
        self,
        namespace:str,
        keys: Optional[Iterable[str]] = None,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get private value(s) from databases

        @param namespace: namespace of the values
        @param keys: keys of the values to get
            None to get all keys/values
        @param binary: True to deserialise binary values
        @param profile: profile to use for individual values
            None to use general values
        @return: gotten keys/values
        """
        if keys is not None:
            # materialise the iterable: it is used both for logging and filtering
            keys = list(keys)
        log.debug(
            f"getting {'general' if profile is None else 'individual'}"
            f"{' binary' if binary else ''} private values from database for namespace "
            f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}"
        )
        cls = self._getPrivateClass(binary, profile)
        stmt = select(cls).filter_by(namespace=namespace)
        if keys:
            stmt = stmt.where(cls.key.in_(keys))
        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])
        async with self.session() as session:
            result = await session.execute(stmt)
            return {p.key: p.value for p in result.scalars()}

    @aio
    async def setPrivateValue(
        self,
        namespace: str,
        key:str,
        value: Any,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Set a private value in database

        An existing value for the same namespace/key (and profile) is replaced.

        @param namespace: namespace of the values
        @param key: key of the value to set
        @param value: value to set
        @param binary: True if it's a binary values
            binary values need to be serialised, used for everything but strings
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        values = {
            "namespace": namespace,
            "key": key,
            "value": value
        }
        index_elements = [cls.namespace, cls.key]

        if profile is not None:
            values["profile_id"] = self.profiles[profile]
            index_elements.append(cls.profile_id)

        async with self.session() as session:
            await session.execute(
                insert(cls).values(**values).on_conflict_do_update(
                    index_elements=index_elements,
                    set_={
                        cls.value: value
                    }
                )
            )
            await session.commit()

    @aio
    async def delPrivateValue(
        self,
        namespace: str,
        key: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete private value from database

        @param namespace: namespace of the private value
        @param key: key of the private value
        @param binary: True if it's a binary values
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace, key=key)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def delPrivateNamespace(
        self,
        namespace: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete all data from a private namespace

        Be really cautious when you use this method, as all data with given namespace
        are removed.
        Params are the same as for delPrivateValue
        """
        cls = self._getPrivateClass(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    ## Files

    @aio
    async def getFiles(
        self,
        client: Optional[SatXMPPEntity],
        file_id: Optional[str] = None,
        version: Optional[str] = '',
        parent: Optional[str] = None,
        type_: Optional[str] = None,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        name: Optional[str] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        projection: Optional[List[str]] = None,
        unique: bool = False
    ) -> List[dict]:
        """Retrieve files with with given filters

        @param client: client owning the files
            None is only allowed when public_id is set
        @param file_id: id of the file
            None to ignore
        @param version: version of the file
            None to ignore
            empty string to look for current version
        @param parent: id of the directory containing the files
            None to ignore
            empty string to look for root files/directories
        @param projection: name of columns to retrieve
            None to retrieve all
        @param unique: if True will remove duplicates
        other params are the same as for [setFile]
        @return: files corresponding to filters, as column name => value dicts
        @raise exceptions.InternalError: client is None while public_id is not set
        @raise NotImplementedError: access filtering was requested
        """
        if projection is None:
            projection = [
                'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
                'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
                'created', 'modified', 'owner', 'access', 'extra'
            ]

        stmt = select(*[getattr(File, f) for f in projection])

        if unique:
            stmt = stmt.distinct()

        if client is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
        else:
            if public_id is None:
                raise exceptions.InternalError(
                    "client can only be omitted when public_id is set"
                )
        if file_id is not None:
            stmt = stmt.filter_by(id=file_id)
        if version is not None:
            stmt = stmt.filter_by(version=version)
        if parent is not None:
            stmt = stmt.filter_by(parent=parent)
        if type_ is not None:
            stmt = stmt.filter_by(type=type_)
        if file_hash is not None:
            stmt = stmt.filter_by(file_hash=file_hash)
        if hash_algo is not None:
            stmt = stmt.filter_by(hash_algo=hash_algo)
        if name is not None:
            stmt = stmt.filter_by(name=name)
        if namespace is not None:
            stmt = stmt.filter_by(namespace=namespace)
        if mime_type is not None:
            if '/' in mime_type:
                media_type, media_subtype = mime_type.split("/", 1)
                stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
            else:
                stmt = stmt.filter_by(media_type=mime_type)
        if public_id is not None:
            stmt = stmt.filter_by(public_id=public_id)
        if owner is not None:
            stmt = stmt.filter_by(owner=owner)
        if access is not None:
            raise NotImplementedError('Access check is not implemented yet')
            # a JSON comparison is needed here

        async with self.session() as session:
            result = await session.execute(stmt)
            # Row objects from the future/ORM API are tuple-like: _asdict() gives the
            # column name => value mapping (dict(row) is not supported there)
            return [r._asdict() for r in result]

    @aio
    async def setFile(
        self,
        client: SatXMPPEntity,
        name: str,
        file_id: str,
        version: str = "",
        parent: str = "",
        type_: str = C.FILE_TYPE_FILE,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        size: Optional[int] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        created: Optional[float] = None,
        modified: Optional[float] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        extra: Optional[dict] = None
    ) -> None:
        """Set a file metadata

        @param client: client owning the file
        @param name: name of the file (must not contain "/")
        @param file_id: unique id of the file
        @param version: version of this file (surrounding whitespace is stripped)
        @param parent: id of the directory containing this file
            Empty string if it is a root file/directory
        @param type_: one of:
            - file
            - directory
        @param file_hash: unique hash of the payload
        @param hash_algo: algorithm used for hashing the file (usually sha-256)
        @param size: size in bytes
        @param namespace: identifier (human readable is better) to group files
            for instance, namespace could be used to group files in a specific photo
            album
        @param mime_type: media type of the file, or None if not known/guessed
        @param public_id: ID used to server the file publicly via HTTP
        @param created: UNIX time of creation, None to use current time
        @param modified: UNIX time of last modification, or None to use created date
            NOTE(review): None is stored as-is here; presumably the created date
            fallback is applied elsewhere — confirm
        @param owner: jid of the owner of the file (mainly useful for component)
        @param access: serialisable dictionary with access rules. See [memory.memory]
            for details
        @param extra: serialisable dictionary of any extra data
            will be encoded to json in database
        """
        if mime_type is None:
            media_type = media_subtype = None
        elif '/' in mime_type:
            media_type, media_subtype = mime_type.split('/', 1)
        else:
            media_type, media_subtype = mime_type, None

        async with self.session() as session:
            async with session.begin():
                session.add(File(
                    id=file_id,
                    version=version.strip(),
                    parent=parent,
                    type=type_,
                    file_hash=file_hash,
                    hash_algo=hash_algo,
                    name=name,
                    size=size,
                    namespace=namespace,
                    media_type=media_type,
                    media_subtype=media_subtype,
                    public_id=public_id,
                    created=time.time() if created is None else created,
                    modified=modified,
                    owner=owner,
                    access=access,
                    extra=extra,
                    profile_id=self.profiles[client.profile]
                ))

    @aio
    async def fileGetUsedSpace(self, client: SatXMPPEntity, owner: jid.JID) -> int:
        """Return the sum of the sizes of the files (not directories) of an owner

        @param client: client owning the files
        @param owner: owner of the files
        @return: used space in bytes (0 if there is no file)
        """
        async with self.session() as session:
            result = await session.execute(
                select(sum_(File.size)).filter_by(
                    owner=owner,
                    type=C.FILE_TYPE_FILE,
                    profile_id=self.profiles[client.profile]
                ))
        # SUM() returns NULL when no row matches
        return result.scalar_one_or_none() or 0

    @aio
    async def fileDelete(self, file_id: str) -> None:
        """Delete file metadata from the database

        @param file_id: id of the file to delete
        NOTE: file itself must still be removed, this method only handle metadata in
            database
        """
        async with self.session() as session:
            await session.execute(delete(File).filter_by(id=file_id))
            await session.commit()

    @aio
    async def fileUpdate(
        self,
        file_id: str,
        column: str,
        update_cb: Callable[[dict], None]
    ) -> None:
        """Update a column value using a method to avoid race conditions

        the older value will be retrieved from database, then update_cb will be applied
        to update it, and file will be updated checking that older value has not been
        changed meanwhile by an other user. If it has changed, it tries again a couple
        of times before failing
        @param file_id: id of the file to update
        @param column: column name (only "access" or "extra" are allowed)
        @param update_cb: method to update the value of the colum
            the method will take older value as argument, and must update it in place
            update_cb must not care about serialization,
            it get the deserialized data (i.e. a Python object) directly
        @raise exceptions.NotFound: there is not file with this id
        @raise exceptions.InternalError: column is not allowed
        @raise exceptions.DatabaseError: the value could not be updated after several
            tries
        """
        if column not in ('access', 'extra'):
            raise exceptions.InternalError('bad column name')
        orm_col = getattr(File, column)

        for i in range(5):
            async with self.session() as session:
                try:
                    value = (await session.execute(
                        select(orm_col).filter_by(id=file_id)
                    )).scalar_one()
                except NoResultFound:
                    raise exceptions.NotFound
                # update_cb modifies value in place, so we keep a copy of the value as
                # read from database: it is what we must compare against to detect a
                # concurrent modification
                old_value = copy.deepcopy(value)
                update_cb(value)
                stmt = (
                    update(File)
                    .filter_by(id=file_id)
                    .values({column: value})
                    # no File instance is loaded in this session, nothing to sync
                    .execution_options(synchronize_session=False)
                )
                if not old_value:
                    # because JsonDefaultDict convert NULL to an empty dict, we have to
                    # test both for empty dict and None when we have an empty dict
                    stmt = stmt.where((orm_col == None) | (orm_col == old_value))
                else:
                    stmt = stmt.where(orm_col == old_value)
                result = await session.execute(stmt)
                await session.commit()

            if result.rowcount == 1:
                break

            log.warning(
                _("table not updated, probably due to race condition, trying again "
                  "({tries})").format(tries=i+1)
            )

        else:
            raise exceptions.DatabaseError(
                _("Can't update file {file_id} due to race condition")
                .format(file_id=file_id)
            )