view libervia/backend/memory/encryption.py @ 4270:0d7bb4df2343

Reformatted code base using black.
author Goffi <goffi@goffi.org>
date Wed, 19 Jun 2024 18:44:57 +0200
parents 4b842c1fb686
children
line wrap: on
line source

#!/usr/bin/env python3


# SAT: a jabber client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import copy
from functools import partial
from typing import Optional
from twisted.words.protocols.jabber import jid
from twisted.internet import defer
from twisted.python import failure
from libervia.backend.core.core_types import (
    EncryptionPlugin,
    EncryptionSession,
    MessageData,
)
from libervia.backend.core.i18n import D_, _
from libervia.backend.core.constants import Const as C
from libervia.backend.core import exceptions
from libervia.backend.core.log import getLogger
from libervia.backend.tools.common import data_format
from libervia.backend.tools import utils
from libervia.backend.memory import persistent


log = getLogger(__name__)


class EncryptionHandler:
    """Class to handle encryption sessions for a client

    Encryption plugins are registered at class level (see [register_plugin]), while
    the running sessions are kept per instance, i.e. per client.
    """

    # plugins able to encrypt messages: the list is filled by [register_plugin], which
    # keeps it sorted by plugin priority; it is a class attribute, shared by all instances
    plugins = []

    def __init__(self, client):
        """Bind the handler to a client and prepare its session stores

        @param client: client owning this handler, its profile is used for the
            persistent storage
        """
        # persistent storage ("core:encryption" namespace) for the client profile
        stored_session = persistent.PersistentDict(
            "core:encryption", profile=client.profile
        )
        self.client = client
        # running sessions: bare_jid ==> encryption_data
        self._sessions = {}
        self._stored_session = stored_session

    @property
    def host(self):
        return self.client.host_app

    async def load_sessions(self):
        """Load persistent sessions and restart each of them

        Every stored ``entity ==> namespace`` pair is restarted with [start].
        A session which can't be restarted is reported with a warning, it doesn't
        prevent the other sessions from being restored.
        """
        await self._stored_session.load()
        # the stored pairs are snapshotted once: [start] writes to the persistent
        # storage, and each failure must be matched with the pair which triggered it
        stored = list(self._stored_session.items())
        if not stored:
            return

        start_d_list = [
            defer.ensureDeferred(self.start(jid.JID(entity_jid_s), namespace))
            for entity_jid_s, namespace in stored
        ]
        # failures are reported below, consumeErrors avoids having them reported a
        # second time as unhandled errors
        results = await defer.DeferredList(start_d_list, consumeErrors=True)
        for (entity_jid_s, namespace), (success, err) in zip(stored, results):
            if not success:
                log.warning(
                    _(
                        "Could not restart {namespace!r} encryption with {entity}: {err}"
                    ).format(namespace=namespace, entity=entity_jid_s, err=err)
                )
        log.info(_("encryption sessions restored"))

    @classmethod
    def register_plugin(cls, plg_instance, name, namespace, priority=0, directed=False):
        """Register a plugin handling an encryption algorithm

        @param plg_instance(object): instance of the plugin
            it must have the following method:
                - get_trust_ui(entity): return a XMLUI for trust management
                    entity(jid.JID): entity to manage
                    The returned XMLUI must be a form
            it may have the following methods:
                - start_encryption(entity): start encrypted session
                    entity(jid.JID): entity to start encrypted session with
                - stop_encryption(entity): stop encrypted session
                    entity(jid.JID): entity to stop encrypted session with
            if they don't exist, those 2 methods will be ignored.

        @param name(unicode): human readable name of the encryption algorithm
        @param namespace(unicode): namespace of the encryption algorithm
        @param priority(int): priority of this plugin to encrypt a message when not
            selected manually
        @param directed(bool): True if this plugin is directed (if it works with one
                               device only at a time)
        @raise exceptions.ConflictError: a plugin with the same namespace or the same
            name (both compared case insensitively) is already registered
        """
        registered = cls.plugins
        # names and namespaces are compared in a case insensitive way
        known_namespaces = {plg.namespace.lower() for plg in registered}
        known_names = {plg.name.lower() for plg in registered}
        if namespace.lower() in known_namespaces:
            raise exceptions.ConflictError("A plugin with this namespace already exists!")
        if name.lower() in known_names:
            raise exceptions.ConflictError("A plugin with this name already exists!")
        registered.append(
            EncryptionPlugin(
                instance=plg_instance,
                name=name,
                namespace=namespace,
                priority=priority,
                directed=directed,
            )
        )
        # the list is kept sorted by priority
        registered.sort(key=lambda plg: plg.priority)
        log.info(_("Encryption plugin registered: {name}").format(name=name))

    @classmethod
    def getPlugins(cls):
        return cls.plugins

    @classmethod
    def get_plugin(cls, namespace):
        try:
            return next(p for p in cls.plugins if p.namespace == namespace)
        except StopIteration:
            raise exceptions.NotFound(
                _("Can't find requested encryption plugin: {namespace}").format(
                    namespace=namespace
                )
            )

    @classmethod
    def get_namespaces(cls):
        """Get available plugin namespaces"""
        return {p.namespace for p in cls.getPlugins()}

    @classmethod
    def get_ns_from_name(cls, name):
        """Retrieve plugin namespace from its name

        @param name(unicode): name of the plugin (case insensitive)
        @return (unicode): namespace of the plugin
        @raise exceptions.NotFound: there is not encryption plugin of this name
        """
        for p in cls.plugins:
            if p.name.lower() == name.lower():
                return p.namespace
        raise exceptions.NotFound(
            _('Can\'t find a plugin with the name "{name}".'.format(name=name))
        )

    def get_bridge_data(self, session):
        """Retrieve session data serialized for bridge.

        @param session(dict): encryption session
        @return (unicode): serialized data for bridge
        """
        if session is None:
            return ""
        plugin = session["plugin"]
        bridge_data = {"name": plugin.name, "namespace": plugin.namespace}
        if "directed_devices" in session:
            bridge_data["directed_devices"] = session["directed_devices"]

        return data_format.serialise(bridge_data)

    async def _start_encryption(self, plugin, entity):
        """Start encryption with a plugin

        This method must be called just before adding a plugin session.
        The session is stored in persistent storage if the plugin is not directed,
        then the start_encryption method of the plugin is called if it exists.

        @param plugin: plugin to start the encryption with
        @param entity(jid.JID): entity to start the encrypted session with
        """
        if not plugin.directed:
            await self._stored_session.aset(entity.userhost(), plugin.namespace)
        try:
            plugin_start = plugin.instance.start_encryption
        except AttributeError:
            log.debug(f"No start_encryption method found for {plugin.namespace}")
            return
        # the plugin gets a copy of entity, so it can't modify the jid used by the caller
        # (e.g. its resource)
        await utils.as_deferred(plugin_start, self.client, copy.copy(entity))

    async def _stop_encryption(self, plugin, entity):
        """Stop encryption with a plugin

        This method must be called just before removing a plugin session.
        The session is removed from persistent storage (if it is there), then the
        stop_encryption method of the plugin is called if it exists.

        @param plugin: plugin to stop the encryption with
        @param entity(jid.JID): entity to stop the encrypted session with
        """
        try:
            await self._stored_session.adel(entity.userhost())
        except KeyError:
            # nothing is stored for this entity
            pass
        try:
            stop_encryption = plugin.instance.stop_encryption
        except AttributeError:
            log.debug(f"No stop_encryption method found for {plugin.namespace}")
        else:
            # the call is awaited (as done in _start_encryption): returning the Deferred
            # from this coroutine would let callers continue before the plugin is done,
            # and would drop any error raised by the plugin.
            # we copy entity to avoid having the resource changed by stop_encryption
            await utils.as_deferred(stop_encryption, self.client, copy.copy(entity))

    async def start(self, entity, namespace=None, replace=False):
        """Start an encryption session with an entity

        @param entity(jid.JID): entity to start an encryption session with
            must be a bare jid if the algorithm encrypts for all devices.
            With a directed algorithm and a bare jid, the resource is set in place with
            the main resource of the entity.
        @param namespace(unicode, None): namespace of the encryption algorithm
            to use.
            None to select automatically an algorithm (first plugin of the list)
        @param replace(bool): if True and an encrypted session already exists,
            it will be replaced by the new one
        @raise exceptions.NotFound: no plugin is registered, the requested plugin is
            unknown, or no resource can be found for a directed algorithm
        @raise exceptions.ConflictError: a session with another algorithm is running
            and replace is not set
        @raise ValueError: a full jid is used with a non directed algorithm
        """
        if not self.plugins:
            raise exceptions.NotFound(
                _(
                    "No encryption plugin is registered, "
                    "an encryption session can't be started"
                )
            )

        if namespace is None:
            plugin = self.plugins[0]
        else:
            plugin = self.get_plugin(namespace)

        bare_jid = entity.userhostJID()
        if bare_jid in self._sessions:
            # we have already an encryption session with this contact
            former_plugin = self._sessions[bare_jid]["plugin"]
            # the selected plugin is used for the comparison (and not the "namespace"
            # argument), so a running session is also detected with namespace=None
            if former_plugin.namespace == plugin.namespace:
                log.info(
                    _(
                        "Session with {bare_jid} is already encrypted with {name}. "
                        "Nothing to do."
                    ).format(bare_jid=bare_jid, name=former_plugin.name)
                )
                return

            if replace:
                # there is a conflict, but replacement is requested
                # so we stop previous encryption to use new one
                del self._sessions[bare_jid]
                await self._stop_encryption(former_plugin, entity)
            else:
                # the message reports the algorithm of the running session
                msg = _(
                    "Session with {bare_jid} is already encrypted with {name}. "
                    "Please stop encryption session before changing algorithm."
                ).format(bare_jid=bare_jid, name=former_plugin.name)
                log.warning(msg)
                raise exceptions.ConflictError(msg)

        data = {"plugin": plugin}
        if plugin.directed:
            if not entity.resource:
                entity.resource = self.host.memory.main_resource_get(self.client, entity)
                if not entity.resource:
                    raise exceptions.NotFound(
                        _(
                            "No resource found for {destinee}, can't encrypt with {name}"
                        ).format(destinee=entity.full(), name=plugin.name)
                    )
                log.info(
                    _(
                        "No resource specified to encrypt with {name}, using "
                        "{destinee}."
                    ).format(destinee=entity.full(), name=plugin.name)
                )
            # indicate that we encrypt only for some devices
            data["directed_devices"] = [entity.resource]
        elif entity.resource:
            raise ValueError(
                _("{name} encryption must be used with bare jids.").format(
                    name=plugin.name
                )
            )

        await self._start_encryption(plugin, entity)
        self._sessions[entity.userhostJID()] = data
        log.info(
            _(
                "Encryption session has been set for {entity_jid} with "
                "{encryption_name}"
            ).format(entity_jid=entity.full(), encryption_name=plugin.name)
        )
        self.host.bridge.message_encryption_started(
            entity.full(), self.get_bridge_data(data), self.client.profile
        )
        msg = D_(
            "Encryption session started: your messages with {destinee} are "
            "now end to end encrypted using {name} algorithm."
        ).format(destinee=entity.full(), name=plugin.name)
        directed_devices = data.get("directed_devices")
        if directed_devices:
            msg += "\n" + D_(
                "Message are encrypted only for {nb_devices} device(s): "
                "{devices_list}."
            ).format(
                nb_devices=len(directed_devices), devices_list=", ".join(directed_devices)
            )

        self.client.feedback(bare_jid, msg)

    async def stop(self, entity, namespace=None):
        """Stop an encryption session with an entity

        @param entity(jid.JID): entity with who the encryption session must be stopped
            must be bare jid if the algorithm encrypt for all devices
        @param namespace(unicode): namespace of the session to stop
            when specified, used to check that we stop the right encryption session
        @raise exceptions.NotFound: there is no session with this entity, or no directed
            session for the resource of the entity
        @raise exceptions.InternalError: the running session doesn't use the algorithm
            specified with namespace
        """
        bare_jid = entity.userhostJID()
        session = self.getSession(bare_jid)
        if not session:
            # the exception is raised directly: wrapped in a Failure, it could not be
            # caught with "except exceptions.NotFound" by a caller awaiting this method
            raise exceptions.NotFound(
                _("There is no encryption session with this entity.")
            )
        plugin = session["plugin"]
        if namespace is not None and plugin.namespace != namespace:
            raise exceptions.InternalError(
                _(
                    "The encryption session is not run with the expected plugin: encrypted "
                    "with {current_name} and was expecting {expected_name}"
                ).format(current_name=plugin.namespace, expected_name=namespace)
            )
        if entity.resource:
            try:
                directed_devices = session["directed_devices"]
            except KeyError:
                raise exceptions.NotFound(
                    _(
                        "There is a session for the whole entity (i.e. all devices of the "
                        "entity), not a directed one. Please use bare jid if you want to "
                        "stop the whole encryption with this entity."
                    )
                )

            try:
                directed_devices.remove(entity.resource)
            except ValueError:
                raise exceptions.NotFound(
                    _("There is no directed session with this entity.")
                )
            else:
                if not directed_devices:
                    # if we have no more directed device sessions,
                    # we stop the whole session
                    # see comment below for deleting session before stopping encryption
                    del self._sessions[bare_jid]
                    await self._stop_encryption(plugin, entity)
        else:
            # plugin's stop_encryption may call stop again (that's the case with OTR)
            # so we need to remove plugin from session before calling self._stop_encryption
            del self._sessions[bare_jid]
            await self._stop_encryption(plugin, entity)

        log.info(
            _("encryption session stopped with entity {entity}").format(
                entity=entity.full()
            )
        )
        self.host.bridge.message_encryption_stopped(
            entity.full(),
            {
                "name": plugin.name,
                "namespace": plugin.namespace,
            },
            self.client.profile,
        )
        msg = D_(
            "Encryption session finished: your messages with {destinee} are "
            "NOT end to end encrypted anymore.\nYour server administrators or "
            "{destinee} server administrators will be able to read them."
        ).format(destinee=entity.full())

        self.client.feedback(entity, msg)

    def getSession(self, entity: jid.JID) -> Optional[EncryptionSession]:
        """Get encryption session for this contact

        @param entity(jid.JID): get the session for this entity
            must be a bare jid
        @return (dict, None): encryption session data
            None if there is not encryption for this session with this jid
        """
        if entity.resource:
            raise ValueError("Full jid given when expecting bare jid")
        return self._sessions.get(entity)

    def get_namespace(self, entity: jid.JID) -> Optional[str]:
        """Helper method to get the current encryption namespace used

        @param entity: get the namespace for this entity must be a bare jid
        @return: the algorithm namespace currently used in this session, or None if no
            e2ee is currently used.
        """
        session = self.getSession(entity)
        if session is None:
            return None
        return session["plugin"].namespace

    def get_trust_ui(self, entity_jid, namespace=None):
        """Retrieve the trust management UI of an encryption algorithm

        @param entity_jid(jid.JID): entity to get the UI for
            must be a bare jid
        @param namespace(unicode, None): namespace of the algorithm to manage
            if None, the algorithm of the current session is used
        @return D(xmlui): XMLUI for trust management
            the xmlui is a form
        @raise exceptions.NotFound: no algorithm/plugin found, or no session is active
            with the entity while namespace is None
        @raise NotImplementedError: plugin doesn't handle UI management
        """
        if namespace is not None:
            plugin = self.get_plugin(namespace)
        else:
            # no namespace specified, the plugin of the running session is used
            session = self.getSession(entity_jid)
            if not session:
                raise exceptions.NotFound(
                    "No encryption session currently active for {entity_jid}".format(
                        entity_jid=entity_jid.full()
                    )
                )
            plugin = session["plugin"]

        try:
            plugin_get_trust_ui = plugin.instance.get_trust_ui
        except AttributeError:
            raise NotImplementedError(
                "Encryption plugin doesn't handle trust management UI"
            )
        return utils.as_deferred(plugin_get_trust_ui, self.client, entity_jid)

    ## Menus ##

    @classmethod
    def _import_menus(cls, host):
        """Import the encryption menus in the host

        One menu is imported to go back to plain text, then 2 menus are imported for
        each registered plugin: one to start a session, and one to manage trust.

        @param host: host application, its import_menu method is used
        """
        unencrypted_cb = partial(cls._on_menu_unencrypted, host=host)
        host.import_menu(
            (D_("Encryption"), D_("unencrypted (plain text)")),
            unencrypted_cb,
            security_limit=0,
            help_string=D_("End encrypted session"),
            type_=C.MENU_SINGLE,
        )
        for plg in cls.getPlugins():
            plg_name = plg.name
            # menu starting a session with this plugin
            start_cb = partial(cls._on_menu_name, host=host, plg=plg)
            host.import_menu(
                (D_("Encryption"), plg_name),
                start_cb,
                security_limit=0,
                help_string=D_("Start {name} session").format(name=plg_name),
                type_=C.MENU_SINGLE,
            )
            # menu managing trust for this plugin
            trust_cb = partial(cls._on_menu_trust, host=host, plg=plg)
            host.import_menu(
                (D_("Encryption"), D_("⛨ {name} trust").format(name=plg_name)),
                trust_cb,
                security_limit=0,
                help_string=D_("Manage {name} trust").format(name=plg_name),
                type_=C.MENU_SINGLE,
            )

    @classmethod
    def _on_menu_unencrypted(cls, data, host, profile):
        """Menu callback stopping the encryption session with a contact

        @param data(dict): menu data, the "jid" key is the jid of the contact
        @param host: host application
        @param profile: profile of the client
        @return (D(dict)): Deferred firing with an empty dict once the session is stopped
        """
        encryption = host.get_client(profile).encryption
        # sessions are stopped using the bare jid
        bare_peer_jid = jid.JID(data["jid"]).userhostJID()
        stop_d = defer.ensureDeferred(encryption.stop(bare_peer_jid))
        return stop_d.addCallback(lambda __: {})

    @classmethod
    def _on_menu_name(cls, data, host, plg, profile):
        """Menu callback starting an encryption session with a contact

        An already running session is replaced.

        @param data(dict): menu data, the "jid" key is the jid of the contact
        @param host: host application
        @param plg: plugin to use for the encryption
        @param profile: profile of the client
        @return (D(dict)): Deferred firing with an empty dict once the session is started
        """
        encryption = host.get_client(profile).encryption
        peer_jid = jid.JID(data["jid"])
        # the full jid is only kept for directed plugins
        target_jid = peer_jid if plg.directed else peer_jid.userhostJID()
        start_d = defer.ensureDeferred(
            encryption.start(target_jid, plg.namespace, replace=True)
        )
        return start_d.addCallback(lambda __: {})

    @classmethod
    @defer.inlineCallbacks
    def _on_menu_trust(cls, data, host, plg, profile):
        """Menu callback showing the trust management UI of a plugin

        @param data(dict): menu data, the "jid" key is the jid of the contact
        @param host: host application
        @param plg: plugin to manage the trust for
        @param profile: profile of the client
        @return (D(dict)): Deferred firing with a dict whose "xmlui" key is the
            serialised trust management UI
        """
        client = host.get_client(profile)
        peer_jid = jid.JID(data["jid"]).userhostJID()
        ui = yield client.encryption.get_trust_ui(peer_jid, plg.namespace)
        # a plain return is used: defer.returnValue is not needed with Python 3
        # generators, and it is deprecated in recent Twisted versions
        return {"xmlui": ui.toXml()}

    ## Triggers ##

    def set_encryption_flag(self, mess_data):
        """Set "encryption" key in mess_data if session with destinee is encrypted"""
        to_jid = mess_data["to"]
        encryption = self._sessions.get(to_jid.userhostJID())
        if encryption is not None:
            plugin = encryption["plugin"]
            if mess_data["type"] == "groupchat" and plugin.directed:
                raise exceptions.InternalError(
                    f"encryption flag must not be set for groupchat if encryption algorithm "
                    f"({encryption['plugin'].name}) is directed!"
                )
            mess_data[C.MESS_KEY_ENCRYPTION] = encryption
            self.mark_as_encrypted(mess_data, plugin.namespace)

    ## Misc ##

    def mark_as_encrypted(self, mess_data, namespace):
        """Helper method to mark a message as having been e2e encrypted.

        This should be used in the post_treat workflow of message_received trigger of
        the plugin.
        If the message comes from another entity and there is no session with it, a
        session is started with the given algorithm (without waiting for the result).

        @param mess_data(dict): message data as used in post treat workflow
        @param namespace(str): namespace of the algorithm used for encrypting the message
        @return (dict): mess_data, with the encrypted flag set in its "extra" data
        """
        mess_data["extra"][C.MESS_KEY_ENCRYPTED] = True
        sender_bare_jid = mess_data["from"].userhostJID()
        own_bare_jid = self.client.jid.userhostJID()
        if sender_bare_jid != own_bare_jid and self.getSession(sender_bare_jid) is None:
            # if we are currently unencrypted, we start a session automatically
            # to avoid sending unencrypted messages in an encrypted context
            log.info(
                _(
                    "Starting e2e session with {peer_jid} as we receive encrypted "
                    "messages"
                ).format(peer_jid=sender_bare_jid)
            )
            defer.ensureDeferred(self.start(sender_bare_jid, namespace))

        return mess_data

    def is_encryption_requested(
        self, mess_data: MessageData, namespace: Optional[str] = None
    ) -> bool:
        """Helper method to check if encryption is requested in an outgoing message

        @param mess_data: message data for outgoing message
        @param namespace: if set, check if encryption is requested for the algorithm
            specified
        @return: True if the encryption flag is present (and is set for the algorithm
            specified, when namespace is used)
        """
        encryption = mess_data.get(C.MESS_KEY_ENCRYPTION)
        if encryption is None:
            return False
        # the plugin is retrieved before checking namespace, to be sure that the key
        # exists even when namespace is None
        requested_plugin = encryption["plugin"]
        return namespace is None or requested_plugin.namespace == namespace

    def isEncrypted(self, mess_data):
        """Helper method to check if a message has the e2e encrypted flag

        @param mess_data(dict): message data, the flag is looked for in its "extra" data
        @return (bool): True if the encrypted flag is present
        """
        extra = mess_data["extra"]
        return extra.get(C.MESS_KEY_ENCRYPTED, False)

    def mark_as_trusted(self, mess_data):
        """Helper method to mark a message as sent from a trusted entity.

        This should be used in the post_treat workflow of message_received trigger of
        the plugin
        @param mess_data(dict): message data as used in post treat workflow
        @return (dict): mess_data, with the trusted flag set to True
        """
        trusted_key = C.MESS_KEY_TRUSTED
        mess_data[trusted_key] = True
        return mess_data

    def mark_as_untrusted(self, mess_data):
        """Helper method to mark a message as sent from an untrusted entity.

        This should be used in the post_treat workflow of message_received trigger of
        the plugin
        @param mess_data(dict): message data as used in post treat workflow
        @return (dict): mess_data, with the trusted flag set to False
        """
        # the same constant as in mark_as_trusted is used, so both helpers are
        # guaranteed to write the same key
        mess_data[C.MESS_KEY_TRUSTED] = False
        return mess_data