view libervia/backend/memory/encryption.py @ 4202:b26339343076

core: use a user specific directory for PID file: default location of pid file is now specific to logged user, this allow to run several instances of Libervia by different users on the same machine without PID conflicts.
author Goffi <goffi@goffi.org>
date Sun, 14 Jan 2024 17:48:02 +0100
parents 4b842c1fb686
children 0d7bb4df2343
line wrap: on
line source

#!/usr/bin/env python3


# SAT: a jabber client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import copy
from functools import partial
from typing import Optional
from twisted.words.protocols.jabber import jid
from twisted.internet import defer
from twisted.python import failure
from libervia.backend.core.core_types import EncryptionPlugin, EncryptionSession, MessageData
from libervia.backend.core.i18n import D_, _
from libervia.backend.core.constants import Const as C
from libervia.backend.core import exceptions
from libervia.backend.core.log import getLogger
from libervia.backend.tools.common import data_format
from libervia.backend.tools import utils
from libervia.backend.memory import persistent


log = getLogger(__name__)


class EncryptionHandler:
    """Class to handle encryption sessions for a client"""
    plugins = []  # plugin able to encrypt messages

    def __init__(self, client):
        """Create the encryption handler of a client

        @param client: client owning this handler
            its profile is used for the persistent storage of encryption sessions
        """
        # bare_jid ==> encryption_data
        self._sessions = {}
        self.client = client
        # persistent storage, written when a session is started with a non directed
        # plugin and cleaned when a session is stopped
        self._stored_session = persistent.PersistentDict(
            "core:encryption",
            profile=client.profile,
        )

    @property
    def host(self):
        return self.client.host_app

    async def load_sessions(self):
        """Load persistent sessions and restart encryption for each of them

        Every stored ``entity ==> namespace`` pair is restarted with ``start``.
        A session which can't be restarted is reported with a warning, and doesn't
        prevent the other sessions from being restored.
        """
        await self._stored_session.load()
        # we work on a snapshot of stored sessions: the storage may be written while
        # sessions are started, and the snapshot keeps each result aligned with the
        # session it belongs to
        stored_sessions = list(self._stored_session.items())
        if not stored_sessions:
            return

        start_d_list = [
            defer.ensureDeferred(self.start(jid.JID(entity_jid_s), namespace))
            for entity_jid_s, namespace in stored_sessions
        ]
        # failures are reported below, consumeErrors avoids having them also logged
        # as unhandled errors of the individual Deferreds
        results = await defer.DeferredList(start_d_list, consumeErrors=True)
        for (entity_jid_s, namespace), (success, err) in zip(stored_sessions, results):
            if not success:
                log.warning(_(
                    "Could not restart {namespace!r} encryption with {entity}: {err}"
                    ).format(namespace=namespace, entity=entity_jid_s, err=err))
        log.info(_("encryption sessions restored"))

    @classmethod
    def register_plugin(cls, plg_instance, name, namespace, priority=0, directed=False):
        """Register a plugin handling an encryption algorithm

        @param plg_instance(object): instance of the plugin
            following methods are looked for on the instance, they are all called
            with the client as first argument:
                - get_trust_ui(client, entity): return a XMLUI for trust management
                    entity(jid.JID): entity to manage
                    The returned XMLUI must be a form
                    If missing, requesting the trust UI raises NotImplementedError
                - start_encryption(client, entity): start encrypted session
                    entity(jid.JID): entity to start encrypted session with
                - stop_encryption(client, entity): stop encrypted session
                    entity(jid.JID): entity to stop encrypted session with
            start_encryption and stop_encryption are optional: if they don't exist,
            they are ignored.
        @param name(unicode): human readable name of the encryption algorithm
        @param namespace(unicode): namespace of the encryption algorithm
        @param priority(int): priority of this plugin to encrypt a message when not
            selected manually
            registered plugins are kept sorted by ascending priority value
        @param directed(bool): True if this plugin is directed (if it works with one
            device only at a time)
        @raise exceptions.ConflictError: a plugin with the same namespace or the same
            name (case insensitive) is already registered
        """
        # names and namespaces are compared in a case insensitive way
        registered_ns = {plg.namespace.lower() for plg in cls.plugins}
        registered_names = {plg.name.lower() for plg in cls.plugins}
        if namespace.lower() in registered_ns:
            raise exceptions.ConflictError("A plugin with this namespace already exists!")
        if name.lower() in registered_names:
            raise exceptions.ConflictError("A plugin with this name already exists!")

        cls.plugins.append(
            EncryptionPlugin(
                instance=plg_instance,
                name=name,
                namespace=namespace,
                priority=priority,
                directed=directed,
            )
        )
        cls.plugins.sort(key=lambda plg: plg.priority)
        log.info(_("Encryption plugin registered: {name}").format(name=name))

    @classmethod
    def getPlugins(cls):
        return cls.plugins

    @classmethod
    def get_plugin(cls, namespace):
        try:
            return next(p for p in cls.plugins if p.namespace == namespace)
        except StopIteration:
            raise exceptions.NotFound(_(
                "Can't find requested encryption plugin: {namespace}").format(
                    namespace=namespace))

    @classmethod
    def get_namespaces(cls):
        """Get available plugin namespaces"""
        return {p.namespace for p in cls.getPlugins()}

    @classmethod
    def get_ns_from_name(cls, name):
        """Retrieve plugin namespace from its name

        @param name(unicode): name of the plugin (case insensitive)
        @return (unicode): namespace of the plugin
        @raise exceptions.NotFound: there is not encryption plugin of this name
        """
        for p in cls.plugins:
            if p.name.lower() == name.lower():
                return p.namespace
        raise exceptions.NotFound(_(
            "Can't find a plugin with the name \"{name}\".".format(
                name=name)))

    def get_bridge_data(self, session):
        """Retrieve session data serialized for bridge.

        @param session(dict): encryption session
        @return (unicode): serialized data for bridge
        """
        if session is None:
            return ''
        plugin = session['plugin']
        bridge_data = {'name': plugin.name,
                       'namespace': plugin.namespace}
        if 'directed_devices' in session:
            bridge_data['directed_devices'] = session['directed_devices']

        return data_format.serialise(bridge_data)

    async def _start_encryption(self, plugin, entity):
        """Start encryption with a plugin

        This method must be called just before adding a plugin session.
        The session is stored in persistent storage if the plugin is not directed,
        then the ``start_encryption`` method of the plugin is called if it exists.

        @param plugin: plugin to start the encryption with
        @param entity(jid.JID): entity to start the encrypted session with
        """
        if not plugin.directed:
            await self._stored_session.aset(entity.userhost(), plugin.namespace)

        try:
            start_encryption = plugin.instance.start_encryption
        except AttributeError:
            log.debug(f"No start_encryption method found for {plugin.namespace}")
            return

        # the plugin gets its own copy of entity, so the jid used by the caller is
        # not affected by a modification done by the plugin
        entity_copy = copy.copy(entity)
        await utils.as_deferred(start_encryption, self.client, entity_copy)

    async def _stop_encryption(self, plugin, entity):
        """Stop encryption with a plugin

        This method must be called just before removing a plugin session.
        The session is removed from persistent storage, then the ``stop_encryption``
        method of the plugin is called if it exists, and waited for.

        @param plugin: plugin currently used for the session
        @param entity(jid.JID): entity to stop the encrypted session with
        """
        try:
            await self._stored_session.adel(entity.userhost())
        except KeyError:
            # nothing is stored for this entity (sessions of directed plugins are
            # not stored when they are started)
            pass
        try:
            stop_encryption = plugin.instance.stop_encryption
        except AttributeError:
            log.debug(f"No stop_encryption method found for {plugin.namespace}")
        else:
            # we copy entity to avoid having the resource changed by stop_encryption
            # the call is awaited (as it is done when starting encryption), so the
            # caller continues only once the plugin has finished, and a failure of
            # the plugin is propagated instead of being lost in a dangling Deferred
            await utils.as_deferred(stop_encryption, self.client, copy.copy(entity))

    async def start(self, entity, namespace=None, replace=False):
        """Start an encryption session with an entity

        @param entity(jid.JID): entity to start an encryption session with
            must be bare jid if the algorithm encrypts for all devices
            if the algorithm is directed and no resource is specified, the main
            resource of the entity is used, and ``entity.resource`` is set in place
        @param namespace(unicode, None): namespace of the encryption algorithm
            to use.
            None to select automatically an algorithm (first registered plugin)
        @param replace(bool): if True and an encrypted session already exists,
            it will be replaced by the new one
        @raise exceptions.NotFound: no plugin is registered, requested plugin is not
            found, or no resource can be found for a directed algorithm
        @raise exceptions.ConflictError: a session with an other algorithm already
            exists and replace is not set
        @raise ValueError: a full jid is used with a non directed algorithm
        """
        if not self.plugins:
            raise exceptions.NotFound(_("No encryption plugin is registered, "
                                        "an encryption session can't be started"))

        if namespace is None:
            plugin = self.plugins[0]
        else:
            plugin = self.get_plugin(namespace)

        bare_jid = entity.userhostJID()
        if bare_jid in self._sessions:
            # we have already an encryption session with this contact
            former_plugin = self._sessions[bare_jid]["plugin"]
            # we compare with the plugin actually selected, and not with the
            # "namespace" argument which is None on automatic selection
            if former_plugin.namespace == plugin.namespace:
                log.info(_("Session with {bare_jid} is already encrypted with {name}. "
                           "Nothing to do.").format(
                               bare_jid=bare_jid, name=former_plugin.name))
                return

            if replace:
                # there is a conflict, but replacement is requested
                # so we stop previous encryption to use new one
                del self._sessions[bare_jid]
                await self._stop_encryption(former_plugin, entity)
            else:
                # the algorithm currently in use is the one of the former plugin
                msg = (_("Session with {bare_jid} is already encrypted with {name}. "
                         "Please stop encryption session before changing algorithm.")
                       .format(bare_jid=bare_jid, name=former_plugin.name))
                log.warning(msg)
                raise exceptions.ConflictError(msg)

        data = {"plugin": plugin}
        if plugin.directed:
            if not entity.resource:
                entity.resource = self.host.memory.main_resource_get(self.client, entity)
                if not entity.resource:
                    raise exceptions.NotFound(
                        _("No resource found for {destinee}, can't encrypt with {name}")
                        .format(destinee=entity.full(), name=plugin.name))
                log.info(_("No resource specified to encrypt with {name}, using "
                           "{destinee}.").format(destinee=entity.full(),
                                                  name=plugin.name))
            # indicate that we encrypt only for some devices
            data['directed_devices'] = [entity.resource]
        elif entity.resource:
            raise ValueError(
                _("{name} encryption must be used with bare jids.")
                .format(name=plugin.name))

        await self._start_encryption(plugin, entity)
        self._sessions[bare_jid] = data
        log.info(_("Encryption session has been set for {entity_jid} with "
                   "{encryption_name}").format(
                   entity_jid=entity.full(), encryption_name=plugin.name))
        self.host.bridge.message_encryption_started(
            entity.full(),
            self.get_bridge_data(data),
            self.client.profile)
        msg = D_("Encryption session started: your messages with {destinee} are "
                 "now end to end encrypted using {name} algorithm.").format(
                 destinee=entity.full(), name=plugin.name)
        directed_devices = data.get('directed_devices')
        if directed_devices:
            msg += "\n" + D_("Message are encrypted only for {nb_devices} device(s): "
                              "{devices_list}.").format(
                              nb_devices=len(directed_devices),
                              devices_list=', '.join(directed_devices))

        self.client.feedback(bare_jid, msg)

    async def stop(self, entity, namespace=None):
        """Stop an encryption session with an entity

        @param entity(jid.JID): entity with who the encryption session must be stopped
            must be bare jid if the algorithm encrypts for all devices
            with a full jid, only the directed session of this resource is removed,
            and the whole session is stopped when no directed device is left
        @param namespace(unicode): namespace of the session to stop
            when specified, used to check that we stop the right encryption session
        @raise exceptions.NotFound: there is no session with this entity, or no
            directed session matching the resource
        @raise exceptions.InternalError: the session is not run with the plugin
            expected by namespace
        """
        bare_jid = entity.userhostJID()
        session = self.getSession(bare_jid)
        if not session:
            # the exception is raised directly (not wrapped in a failure.Failure),
            # like other errors of this method: it can then be caught with
            # "except exceptions.NotFound" when this coroutine is awaited, and it is
            # still converted to a Failure when the coroutine is run as a Deferred
            raise exceptions.NotFound(_("There is no encryption session with this "
                                        "entity."))
        plugin = session['plugin']
        if namespace is not None and plugin.namespace != namespace:
            raise exceptions.InternalError(_(
                "The encryption session is not run with the expected plugin: encrypted "
                "with {current_name} and was expecting {expected_name}").format(
                current_name=plugin.namespace,
                expected_name=namespace))
        if entity.resource:
            try:
                directed_devices = session['directed_devices']
            except KeyError:
                raise exceptions.NotFound(_(
                    "There is a session for the whole entity (i.e. all devices of the "
                    "entity), not a directed one. Please use bare jid if you want to "
                    "stop the whole encryption with this entity."))

            try:
                directed_devices.remove(entity.resource)
            except ValueError:
                raise exceptions.NotFound(_("There is no directed session with this "
                                            "entity."))
            else:
                if not directed_devices:
                    # if we have no more directed device sessions,
                    # we stop the whole session
                    # see comment below for deleting session before stopping encryption
                    del self._sessions[bare_jid]
                    await self._stop_encryption(plugin, entity)
        else:
            # plugin's stop_encryption may call stop again (that's the case with OTR)
            # so we need to remove plugin from session before calling self._stop_encryption
            del self._sessions[bare_jid]
            await self._stop_encryption(plugin, entity)

        log.info(_("encryption session stopped with entity {entity}").format(
            entity=entity.full()))
        self.host.bridge.message_encryption_stopped(
            entity.full(),
            {'name': plugin.name,
             'namespace': plugin.namespace,
            },
            self.client.profile)
        msg = D_("Encryption session finished: your messages with {destinee} are "
                 "NOT end to end encrypted anymore.\nYour server administrators or "
                 "{destinee} server administrators will be able to read them.").format(
                 destinee=entity.full())

        self.client.feedback(entity, msg)

    def getSession(self, entity: jid.JID) -> Optional[EncryptionSession]:
        """Get encryption session for this contact

        @param entity(jid.JID): get the session for this entity
            must be a bare jid
        @return (dict, None): encryption session data
            None if there is not encryption for this session with this jid
        """
        if entity.resource:
            raise ValueError("Full jid given when expecting bare jid")
        return self._sessions.get(entity)

    def get_namespace(self, entity: jid.JID) -> Optional[str]:
        """Helper method to get the current encryption namespace used

        @param entity: get the namespace for this entity must be a bare jid
        @return: the algorithm namespace currently used in this session, or None if no
            e2ee is currently used.
        """
        session = self.getSession(entity)
        if session is None:
            return None
        return session["plugin"].namespace

    def get_trust_ui(self, entity_jid, namespace=None):
        """Retrieve the trust management UI of an encryption algorithm

        @param entity_jid(jid.JID): get the UI for this entity
            must be a bare jid
        @param namespace(unicode): namespace of the algorithm to manage
            if None, the algorithm of the current session with the entity is used
        @return D(xmlui): result of the ``get_trust_ui`` method of the plugin,
            called through utils.as_deferred
            the xmlui is expected to be a form
        @raise exceptions.NotFound: no algorithm/plugin found, or namespace is None
            and there is no encryption session with the entity
        @raise NotImplementedError: plugin doesn't handle UI management
        """
        if namespace is not None:
            plugin = self.get_plugin(namespace)
        else:
            session = self.getSession(entity_jid)
            if not session:
                raise exceptions.NotFound(
                    "No encryption session currently active for {entity_jid}"
                    .format(entity_jid=entity_jid.full()))
            plugin = session['plugin']

        try:
            get_trust_ui = plugin.instance.get_trust_ui
        except AttributeError:
            raise NotImplementedError(
                "Encryption plugin doesn't handle trust management UI")
        return utils.as_deferred(get_trust_ui, self.client, entity_jid)

    ## Menus ##

    @classmethod
    def _import_menus(cls, host):
        """Import encryption menus in the host

        One menu is imported to go back to plain text, then for each registered
        plugin, one menu to start a session and one menu to manage trust.

        @param host: host application, its ``import_menu`` method is used
        """
        host.import_menu(
            (D_("Encryption"), D_("unencrypted (plain text)")),
            partial(cls._on_menu_unencrypted, host=host),
            security_limit=0,
            help_string=D_("End encrypted session"),
            type_=C.MENU_SINGLE,
        )
        for plg in cls.getPlugins():
            # (menu label, callback, help template): session menu first, then trust
            plugin_menus = (
                (plg.name, cls._on_menu_name, D_("Start {name} session")),
                (
                    D_("⛨ {name} trust").format(name=plg.name),
                    cls._on_menu_trust,
                    D_("Manage {name} trust"),
                ),
            )
            for label, callback, help_template in plugin_menus:
                host.import_menu(
                    (D_("Encryption"), label),
                    partial(callback, host=host, plg=plg),
                    security_limit=0,
                    help_string=help_template.format(name=plg.name),
                    type_=C.MENU_SINGLE,
                )

    @classmethod
    def _on_menu_unencrypted(cls, data, host, profile):
        """Menu callback: stop the encryption session with an entity

        @param data(dict): menu data, the "jid" key is the jid of the peer
            the session is stopped for the bare jid
        @param host: host application
        @param profile: profile of the client
        @return (D(dict)): Deferred firing with an empty dict once session is stopped
        """
        client = host.get_client(profile)
        bare_peer = jid.JID(data['jid']).userhostJID()
        stop_d = defer.ensureDeferred(client.encryption.stop(bare_peer))
        # addCallback returns the Deferred itself
        return stop_d.addCallback(lambda __: {})

    @classmethod
    def _on_menu_name(cls, data, host, plg, profile):
        """Menu callback: start an encryption session using a plugin

        Any existing session with the entity is replaced.

        @param data(dict): menu data, the "jid" key is the jid of the peer
            the bare jid is used unless the plugin is directed
        @param host: host application
        @param plg: plugin to use for the encryption
        @param profile: profile of the client
        @return (D(dict)): Deferred firing with an empty dict once session is started
        """
        client = host.get_client(profile)
        requested_jid = jid.JID(data['jid'])
        peer_jid = requested_jid if plg.directed else requested_jid.userhostJID()
        start_d = defer.ensureDeferred(
            client.encryption.start(peer_jid, plg.namespace, replace=True))
        # addCallback returns the Deferred itself
        return start_d.addCallback(lambda __: {})

    @classmethod
    @defer.inlineCallbacks
    def _on_menu_trust(cls, data, host, plg, profile):
        """Menu callback: get the trust management UI of a plugin

        @param data(dict): menu data, the "jid" key is the jid of the peer
            the UI is requested for the bare jid
        @param host: host application
        @param plg: plugin to manage trust for
        @param profile: profile of the client
        @return (D(dict)): dict with the "xmlui" key set to the XML of the trust UI
        """
        client = host.get_client(profile)
        bare_peer = jid.JID(data['jid']).userhostJID()
        trust_ui = yield client.encryption.get_trust_ui(bare_peer, plg.namespace)
        return {'xmlui': trust_ui.toXml()}

    ## Triggers ##

    def set_encryption_flag(self, mess_data):
        """Set "encryption" key in mess_data if session with destinee is encrypted

        Nothing is done if there is no encryption session with the bare jid of the
        recipient. Otherwise the session is attached to the message, and the message
        is marked as encrypted.

        @param mess_data(dict): message data, "to" and "type" keys are used
        @raise exceptions.InternalError: the message is a groupchat one and the
            algorithm of the session is directed
        """
        recipient_bare_jid = mess_data['to'].userhostJID()
        session = self._sessions.get(recipient_bare_jid)
        if session is None:
            return

        plugin = session['plugin']
        if mess_data["type"] == "groupchat" and plugin.directed:
            raise exceptions.InternalError(
                "encryption flag must not be set for groupchat if encryption "
                f"algorithm ({plugin.name}) is directed!")
        mess_data[C.MESS_KEY_ENCRYPTION] = session
        self.mark_as_encrypted(mess_data, plugin.namespace)

    ## Misc ##

    def mark_as_encrypted(self, mess_data, namespace):
        """Helper method to mark a message as having been e2e encrypted.

        This should be used in the post_treat workflow of message_received trigger of
        the plugin.
        If the message comes from another entity than our own bare jid and there is
        no encryption session with it yet, a session is started in the background
        (its result is not waited for).

        @param mess_data(dict): message data as used in post treat workflow
        @param namespace(str): namespace of the algorithm used for encrypting the message
        @return (dict): mess_data, with the encrypted flag set in its "extra" data
        """
        mess_data['extra'][C.MESS_KEY_ENCRYPTED] = True
        from_bare_jid = mess_data['from'].userhostJID()
        if (
            from_bare_jid != self.client.jid.userhostJID()
            and self.getSession(from_bare_jid) is None
        ):
            # if we are currently unencrypted, we start a session automatically
            # to avoid sending unencrypted messages in an encrypted context
            log.info(_(
                "Starting e2e session with {peer_jid} as we receive encrypted "
                "messages")
                .format(peer_jid=from_bare_jid)
            )
            defer.ensureDeferred(self.start(from_bare_jid, namespace))

        return mess_data

    def is_encryption_requested(
        self,
        mess_data: MessageData,
        namespace: Optional[str] = None
    ) -> bool:
        """Helper method to check if encryption is requested in an outgoing message

        @param mess_data: message data for outgoing message
        @param namespace: if set, check if encryption is requested for the algorithm
            specified
        @return: True if the encryption flag is present (and, when namespace is set,
            if it is for the specified algorithm)
        """
        encryption = mess_data.get(C.MESS_KEY_ENCRYPTION)
        if encryption is None:
            return False
        # plugin is retrieved even if namespace is None, to be sure that the key exists
        plugin = encryption['plugin']
        return namespace is None or plugin.namespace == namespace

    def isEncrypted(self, mess_data):
        """Helper method to check if a message has the e2e encrypted flag

        @param mess_data(dict): message data, the flag is read from its "extra" data
        @return (bool): value of the encrypted flag, False if the flag is not set
        """
        extra = mess_data['extra']
        return extra.get(C.MESS_KEY_ENCRYPTED, False)


    def mark_as_trusted(self, mess_data):
        """Helper method to flag a message as coming from a trusted entity.

        This should be used in the post_treat workflow of message_received trigger of
        the plugin

        @param mess_data(dict): message data as used in post treat workflow
        @return (dict): mess_data, with the trusted flag set to True
        """
        trusted_key = C.MESS_KEY_TRUSTED
        mess_data[trusted_key] = True
        return mess_data

    def mark_as_untrusted(self, mess_data):
        """Helper method to mark a message as sent from an untrusted entity.

        This should be used in the post_treat workflow of message_received trigger of
        the plugin

        @param mess_data(dict): message data as used in post treat workflow
        @return (dict): mess_data, with the trusted flag set to False
        """
        # the same key as in mark_as_trusted must be used, so the constant is used
        # instead of a hard-coded name
        mess_data[C.MESS_KEY_TRUSTED] = False
        return mess_data