diff libervia/backend/memory/encryption.py @ 4071:4b842c1fb686

refactoring: renamed `sat` package to `libervia.backend`
author Goffi <goffi@goffi.org>
date Fri, 02 Jun 2023 11:49:51 +0200
parents sat/memory/encryption.py@c23cad65ae99
children 0d7bb4df2343
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/libervia/backend/memory/encryption.py	Fri Jun 02 11:49:51 2023 +0200
@@ -0,0 +1,534 @@
+#!/usr/bin/env python3
+
+
+# SAT: a jabber client
+# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import copy
+from functools import partial
+from typing import Optional
+from twisted.words.protocols.jabber import jid
+from twisted.internet import defer
+from twisted.python import failure
+from libervia.backend.core.core_types import EncryptionPlugin, EncryptionSession, MessageData
+from libervia.backend.core.i18n import D_, _
+from libervia.backend.core.constants import Const as C
+from libervia.backend.core import exceptions
+from libervia.backend.core.log import getLogger
+from libervia.backend.tools.common import data_format
+from libervia.backend.tools import utils
+from libervia.backend.memory import persistent
+
+
+log = getLogger(__name__)
+
+
+class EncryptionHandler:
+    """Class to handle encryption sessions for a client"""
+    plugins = []  # plugin able to encrypt messages
+
+    def __init__(self, client):
+        self.client = client
+        self._sessions = {}  # bare_jid ==> encryption_data
+        self._stored_session = persistent.PersistentDict(
+            "core:encryption", profile=client.profile)
+
+    @property
+    def host(self):
+        return self.client.host_app
+
+    async def load_sessions(self):
+        """Load persistent sessions"""
+        await self._stored_session.load()
+        start_d_list = []
+        for entity_jid_s, namespace in self._stored_session.items():
+            entity = jid.JID(entity_jid_s)
+            start_d_list.append(defer.ensureDeferred(self.start(entity, namespace)))
+
+        if start_d_list:
+            result = await defer.DeferredList(start_d_list)
+            for idx, (success, err) in enumerate(result):
+                if not success:
+                    entity_jid_s, namespace = list(self._stored_session.items())[idx]
+                    log.warning(_(
+                        "Could not restart {namespace!r} encryption with {entity}: {err}"
+                        ).format(namespace=namespace, entity=entity_jid_s, err=err))
+            log.info(_("encryption sessions restored"))
+
+    @classmethod
+    def register_plugin(cls, plg_instance, name, namespace, priority=0, directed=False):
+        """Register a plugin handling an encryption algorithm
+
+        @param plg_instance(object): instance of the plugin
+            it must have the following methods:
+                - get_trust_ui(entity): return a XMLUI for trust management
+                    entity(jid.JID): entity to manage
+                    The returned XMLUI must be a form
+            if may have the following methods:
+                - start_encryption(entity): start encrypted session
+                    entity(jid.JID): entity to start encrypted session with
+                - stop_encryption(entity): start encrypted session
+                    entity(jid.JID): entity to stop encrypted session with
+            if they don't exists, those 2 methods will be ignored.
+
+        @param name(unicode): human readable name of the encryption algorithm
+        @param namespace(unicode): namespace of the encryption algorithm
+        @param priority(int): priority of this plugin to encrypt an message when not
+            selected manually
+        @param directed(bool): True if this plugin is directed (if it works with one
+                               device only at a time)
+        """
+        existing_ns = set()
+        existing_names = set()
+        for p in cls.plugins:
+            existing_ns.add(p.namespace.lower())
+            existing_names.add(p.name.lower())
+        if namespace.lower() in existing_ns:
+            raise exceptions.ConflictError("A plugin with this namespace already exists!")
+        if name.lower() in existing_names:
+            raise exceptions.ConflictError("A plugin with this name already exists!")
+        plugin = EncryptionPlugin(
+            instance=plg_instance,
+            name=name,
+            namespace=namespace,
+            priority=priority,
+            directed=directed)
+        cls.plugins.append(plugin)
+        cls.plugins.sort(key=lambda p: p.priority)
+        log.info(_("Encryption plugin registered: {name}").format(name=name))
+
+    @classmethod
+    def getPlugins(cls):
+        return cls.plugins
+
+    @classmethod
+    def get_plugin(cls, namespace):
+        try:
+            return next(p for p in cls.plugins if p.namespace == namespace)
+        except StopIteration:
+            raise exceptions.NotFound(_(
+                "Can't find requested encryption plugin: {namespace}").format(
+                    namespace=namespace))
+
+    @classmethod
+    def get_namespaces(cls):
+        """Get available plugin namespaces"""
+        return {p.namespace for p in cls.getPlugins()}
+
+    @classmethod
+    def get_ns_from_name(cls, name):
+        """Retrieve plugin namespace from its name
+
+        @param name(unicode): name of the plugin (case insensitive)
+        @return (unicode): namespace of the plugin
+        @raise exceptions.NotFound: there is not encryption plugin of this name
+        """
+        for p in cls.plugins:
+            if p.name.lower() == name.lower():
+                return p.namespace
+        raise exceptions.NotFound(_(
+            "Can't find a plugin with the name \"{name}\".".format(
+                name=name)))
+
+    def get_bridge_data(self, session):
+        """Retrieve session data serialized for bridge.
+
+        @param session(dict): encryption session
+        @return (unicode): serialized data for bridge
+        """
+        if session is None:
+            return ''
+        plugin = session['plugin']
+        bridge_data = {'name': plugin.name,
+                       'namespace': plugin.namespace}
+        if 'directed_devices' in session:
+            bridge_data['directed_devices'] = session['directed_devices']
+
+        return data_format.serialise(bridge_data)
+
+    async def _start_encryption(self, plugin, entity):
+        """Start encryption with a plugin
+
+        This method must be called just before adding a plugin session.
+        StartEncryptionn method of plugin will be called if it exists.
+        """
+        if not plugin.directed:
+            await self._stored_session.aset(entity.userhost(), plugin.namespace)
+        try:
+            start_encryption = plugin.instance.start_encryption
+        except AttributeError:
+            log.debug(f"No start_encryption method found for {plugin.namespace}")
+        else:
+            # we copy entity to avoid having the resource changed by stop_encryption
+            await utils.as_deferred(start_encryption, self.client, copy.copy(entity))
+
+    async def _stop_encryption(self, plugin, entity):
+        """Stop encryption with a plugin
+
+        This method must be called just before removing a plugin session.
+        StopEncryptionn method of plugin will be called if it exists.
+        """
+        try:
+            await self._stored_session.adel(entity.userhost())
+        except KeyError:
+            pass
+        try:
+            stop_encryption = plugin.instance.stop_encryption
+        except AttributeError:
+            log.debug(f"No stop_encryption method found for {plugin.namespace}")
+        else:
+            # we copy entity to avoid having the resource changed by stop_encryption
+            return utils.as_deferred(stop_encryption, self.client, copy.copy(entity))
+
+    async def start(self, entity, namespace=None, replace=False):
+        """Start an encryption session with an entity
+
+        @param entity(jid.JID): entity to start an encryption session with
+            must be bare jid is the algorithm encrypt for all devices
+        @param namespace(unicode, None): namespace of the encryption algorithm
+            to use.
+            None to select automatically an algorithm
+        @param replace(bool): if True and an encrypted session already exists,
+            it will be replaced by the new one
+        """
+        if not self.plugins:
+            raise exceptions.NotFound(_("No encryption plugin is registered, "
+                                        "an encryption session can't be started"))
+
+        if namespace is None:
+            plugin = self.plugins[0]
+        else:
+            plugin = self.get_plugin(namespace)
+
+        bare_jid = entity.userhostJID()
+        if bare_jid in self._sessions:
+            # we have already an encryption session with this contact
+            former_plugin = self._sessions[bare_jid]["plugin"]
+            if former_plugin.namespace == namespace:
+                log.info(_("Session with {bare_jid} is already encrypted with {name}. "
+                           "Nothing to do.").format(
+                               bare_jid=bare_jid, name=former_plugin.name))
+                return
+
+            if replace:
+                # there is a conflict, but replacement is requested
+                # so we stop previous encryption to use new one
+                del self._sessions[bare_jid]
+                await self._stop_encryption(former_plugin, entity)
+            else:
+                msg = (_("Session with {bare_jid} is already encrypted with {name}. "
+                         "Please stop encryption session before changing algorithm.")
+                       .format(bare_jid=bare_jid, name=plugin.name))
+                log.warning(msg)
+                raise exceptions.ConflictError(msg)
+
+        data = {"plugin": plugin}
+        if plugin.directed:
+            if not entity.resource:
+                entity.resource = self.host.memory.main_resource_get(self.client, entity)
+                if not entity.resource:
+                    raise exceptions.NotFound(
+                        _("No resource found for {destinee}, can't encrypt with {name}")
+                        .format(destinee=entity.full(), name=plugin.name))
+                log.info(_("No resource specified to encrypt with {name}, using "
+                           "{destinee}.").format(destinee=entity.full(),
+                                                  name=plugin.name))
+            # indicate that we encrypt only for some devices
+            directed_devices = data['directed_devices'] = [entity.resource]
+        elif entity.resource:
+            raise ValueError(_("{name} encryption must be used with bare jids."))
+
+        await self._start_encryption(plugin, entity)
+        self._sessions[entity.userhostJID()] = data
+        log.info(_("Encryption session has been set for {entity_jid} with "
+                   "{encryption_name}").format(
+                   entity_jid=entity.full(), encryption_name=plugin.name))
+        self.host.bridge.message_encryption_started(
+            entity.full(),
+            self.get_bridge_data(data),
+            self.client.profile)
+        msg = D_("Encryption session started: your messages with {destinee} are "
+                 "now end to end encrypted using {name} algorithm.").format(
+                 destinee=entity.full(), name=plugin.name)
+        directed_devices = data.get('directed_devices')
+        if directed_devices:
+            msg += "\n" + D_("Message are encrypted only for {nb_devices} device(s): "
+                              "{devices_list}.").format(
+                              nb_devices=len(directed_devices),
+                              devices_list = ', '.join(directed_devices))
+
+        self.client.feedback(bare_jid, msg)
+
+    async def stop(self, entity, namespace=None):
+        """Stop an encryption session with an entity
+
+        @param entity(jid.JID): entity with who the encryption session must be stopped
+            must be bare jid if the algorithm encrypt for all devices
+        @param namespace(unicode): namespace of the session to stop
+            when specified, used to check that we stop the right encryption session
+        """
+        session = self.getSession(entity.userhostJID())
+        if not session:
+            raise failure.Failure(
+                exceptions.NotFound(_("There is no encryption session with this "
+                                      "entity.")))
+        plugin = session['plugin']
+        if namespace is not None and plugin.namespace != namespace:
+            raise exceptions.InternalError(_(
+                "The encryption session is not run with the expected plugin: encrypted "
+                "with {current_name} and was expecting {expected_name}").format(
+                current_name=session['plugin'].namespace,
+                expected_name=namespace))
+        if entity.resource:
+            try:
+                directed_devices = session['directed_devices']
+            except KeyError:
+                raise exceptions.NotFound(_(
+                    "There is a session for the whole entity (i.e. all devices of the "
+                    "entity), not a directed one. Please use bare jid if you want to "
+                    "stop the whole encryption with this entity."))
+
+            try:
+                directed_devices.remove(entity.resource)
+            except ValueError:
+                raise exceptions.NotFound(_("There is no directed session with this "
+                                            "entity."))
+            else:
+                if not directed_devices:
+                    # if we have no more directed device sessions,
+                    # we stop the whole session
+                    # see comment below for deleting session before stopping encryption
+                    del self._sessions[entity.userhostJID()]
+                    await self._stop_encryption(plugin, entity)
+        else:
+            # plugin's stop_encryption may call stop again (that's the case with OTR)
+            # so we need to remove plugin from session before calling self._stop_encryption
+            del self._sessions[entity.userhostJID()]
+            await self._stop_encryption(plugin, entity)
+
+        log.info(_("encryption session stopped with entity {entity}").format(
+            entity=entity.full()))
+        self.host.bridge.message_encryption_stopped(
+            entity.full(),
+            {'name': plugin.name,
+             'namespace': plugin.namespace,
+            },
+            self.client.profile)
+        msg = D_("Encryption session finished: your messages with {destinee} are "
+                 "NOT end to end encrypted anymore.\nYour server administrators or "
+                 "{destinee} server administrators will be able to read them.").format(
+                 destinee=entity.full())
+
+        self.client.feedback(entity, msg)
+
+    def getSession(self, entity: jid.JID) -> Optional[EncryptionSession]:
+        """Get encryption session for this contact
+
+        @param entity(jid.JID): get the session for this entity
+            must be a bare jid
+        @return (dict, None): encryption session data
+            None if there is not encryption for this session with this jid
+        """
+        if entity.resource:
+            raise ValueError("Full jid given when expecting bare jid")
+        return self._sessions.get(entity)
+
+    def get_namespace(self, entity: jid.JID) -> Optional[str]:
+        """Helper method to get the current encryption namespace used
+
+        @param entity: get the namespace for this entity must be a bare jid
+        @return: the algorithm namespace currently used in this session, or None if no
+            e2ee is currently used.
+        """
+        session = self.getSession(entity)
+        if session is None:
+            return None
+        return session["plugin"].namespace
+
+    def get_trust_ui(self, entity_jid, namespace=None):
+        """Retrieve encryption UI
+
+        @param entity_jid(jid.JID): get the UI for this entity
+            must be a bare jid
+        @param namespace(unicode): namespace of the algorithm to manage
+            if None use current algorithm
+        @return D(xmlui): XMLUI for trust management
+            the xmlui is a form
+            None if there is not encryption for this session with this jid
+        @raise exceptions.NotFound: no algorithm/plugin found
+        @raise NotImplementedError: plugin doesn't handle UI management
+        """
+        if namespace is None:
+            session = self.getSession(entity_jid)
+            if not session:
+                raise exceptions.NotFound(
+                    "No encryption session currently active for {entity_jid}"
+                    .format(entity_jid=entity_jid.full()))
+            plugin = session['plugin']
+        else:
+            plugin = self.get_plugin(namespace)
+        try:
+            get_trust_ui = plugin.instance.get_trust_ui
+        except AttributeError:
+            raise NotImplementedError(
+                "Encryption plugin doesn't handle trust management UI")
+        else:
+            return utils.as_deferred(get_trust_ui, self.client, entity_jid)
+
+    ## Menus ##
+
+    @classmethod
+    def _import_menus(cls, host):
+        host.import_menu(
+             (D_("Encryption"), D_("unencrypted (plain text)")),
+             partial(cls._on_menu_unencrypted, host=host),
+             security_limit=0,
+             help_string=D_("End encrypted session"),
+             type_=C.MENU_SINGLE,
+        )
+        for plg in cls.getPlugins():
+            host.import_menu(
+                 (D_("Encryption"), plg.name),
+                 partial(cls._on_menu_name, host=host, plg=plg),
+                 security_limit=0,
+                 help_string=D_("Start {name} session").format(name=plg.name),
+                 type_=C.MENU_SINGLE,
+            )
+            host.import_menu(
+                 (D_("Encryption"), D_("⛨ {name} trust").format(name=plg.name)),
+                 partial(cls._on_menu_trust, host=host, plg=plg),
+                 security_limit=0,
+                 help_string=D_("Manage {name} trust").format(name=plg.name),
+                 type_=C.MENU_SINGLE,
+            )
+
+    @classmethod
+    def _on_menu_unencrypted(cls, data, host, profile):
+        client = host.get_client(profile)
+        peer_jid = jid.JID(data['jid']).userhostJID()
+        d = defer.ensureDeferred(client.encryption.stop(peer_jid))
+        d.addCallback(lambda __: {})
+        return d
+
+    @classmethod
+    def _on_menu_name(cls, data, host, plg, profile):
+        client = host.get_client(profile)
+        peer_jid = jid.JID(data['jid'])
+        if not plg.directed:
+            peer_jid = peer_jid.userhostJID()
+        d = defer.ensureDeferred(
+            client.encryption.start(peer_jid, plg.namespace, replace=True))
+        d.addCallback(lambda __: {})
+        return d
+
+    @classmethod
+    @defer.inlineCallbacks
+    def _on_menu_trust(cls, data, host, plg, profile):
+        client = host.get_client(profile)
+        peer_jid = jid.JID(data['jid']).userhostJID()
+        ui = yield client.encryption.get_trust_ui(peer_jid, plg.namespace)
+        defer.returnValue({'xmlui': ui.toXml()})
+
+    ## Triggers ##
+
+    def set_encryption_flag(self, mess_data):
+        """Set "encryption" key in mess_data if session with destinee is encrypted"""
+        to_jid = mess_data['to']
+        encryption = self._sessions.get(to_jid.userhostJID())
+        if encryption is not None:
+            plugin = encryption['plugin']
+            if mess_data["type"] == "groupchat" and plugin.directed:
+                raise exceptions.InternalError(
+                f"encryption flag must not be set for groupchat if encryption algorithm "
+                f"({encryption['plugin'].name}) is directed!")
+            mess_data[C.MESS_KEY_ENCRYPTION] = encryption
+            self.mark_as_encrypted(mess_data, plugin.namespace)
+
+    ## Misc ##
+
+    def mark_as_encrypted(self, mess_data, namespace):
+        """Helper method to mark a message as having been e2e encrypted.
+
+        This should be used in the post_treat workflow of message_received trigger of
+        the plugin
+        @param mess_data(dict): message data as used in post treat workflow
+        @param namespace(str): namespace of the algorithm used for encrypting the message
+        """
+        mess_data['extra'][C.MESS_KEY_ENCRYPTED] = True
+        from_bare_jid = mess_data['from'].userhostJID()
+        if from_bare_jid != self.client.jid.userhostJID():
+            session = self.getSession(from_bare_jid)
+            if session is None:
+                # if we are currently unencrypted, we start a session automatically
+                # to avoid sending unencrypted messages in an encrypted context
+                log.info(_(
+                    "Starting e2e session with {peer_jid} as we receive encrypted "
+                    "messages")
+                    .format(peer_jid=from_bare_jid)
+                )
+                defer.ensureDeferred(self.start(from_bare_jid, namespace))
+
+        return mess_data
+
+    def is_encryption_requested(
+        self,
+        mess_data: MessageData,
+        namespace: Optional[str] = None
+    ) -> bool:
+        """Helper method to check if encryption is requested in an outgoind message
+
+        @param mess_data: message data for outgoing message
+        @param namespace: if set, check if encryption is requested for the algorithm
+            specified
+        @return: True if the encryption flag is present
+        """
+        encryption = mess_data.get(C.MESS_KEY_ENCRYPTION)
+        if encryption is None:
+            return False
+        # we get plugin even if namespace is None to be sure that the key exists
+        plugin = encryption['plugin']
+        if namespace is None:
+            return True
+        return plugin.namespace == namespace
+
+    def isEncrypted(self, mess_data):
+        """Helper method to check if a message has the e2e encrypted flag
+
+        @param mess_data(dict): message data
+        @return (bool): True if the encrypted flag is present
+        """
+        return mess_data['extra'].get(C.MESS_KEY_ENCRYPTED, False)
+
+
+    def mark_as_trusted(self, mess_data):
+        """Helper methor to mark a message as sent from a trusted entity.
+
+        This should be used in the post_treat workflow of message_received trigger of
+        the plugin
+        @param mess_data(dict): message data as used in post treat workflow
+        """
+        mess_data[C.MESS_KEY_TRUSTED] = True
+        return mess_data
+
+    def mark_as_untrusted(self, mess_data):
+        """Helper methor to mark a message as sent from an untrusted entity.
+
+        This should be used in the post_treat workflow of message_received trigger of
+        the plugin
+        @param mess_data(dict): message data as used in post treat workflow
+        """
+        mess_data['trusted'] = False
+        return mess_data