Mercurial > libervia-backend
diff libervia/backend/memory/memory.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 02 Jun 2023 11:49:51 +0200 |
parents | sat/memory/memory.py@524856bd7b19 |
children | 02f0adc745c6 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libervia/backend/memory/memory.py Fri Jun 02 11:49:51 2023 +0200 @@ -0,0 +1,1881 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import os.path +import copy +import shortuuid +import mimetypes +import time +from functools import partial +from typing import Optional, Tuple, Dict +from pathlib import Path +from uuid import uuid4 +from collections import namedtuple +from twisted.python import failure +from twisted.internet import defer, reactor, error +from twisted.words.protocols.jabber import jid +from libervia.backend.core.i18n import _ +from libervia.backend.core.log import getLogger +from libervia.backend.core import exceptions +from libervia.backend.core.constants import Const as C +from libervia.backend.memory.sqla import Storage +from libervia.backend.memory.persistent import PersistentDict +from libervia.backend.memory.params import Params +from libervia.backend.memory.disco import Discovery +from libervia.backend.memory.crypto import BlockCipher +from libervia.backend.memory.crypto import PasswordHasher +from libervia.backend.tools import config as tools_config +from libervia.backend.tools.common import data_format +from libervia.backend.tools.common import regex + + +log = 
getLogger(__name__) + + +PresenceTuple = namedtuple("PresenceTuple", ("show", "priority", "statuses")) +MSG_NO_SESSION = "Session id doesn't exist or is finished" + + +class Sessions(object): + """Sessions are data associated to key used for a temporary moment, with optional profile checking.""" + + DEFAULT_TIMEOUT = 600 + + def __init__(self, timeout=None, resettable_timeout=True): + """ + @param timeout (int): nb of seconds before session destruction + @param resettable_timeout (bool): if True, the timeout is reset on each access + """ + self._sessions = dict() + self.timeout = timeout or Sessions.DEFAULT_TIMEOUT + self.resettable_timeout = resettable_timeout + + def new_session(self, session_data=None, session_id=None, profile=None): + """Create a new session + + @param session_data: mutable data to use, default to a dict + @param session_id (str): force the session_id to the given string + @param profile: if set, the session is owned by the profile, + and profile_get must be used instead of __getitem__ + @return: session_id, session_data + """ + if session_id is None: + session_id = str(uuid4()) + elif session_id in self._sessions: + raise exceptions.ConflictError( + "Session id {} is already used".format(session_id) + ) + timer = reactor.callLater(self.timeout, self._purge_session, session_id) + if session_data is None: + session_data = {} + self._sessions[session_id] = ( + (timer, session_data) if profile is None else (timer, session_data, profile) + ) + return session_id, session_data + + def _purge_session(self, session_id): + try: + timer, session_data, profile = self._sessions[session_id] + except ValueError: + timer, session_data = self._sessions[session_id] + profile = None + try: + timer.cancel() + except error.AlreadyCalled: + # if the session is time-outed, the timer has been called + pass + del self._sessions[session_id] + log.debug( + "Session {} purged{}".format( + session_id, + " (profile {})".format(profile) if profile is not None else "", + ) + 
) + + def __len__(self): + return len(self._sessions) + + def __contains__(self, session_id): + return session_id in self._sessions + + def profile_get(self, session_id, profile): + try: + timer, session_data, profile_set = self._sessions[session_id] + except ValueError: + raise exceptions.InternalError( + "You need to use __getitem__ when profile is not set" + ) + except KeyError: + raise failure.Failure(KeyError(MSG_NO_SESSION)) + if profile_set != profile: + raise exceptions.InternalError("current profile differ from set profile !") + if self.resettable_timeout: + timer.reset(self.timeout) + return session_data + + def __getitem__(self, session_id): + try: + timer, session_data = self._sessions[session_id] + except ValueError: + raise exceptions.InternalError( + "You need to use profile_get instead of __getitem__ when profile is set" + ) + except KeyError: + raise failure.Failure(KeyError(MSG_NO_SESSION)) + if self.resettable_timeout: + timer.reset(self.timeout) + return session_data + + def __setitem__(self, key, value): + raise NotImplementedError("You need do use new_session to create a session") + + def __delitem__(self, session_id): + """ delete the session data """ + self._purge_session(session_id) + + def keys(self): + return list(self._sessions.keys()) + + def iterkeys(self): + return iter(self._sessions.keys()) + + +class ProfileSessions(Sessions): + """ProfileSessions extends the Sessions class, but here the profile can be + used as the key to retrieve data or delete a session (instead of session id). + """ + + def _profile_get_all_ids(self, profile): + """Return a list of the sessions ids that are associated to the given profile. 
+ + @param profile: %(doc_profile)s + @return: a list containing the sessions ids + """ + ret = [] + for session_id in self._sessions.keys(): + try: + timer, session_data, profile_set = self._sessions[session_id] + except ValueError: + continue + if profile == profile_set: + ret.append(session_id) + return ret + + def profile_get_unique(self, profile): + """Return the data of the unique session that is associated to the given profile. + + @param profile: %(doc_profile)s + @return: + - mutable data (default: dict) of the unique session + - None if no session is associated to the profile + - raise an error if more than one session are found + """ + ids = self._profile_get_all_ids(profile) + if len(ids) > 1: + raise exceptions.InternalError( + "profile_get_unique has been used but more than one session has been found!" + ) + return ( + self.profile_get(ids[0], profile) if len(ids) == 1 else None + ) # XXX: timeout might be reset + + def profile_del_unique(self, profile): + """Delete the unique session that is associated to the given profile. + + @param profile: %(doc_profile)s + @return: None, but raise an error if more than one session are found + """ + ids = self._profile_get_all_ids(profile) + if len(ids) > 1: + raise exceptions.InternalError( + "profile_del_unique has been used but more than one session has been found!" + ) + if len(ids) == 1: + del self._sessions[ids[0]] + + +class PasswordSessions(ProfileSessions): + + # FIXME: temporary hack for the user personal key not to be lost. The session + # must actually be purged and later, when the personal key is needed, the + # profile password should be asked again in order to decrypt it. 
+ def __init__(self, timeout=None): + ProfileSessions.__init__(self, timeout, resettable_timeout=False) + + def _purge_session(self, session_id): + log.debug( + "FIXME: PasswordSessions should ask for the profile password after the session expired" + ) + + +class Memory: + """This class manage all the persistent information""" + + def __init__(self, host): + log.info(_("Memory manager init")) + self.host = host + self._entities_cache = {} # XXX: keep presence/last resource/other data in cache + # /!\ an entity is not necessarily in roster + # main key is bare jid, value is a dict + # where main key is resource, or None for bare jid + self._key_signals = set() # key which need a signal to frontends when updated + self.subscriptions = {} + self.auth_sessions = PasswordSessions() # remember the authenticated profiles + self.disco = Discovery(host) + self.config = tools_config.parse_main_conf(log_filenames=True) + self._cache_path = Path(self.config_get("", "local_dir"), C.CACHE_DIR) + self.admins = self.config_get("", "admins_list", []) + self.admin_jids = set() + + + async def initialise(self): + self.storage = Storage() + await self.storage.initialise() + PersistentDict.storage = self.storage + self.params = Params(self.host, self.storage) + log.info(_("Loading default params template")) + self.params.load_default_params() + await self.load() + self.memory_data = PersistentDict("memory") + await self.memory_data.load() + await self.disco.load() + for admin in self.admins: + try: + admin_jid_s = await self.param_get_a_async( + "JabberID", "Connection", profile_key=admin + ) + except Exception as e: + log.warning(f"Can't retrieve jid of admin {admin!r}: {e}") + else: + if admin_jid_s is not None: + try: + admin_jid = jid.JID(admin_jid_s).userhostJID() + except RuntimeError: + log.warning(f"Invalid JID for admin {admin}: {admin_jid_s}") + else: + self.admin_jids.add(admin_jid) + + + ## Configuration ## + + def config_get(self, section, name, default=None): + """Get the 
main configuration option + + @param section: section of the config file (None or '' for DEFAULT) + @param name: name of the option + @param default: value to use if not found + @return: str, list or dict + """ + return tools_config.config_get(self.config, section, name, default) + + def load_xml(self, filename): + """Load parameters template from xml file + + @param filename (str): input file + @return: bool: True in case of success + """ + if not filename: + return False + filename = os.path.expanduser(filename) + if os.path.exists(filename): + try: + self.params.load_xml(filename) + log.debug(_("Parameters loaded from file: %s") % filename) + return True + except Exception as e: + log.error(_("Can't load parameters from file: %s") % e) + return False + + def save_xml(self, filename): + """Save parameters template to xml file + + @param filename (str): output file + @return: bool: True in case of success + """ + if not filename: + return False + # TODO: need to encrypt files (at least passwords !) 
and set permissions + filename = os.path.expanduser(filename) + try: + self.params.save_xml(filename) + log.debug(_("Parameters saved to file: %s") % filename) + return True + except Exception as e: + log.error(_("Can't save parameters to file: %s") % e) + return False + + def load(self): + """Load parameters and all memory things from db""" + # parameters data + return self.params.load_gen_params() + + def load_individual_params(self, profile): + """Load individual parameters for a profile + @param profile: %(doc_profile)s""" + return self.params.load_ind_params(profile) + + ## Profiles/Sessions management ## + + def start_session(self, password, profile): + """"Iniatialise session for a profile + + @param password(unicode): profile session password + or empty string is no password is set + @param profile: %(doc_profile)s + @raise exceptions.ProfileUnknownError if profile doesn't exists + @raise exceptions.PasswordError: the password does not match + """ + profile = self.get_profile_name(profile) + + def create_session(__): + """Called once params are loaded.""" + self._entities_cache[profile] = {} + log.info("[{}] Profile session started".format(profile)) + return False + + def backend_initialised(__): + def do_start_session(__=None): + if self.is_session_started(profile): + log.info("Session already started!") + return True + try: + # if there is a value at this point in self._entities_cache, + # it is the load_individual_params Deferred, the session is starting + session_d = self._entities_cache[profile] + except KeyError: + # else we do request the params + session_d = self._entities_cache[profile] = self.load_individual_params( + profile + ) + session_d.addCallback(create_session) + finally: + return session_d + + auth_d = defer.ensureDeferred(self.profile_authenticate(password, profile)) + auth_d.addCallback(do_start_session) + return auth_d + + if self.host.initialised.called: + return defer.succeed(None).addCallback(backend_initialised) + else: + return 
self.host.initialised.addCallback(backend_initialised) + + def stop_session(self, profile): + """Delete a profile session + + @param profile: %(doc_profile)s + """ + if self.host.is_connected(profile): + log.debug("Disconnecting profile because of session stop") + self.host.disconnect(profile) + self.auth_sessions.profile_del_unique(profile) + try: + self._entities_cache[profile] + except KeyError: + log.warning("Profile was not in cache") + + def _is_session_started(self, profile_key): + return self.is_session_started(self.get_profile_name(profile_key)) + + def is_session_started(self, profile): + try: + # XXX: if the value in self._entities_cache is a Deferred, + # the session is starting but not started yet + return not isinstance(self._entities_cache[profile], defer.Deferred) + except KeyError: + return False + + async def profile_authenticate(self, password, profile): + """Authenticate the profile. + + @param password (unicode): the SàT profile password + @return: None in case of success (an exception is raised otherwise) + @raise exceptions.PasswordError: the password does not match + """ + if not password and self.auth_sessions.profile_get_unique(profile): + # XXX: this allows any frontend to connect with the empty password as soon as + # the profile has been authenticated at least once before. It is OK as long as + # submitting a form with empty passwords is restricted to local frontends. + return + + sat_cipher = await self.param_get_a_async( + C.PROFILE_PASS_PATH[1], C.PROFILE_PASS_PATH[0], profile_key=profile + ) + valid = PasswordHasher.verify(password, sat_cipher) + if not valid: + log.warning(_("Authentication failure of profile {profile}").format( + profile=profile)) + raise exceptions.PasswordError("The provided profile password doesn't match.") + return await self.new_auth_session(password, profile) + + async def new_auth_session(self, key, profile): + """Start a new session for the authenticated profile. 
+ + If there is already an existing session, no new one is created + The personal key is loaded encrypted from a PersistentDict before being decrypted. + + @param key: the key to decrypt the personal key + @param profile: %(doc_profile)s + """ + data = await PersistentDict(C.MEMORY_CRYPTO_NAMESPACE, profile).load() + personal_key = BlockCipher.decrypt(key, data[C.MEMORY_CRYPTO_KEY]) + # Create the session for this profile and store the personal key + session_data = self.auth_sessions.profile_get_unique(profile) + if not session_data: + self.auth_sessions.new_session( + {C.MEMORY_CRYPTO_KEY: personal_key}, profile=profile + ) + log.debug("auth session created for profile %s" % profile) + + def purge_profile_session(self, profile): + """Delete cache of data of profile + @param profile: %(doc_profile)s""" + log.info(_("[%s] Profile session purge" % profile)) + self.params.purge_profile(profile) + try: + del self._entities_cache[profile] + except KeyError: + log.error( + _( + "Trying to purge roster status cache for a profile not in memory: [%s]" + ) + % profile + ) + + def get_profiles_list(self, clients=True, components=False): + """retrieve profiles list + + @param clients(bool): if True return clients profiles + @param components(bool): if True return components profiles + @return (list[unicode]): selected profiles + """ + if not clients and not components: + log.warning(_("requesting no profiles at all")) + return [] + profiles = self.storage.get_profiles_list() + if clients and components: + return sorted(profiles) + is_component = self.storage.profile_is_component + if clients: + p_filter = lambda p: not is_component(p) + else: + p_filter = lambda p: is_component(p) + + return sorted(p for p in profiles if p_filter(p)) + + def get_profile_name(self, profile_key, return_profile_keys=False): + """Return name of profile from keyword + + @param profile_key: can be the profile name or a keyword (like @DEFAULT@) + @param return_profile_keys: if True, return unmanaged 
profile keys (like "@ALL@"). This keys must be managed by the caller + @return: requested profile name + @raise exceptions.ProfileUnknownError if profile doesn't exists + """ + return self.params.get_profile_name(profile_key, return_profile_keys) + + def profile_set_default(self, profile): + """Set default profile + + @param profile: %(doc_profile)s + """ + # we want to be sure that the profile exists + profile = self.get_profile_name(profile) + + self.memory_data["Profile_default"] = profile + + def create_profile(self, name, password, component=None): + """Create a new profile + + @param name(unicode): profile name + @param password(unicode): profile password + Can be empty to disable password + @param component(None, unicode): set to entry point if this is a component + @return: Deferred + @raise exceptions.NotFound: component is not a known plugin import name + """ + if not name: + raise ValueError("Empty profile name") + if name[0] == "@": + raise ValueError("A profile name can't start with a '@'") + if "\n" in name: + raise ValueError("A profile name can't contain line feed ('\\n')") + + if name in self._entities_cache: + raise exceptions.ConflictError("A session for this profile exists") + + if component: + if not component in self.host.plugins: + raise exceptions.NotFound( + _( + "Can't find component {component} entry point".format( + component=component + ) + ) + ) + # FIXME: PLUGIN_INFO is not currently accessible after import, but type shoul be tested here + # if self.host.plugins[component].PLUGIN_INFO[u"type"] != C.PLUG_TYPE_ENTRY_POINT: + # raise ValueError(_(u"Plugin {component} is not an entry point !".format( + # component = component))) + + d = self.params.create_profile(name, component) + + def init_personal_key(__): + # be sure to call this after checking that the profile doesn't exist yet + + # generated once for all and saved in a PersistentDict + personal_key = BlockCipher.get_random_key( + base64=True + ).decode('utf-8') + 
self.auth_sessions.new_session( + {C.MEMORY_CRYPTO_KEY: personal_key}, profile=name + ) # will be encrypted by param_set + + def start_fake_session(__): + # avoid ProfileNotConnected exception in param_set + self._entities_cache[name] = None + self.params.load_ind_params(name) + + def stop_fake_session(__): + del self._entities_cache[name] + self.params.purge_profile(name) + + d.addCallback(init_personal_key) + d.addCallback(start_fake_session) + d.addCallback( + lambda __: self.param_set( + C.PROFILE_PASS_PATH[1], password, C.PROFILE_PASS_PATH[0], profile_key=name + ) + ) + d.addCallback(stop_fake_session) + d.addCallback(lambda __: self.auth_sessions.profile_del_unique(name)) + return d + + def profile_delete_async(self, name, force=False): + """Delete an existing profile + + @param name: Name of the profile + @param force: force the deletion even if the profile is connected. + To be used for direct calls only (not through the bridge). + @return: a Deferred instance + """ + + def clean_memory(__): + self.auth_sessions.profile_del_unique(name) + try: + del self._entities_cache[name] + except KeyError: + pass + + d = self.params.profile_delete_async(name, force) + d.addCallback(clean_memory) + return d + + def is_component(self, profile_name): + """Tell if a profile is a component + + @param profile_name(unicode): name of the profile + @return (bool): True if profile is a component + @raise exceptions.NotFound: profile doesn't exist + """ + return self.storage.profile_is_component(profile_name) + + def get_entry_point(self, profile_name): + """Get a component entry point + + @param profile_name(unicode): name of the profile + @return (bool): True if profile is a component + @raise exceptions.NotFound: profile doesn't exist + """ + return self.storage.get_entry_point(profile_name) + + ## History ## + + def add_to_history(self, client, data): + return self.storage.add_to_history(data, client.profile) + + def _history_get_serialise(self, history_data): + return [ + 
(uid, timestamp, from_jid, to_jid, message, subject, mess_type, + data_format.serialise(extra)) for uid, timestamp, from_jid, to_jid, message, + subject, mess_type, extra in history_data + ] + + def _history_get(self, from_jid_s, to_jid_s, limit=C.HISTORY_LIMIT_NONE, between=True, + filters=None, profile=C.PROF_KEY_NONE): + d = self.history_get(jid.JID(from_jid_s), jid.JID(to_jid_s), limit, between, + filters, profile) + d.addCallback(self._history_get_serialise) + return d + + def history_get(self, from_jid, to_jid, limit=C.HISTORY_LIMIT_NONE, between=True, + filters=None, profile=C.PROF_KEY_NONE): + """Retrieve messages in history + + @param from_jid (JID): source JID (full, or bare for catchall) + @param to_jid (JID): dest JID (full, or bare for catchall) + @param limit (int): maximum number of messages to get: + - 0 for no message (returns the empty list) + - C.HISTORY_LIMIT_NONE or None for unlimited + - C.HISTORY_LIMIT_DEFAULT to use the HISTORY_LIMIT parameter value + @param between (bool): confound source and dest (ignore the direction) + @param filters (dict[unicode, unicode]): pattern to filter the history results + (see bridge API for details) + @param profile (str): %(doc_profile)s + @return (D(list)): list of message data as in [message_new] + """ + assert profile != C.PROF_KEY_NONE + if limit == C.HISTORY_LIMIT_DEFAULT: + limit = int(self.param_get_a(C.HISTORY_LIMIT, "General", profile_key=profile)) + elif limit == C.HISTORY_LIMIT_NONE: + limit = None + if limit == 0: + return defer.succeed([]) + return self.storage.history_get(from_jid, to_jid, limit, between, filters, profile) + + ## Statuses ## + + def _get_presence_statuses(self, profile_key): + ret = self.presence_statuses_get(profile_key) + return {entity.full(): data for entity, data in ret.items()} + + def presence_statuses_get(self, profile_key): + """Get all the presence statuses of a profile + + @param profile_key: %(doc_profile_key)s + @return: presence data: key=entity JID, value=presence 
data for this entity + """ + client = self.host.get_client(profile_key) + profile_cache = self._get_profile_cache(client) + entities_presence = {} + + for entity_jid, entity_data in profile_cache.items(): + for resource, resource_data in entity_data.items(): + full_jid = copy.copy(entity_jid) + full_jid.resource = resource + try: + presence_data = self.get_entity_datum(client, full_jid, "presence") + except KeyError: + continue + entities_presence.setdefault(entity_jid, {})[ + resource or "" + ] = presence_data + + return entities_presence + + def set_presence_status(self, entity_jid, show, priority, statuses, profile_key): + """Change the presence status of an entity + + @param entity_jid: jid.JID of the entity + @param show: show status + @param priority: priority + @param statuses: dictionary of statuses + @param profile_key: %(doc_profile_key)s + """ + client = self.host.get_client(profile_key) + presence_data = PresenceTuple(show, priority, statuses) + self.update_entity_data( + client, entity_jid, "presence", presence_data + ) + if entity_jid.resource and show != C.PRESENCE_UNAVAILABLE: + # If a resource is available, bare jid should not have presence information + try: + self.del_entity_datum(client, entity_jid.userhostJID(), "presence") + except (KeyError, exceptions.UnknownEntityError): + pass + + ## Resources ## + + def _get_all_resource(self, jid_s, profile_key): + client = self.host.get_client(profile_key) + jid_ = jid.JID(jid_s) + return self.get_all_resources(client, jid_) + + def get_all_resources(self, client, entity_jid): + """Return all resource from jid for which we have had data in this session + + @param entity_jid: bare jid of the entity + return (set[unicode]): set of resources + + @raise exceptions.UnknownEntityError: if entity is not in cache + @raise ValueError: entity_jid has a resource + """ + # FIXME: is there a need to keep cache data for resources which are not connected anymore? 
+ if entity_jid.resource: + raise ValueError( + "get_all_resources must be used with a bare jid (got {})".format(entity_jid) + ) + profile_cache = self._get_profile_cache(client) + try: + entity_data = profile_cache[entity_jid.userhostJID()] + except KeyError: + raise exceptions.UnknownEntityError( + "Entity {} not in cache".format(entity_jid) + ) + resources = set(entity_data.keys()) + resources.discard(None) + return resources + + def get_available_resources(self, client, entity_jid): + """Return available resource for entity_jid + + This method differs from get_all_resources by returning only available resources + @param entity_jid: bare jid of the entit + return (list[unicode]): list of available resources + + @raise exceptions.UnknownEntityError: if entity is not in cache + """ + available = [] + for resource in self.get_all_resources(client, entity_jid): + full_jid = copy.copy(entity_jid) + full_jid.resource = resource + try: + presence_data = self.get_entity_datum(client, full_jid, "presence") + except KeyError: + log.debug("Can't get presence data for {}".format(full_jid)) + else: + if presence_data.show != C.PRESENCE_UNAVAILABLE: + available.append(resource) + return available + + def _get_main_resource(self, jid_s, profile_key): + client = self.host.get_client(profile_key) + jid_ = jid.JID(jid_s) + return self.main_resource_get(client, jid_) or "" + + def main_resource_get(self, client, entity_jid): + """Return the main resource used by an entity + + @param entity_jid: bare entity jid + @return (unicode): main resource or None + """ + if entity_jid.resource: + raise ValueError( + "main_resource_get must be used with a bare jid (got {})".format(entity_jid) + ) + try: + if self.host.plugins["XEP-0045"].is_joined_room(client, entity_jid): + return None # MUC rooms have no main resource + except KeyError: # plugin not found + pass + try: + resources = self.get_all_resources(client, entity_jid) + except exceptions.UnknownEntityError: + log.warning("Entity is 
not in cache, we can't find any resource") + return None + priority_resources = [] + for resource in resources: + full_jid = copy.copy(entity_jid) + full_jid.resource = resource + try: + presence_data = self.get_entity_datum(client, full_jid, "presence") + except KeyError: + log.debug("No presence information for {}".format(full_jid)) + continue + priority_resources.append((resource, presence_data.priority)) + try: + return max(priority_resources, key=lambda res_tuple: res_tuple[1])[0] + except ValueError: + log.warning("No resource found at all for {}".format(entity_jid)) + return None + + ## Entities data ## + + def _get_profile_cache(self, client): + """Check profile validity and return its cache + + @param client: SatXMPPClient + @return (dict): profile cache + """ + return self._entities_cache[client.profile] + + def set_signal_on_update(self, key, signal=True): + """Set a signal flag on the key + + When the key will be updated, a signal will be sent to frontends + @param key: key to signal + @param signal(boolean): if True, do the signal + """ + if signal: + self._key_signals.add(key) + else: + self._key_signals.discard(key) + + def get_all_entities_iter(self, client, with_bare=False): + """Return an iterator of full jids of all entities in cache + + @param with_bare: if True, include bare jids + @return (list[unicode]): list of jids + """ + profile_cache = self._get_profile_cache(client) + # we construct a list of all known full jids (bare jid of entities x resources) + for bare_jid, entity_data in profile_cache.items(): + for resource in entity_data.keys(): + if resource is None: + continue + full_jid = copy.copy(bare_jid) + full_jid.resource = resource + yield full_jid + + def update_entity_data( + self, client, entity_jid, key, value, silent=False + ): + """Set a misc data for an entity + + If key was registered with set_signal_on_update, a signal will be sent to frontends + @param entity_jid: JID of the entity, C.ENTITY_ALL_RESOURCES for all resources of 
+ all entities, C.ENTITY_ALL for all entities (all resources + bare jids) + @param key: key to set (eg: C.ENTITY_TYPE) + @param value: value for this key (eg: C.ENTITY_TYPE_MUC) + @param silent(bool): if True, doesn't send signal to frontend, even if there is a + signal flag (see set_signal_on_update) + """ + profile_cache = self._get_profile_cache(client) + if entity_jid in (C.ENTITY_ALL_RESOURCES, C.ENTITY_ALL): + entities = self.get_all_entities_iter(client, entity_jid == C.ENTITY_ALL) + else: + entities = (entity_jid,) + + for jid_ in entities: + entity_data = profile_cache.setdefault(jid_.userhostJID(), {}).setdefault( + jid_.resource, {} + ) + + entity_data[key] = value + if key in self._key_signals and not silent: + self.host.bridge.entity_data_updated( + jid_.full(), + key, + data_format.serialise(value), + client.profile + ) + + def del_entity_datum(self, client, entity_jid, key): + """Delete a data for an entity + + @param entity_jid: JID of the entity, C.ENTITY_ALL_RESOURCES for all resources of all entities, + C.ENTITY_ALL for all entities (all resources + bare jids) + @param key: key to delete (eg: C.ENTITY_TYPE) + + @raise exceptions.UnknownEntityError: if entity is not in cache + @raise KeyError: key is not in cache + """ + profile_cache = self._get_profile_cache(client) + if entity_jid in (C.ENTITY_ALL_RESOURCES, C.ENTITY_ALL): + entities = self.get_all_entities_iter(client, entity_jid == C.ENTITY_ALL) + else: + entities = (entity_jid,) + + for jid_ in entities: + try: + entity_data = profile_cache[jid_.userhostJID()][jid_.resource] + except KeyError: + raise exceptions.UnknownEntityError( + "Entity {} not in cache".format(jid_) + ) + try: + del entity_data[key] + except KeyError as e: + if entity_jid in (C.ENTITY_ALL_RESOURCES, C.ENTITY_ALL): + continue # we ignore KeyError when deleting keys from several entities + else: + raise e + + def _get_entities_data(self, entities_jids, keys_list, profile_key): + client = self.host.get_client(profile_key) 
+ ret = self.entities_data_get( + client, [jid.JID(jid_) for jid_ in entities_jids], keys_list + ) + return { + jid_.full(): {k: data_format.serialise(v) for k,v in data.items()} + for jid_, data in ret.items() + } + + def entities_data_get(self, client, entities_jids, keys_list=None): + """Get a list of cached values for several entities at once + + @param entities_jids: jids of the entities, or empty list for all entities in cache + @param keys_list (iterable,None): list of keys to get, None for everything + @param profile_key: %(doc_profile_key)s + @return: dict withs values for each key in keys_list. + if there is no value of a given key, resulting dict will + have nothing with that key nether + if an entity doesn't exist in cache, it will not appear + in resulting dict + + @raise exceptions.UnknownEntityError: if entity is not in cache + """ + + def fill_entity_data(entity_cache_data): + entity_data = {} + if keys_list is None: + entity_data = entity_cache_data + else: + for key in keys_list: + try: + entity_data[key] = entity_cache_data[key] + except KeyError: + continue + return entity_data + + profile_cache = self._get_profile_cache(client) + ret_data = {} + if entities_jids: + for entity in entities_jids: + try: + entity_cache_data = profile_cache[entity.userhostJID()][ + entity.resource + ] + except KeyError: + continue + ret_data[entity.full()] = fill_entity_data(entity_cache_data, keys_list) + else: + for bare_jid, data in profile_cache.items(): + for resource, entity_cache_data in data.items(): + full_jid = copy.copy(bare_jid) + full_jid.resource = resource + ret_data[full_jid] = fill_entity_data(entity_cache_data) + + return ret_data + + def _get_entity_data(self, entity_jid_s, keys_list=None, profile=C.PROF_KEY_NONE): + return self.entity_data_get( + self.host.get_client(profile), jid.JID(entity_jid_s), keys_list) + + def entity_data_get(self, client, entity_jid, keys_list=None): + """Get a list of cached values for entity + + @param entity_jid: JID 
of the entity
        @param keys_list (iterable,None): list of keys to get, None for everything
        @param profile_key: %(doc_profile_key)s
        @return: dict with values for each key in keys_list.
            If there is no value for a given key, the resulting dict will
            simply lack that key.

        @raise exceptions.UnknownEntityError: if entity is not in cache
        """
        profile_cache = self._get_profile_cache(client)
        try:
            entity_data = profile_cache[entity_jid.userhostJID()][entity_jid.resource]
        except KeyError:
            raise exceptions.UnknownEntityError(
                "Entity {} not in cache (was requesting {})".format(
                    entity_jid, keys_list
                )
            )
        if keys_list is None:
            return entity_data

        # silently skip keys absent from the cache, per the contract above
        return {key: entity_data[key] for key in keys_list if key in entity_data}

    def get_entity_datum(self, client, entity_jid, key):
        """Get a single datum from entity

        @param entity_jid: JID of the entity
        @param key: key to get
        @return: requested value

        @raise exceptions.UnknownEntityError: if entity is not in cache
        @raise KeyError: if there is no value for this key and this entity
        """
        return self.entity_data_get(client, entity_jid, (key,))[key]

    def del_entity_cache(
        self, entity_jid, delete_all_resources=True, profile_key=C.PROF_KEY_NONE
    ):
        """Remove all cached data for entity

        @param entity_jid: JID of the entity to delete
        @param delete_all_resources: if True also delete all known resources from
            cache (a bare jid must be given in this case)
        @param profile_key: %(doc_profile_key)s

        @raise exceptions.UnknownEntityError: if entity is not in cache
        """
        client = self.host.get_client(profile_key)
        profile_cache = self._get_profile_cache(client)

        if delete_all_resources:
            if entity_jid.resource:
                raise ValueError(_("Need a bare jid to delete all resources"))
            try:
                del profile_cache[entity_jid]
            except KeyError:
                raise exceptions.UnknownEntityError(
                    "Entity {} not in cache".format(entity_jid)
                )
        else:
            try:
                del 
