Mercurial > libervia-backend
diff sat/memory/sqla.py @ 3537:f9a5b810f14d
core (memory/storage): backend storage is now based on SQLAlchemy
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 03 Jun 2021 15:20:47 +0200 |
parents | |
children | 71516731d0aa |
line wrap: on
line diff
#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import copy
import time
from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
from urllib.parse import quote
from pathlib import Path
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
from sqlalchemy.future import select
from sqlalchemy.engine import Engine
from sqlalchemy import update, delete, and_, or_, event
from sqlalchemy.sql.functions import coalesce, sum as sum_
from sqlalchemy.dialects.sqlite import insert
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from sat.core.i18n import _
from sat.core import exceptions
from sat.core.log import getLogger
from sat.core.constants import Const as C
from sat.core.core_types import SatXMPPEntity
from sat.tools.utils import aio
from sat.memory.sqla_mapping import (
    NOT_IN_EXTRA,
    Base,
    Profile,
    Component,
    History,
    Message,
    Subject,
    Thread,
    ParamGen,
    ParamInd,
    PrivateGen,
    PrivateInd,
    PrivateGenBin,
    PrivateIndBin,
    File
)


log = getLogger(__name__)


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Enable foreign key enforcement on every new SQLite connection.

    SQLite ships with foreign keys disabled by default and the setting is
    per-connection, hence the ``connect`` event hook.
    """
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()


class Storage:
    """SQLAlchemy-based persistent storage backend (SQLite + aiosqlite)."""

    def __init__(self, db_filename, sat_version):
        self.initialized = defer.Deferred()
        self.filename = Path(db_filename)
        # cache of profiles: profile name => profile id
        # (FIX: original comments described the mapping backwards; every lookup
        # in this class uses self.profiles[profile_name])
        self.profiles: Dict[str, int] = {}
        # profile id => component entry point
        self.components: Dict[int, str] = {}

    @aio
    async def initialise(self):
        """Connect to the database, creating schema on first run.

        Fills the profiles/components caches and fires ``self.initialized``.
        """
        log.info(_("Connecting database"))
        engine = create_async_engine(
            f"sqlite+aiosqlite:///{quote(str(self.filename))}",
            future=True
        )
        self.session = sessionmaker(
            engine, expire_on_commit=False, class_=AsyncSession
        )
        new_base = not self.filename.exists()
        if new_base:
            log.info(_("The database is new, creating the tables"))
            # the dir may not exist if it's not the XDG recommended one
            self.filename.parent.mkdir(0o700, True, True)
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)

        async with self.session() as session:
            result = await session.execute(select(Profile))
            for p in result.scalars():
                self.profiles[p.name] = p.id
            result = await session.execute(select(Component))
            for c in result.scalars():
                self.components[c.profile_id] = c.entry_point

        self.initialized.callback(None)

    ## Profiles

    def getProfilesList(self) -> List[str]:
        """Return list of all registered profiles"""
        return list(self.profiles.keys())

    def hasProfile(self, profile_name: str) -> bool:
        """return True if profile_name exists

        @param profile_name: name of the profile to check
        """
        return profile_name in self.profiles

    def profileIsComponent(self, profile_name: str) -> bool:
        """Return True if the given profile is a component.

        @raise exceptions.NotFound: profile_name is not registered
        """
        try:
            return self.profiles[profile_name] in self.components
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists")

    def getEntryPoint(self, profile_name: str) -> str:
        """Return the component entry point of the given profile.

        @raise exceptions.NotFound: profile doesn't exist or is not a component
        """
        try:
            return self.components[self.profiles[profile_name]]
        except KeyError:
            raise exceptions.NotFound(
                "the requested profile doesn't exists or is not a component"
            )

    @aio
    async def createProfile(self, name: str, component_ep: Optional[str] = None) -> None:
        """Create a new profile

        @param name: name of the profile
        @param component_ep: if not None, must point to a component entry point
        @return: the created Profile instance
        """
        async with self.session() as session:
            profile = Profile(name=name)
            async with session.begin():
                session.add(profile)
            # FIX: the cache maps name => id (see initialise and all lookups);
            # the original stored it inverted (id => name), breaking every
            # subsequent self.profiles[name] lookup for a new profile.
            self.profiles[profile.name] = profile.id
            if component_ep is not None:
                async with session.begin():
                    component = Component(profile=profile, entry_point=component_ep)
                    session.add(component)
                self.components[profile.id] = component_ep
            return profile

    @aio
    async def deleteProfile(self, name: str) -> None:
        """Delete profile

        @param name: name of the profile
        @raise exceptions.NotFound: no profile with this name exists
        """
        async with self.session() as session:
            result = await session.execute(select(Profile).where(Profile.name == name))
            profile = result.scalar()
            # FIX: guard against a missing profile instead of letting
            # session.delete(None) raise an obscure error
            if profile is None:
                raise exceptions.NotFound("the requested profile doesn't exists")
            await session.delete(profile)
            await session.commit()
        # FIX: the cache is keyed by name, not by id; the original
        # "del self.profiles[profile.id]" raised KeyError on every delete
        del self.profiles[profile.name]
        if profile.id in self.components:
            del self.components[profile.id]
        log.info(_("Profile {name!r} deleted").format(name = name))

    ## Params

    @aio
    async def loadGenParams(self, params_gen: dict) -> None:
        """Load general parameters

        @param params_gen: dictionary to fill
        """
        log.debug(_("loading general parameters from database"))
        async with self.session() as session:
            result = await session.execute(select(ParamGen))
        for p in result.scalars():
            params_gen[(p.category, p.name)] = p.value

    @aio
    async def loadIndParams(self, params_ind: dict, profile: str) -> None:
        """Load individual parameters

        @param params_ind: dictionary to fill
        @param profile: a profile which *must* exist
        """
        log.debug(_("loading individual parameters from database"))
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd).where(ParamInd.profile_id == self.profiles[profile])
            )
        for p in result.scalars():
            params_ind[(p.category, p.name)] = p.value

    @aio
    async def getIndParam(self, category: str, name: str, profile: str) -> Optional[str]:
        """Ask database for the value of one specific individual parameter

        @param category: category of the parameter
        @param name: name of the parameter
        @param profile: %(doc_profile)s
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd.value)
                .filter_by(
                    category=category,
                    name=name,
                    profile_id=self.profiles[profile]
                )
            )
        return result.scalar_one_or_none()

    @aio
    async def getIndParamValues(self, category: str, name: str) -> Dict[str, str]:
        """Ask database for the individual values of a parameter for all profiles

        @param category: category of the parameter
        @param name: name of the parameter
        @return dict: profile => value map
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd)
                .filter_by(
                    category=category,
                    name=name
                )
                .options(subqueryload(ParamInd.profile))
            )
        return {param.profile.name: param.value for param in result.scalars()}

    @aio
    async def setGenParam(self, category: str, name: str, value: Optional[str]) -> None:
        """Save the general parameters in database

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        """
        async with self.session() as session:
            # SQLite upsert: update the value if (category, name) already exists
            stmt = insert(ParamGen).values(
                category=category,
                name=name,
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamGen.category, ParamGen.name),
                set_={
                    ParamGen.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    @aio
    async def setIndParam(
        self,
        category: str,
        name: str,
        value: Optional[str],
        profile: str
    ) -> None:
        """Save the individual parameters in database

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        @param profile: a profile which *must* exist
        """
        async with self.session() as session:
            # SQLite upsert keyed on (category, name, profile)
            stmt = insert(ParamInd).values(
                category=category,
                name=name,
                profile_id=self.profiles[profile],
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id),
                set_={
                    ParamInd.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    def _jid_filter(self, jid_: jid.JID, dest: bool = False):
        """Generate condition to filter on a JID, using relevant columns

        @param dest: True if it's the destinee JID, otherwise it's the source one
        @param jid_: JID to filter by
        """
        if jid_.resource:
            # full JID: match bare part and resource separately
            if dest:
                return and_(
                    History.dest == jid_.userhost(),
                    History.dest_res == jid_.resource
                )
            else:
                return and_(
                    History.source == jid_.userhost(),
                    History.source_res == jid_.resource
                )
        else:
            # bare JID: catchall on the bare column only
            if dest:
                return History.dest == jid_.userhost()
            else:
                return History.source == jid_.userhost()

    @aio
    async def historyGet(
        self,
        from_jid: Optional[jid.JID],
        to_jid: Optional[jid.JID],
        limit: Optional[int] = None,
        between: bool = True,
        filters: Optional[Dict[str, str]] = None,
        profile: Optional[str] = None,
    ) -> List[Tuple[
        str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
    ]:
        """Retrieve messages in history

        @param from_jid: source JID (full, or bare for catchall)
        @param to_jid: dest JID (full, or bare for catchall)
        @param limit: maximum number of messages to get:
            - 0 for no message (returns the empty list)
            - None for unlimited
        @param between: confound source and dest (ignore the direction)
        @param filters: pattern to filter the history results
        @return: list of messages as in [messageNew], minus the profile which is already
            known.
        """
        # we have to set a default value to profile because it's last argument
        # and thus follow other keyword arguments with default values
        # but None should not be used for it
        assert profile is not None
        if limit == 0:
            return []
        if filters is None:
            filters = {}

        stmt = (
            select(History)
            .filter_by(
                profile_id=self.profiles[profile]
            )
            .outerjoin(History.messages)
            .outerjoin(History.subjects)
            .outerjoin(History.thread)
            .options(
                contains_eager(History.messages),
                contains_eager(History.subjects),
                contains_eager(History.thread),
            )
            .order_by(
                # timestamp may be identical for 2 close messages (specially when delay is
                # used) that's why we order ties by received_timestamp. We'll reverse the
                # order when returning the result. We use DESC here so LIMIT keep the last
                # messages
                History.timestamp.desc(),
                History.received_timestamp.desc()
            )
        )

        if not from_jid and not to_jid:
            # no jid specified, we want all one2one communications
            pass
        elif between:
            if not from_jid or not to_jid:
                # we only have one jid specified, we check all messages
                # from or to this jid
                jid_ = from_jid or to_jid
                stmt = stmt.where(
                    or_(
                        self._jid_filter(jid_),
                        self._jid_filter(jid_, dest=True)
                    )
                )
            else:
                # we have 2 jids specified, we check all communications between
                # those 2 jids
                stmt = stmt.where(
                    or_(
                        and_(
                            self._jid_filter(from_jid),
                            self._jid_filter(to_jid, dest=True),
                        ),
                        and_(
                            self._jid_filter(to_jid),
                            self._jid_filter(from_jid, dest=True),
                        )
                    )
                )
        else:
            # we want one communication in specific direction (from somebody or
            # to somebody).
            if from_jid is not None:
                stmt = stmt.where(self._jid_filter(from_jid))
            if to_jid is not None:
                stmt = stmt.where(self._jid_filter(to_jid, dest=True))

        if filters:
            if 'timestamp_start' in filters:
                stmt = stmt.where(History.timestamp >= float(filters['timestamp_start']))
            if 'before_uid' in filters:
                # orignially this query was using SQLITE's rowid. This has been changed
                # to use coalesce(received_timestamp, timestamp) to be SQL engine
                # independant
                stmt = stmt.where(
                    coalesce(
                        History.received_timestamp,
                        History.timestamp
                    ) < (
                        select(coalesce(History.received_timestamp, History.timestamp))
                        .filter_by(uid=filters["before_uid"])
                    ).scalar_subquery()
                )
            if 'body' in filters:
                # TODO: use REGEXP (function to be defined) instead of
                # GLOB: https://www.sqlite.org/lang_expr.html
                stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
            if 'search' in filters:
                search_term = f"%{filters['search']}%"
                stmt = stmt.where(or_(
                    Message.message.like(search_term),
                    History.source_res.like(search_term)
                ))
            if 'types' in filters:
                types = filters['types'].split()
                stmt = stmt.where(History.type.in_(types))
            if 'not_types' in filters:
                types = filters['not_types'].split()
                stmt = stmt.where(History.type.not_in(types))
            if 'last_stanza_id' in filters:
                # this request get the last message with a "stanza_id" that we
                # have in history. This is mainly used to retrieve messages sent
                # while we were offline, using MAM (XEP-0313).
                if (filters['last_stanza_id'] is not True
                    or limit != 1):
                    raise ValueError("Unexpected values for last_stanza_id filter")
                stmt = stmt.where(History.stanza_id.is_not(None))

        if limit is not None:
            stmt = stmt.limit(limit)

        async with self.session() as session:
            result = await session.execute(stmt)

        # rows were fetched newest first (see order_by above); return them
        # oldest first as callers expect
        result = result.scalars().unique().all()
        result.reverse()
        return [h.as_tuple() for h in result]

    @aio
    async def addToHistory(self, data: dict, profile: str) -> None:
        """Store a new message in history

        @param data: message data as build by SatMessageProtocol.onMessage
        """
        extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
        messages = [Message(message=mess, language=lang)
                    for lang, mess in data["message"].items()]
        subjects = [Subject(subject=mess, language=lang)
                    for lang, mess in data["subject"].items()]
        if "thread" in data["extra"]:
            # FIX: the original subscripted the bound method
            # (data["extra"].get["thread_parent"]) instead of calling it,
            # raising TypeError for any threaded message
            thread = Thread(thread_id=data["extra"]["thread"],
                            parent_id=data["extra"].get("thread_parent"))
        else:
            thread = None
        try:
            async with self.session() as session:
                async with session.begin():
                    session.add(History(
                        uid=data["uid"],
                        stanza_id=data["extra"].get("stanza_id"),
                        update_uid=data["extra"].get("update_uid"),
                        profile_id=self.profiles[profile],
                        source_jid=data["from"],
                        dest_jid=data["to"],
                        timestamp=data["timestamp"],
                        received_timestamp=data.get("received_timestamp"),
                        type=data["type"],
                        extra=extra,
                        messages=messages,
                        subjects=subjects,
                        thread=thread,
                    ))
        except IntegrityError as e:
            # a UNIQUE violation just means the message is already stored
            # (e.g. received twice through MAM and carbons); anything else is
            # a real error
            if "unique" in str(e.orig).lower():
                log.debug(
                    f"message {data['uid']!r} is already in history, not storing it again"
                )
            else:
                log.error(f"Can't store message {data['uid']!r} in history: {e}")
        except Exception as e:
            log.critical(
                f"Can't store message, unexpected exception (uid: {data['uid']}): {e}"
            )

    ## Private values

    def _getPrivateClass(self, binary, profile):
        """Get ORM class to use for private values"""
        if profile is None:
            return PrivateGenBin if binary else PrivateGen
        else:
            return PrivateIndBin if binary else PrivateInd

    @aio
    async def getPrivates(
        self,
        namespace: str,
        keys: Optional[Iterable[str]] = None,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get private value(s) from databases

        @param namespace: namespace of the values
        @param keys: keys of the values to get
            None to get all keys/values
        @param binary: True to deserialise binary values
        @param profile: profile to use for individual values
            None to use general values
        @return: gotten keys/values
        """
        if keys is not None:
            keys = list(keys)
        log.debug(
            f"getting {'general' if profile is None else 'individual'}"
            f"{' binary' if binary else ''} private values from database for namespace "
            f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}"
        )
        cls = self._getPrivateClass(binary, profile)
        stmt = select(cls).filter_by(namespace=namespace)
        # NOTE(review): an empty ``keys`` iterable currently behaves like None
        # (all values are returned) — confirm callers rely on this before
        # tightening the test to ``keys is not None``
        if keys:
            stmt = stmt.where(cls.key.in_(list(keys)))
        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])
        async with self.session() as session:
            result = await session.execute(stmt)
        return {p.key: p.value for p in result.scalars()}

    @aio
    async def setPrivateValue(
        self,
        namespace: str,
        key: str,
        value: Any,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Set a private value in database

        @param namespace: namespace of the values
        @param key: key of the value to set
        @param value: value to set
        @param binary: True if it's a binary values
            binary values need to be serialised, used for everything but strings
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        values = {
            "namespace": namespace,
            "key": key,
            "value": value
        }
        index_elements = [cls.namespace, cls.key]

        if profile is not None:
            values["profile_id"] = self.profiles[profile]
            index_elements.append(cls.profile_id)

        async with self.session() as session:
            # SQLite upsert on the table's natural key
            await session.execute(
                insert(cls).values(**values).on_conflict_do_update(
                    index_elements=index_elements,
                    set_={
                        cls.value: value
                    }
                )
            )
            await session.commit()

    @aio
    async def delPrivateValue(
        self,
        namespace: str,
        key: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete private value from database

        @param namespace: namespace of the value
        @param key: key of the private value
        @param binary: True if it's a binary values
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace, key=key)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def delPrivateNamespace(
        self,
        namespace: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete all data from a private namespace

        Be really cautious when you use this method, as all data with given namespace are
        removed.
        Params are the same as for delPrivateValue
        """
        cls = self._getPrivateClass(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    ## Files

    @aio
    async def getFiles(
        self,
        client: Optional[SatXMPPEntity],
        file_id: Optional[str] = None,
        version: Optional[str] = '',
        parent: Optional[str] = None,
        type_: Optional[str] = None,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        name: Optional[str] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        projection: Optional[List[str]] = None,
        unique: bool = False
    ) -> List[dict]:
        """Retrieve files with with given filters

        @param file_id: id of the file
            None to ignore
        @param version: version of the file
            None to ignore
            empty string to look for current version
        @param parent: id of the directory containing the files
            None to ignore
            empty string to look for root files/directories
        @param projection: name of columns to retrieve
            None to retrieve all
        @param unique: if True will remove duplicates
        other params are the same as for [setFile]
        @return: files corresponding to filters
        """
        if projection is None:
            projection = [
                'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
                'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
                'created', 'modified', 'owner', 'access', 'extra'
            ]

        stmt = select(*[getattr(File, f) for f in projection])

        if unique:
            stmt = stmt.distinct()

        if client is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
        else:
            # without a client we can't scope by profile, so a public_id is
            # mandatory to avoid leaking other profiles' files
            if public_id is None:
                raise exceptions.InternalError(
                    "client can only be omitted when public_id is set"
                )
        if file_id is not None:
            stmt = stmt.filter_by(id=file_id)
        if version is not None:
            stmt = stmt.filter_by(version=version)
        if parent is not None:
            stmt = stmt.filter_by(parent=parent)
        if type_ is not None:
            stmt = stmt.filter_by(type=type_)
        if file_hash is not None:
            stmt = stmt.filter_by(file_hash=file_hash)
        if hash_algo is not None:
            stmt = stmt.filter_by(hash_algo=hash_algo)
        if name is not None:
            stmt = stmt.filter_by(name=name)
        if namespace is not None:
            stmt = stmt.filter_by(namespace=namespace)
        if mime_type is not None:
            if '/' in mime_type:
                media_type, media_subtype = mime_type.split("/", 1)
                stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
            else:
                stmt = stmt.filter_by(media_type=mime_type)
        if public_id is not None:
            stmt = stmt.filter_by(public_id=public_id)
        if owner is not None:
            stmt = stmt.filter_by(owner=owner)
        if access is not None:
            raise NotImplementedError('Access check is not implemented yet')
            # a JSON comparison is needed here

        async with self.session() as session:
            result = await session.execute(stmt)

        return [dict(r) for r in result]

    @aio
    async def setFile(
        self,
        client: SatXMPPEntity,
        name: str,
        file_id: str,
        version: str = "",
        parent: str = "",
        type_: str = C.FILE_TYPE_FILE,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        size: Optional[int] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        created: Optional[float] = None,
        modified: Optional[float] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        extra: Optional[dict] = None
    ) -> None:
        """Set a file metadata

        @param client: client owning the file
        @param name: name of the file (must not contain "/")
        @param file_id: unique id of the file
        @param version: version of this file
        @param parent: id of the directory containing this file
            Empty string if it is a root file/directory
        @param type_: one of:
            - file
            - directory
        @param file_hash: unique hash of the payload
        @param hash_algo: algorithm used for hashing the file (usually sha-256)
        @param size: size in bytes
        @param namespace: identifier (human readable is better) to group files
            for instance, namespace could be used to group files in a specific photo album
        @param mime_type: media type of the file, or None if not known/guessed
        @param public_id: ID used to server the file publicly via HTTP
        @param created: UNIX time of creation
        @param modified: UNIX time of last modification, or None to use created date
        @param owner: jid of the owner of the file (mainly useful for component)
        @param access: serialisable dictionary with access rules. See [memory.memory] for
            details
        @param extra: serialisable dictionary of any extra data
            will be encoded to json in database
        """
        if mime_type is None:
            media_type = media_subtype = None
        elif '/' in mime_type:
            media_type, media_subtype = mime_type.split('/', 1)
        else:
            media_type, media_subtype = mime_type, None

        async with self.session() as session:
            async with session.begin():
                session.add(File(
                    id=file_id,
                    version=version.strip(),
                    parent=parent,
                    type=type_,
                    file_hash=file_hash,
                    hash_algo=hash_algo,
                    name=name,
                    size=size,
                    namespace=namespace,
                    media_type=media_type,
                    media_subtype=media_subtype,
                    public_id=public_id,
                    created=time.time() if created is None else created,
                    modified=modified,
                    owner=owner,
                    access=access,
                    extra=extra,
                    profile_id=self.profiles[client.profile]
                ))

    @aio
    async def fileGetUsedSpace(self, client: SatXMPPEntity, owner: jid.JID) -> int:
        """Return the total size in bytes of files owned by ``owner``.

        Directories are excluded (only C.FILE_TYPE_FILE rows are summed).
        """
        async with self.session() as session:
            result = await session.execute(
                select(sum_(File.size)).filter_by(
                    owner=owner,
                    type=C.FILE_TYPE_FILE,
                    profile_id=self.profiles[client.profile]
                ))
        # SUM over no rows is NULL, normalise it to 0
        return result.scalar_one_or_none() or 0

    @aio
    async def fileDelete(self, file_id: str) -> None:
        """Delete file metadata from the database

        @param file_id: id of the file to delete
        NOTE: file itself must still be removed, this method only handle metadata in
            database
        """
        async with self.session() as session:
            await session.execute(delete(File).filter_by(id=file_id))
            await session.commit()

    @aio
    async def fileUpdate(
        self,
        file_id: str,
        column: str,
        update_cb: Callable[[dict], None]
    ) -> None:
        """Update a column value using a method to avoid race conditions

        the older value will be retrieved from database, then update_cb will be applied to
        update it, and file will be updated checking that older value has not been changed
        meanwhile by an other user. If it has changed, it tries again a couple of times
        before failing
        @param column: column name (only "access" or "extra" are allowed)
        @param update_cb: method to update the value of the colum
            the method will take older value as argument, and must update it in place
            update_cb must not care about serialization,
            it get the deserialized data (i.e. a Python object) directly
        @raise exceptions.NotFound: there is not file with this id
        """
        if column not in ('access', 'extra'):
            raise exceptions.InternalError('bad column name')
        orm_col = getattr(File, column)

        for i in range(5):
            async with self.session() as session:
                try:
                    value = (await session.execute(
                        select(orm_col).filter_by(id=file_id)
                    )).scalar_one()
                except NoResultFound:
                    raise exceptions.NotFound
                # FIX: update_cb mutates ``value`` in place, so the pre-update
                # value must be snapshotted first — the original compared the
                # WHERE clause against the already-mutated value, defeating the
                # optimistic-lock check entirely
                old_value = copy.deepcopy(value)
                update_cb(value)
                # FIX: update() must target the mapped entity and carry the new
                # value; the original did update(orm_col) (invalid target) and
                # never called .values()
                stmt = update(File).filter_by(id=file_id).values({orm_col: value})
                if not old_value:
                    # because JsonDefaultDict convert NULL to an empty dict, we have to
                    # test both for empty dict and None when we have an empty dict
                    stmt = stmt.where((orm_col == None) | (orm_col == old_value))
                else:
                    stmt = stmt.where(orm_col == old_value)
                result = await session.execute(stmt)
                await session.commit()

            if result.rowcount == 1:
                break

            log.warning(
                _("table not updated, probably due to race condition, trying again "
                  "({tries})").format(tries=i+1)
            )

        else:
            raise exceptions.DatabaseError(
                _("Can't update file {file_id} due to race condition")
                .format(file_id=file_id)
            )