profile_cache[entity_jid.userhostJID()][entity_jid.resource]
            except KeyError:
                raise exceptions.UnknownEntityError(
                    "Entity {} not in cache".format(entity_jid)
                )

    ## Encryption ##

    def encrypt_value(self, value, profile):
        """Encrypt a value for the given profile.

        The personal key must be loaded already in the profile session, which
        should be the case if the profile is already authenticated.

        @param value (str): the value to encrypt
        @param profile (str): %(doc_profile)s
        @return: the deferred encrypted value
        """
        try:
            personal_key = self.auth_sessions.profile_get_unique(profile)[
                C.MEMORY_CRYPTO_KEY
            ]
        except TypeError:
            # NOTE(review): presumably profile_get_unique() returns None when no
            # session is open, hence the TypeError on subscript — confirm.
            raise exceptions.InternalError(
                _("Trying to encrypt a value for %s while the personal key is undefined!")
                % profile
            )
        return BlockCipher.encrypt(personal_key, value)

    def decrypt_value(self, value, profile):
        """Decrypt a value for the given profile.

        The personal key must be loaded already in the profile session, which
        should be the case if the profile is already authenticated.

        @param value (str): the value to decrypt
        @param profile (str): %(doc_profile)s
        @return: the deferred decrypted value
        """
        try:
            personal_key = self.auth_sessions.profile_get_unique(profile)[
                C.MEMORY_CRYPTO_KEY
            ]
        except TypeError:
            raise exceptions.InternalError(
                _("Trying to decrypt a value for %s while the personal key is undefined!")
                % profile
            )
        return BlockCipher.decrypt(personal_key, value)

    def encrypt_personal_data(self, data_key, data_value, crypto_key, profile):
        """Re-encrypt a personal data (saved to a PersistentDict). 

        @param data_key: key for the individual PersistentDict instance
        @param data_value: the value to be encrypted
        @param crypto_key: the key to encrypt the value
        @param profile: %(profile_doc)s
        @return: a deferred None value
        """

        def got_ind_memory(data):
            # store the encrypted value and persist only this key
            data[data_key] = BlockCipher.encrypt(crypto_key, data_value)
            return data.force(data_key)

        def done(__):
            log.debug(
                _("Personal data (%(ns)s, %(key)s) has been successfuly encrypted")
                % {"ns": C.MEMORY_CRYPTO_NAMESPACE, "key": data_key}
            )

        d = PersistentDict(C.MEMORY_CRYPTO_NAMESPACE, profile).load()
        return d.addCallback(got_ind_memory).addCallback(done)

    ## Subscription requests ##

    def add_waiting_sub(self, type_, entity_jid, profile_key):
        """Called when a subscription request is received"""
        profile = self.get_profile_name(profile_key)
        assert profile
        if profile not in self.subscriptions:
            self.subscriptions[profile] = {}
        self.subscriptions[profile][entity_jid] = type_

    def del_waiting_sub(self, entity_jid, profile_key):
        """Called when a subscription request is finished"""
        profile = self.get_profile_name(profile_key)
        assert profile
        if profile in self.subscriptions and entity_jid in self.subscriptions[profile]:
            del self.subscriptions[profile][entity_jid]

    def sub_waiting_get(self, profile_key):
        """Called to get a list of currently waiting subscription requests"""
        profile = self.get_profile_name(profile_key)
        if not profile:
            log.error(_("Asking waiting subscriptions for a non-existant profile"))
            return {}
        if profile not in self.subscriptions:
            return {}

        return self.subscriptions[profile]

    ## Parameters ##
    # thin delegation layer over self.params (memory.params.Params)

    def get_string_param_a(self, name, category, attr="value", profile_key=C.PROF_KEY_NONE):
        return self.params.get_string_param_a(name, category, attr, profile_key)

    def param_get_a(self, name, category, attr="value", profile_key=C.PROF_KEY_NONE):
        return self.params.param_get_a(name, category, attr, profile_key=profile_key)

    def param_get_a_async(
        self,
        name,
        category,
        attr="value",
        security_limit=C.NO_SECURITY_LIMIT,
        profile_key=C.PROF_KEY_NONE,
    ):
        return self.params.param_get_a_async(
            name, category, attr, security_limit, profile_key
        )

    def _get_params_values_from_category(
        self, category, security_limit, app, extra_s, profile_key
    ):
        return self.params._get_params_values_from_category(
            category, security_limit, app, extra_s, profile_key
        )

    def async_get_string_param_a(
        self, name, category, attribute="value", security_limit=C.NO_SECURITY_LIMIT,
        profile_key=C.PROF_KEY_NONE):

        # resolve the profile name first, then delegate the coroutine as Deferred
        profile = self.get_profile_name(profile_key)
        return defer.ensureDeferred(self.params.async_get_string_param_a(
            name, category, attribute, security_limit, profile
        ))

    def _get_params_ui(self, security_limit, app, extra_s, profile_key):
        return self.params._get_params_ui(security_limit, app, extra_s, profile_key)

    def params_categories_get(self):
        return self.params.params_categories_get()

    def param_set(
        self,
        name,
        value,
        category,
        security_limit=C.NO_SECURITY_LIMIT,
        profile_key=C.PROF_KEY_NONE,
    ):
        return self.params.param_set(name, value, category, security_limit, profile_key)

    def update_params(self, xml):
        return self.params.update_params(xml)

    def params_register_app(self, xml, security_limit=C.NO_SECURITY_LIMIT, app=""):
        return self.params.params_register_app(xml, security_limit, app)

    def set_default(self, name, category, callback, errback=None):
        return self.params.set_default(name, category, callback, errback)

    ## Private Data ##

    def _private_data_set(self, namespace, key, data_s, profile_key):
        client = self.host.get_client(profile_key)
        # we accept any type
        data = data_format.deserialise(data_s, type_check=None)
        return defer.ensureDeferred(self.storage.set_private_value(
            namespace, key, data, binary=True, profile=client.profile))

    def _private_data_get(self, namespace, key, profile_key):
        client 
= self.host.get_client(profile_key)
        d = defer.ensureDeferred(
            self.storage.get_privates(
                namespace, [key], binary=True, profile=client.profile)
        )
        # serialise the single requested value (None if the key is absent)
        d.addCallback(lambda data_dict: data_format.serialise(data_dict.get(key)))
        return d

    def _private_data_delete(self, namespace, key, profile_key):
        client = self.host.get_client(profile_key)
        return defer.ensureDeferred(self.storage.del_private_value(
            namespace, key, binary=True, profile=client.profile))

    ## Files ##

    def check_file_permission(
        self,
        file_data: dict,
        peer_jid: Optional[jid.JID],
        perms_to_check: Optional[Tuple[str]],
        set_affiliation: bool = False
    ) -> None:
        """Check that an entity has the right permission on a file

        @param file_data: data of one file, as returned by get_files
        @param peer_jid: entity trying to access the file
        @param perms_to_check: permissions to check
            tuple of C.ACCESS_PERM_*
        @param set_affiliation: if True, "affiliation" metadata will be set
        @raise exceptions.PermissionError: peer_jid doesn't have all permission
            in perms_to_check for file_data
        @raise exceptions.InternalError: perms_to_check is invalid
        """
        # TODO: knowing if user is owner is not enough, we need to check permission
        #   to see if user can modify/delete files, and set corresponding affiliation (publisher, member)
        if peer_jid is None and perms_to_check is None:
            return
        peer_jid = peer_jid.userhostJID()
        if peer_jid == file_data["owner"]:
            if set_affiliation:
                file_data['affiliation'] = 'owner'
            # the owner has all rights, nothing to check
            return
        if not C.ACCESS_PERMS.issuperset(perms_to_check):
            raise exceptions.InternalError(_("invalid permission"))

        for perm in perms_to_check:
            # we check each perm and raise PermissionError as soon as one condition is
            # not valid; we must never "return" here, we only return after the loop if
            # nothing was blocking the access
            try:
                perm_data = file_data["access"][perm]
                perm_type = perm_data["type"]
            except KeyError:
                # No permission is set.
                # If we are in a root file/directory, we deny access;
                # otherwise, we use public permission, as the parent directory will
                # block anyway. This avoids having to recursively change permissions
                # for all sub directories/files when modifying a permission.
                if not file_data.get('parent'):
                    raise exceptions.PermissionError()
                else:
                    perm_type = C.ACCESS_TYPE_PUBLIC
            if perm_type == C.ACCESS_TYPE_PUBLIC:
                continue
            elif perm_type == C.ACCESS_TYPE_WHITELIST:
                try:
                    jids = perm_data["jids"]
                except KeyError:
                    raise exceptions.PermissionError()
                if peer_jid.full() in jids:
                    continue
                else:
                    raise exceptions.PermissionError()
            else:
                raise exceptions.InternalError(
                    _("unknown access type: {type}").format(type=perm_type)
                )

    async def check_permission_to_root(self, client, file_data, peer_jid, perms_to_check):
        """do check_file_permission on file_data and all its parents until root"""
        current = file_data
        while True:
            self.check_file_permission(current, peer_jid, perms_to_check)
            parent = current["parent"]
            if not parent:
                break
            # parents are looked up with permission check disabled: we check them
            # ourselves, one level per loop iteration
            files_data = await self.get_files(
                client, peer_jid=None, file_id=parent, perms_to_check=None
            )
            try:
                current = files_data[0]
            except IndexError:
                raise exceptions.DataError("Missing parent")

    async def _get_parent_dir(
        self, client, path, parent, namespace, owner, peer_jid, perms_to_check
    ):
        """Retrieve parent node from a path, or last existing directory

        each directory of the path will be retrieved, until the last existing one
        @return (tuple[unicode, list[unicode])): parent, remaining path elements:
            - parent is the id of the last retrieved directory (or u'' for root)
            - remaining path elements are the directories which have not been retrieved
              (i.e. 
which don't exist)
        """
        # if path is set, we have to retrieve parent directory of the file(s) from it
        if parent is not None:
            raise exceptions.ConflictError(
                _("You can't use path and parent at the same time")
            )
        path_elts = [_f for _f in path.split("/") if _f]
        if {"..", "."}.intersection(path_elts):
            raise ValueError(_('".." or "." can\'t be used in path'))

        # we retrieve all directories from path until we get the parent container
        # non existing directories will be created
        parent = ""
        for idx, path_elt in enumerate(path_elts):
            directories = await self.storage.get_files(
                client,
                parent=parent,
                type_=C.FILE_TYPE_DIRECTORY,
                name=path_elt,
                namespace=namespace,
                owner=owner,
            )
            if not directories:
                # from this point, directories don't exist anymore: the caller
                # will have to create them
                return (parent, path_elts[idx:])
            elif len(directories) > 1:
                raise exceptions.InternalError(
                    _("Several directories found, this should not happen")
                )
            else:
                directory = directories[0]
                self.check_file_permission(directory, peer_jid, perms_to_check)
                parent = directory["id"]
        return (parent, [])

    def get_file_affiliations(self, file_data: dict) -> Dict[jid.JID, str]:
        """Convert file access to pubsub like affiliations"""
        affiliations = {}
        access_data = file_data['access']

        # read whitelist => "member"
        read_data = access_data.get(C.ACCESS_PERM_READ, {})
        if read_data.get('type') == C.ACCESS_TYPE_WHITELIST:
            for entity_jid_s in read_data['jids']:
                entity_jid = jid.JID(entity_jid_s)
                affiliations[entity_jid] = 'member'

        # write whitelist => "publisher" (overrides "member")
        write_data = access_data.get(C.ACCESS_PERM_WRITE, {})
        if write_data.get('type') == C.ACCESS_TYPE_WHITELIST:
            for entity_jid_s in write_data['jids']:
                entity_jid = jid.JID(entity_jid_s)
                affiliations[entity_jid] = 'publisher'

        owner = file_data.get('owner')
        if owner:
            affiliations[owner] = 'owner'

        return affiliations

    def _set_file_affiliations_update(
        self,
        access: dict,
        file_data: dict,
        affiliations: Dict[jid.JID, str] 
    ) -> None:
        # update callback used by set_file_affiliations (runs via file_update, so
        # it must mutate `access` in place); forces both read and write rules to
        # whitelist type before applying affiliations
        read_data = access.setdefault(C.ACCESS_PERM_READ, {})
        if read_data.get('type') != C.ACCESS_TYPE_WHITELIST:
            read_data['type'] = C.ACCESS_TYPE_WHITELIST
            if 'jids' not in read_data:
                read_data['jids'] = []
        read_whitelist = read_data['jids']
        write_data = access.setdefault(C.ACCESS_PERM_WRITE, {})
        if write_data.get('type') != C.ACCESS_TYPE_WHITELIST:
            write_data['type'] = C.ACCESS_TYPE_WHITELIST
            if 'jids' not in write_data:
                write_data['jids'] = []
        write_whitelist = write_data['jids']
        for entity_jid, affiliation in affiliations.items():
            entity_jid_s = entity_jid.full()
            if affiliation == "none":
                # remove from both whitelists
                try:
                    read_whitelist.remove(entity_jid_s)
                except ValueError:
                    log.warning(
                        "removing affiliation from an entity without read permission: "
                        f"{entity_jid}"
                    )
                try:
                    write_whitelist.remove(entity_jid_s)
                except ValueError:
                    pass
            elif affiliation == "publisher":
                # read + write
                if entity_jid_s not in read_whitelist:
                    read_whitelist.append(entity_jid_s)
                if entity_jid_s not in write_whitelist:
                    write_whitelist.append(entity_jid_s)
            elif affiliation == "member":
                # read only
                if entity_jid_s not in read_whitelist:
                    read_whitelist.append(entity_jid_s)
                try:
                    write_whitelist.remove(entity_jid_s)
                except ValueError:
                    pass
            elif affiliation == "owner":
                raise NotImplementedError('"owner" affiliation can\'t be set')
            else:
                raise ValueError(f"unknown affiliation: {affiliation!r}")

    async def set_file_affiliations(
        self,
        client,
        file_data: dict,
        affiliations: Dict[jid.JID, str]
    ) -> None:
        """Apply pubsub like affiliation to file_data

        Affiliations are converted to access types, then set in a whitelist. 
        Affiliations are mapped as follows:
            - "owner" can't be set (for now)
            - "publisher" gives read and write permissions
            - "member" gives read permission only
            - "none" removes both read and write permissions
        """
        file_id = file_data['id']
        await self.file_update(
            file_id,
            'access',
            update_cb=partial(
                self._set_file_affiliations_update,
                file_data=file_data,
                affiliations=affiliations
            ),
        )

    def _set_file_access_model_update(
        self,
        access: dict,
        file_data: dict,
        access_model: str
    ) -> None:
        # update callback used by set_file_access_model; mutates `access` in place
        read_data = access.setdefault(C.ACCESS_PERM_READ, {})
        if access_model == "open":
            requested_type = C.ACCESS_TYPE_PUBLIC
        elif access_model == "whitelist":
            requested_type = C.ACCESS_TYPE_WHITELIST
        else:
            raise ValueError(f"unknown access model: {access_model}")

        read_data['type'] = requested_type
        if requested_type == C.ACCESS_TYPE_WHITELIST and 'jids' not in read_data:
            read_data['jids'] = []

    async def set_file_access_model(
        self,
        client,
        file_data: dict,
        access_model: str,
    ) -> None:
        """Apply pubsub like access_model to file_data

        Only 2 access models are supported so far:
            - "open": set public access to file/dir
            - "whitelist": set whitelist to file/dir
        """
        file_id = file_data['id']
        await self.file_update(
            file_id,
            'access',
            update_cb=partial(
                self._set_file_access_model_update,
                file_data=file_data,
                access_model=access_model
            ),
        )

    def get_files_owner(
        self,
        client,
        owner: Optional[jid.JID],
        peer_jid: Optional[jid.JID],
        file_id: Optional[str] = None,
        parent: Optional[str] = None
    ) -> jid.JID:
        """Get owner to use for a file operation

        if owner is not explicitely set, a suitable one will be used (client.jid for
        clients, peer_jid for components). 
        @raise exception.InternalError: we are on a component, and neither owner nor
            peer_jid are set
        """
        if owner is not None:
            return owner.userhostJID()
        if client is None:
            # client may be None when looking for file with public_id
            return None
        if file_id or parent:
            # owner has already been filtered on parent file
            return None
        if not client.is_component:
            return client.jid.userhostJID()
        if peer_jid is None:
            raise exceptions.InternalError(
                "Owner must be set for component if peer_jid is None"
            )
        return peer_jid.userhostJID()

    async def get_files(
        self, client, peer_jid, file_id=None, version=None, parent=None, path=None,
        type_=None, file_hash=None, hash_algo=None, name=None, namespace=None,
        mime_type=None, public_id=None, owner=None, access=None, projection=None,
        unique=False, perms_to_check=(C.ACCESS_PERM_READ,)):
        """Retrieve files with given filters

        @param peer_jid(jid.JID, None): jid trying to access the file
            needed to check permission.
            Use None to ignore permission (perms_to_check must be None too)
        @param file_id(unicode, None): id of the file
            None to ignore
        @param version(unicode, None): version of the file
            None to ignore
            empty string to look for current version
        @param parent(unicode, None): id of the directory containing the files
            None to ignore
            empty string to look for root files/directories
        @param path(Path, unicode, None): path to the directory containing the files
        @param type_(unicode, None): type of file filter, can be one of C.FILE_TYPE_*
        @param file_hash(unicode, None): hash of the file to retrieve
        @param hash_algo(unicode, None): algorithm use for file_hash
        @param name(unicode, None): name of the file to retrieve
        @param namespace(unicode, None): namespace of the files to retrieve
        @param mime_type(unicode, None): filter on this mime type
        @param public_id(unicode, None): filter on this public id
        @param owner(jid.JID, None): if not None, only get files from this owner
        @param access(dict, None): get file with given access (see [set_file])
        @param projection(list[unicode], None): name of columns to retrieve
            None to retrieve all
        @param unique(bool): if True will remove duplicates
        @param perms_to_check(tuple[unicode],None): permission to check
            must be a tuple of C.ACCESS_PERM_* or None
            if None, permission will not be checked (peer_jid must be None too in this
            case)
        other params are the same as for [set_file]
        @return (list[dict]): files corresponding to filters
        @raise exceptions.NotFound: parent directory not found (when path is specified)
        @raise exceptions.PermissionError: peer_jid can't use perms_to_check for one of
            the file on the path
        """
        # peer_jid and perms_to_check must be both set, or both None
        if peer_jid is None and perms_to_check or perms_to_check is None and peer_jid:
            raise exceptions.InternalError(
                "if you want to disable permission check, both peer_jid and "
                "perms_to_check must be None"
            )
        owner = self.get_files_owner(client, owner, peer_jid, file_id, parent)
        if 
path is not None: + path = str(path) + # permission are checked by _get_parent_dir + parent, remaining_path_elts = await self._get_parent_dir( + client, path, parent, namespace, owner, peer_jid, perms_to_check + ) + if remaining_path_elts: + # if we have remaining path elements, + # the parent directory is not found + raise failure.Failure(exceptions.NotFound()) + if parent and peer_jid: + # if parent is given directly and permission check is requested, + # we need to check all the parents + parent_data = await self.storage.get_files(client, file_id=parent) + try: + parent_data = parent_data[0] + except IndexError: + raise exceptions.DataError("mising parent") + await self.check_permission_to_root( + client, parent_data, peer_jid, perms_to_check + ) + + files = await self.storage.get_files( + client, + file_id=file_id, + version=version, + parent=parent, + type_=type_, + file_hash=file_hash, + hash_algo=hash_algo, + name=name, + namespace=namespace, + mime_type=mime_type, + public_id=public_id, + owner=owner, + access=access, + projection=projection, + unique=unique, + ) + + if peer_jid: + # if permission are checked, we must remove all file that user can't access + to_remove = [] + for file_data in files: + try: + self.check_file_permission( + file_data, peer_jid, perms_to_check, set_affiliation=True + ) + except exceptions.PermissionError: + to_remove.append(file_data) + for file_data in to_remove: + files.remove(file_data) + return files + + async def set_file( + self, client, name, file_id=None, version="", parent=None, path=None, + type_=C.FILE_TYPE_FILE, file_hash=None, hash_algo=None, size=None, + namespace=None, mime_type=None, public_id=None, created=None, modified=None, + owner=None, access=None, extra=None, peer_jid=None, + perms_to_check=(C.ACCESS_PERM_WRITE,) + ): + """Set a file metadata + + @param name(unicode): basename of the file + @param file_id(unicode): unique id of the file + @param version(unicode): version of this file + empty string for 
current version or when there is no versioning
        @param parent(unicode, None): id of the directory containing the files
        @param path(unicode, None): virtual path of the file in the namespace
            if set, parent must be None. All intermediate directories will be created
            if needed, using current access.
        @param type_(str, None): type of file filter, can be one of C.FILE_TYPE_*
        @param file_hash(unicode): unique hash of the payload
        @param hash_algo(unicode): algorithm used for hashing the file (usually sha-256)
        @param size(int): size in bytes
        @param namespace(unicode, None): identifier (human readable is better) to group
            files
            For instance, namespace could be used to group files in a specific photo
            album
        @param mime_type(unicode): MIME type of the file, or None if not known/guessed
        @param public_id(unicode): id used to share publicly the file via HTTP
        @param created(int): UNIX time of creation
        @param modified(int,None): UNIX time of last modification, or None to use
            created date
        @param owner(jid.JID, None): jid of the owner of the file (mainly useful for
            component)
            will be used to check permission (only bare jid is used, don't use with
            MUC).
            Use None to ignore permission (perms_to_check must be None too)
        @param access(dict, None): serialisable dictionary with access rules.
            None (or empty dict) to use private access, i.e. allow only profile's jid
            to access the file
            key can be one of C.ACCESS_PERM_*,
            then a sub dictionary with a type key is used (one of C.ACCESS_TYPE_*). 
            According to type, extra keys can be used:
                - C.ACCESS_TYPE_PUBLIC: the permission is granted for everybody
                - C.ACCESS_TYPE_WHITELIST: the permission is granted for jids (as
                  unicode) in the 'jids' key
            will be encoded to json in database
        @param extra(dict, None): serialisable dictionary of any extra data
            will be encoded to json in database
        @param perms_to_check(tuple[unicode],None): permission to check
            must be a tuple of C.ACCESS_PERM_* or None
            if None, permission will not be checked (peer_jid must be None too in this
            case)
        @param profile(unicode): profile owning the file
        """
        if "/" in name:
            raise ValueError('name must not contain a slash ("/")')
        if file_id is None:
            file_id = shortuuid.uuid()
        # file_hash and hash_algo must be set together or not at all
        if (
            file_hash is not None
            and hash_algo is None
            or hash_algo is not None
            and file_hash is None
        ):
            raise ValueError("file_hash and hash_algo must be set at the same time")
        if mime_type is None:
            mime_type, __ = mimetypes.guess_type(name)
        else:
            mime_type = mime_type.lower()
        if public_id is not None:
            assert len(public_id)>0
        if created is None:
            created = time.time()
        if namespace is not None:
            # an all-whitespace namespace is treated as unset
            namespace = namespace.strip() or None
        if type_ == C.FILE_TYPE_DIRECTORY:
            if any((version, file_hash, size, mime_type)):
                raise ValueError(
                    "version, file_hash, size and mime_type can't be set for a directory"
                )
        owner = self.get_files_owner(client, owner, peer_jid, file_id, parent)

        if path is not None:
            path = str(path)
            # _get_parent_dir will check permissions if peer_jid is set, so we use owner
            parent, remaining_path_elts = await self._get_parent_dir(
                client, path, parent, namespace, owner, owner, perms_to_check
            )
            # if remaining directories don't exist, we have to create them
            for new_dir in remaining_path_elts:
                new_dir_id = shortuuid.uuid()
                await self.storage.set_file(
                    client,
                    name=new_dir,
                    file_id=new_dir_id,
                    version="",
                    parent=parent,
                    type_=C.FILE_TYPE_DIRECTORY,
                    namespace=namespace,
                    created=time.time(),
                    owner=owner,
                    access=access,
                    extra={},
                )
                parent = new_dir_id
        elif parent is None:
            parent = ""

        await self.storage.set_file(
            client,
            file_id=file_id,
            version=version,
            parent=parent,
            type_=type_,
            file_hash=file_hash,
            hash_algo=hash_algo,
            name=name,
            size=size,
            namespace=namespace,
            mime_type=mime_type,
            public_id=public_id,
            created=created,
            modified=modified,
            owner=owner,
            access=access,
            extra=extra,
        )

    async def file_get_used_space(
        self,
        client,
        peer_jid: jid.JID,
        owner: Optional[jid.JID] = None
    ) -> int:
        """Get space taken by all files owned by an entity

        @param peer_jid: entity requesting the size
        @param owner: entity owning the file to check. If None, will be determined by
            get_files_owner
        @return: size of total space used by files of this owner
        @raise exceptions.PermissionError: peer_jid is neither the owner nor an admin
        """
        owner = self.get_files_owner(client, owner, peer_jid)
        if peer_jid.userhostJID() != owner and client.profile not in self.admins:
            raise exceptions.PermissionError("You are not allowed to check this size")
        return await self.storage.file_get_used_space(client, owner)

    def file_update(self, file_id, column, update_cb):
        """Update a file column taking care of race condition

        access is NOT checked in this method, it must be checked beforehand
        @param file_id(unicode): id of the file to update
        @param column(unicode): one of "access" or "extra"
        @param update_cb(callable): method to update the value of the column
            the method will take older value as argument, and must update it in place
            Note that the callable must be thread-safe
        """
        return self.storage.file_update(file_id, column, update_cb)

    @defer.inlineCallbacks
    def _delete_file(
        self,
        client,
        peer_jid: jid.JID,
        recursive: bool,
        files_path: Path,
        file_data: dict
    ):
        """Internal method to delete files/directories recursively

        @param peer_jid(jid.JID): entity requesting the deletion (must be owner of files
            to 
delete)
        @param recursive(boolean): True if recursive deletion is needed
        @param files_path(unicode): path of the directory containing the actual files
        @param file_data(dict): data of the file to delete
        """
        # NOTE(review): this compares the owner JID to peer_jid as passed by the
        # caller — presumably both are bare jids here; confirm against callers.
        if file_data['owner'] != peer_jid:
            raise exceptions.PermissionError(
                "file {file_name} can't be deleted, {peer_jid} is not the owner"
                .format(file_name=file_data['name'], peer_jid=peer_jid.full()))
        if file_data['type'] == C.FILE_TYPE_DIRECTORY:
            sub_files = yield self.get_files(client, peer_jid, parent=file_data['id'])
            if sub_files and not recursive:
                raise exceptions.DataError(_("Can't delete directory, it is not empty"))
            # we first delete the sub-files
            for sub_file_data in sub_files:
                if sub_file_data['type'] == C.FILE_TYPE_DIRECTORY:
                    sub_file_path = files_path / sub_file_data['name']
                else:
                    sub_file_path = files_path
                yield self._delete_file(
                    client, peer_jid, recursive, sub_file_path, sub_file_data)
            # then the directory itself
            yield self.storage.file_delete(file_data['id'])
        elif file_data['type'] == C.FILE_TYPE_FILE:
            log.info(_("deleting file {name} with hash {file_hash}").format(
                name=file_data['name'], file_hash=file_data['file_hash']))
            yield self.storage.file_delete(file_data['id'])
            # the payload on disk is only removed once no database row references
            # its hash anymore (files are deduplicated by hash)
            references = yield self.get_files(
                client, peer_jid, file_hash=file_data['file_hash'])
            if references:
                log.debug("there are still references to the file, we keep it")
            else:
                file_path = os.path.join(files_path, file_data['file_hash'])
                log.info(_("no reference left to {file_path}, deleting").format(
                    file_path=file_path))
                try:
                    os.unlink(file_path)
                except FileNotFoundError:
                    log.error(f"file at {file_path!r} doesn't exist but it was referenced in files database")
        else:
            raise exceptions.InternalError('Unexpected file type: {file_type}'
                .format(file_type=file_data['type']))

    async def file_delete(self, client, peer_jid, file_id, recursive=False):
        """Delete a single file or a directory and all its sub-files 

        @param file_id(unicode): id of the file to delete
        @param peer_jid(jid.JID): entity requesting the deletion,
            must be owner of all files to delete
        @param recursive(boolean): must be True to delete a directory and all sub-files
        @raise exceptions.NotFound: no file with this id is visible to peer_jid
        """
        # FIXME: we only allow owner of file to delete files for now, but WRITE access
        #   should be checked too
        files_data = await self.get_files(client, peer_jid, file_id)
        if not files_data:
            raise exceptions.NotFound("Can't find the file with id {file_id}".format(
                file_id=file_id))
        file_data = files_data[0]
        if file_data["type"] != C.FILE_TYPE_DIRECTORY and recursive:
            raise ValueError("recursive can only be set for directories")
        files_path = self.host.get_local_path(None, C.FILES_DIR)
        await self._delete_file(client, peer_jid, recursive, files_path, file_data)

    ## Cache ##

    def get_cache_path(self, namespace: str, *args: str) -> Path:
        """Get path to use to get a common path for a namespace

        This can be used by plugins to manage permanent data. It's the responsibility
        of plugins to clean this directory from unused data.
        @param namespace: unique namespace to use
        @param args: extra identifier which will be added to the path
        """
        # namespace and extra args are escaped so they are safe path components
        namespace = namespace.strip().lower()
        return Path(
            self._cache_path,
            regex.path_escape(namespace),
            *(regex.path_escape(a) for a in args)
        )

    ## Misc ##

    def is_entity_available(self, client, entity_jid):
        """Tell from the presence information if the given entity is available. 
+ + @param entity_jid (JID): the entity to check (if bare jid is used, all resources are tested) + @return (bool): True if entity is available + """ + if not entity_jid.resource: + return bool( + self.get_available_resources(client, entity_jid) + ) # is any resource is available, entity is available + try: + presence_data = self.get_entity_datum(client, entity_jid, "presence") + except KeyError: + log.debug("No presence information for {}".format(entity_jid)) + return False + return presence_data.show != C.PRESENCE_UNAVAILABLE + + def is_admin(self, profile: str) -> bool: + """Tell if given profile has administrator privileges""" + return profile in self.admins + + def is_admin_jid(self, entity: jid.JID) -> bool: + """Tells if an entity jid correspond to an admin one + + It is sometime not possible to use the profile alone to check if an entity is an + admin (e.g. a request managed by a component). In this case we check if the JID + correspond to an admin profile + """ + return entity.userhostJID() in self.admin_jids