view libervia/backend/plugins/plugin_xep_0384.py @ 4332:71c939e34ca6

XEP-0373 (OX): Adjust to gpgme updates: generate with explicit algorithm and subkeys
author Syndace <me@syndace.dev>
date Sat, 13 Jul 2024 18:28:28 +0200 (5 months ago)
parents 23842a63ea00
children
line wrap: on
line source
#!/usr/bin/env python3

# Libervia plugin for OMEMO encryption
# Copyright (C) 2022-2022 Tim Henkes (me@syndace.dev)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
from datetime import datetime
import enum
import logging
import time
from typing import (
    Any,
    Dict,
    FrozenSet,
    Iterable,
    List,
    Literal,
    NamedTuple,
    Optional,
    Set,
    Type,
    Union,
    cast,
)
import uuid
import xml.etree.ElementTree as ET
from xml.sax.saxutils import quoteattr

from typing_extensions import Final, Never, assert_never
from wokkel import muc, pubsub  # type: ignore[import]
import xmlschema

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import MessageData, SatXMPPEntity
from libervia.backend.core.i18n import _, D_
from libervia.backend.core.log import getLogger, Logger
from libervia.backend.core.main import LiberviaBackend
from libervia.backend.core.xmpp import SatXMPPClient
from libervia.backend.memory import persistent
from libervia.backend.plugins.plugin_misc_text_commands import TextCommands
from libervia.backend.plugins.plugin_xep_0045 import XEP_0045
from libervia.backend.plugins.plugin_xep_0060 import XEP_0060
from libervia.backend.plugins.plugin_xep_0163 import XEP_0163
from libervia.backend.plugins.plugin_xep_0334 import XEP_0334
from libervia.backend.plugins.plugin_xep_0359 import XEP_0359
from libervia.backend.plugins.plugin_xep_0420 import (
    XEP_0420,
    SCEAffixPolicy,
    SCEAffixValues,
    SCEProfile,
)
from libervia.backend.tools import xml_tools
from twisted.internet import defer
from twisted.words.protocols.jabber import error, jid
from twisted.words.xish import domish

try:
    import omemo
    import omemo.identity_key_pair
    import twomemo
    import twomemo.etree
    import oldmemo
    import oldmemo.etree
    import oldmemo.migrations
    from xmlschema import XMLSchemaValidationError

    # An explicit version check of the OMEMO libraries should not be required here, since
    # the stored data is fully versioned and the library will complain if a downgrade is
    # attempted.
except ImportError as import_error:
    raise exceptions.MissingModule(
        "You are missing one or more package required by the OMEMO plugin. Please"
        " download/install the pip packages 'omemo', 'twomemo', 'oldmemo' and"
        f" 'xmlschema'.\nexception: {import_error}"
    ) from import_error


__all__ = ["PLUGIN_INFO", "OMEMO"]  # pylint: disable=unused-variable

# Module logger, cast to Libervia's Logger type for static type checking.
log = cast(Logger, getLogger(__name__))  # type: ignore[no-untyped-call]


# Plugin metadata declared to the Libervia backend: identity, implemented protocols,
# hard dependencies, optional recommendations and the main class name.
PLUGIN_INFO = {
    C.PI_NAME: "OMEMO",
    C.PI_IMPORT_NAME: "XEP-0384",
    C.PI_TYPE: "SEC",
    C.PI_PROTOCOLS: ["XEP-0384"],
    C.PI_DEPENDENCIES: ["XEP-0163", "XEP-0280", "XEP-0334", "XEP-0060", "XEP-0420"],
    C.PI_RECOMMENDATIONS: ["XEP-0045", "XEP-0359", C.TEXT_CMDS],
    C.PI_MAIN: "OMEMO",
    C.PI_HANDLER: "no",
    C.PI_DESCRIPTION: _("""Implementation of OMEMO"""),
}


# Category and name of the OMEMO policy parameter (presumably registered with the
# parameter system further down in this file; the registration is not visible here).
PARAM_CATEGORY = "Security"
PARAM_NAME = "omemo_policy"

# The two OMEMO namespaces handled by this plugin: omemo:2 (twomemo) and the legacy
# axolotl namespace (oldmemo).
NamespaceType = Literal["urn:xmpp:omemo:2", "eu.siacs.conversations.axolotl"]


class LogHandler(logging.Handler):
    """
    Forward the records emitted through python-omemo's standard library loggers to
    Libervia's own log system.
    """

    def emit(self, record: logging.LogRecord) -> None:
        """Re-emit ``record`` through the module logger, keeping its level name.

        @param record: The record produced by the standard ``logging`` machinery.
        """
        level_name = record.levelname
        rendered_message = record.getMessage()
        log.log(level_name, rendered_message)


# Route python-omemo's own loggers (session manager and identity key pair) into
# Libervia's log system at DEBUG verbosity. Propagation is disabled so that the records
# are not additionally handled by the handlers of ancestor loggers.
sm_logger = logging.getLogger(omemo.SessionManager.LOG_TAG)
ikp_logger = logging.getLogger(omemo.identity_key_pair.IdentityKeyPair.LOG_TAG)

for _omemo_logger in (sm_logger, ikp_logger):
    _omemo_logger.setLevel(logging.DEBUG)
    _omemo_logger.propagate = False
    _omemo_logger.addHandler(LogHandler())

# Do not leak the loop variable as a module attribute.
del _omemo_logger


# TODO: Add handling for device labels, i.e. show device labels in the trust UI and give
# the user a way to change their own device label.


class MUCPlaintextCacheKey(NamedTuple):
    # pylint: disable=invalid-name
    """
    Structure identifying an encrypted message sent to a MUC.

    As a named tuple it is hashable as long as its fields are, so it can serve as a
    dictionary key (presumably for a cache of plaintexts, as the name suggests; the
    cache itself is outside this view).
    """

    # The client the message was sent with.
    client: SatXMPPClient
    # The JID of the MUC room the message was sent to.
    room_jid: jid.JID
    # Unique identifier of the message.
    message_uid: str


@enum.unique
class TrustLevel(enum.Enum):
    """
    The trust levels required for ATM and BTBV.

    Every member's value equals its name, so ``TrustLevel(name)`` maps a stored trust
    level name back to its member.
    """

    # Explicitly trusted, either by the user or through ATM trust messages.
    TRUSTED: str = "TRUSTED"
    # Trusted without verification (BTBV: "blind trust before verification").
    BLINDLY_TRUSTED: str = "BLINDLY_TRUSTED"
    # No trust decision has been made yet.
    UNDECIDED: str = "UNDECIDED"
    # Explicitly distrusted.
    DISTRUSTED: str = "DISTRUSTED"


# Pubsub node names of the OMEMO device lists: one for omemo:2 (twomemo) and one for the
# legacy axolotl namespace (oldmemo).
TWOMEMO_DEVICE_LIST_NODE = "urn:xmpp:omemo:2:devices"
OLDMEMO_DEVICE_LIST_NODE = "eu.siacs.conversations.axolotl.devicelist"


class StorageImpl(omemo.Storage):
    """
    Storage implementation for OMEMO based on :class:`persistent.LazyPersistentBinaryDict`

    The data is stored in the profile's ``XEP-0384`` persistent storage. Failures of the
    underlying storage are reported as :class:`omemo.StorageException`.
    """

    def __init__(self, profile: str) -> None:
        """
        @param profile: The profile this OMEMO data belongs to.
        """

        # persistent.LazyPersistentBinaryDict does not cache at all, so keep the caching
        # option of omemo.Storage enabled.
        super().__init__()

        self.__storage = persistent.LazyPersistentBinaryDict("XEP-0384", profile)

    async def _load(self, key: str) -> omemo.Maybe[omemo.JSONType]:
        """Load the value stored under ``key``.

        @param key: The key to load.
        @return: ``omemo.Just`` wrapping the stored value, or ``omemo.Nothing`` if no
            value is stored under ``key``.
        @raise omemo.StorageException: if the underlying storage fails.
        """
        try:
            return omemo.Just(await self.__storage[key])
        except KeyError:
            return omemo.Nothing()
        except Exception as e:
            raise omemo.StorageException(f"Error while loading key {key}") from e

    async def _store(self, key: str, value: omemo.JSONType) -> None:
        """Store ``value`` under ``key``, overwriting any previous value.

        @param key: The key to store the value under.
        @param value: The value to store.
        @raise omemo.StorageException: if the underlying storage fails.
        """
        try:
            await self.__storage.force(key, value)
        except Exception as e:
            # The value is deliberately left out of the message: OMEMO storage may hold
            # private key material and session state, which must never end up in
            # exception messages or logs.
            raise omemo.StorageException(f"Error while storing key {key}") from e

    async def _delete(self, key: str) -> None:
        """Delete the value stored under ``key``. Deleting a missing key is a no-op.

        @param key: The key to delete.
        @raise omemo.StorageException: if the underlying storage fails.
        """
        try:
            await self.__storage.remove(key)
        except KeyError:
            pass
        except Exception as e:
            raise omemo.StorageException(f"Error while deleting key {key}") from e


class LegacyStorageImpl(oldmemo.migrations.LegacyStorage):
    """
    Legacy storage implementation to migrate data from the old XEP-0384 plugin.

    Everything is read from, and deleted from, the profile's ``XEP-0384`` persistent
    storage. Composite keys consist of a ``KEY_*`` prefix followed by its qualifiers
    (bare JID, device id), joined with newlines. Deleting data that is not stored is a
    no-op.
    """

    KEY_DEVICE_ID = "DEVICE_ID"
    KEY_STATE = "STATE"
    KEY_SESSION = "SESSION"
    KEY_ACTIVE_DEVICES = "DEVICES"
    KEY_INACTIVE_DEVICES = "INACTIVE_DEVICES"
    KEY_TRUST = "TRUST"
    KEY_ALL_JIDS = "ALL_JIDS"

    def __init__(self, profile: str, own_bare_jid: str) -> None:
        """
        @param profile: The profile this OMEMO data belongs to.
        @param own_bare_jid: The own bare JID, to return by the :meth:`loadOwnData` call.
        """

        self.__storage = persistent.LazyPersistentBinaryDict("XEP-0384", profile)
        self.__own_bare_jid = own_bare_jid

    @staticmethod
    def __compose_key(*parts: str) -> str:
        """Build a storage key by joining a key prefix and its qualifiers with newlines."""
        return "\n".join(parts)

    async def __discard(self, key: str) -> None:
        """Remove ``key`` from the storage, silently ignoring a missing key."""
        try:
            await self.__storage.remove(key)
        except KeyError:
            pass

    async def loadOwnData(self) -> Optional[oldmemo.migrations.OwnData]:
        stored_device_id = await self.__storage.get(
            LegacyStorageImpl.KEY_DEVICE_ID, None
        )
        if stored_device_id is not None:
            return oldmemo.migrations.OwnData(
                own_bare_jid=self.__own_bare_jid, own_device_id=stored_device_id
            )
        return None

    async def deleteOwnData(self) -> None:
        await self.__discard(LegacyStorageImpl.KEY_DEVICE_ID)

    async def loadState(self) -> Optional[oldmemo.migrations.State]:
        stored_state = await self.__storage.get(LegacyStorageImpl.KEY_STATE, None)
        return cast(Optional[oldmemo.migrations.State], stored_state)

    async def deleteState(self) -> None:
        await self.__discard(LegacyStorageImpl.KEY_STATE)

    async def loadSession(
        self, bare_jid: str, device_id: int
    ) -> Optional[oldmemo.migrations.Session]:
        session_key = self.__compose_key(
            LegacyStorageImpl.KEY_SESSION, bare_jid, str(device_id)
        )
        stored_session = await self.__storage.get(session_key, None)
        return cast(Optional[oldmemo.migrations.Session], stored_session)

    async def deleteSession(self, bare_jid: str, device_id: int) -> None:
        await self.__discard(
            self.__compose_key(LegacyStorageImpl.KEY_SESSION, bare_jid, str(device_id))
        )

    async def loadActiveDevices(self, bare_jid: str) -> Optional[List[int]]:
        devices_key = self.__compose_key(LegacyStorageImpl.KEY_ACTIVE_DEVICES, bare_jid)
        stored_devices = await self.__storage.get(devices_key, None)
        return cast(Optional[List[int]], stored_devices)

    async def loadInactiveDevices(self, bare_jid: str) -> Optional[Dict[int, int]]:
        devices_key = self.__compose_key(
            LegacyStorageImpl.KEY_INACTIVE_DEVICES, bare_jid
        )
        stored_devices = await self.__storage.get(devices_key, None)
        return cast(Optional[Dict[int, int]], stored_devices)

    async def deleteActiveDevices(self, bare_jid: str) -> None:
        await self.__discard(
            self.__compose_key(LegacyStorageImpl.KEY_ACTIVE_DEVICES, bare_jid)
        )

    async def deleteInactiveDevices(self, bare_jid: str) -> None:
        await self.__discard(
            self.__compose_key(LegacyStorageImpl.KEY_INACTIVE_DEVICES, bare_jid)
        )

    async def loadTrust(
        self, bare_jid: str, device_id: int
    ) -> Optional[oldmemo.migrations.Trust]:
        trust_key = self.__compose_key(
            LegacyStorageImpl.KEY_TRUST, bare_jid, str(device_id)
        )
        stored_trust = await self.__storage.get(trust_key, None)
        return cast(Optional[oldmemo.migrations.Trust], stored_trust)

    async def deleteTrust(self, bare_jid: str, device_id: int) -> None:
        await self.__discard(
            self.__compose_key(LegacyStorageImpl.KEY_TRUST, bare_jid, str(device_id))
        )

    async def listJIDs(self) -> Optional[List[str]]:
        stored_jids = await self.__storage.get(LegacyStorageImpl.KEY_ALL_JIDS, None)
        if stored_jids is None:
            return None
        return list(stored_jids)

    async def deleteJIDList(self) -> None:
        await self.__discard(LegacyStorageImpl.KEY_ALL_JIDS)


async def download_oldmemo_bundle(
    client: SatXMPPClient, xep_0060: XEP_0060, bare_jid: str, device_id: int
) -> oldmemo.oldmemo.BundleImpl:
    """Download the oldmemo bundle corresponding to a specific device.

    The bundle is fetched from the ``eu.siacs.conversations.axolotl.bundles:<device_id>``
    node of ``bare_jid`` (a single item is requested) and parsed with
    ``oldmemo.etree.parse_bundle``.

    @param client: The client.
    @param xep_0060: The XEP-0060 plugin instance to use for pubsub interactions.
    @param bare_jid: The bare JID the device belongs to.
    @param device_id: The id of the device.
    @return: The bundle.
    @raise BundleDownloadFailed: if the download failed. Feel free to raise a subclass
        instead.
    """
    # Bundle downloads are needed by the session manager and for migrations from legacy,
    # thus it is made a separate function.

    namespace = oldmemo.oldmemo.NAMESPACE
    node = f"eu.siacs.conversations.axolotl.bundles:{device_id}"

    # Common head of the download failure messages below.
    download_failure = (
        f"Bundle download failed for {bare_jid}: {device_id} under namespace"
        f" {namespace}"
    )

    try:
        items, __ = await xep_0060.get_items(client, jid.JID(bare_jid), node, max_items=1)
    except Exception as e:
        raise omemo.BundleDownloadFailed(download_failure) from e

    item_count = len(items)
    if item_count != 1:
        raise omemo.BundleDownloadFailed(
            f"{download_failure}: Unexpected number of items retrieved: {item_count}."
        )

    # The bundle is the first child of the (only) pubsub item.
    et_item = xml_tools.domish_elt_2_et_elt(cast(domish.Element, items[0]))
    element = next(iter(et_item), None)
    if element is None:
        raise omemo.BundleDownloadFailed(
            f"{download_failure}: Item download succeeded but parsing failed: {element}."
        )

    try:
        return oldmemo.etree.parse_bundle(element, bare_jid, device_id)
    except Exception as e:
        raise omemo.BundleDownloadFailed(
            f"Bundle parsing failed for {bare_jid}: {device_id} under namespace"
            f" {namespace}"
        ) from e


# ATM only supports protocols based on SCE, which is currently only omemo:2, and relies on
# so many implementation details of the encryption protocol that it makes more sense to
# add ATM to the OMEMO plugin directly instead of having it a separate Libervia plugin.
# Namespaces of Trust Messages (TM) and of their Automatic Trust Management usage (ATM,
# XEP-0450).
NS_TM: Final = "urn:xmpp:tm:1"
NS_ATM: Final = "urn:xmpp:atm:1"


# XML schema of the <trust-message/> element. Note that the <xs:sequence> of <key-owner/>
# requires all of its <trust/> children to come before all of its <distrust/> children,
# so senders must emit them in that order.
TRUST_MESSAGE_SCHEMA = xmlschema.XMLSchema(
    """<?xml version='1.0' encoding='UTF-8'?>
<xs:schema xmlns:xs='http://www.w3.org/2001/XMLSchema'
           targetNamespace='urn:xmpp:tm:1'
           xmlns='urn:xmpp:tm:1'
           elementFormDefault='qualified'>

  <xs:element name='trust-message'>
    <xs:complexType>
      <xs:sequence>
        <xs:element ref='key-owner' minOccurs='1' maxOccurs='unbounded'/>
      </xs:sequence>
      <xs:attribute name='usage' type='xs:string' use='required'/>
      <xs:attribute name='encryption' type='xs:string' use='required'/>
    </xs:complexType>
  </xs:element>

  <xs:element name='key-owner'>
    <xs:complexType>
      <xs:sequence>
        <xs:element
            name='trust' type='xs:base64Binary' minOccurs='0' maxOccurs='unbounded'/>
        <xs:element
            name='distrust' type='xs:base64Binary' minOccurs='0' maxOccurs='unbounded'/>
      </xs:sequence>
      <xs:attribute name='jid' type='xs:string' use='required'/>
    </xs:complexType>
  </xs:element>
</xs:schema>
"""
)


# This is compatible with omemo:2's SCE profile.
# SCE profile used to pack trust message stanzas before they are encrypted: the random
# padding and timestamp affixes are required, the "to" and "from" affixes are optional,
# and no custom affixes are defined.
TM_SCE_PROFILE = SCEProfile(
    rpad_policy=SCEAffixPolicy.REQUIRED,
    time_policy=SCEAffixPolicy.REQUIRED,
    to_policy=SCEAffixPolicy.OPTIONAL,
    from_policy=SCEAffixPolicy.OPTIONAL,
    custom_policies={},
)


class TrustUpdate(NamedTuple):
    # pylint: disable=invalid-name
    """
    An update to the trust status of an identity key, used by Automatic Trust Management.
    """

    # JID of the owner of the identity key.
    target_jid: jid.JID
    # The identity key whose trust is updated.
    target_key: bytes
    # True if the key is to be trusted, False if it is to be distrusted.
    target_trust: bool

    def to_dict(self) -> dict[str, Any]:
        """Convert the instance to a serialised dictionary

        @return: A dictionary with the JID as a full JID string, the key as a hex string
            and the trust flag as is. It can be loaded back with :meth:`from_dict`.
        """
        return {
            "target_jid": self.target_jid.full(),
            "target_key": self.target_key.hex(),
            "target_trust": self.target_trust,
        }

    @staticmethod
    def from_dict(data: dict[str, Any]) -> "TrustUpdate":
        """Load a serialized dictionary, as produced by :meth:`to_dict`.

        The input dictionary is not modified, so it can safely be loaded several times.

        @param data: The serialised trust update.
        @return: The deserialised trust update.
        @raise KeyError: if ``target_jid`` or ``target_key`` is missing.
        @raise TypeError: if other fields are missing or unexpected fields are present.
        """
        # Build a new mapping instead of overwriting the caller's values in place.
        return TrustUpdate(
            **{
                **data,
                "target_jid": jid.JID(data["target_jid"]),
                "target_key": bytes.fromhex(data["target_key"]),
            }
        )


class TrustMessageCacheEntry(NamedTuple):
    # pylint: disable=invalid-name
    """
    An entry in the trust message cache used by ATM.
    """

    # JID of the sender of the cached trust message.
    sender_jid: jid.JID
    # Identity key of the sending device. The cached update is applied once this key
    # becomes trusted.
    sender_key: bytes
    # Timestamp associated with the cached trust message (where it comes from is not
    # visible here).
    timestamp: datetime
    # The trust update carried by the cached trust message.
    trust_update: TrustUpdate

    def to_dict(self) -> dict[str, Any]:
        """Convert the instance to a serialised dictionary

        @return: A dictionary with the sender JID as a full JID string, the sender key as
            a hex string, the timestamp in ISO 8601 format and the trust update
            serialised with :meth:`TrustUpdate.to_dict`.
        """
        return {
            "sender_jid": self.sender_jid.full(),
            "sender_key": self.sender_key.hex(),
            "timestamp": self.timestamp.isoformat(),
            "trust_update": self.trust_update.to_dict(),
        }

    @staticmethod
    def from_dict(data: dict[str, Any]) -> "TrustMessageCacheEntry":
        """Load a serialized dictionary, as produced by :meth:`to_dict`.

        Neither the input dictionary nor its nested trust update dictionary is
        modified, so the same data can safely be loaded several times.

        @param data: The serialised cache entry.
        @return: The deserialised cache entry.
        @raise KeyError: if a serialised field is missing.
        @raise TypeError: if unexpected fields are present.
        """
        return TrustMessageCacheEntry(
            **{
                **data,
                "sender_jid": jid.JID(data["sender_jid"]),
                "sender_key": bytes.fromhex(data["sender_key"]),
                "timestamp": datetime.fromisoformat(data["timestamp"]),
                # Pass a copy so the nested dict is preserved whatever from_dict does.
                "trust_update": TrustUpdate.from_dict(dict(data["trust_update"])),
            }
        )


class PartialTrustMessage(NamedTuple):
    # pylint: disable=invalid-name
    """
    A structure representing a partial trust message, used by :func:`send_trust_messages`
    to build trust messages.

    All partial trust messages directed at the same recipient are merged into a single
    trust message before it is sent.
    """

    # JID the (merged) trust message is sent to.
    recipient_jid: jid.JID
    # JID whose keys' trust is described (becomes a <key-owner/> element).
    updated_jid: jid.JID
    # Trust updates for keys owned by ``updated_jid``.
    trust_updates: FrozenSet[TrustUpdate]


async def manage_trust_message_cache(
    client: SatXMPPClient,
    session_manager: omemo.SessionManager,
    applied_trust_updates: FrozenSet[TrustUpdate],
) -> None:
    """Manage the ATM trust message cache after trust updates have been applied.

    Cached entries whose target key just received a new trust decision are expired.
    Cached trust updates sent by devices that just became trusted are then applied and
    removed from the cache. Applying cached updates is itself a trust change, so the
    function recurses until no further cached update applies.

    @param client: The client this operation runs under.
    @param session_manager: The session manager to use.
    @param applied_trust_updates: The trust updates that have already been applied,
        triggering this cache management run.
    """

    trust_message_cache = persistent.LazyPersistentBinaryDict(
        "XEP-0384/TM", client.profile
    )

    # Load cache entries
    cache_entries = {
        TrustMessageCacheEntry.from_dict(d)
        for d in await trust_message_cache.get("cache", [])
    }

    # Expire cache entries that were overwritten by the applied trust updates. Several
    # cached entries (e.g. from different senders) may target the same key, so every one
    # of them is dropped, not only one per target.
    overwritten_targets = {
        (trust_update.target_jid.userhostJID(), trust_update.target_key)
        for trust_update in applied_trust_updates
    }
    cache_entries = {
        cache_entry
        for cache_entry in cache_entries
        if (
            cache_entry.trust_update.target_jid.userhostJID(),
            cache_entry.trust_update.target_key,
        )
        not in overwritten_targets
    }

    # Apply cached Trust Messages by newly trusted devices
    new_trust_updates: Set[TrustUpdate] = set()

    for trust_update in applied_trust_updates:
        if not trust_update.target_trust:
            continue

        # Iterate over a copy such that cache_entries can be modified
        for cache_entry in set(cache_entries):
            if (
                cache_entry.sender_jid.userhostJID()
                == trust_update.target_jid.userhostJID()
                and cache_entry.sender_key == trust_update.target_key
            ):
                trust_level = (
                    TrustLevel.TRUSTED
                    if cache_entry.trust_update.target_trust
                    else TrustLevel.DISTRUSTED
                )

                # Apply the trust update
                await session_manager.set_trust(
                    cache_entry.trust_update.target_jid.userhost(),
                    cache_entry.trust_update.target_key,
                    trust_level.name,
                )

                # Track the fact that this trust update has been applied
                new_trust_updates.add(cache_entry.trust_update)

                # Remove the corresponding cache entry
                cache_entries.remove(cache_entry)

    # Store the updated cache entries
    await trust_message_cache.force("cache", [tm.to_dict() for tm in cache_entries])

    # TODO: Notify the user ("feedback") about automatically updated trust?

    if new_trust_updates:
        # If any trust has been updated, recursively perform another run of cache
        # management
        await manage_trust_message_cache(
            client, session_manager, frozenset(new_trust_updates)
        )


async def get_trust_as_trust_updates(
    session_manager: omemo.SessionManager, target_jid: jid.JID
) -> FrozenSet[TrustUpdate]:
    """Express the current trust in every known key of a JID as ATM trust updates.

    Only explicit decisions are exported: trusted keys become positive updates and
    distrusted keys negative ones. Keys at any other level (blindly trusted, undecided)
    are left out.

    @param session_manager: The session manager to load the trust from.
    @param target_jid: The JID to load the trust for. Its bare form is used in the
        resulting updates.
    @return: One trust update per explicitly trusted or distrusted key.
    """
    explicit_decisions = {TrustLevel.TRUSTED: True, TrustLevel.DISTRUSTED: False}

    devices = await session_manager.get_device_information(target_jid.userhost())
    bare_target_jid = target_jid.userhostJID()

    leveled_devices = (
        (device, TrustLevel(device.trust_level_name)) for device in devices
    )

    return frozenset(
        TrustUpdate(
            target_jid=bare_target_jid,
            target_key=device.identity_key,
            target_trust=explicit_decisions[level],
        )
        for device, level in leveled_devices
        if level in explicit_decisions
    )


def _build_trust_message_element(
    updated: Dict[jid.JID, Set[TrustUpdate]]
) -> domish.Element:
    """Build an ATM ``<trust-message/>`` element from trust updates grouped by key owner.

    Within each ``<key-owner/>``, all ``<trust/>`` children are emitted before the
    ``<distrust/>`` children, as mandated by the ``xs:sequence`` of
    :data:`TRUST_MESSAGE_SCHEMA`. Receivers validating against that schema would reject
    any other order.

    @param updated: The trust updates to include, keyed by the JID owning the keys.
    @return: The trust message element, with ATM as usage and omemo:2 as encryption.
    """
    trust_message_elt = domish.Element((NS_TM, "trust-message"))
    trust_message_elt["usage"] = NS_ATM
    trust_message_elt["encryption"] = twomemo.twomemo.NAMESPACE

    for updated_jid, trust_updates in updated.items():
        key_owner_elt = trust_message_elt.addElement((NS_TM, "key-owner"))
        key_owner_elt["jid"] = updated_jid.userhost()

        trusted = [update for update in trust_updates if update.target_trust]
        distrusted = [update for update in trust_updates if not update.target_trust]

        for element_name, group in (("trust", trusted), ("distrust", distrusted)):
            for trust_update in group:
                serialized_identity_key = base64.b64encode(
                    trust_update.target_key
                ).decode("ASCII")
                key_owner_elt.addElement(
                    (NS_TM, element_name), content=serialized_identity_key
                )

    return trust_message_elt


async def send_trust_messages(
    client: SatXMPPClient,
    session_manager: omemo.SessionManager,
    applied_trust_updates: FrozenSet[TrustUpdate],
) -> None:
    """Send information about updated trust to peers via ATM (XEP-0450).

    Partial trust messages are first prepared for every JID whose trust was updated (see
    the comments below for who is told what). They are then merged per recipient, and
    each merged trust message is packed with SCE, encrypted with omemo:2 and sent.

    @param client: The client.
    @param session_manager: The session manager.
    @param applied_trust_updates: The trust updates that have already been applied, to
        notify other peers about.
    @raise Exception: any exception raised by the session manager while encrypting is
        logged, reported to the user via feedback and re-raised.
    """
    # NOTE: This currently sends information about oldmemo trust too. This is not
    # specified and experimental, but since twomemo and oldmemo share the same identity
    # keys and trust systems, this could be a cool side effect.

    # Send Trust Messages for newly trusted and distrusted devices
    own_jid = client.jid.userhostJID()
    own_trust_updates = await get_trust_as_trust_updates(session_manager, own_jid)

    # JIDs of which at least one device's trust has been updated
    updated_jids = frozenset(
        {trust_update.target_jid.userhostJID() for trust_update in applied_trust_updates}
    )

    trust_messages: Set[PartialTrustMessage] = set()

    for updated_jid in updated_jids:
        # Get the trust updates for that JID
        trust_updates = frozenset(
            {
                trust_update
                for trust_update in applied_trust_updates
                if trust_update.target_jid.userhostJID() == updated_jid
            }
        )

        if updated_jid == own_jid:
            # If the own JID is updated, _all_ peers have to be notified
            # TODO: Using my author's privilege here to shamelessly access private fields
            # and storage keys until I've added public API to get a list of peers to
            # python-omemo.
            storage: omemo.Storage = getattr(session_manager, "_SessionManager__storage")
            peer_jids = frozenset(
                {
                    jid.JID(bare_jid).userhostJID()
                    for bare_jid in (
                        await storage.load_list(f"/{OMEMO.NS_TWOMEMO}/bare_jids", str)
                    ).maybe([])
                }
            )

            if not peer_jids:
                # If there are no peers to notify, notify our other devices about the
                # changes directly
                trust_messages.add(
                    PartialTrustMessage(
                        recipient_jid=own_jid,
                        updated_jid=own_jid,
                        trust_updates=trust_updates,
                    )
                )
            else:
                # Otherwise, notify all peers about the changes in trust and let carbons
                # handle the copy to our own JID
                for peer_jid in peer_jids:
                    trust_messages.add(
                        PartialTrustMessage(
                            recipient_jid=peer_jid,
                            updated_jid=own_jid,
                            trust_updates=trust_updates,
                        )
                    )

                    # Also send full trust information about _every_ peer to our newly
                    # trusted devices
                    peer_trust_updates = await get_trust_as_trust_updates(
                        session_manager, peer_jid
                    )

                    trust_messages.add(
                        PartialTrustMessage(
                            recipient_jid=own_jid,
                            updated_jid=peer_jid,
                            trust_updates=peer_trust_updates,
                        )
                    )

            # Send information about our own devices to our newly trusted devices
            trust_messages.add(
                PartialTrustMessage(
                    recipient_jid=own_jid,
                    updated_jid=own_jid,
                    trust_updates=own_trust_updates,
                )
            )
        else:
            # Notify our other devices about the changes in trust
            trust_messages.add(
                PartialTrustMessage(
                    recipient_jid=own_jid,
                    updated_jid=updated_jid,
                    trust_updates=trust_updates,
                )
            )

            # Send a summary of our own trust to newly trusted devices
            trust_messages.add(
                PartialTrustMessage(
                    recipient_jid=updated_jid,
                    updated_jid=own_jid,
                    trust_updates=own_trust_updates,
                )
            )

    # All trust messages prepared. Merge all trust messages directed at the same
    # recipient.
    recipient_jids = {trust_message.recipient_jid for trust_message in trust_messages}

    for recipient_jid in recipient_jids:
        updated: Dict[jid.JID, Set[TrustUpdate]] = {}

        for trust_message in trust_messages:
            # Merge the trust updates of the messages directed at that recipient
            if trust_message.recipient_jid == recipient_jid:
                updated.setdefault(trust_message.updated_jid, set()).update(
                    trust_message.trust_updates
                )

        # Build the trust message
        trust_message_elt = _build_trust_message_element(updated)

        # Finally, encrypt and send the trust message!
        message_data = client.generate_message_xml(
            MessageData(
                {
                    "from": own_jid,
                    "to": recipient_jid,
                    "uid": str(uuid.uuid4()),
                    "message": {},
                    "subject": {},
                    "type": C.MESS_TYPE_CHAT,
                    "extra": {},
                    "timestamp": time.time(),
                }
            )
        )

        message_data["xml"].addChild(trust_message_elt)

        plaintext = XEP_0420.pack_stanza(TM_SCE_PROFILE, message_data["xml"])

        feedback_jid = recipient_jid

        # TODO: The following is mostly duplicate code
        try:
            messages, encryption_errors = await session_manager.encrypt(
                frozenset({own_jid.userhost(), recipient_jid.userhost()}),
                {OMEMO.NS_TWOMEMO: plaintext},
                backend_priority_order=[OMEMO.NS_TWOMEMO],
                identifier=feedback_jid.userhost(),
            )
        except Exception as e:
            # Translate the template first, then format it: an already formatted string
            # would not match the catalog entry of the template.
            msg = _("Can't encrypt message for {entities}: {reason}").format(
                entities=", ".join({own_jid.userhost(), recipient_jid.userhost()}),
                reason=e,
            )
            log.warning(msg)
            client.feedback(feedback_jid, msg, {C.MESS_EXTRA_INFO: C.EXTRA_INFO_ENCR_ERR})
            raise

        if encryption_errors:
            log.warning(
                f"Ignored the following non-critical encryption errors:"
                f" {encryption_errors}"
            )

            encrypted_errors_stringified = ", ".join(
                [
                    f"device {err.device_id} of {err.bare_jid} under namespace"
                    f" {err.namespace}"
                    for err in encryption_errors
                ]
            )

            client.feedback(
                feedback_jid,
                D_(
                    "There were non-critical errors during encryption resulting in some"
                    " of your destinees' devices potentially not receiving the message."
                    " This happens when the encryption data/key material of a device is"
                    " incomplete or broken, which shouldn't happen for actively used"
                    " devices, and can usually be ignored. The following devices are"
                    f" affected: {encrypted_errors_stringified}."
                ),
            )

        message = next(
            message for message in messages if message.namespace == OMEMO.NS_TWOMEMO
        )

        # Add the encrypted element
        message_data["xml"].addChild(
            xml_tools.et_elt_2_domish_elt(twomemo.etree.serialize_message(message))
        )

        await client.a_send(message_data["xml"])


def make_session_manager(
    sat: LiberviaBackend, profile: str
) -> Type[omemo.SessionManager]:
    """Build a session manager class bound to a specific profile.

    The returned class implements the abstract hooks of
    :class:`~omemo.session_manager.SessionManager`:
    - bundles and device lists are published to / fetched from pubsub via the XEP-0060
      plugin;
    - messages are sent with the profile's client;
    - trust is evaluated according to the trust system ("atm" or "btbv") stored in the
      profile parameters.

    @param sat: The Libervia backend instance.
    @param profile: The profile.
    @return: A non-abstract subclass of :class:`~omemo.session_manager.SessionManager`
        with XMPP interactions and trust handled via the backend instance.
    """

    client = sat.get_client(profile)
    xep_0060 = cast(XEP_0060, sat.plugins["XEP-0060"])

    def _is_precondition_not_met(e: Exception) -> bool:
        """Tell whether a publish attempt failed with a ``precondition-not-met`` conflict.

        @param e: The exception raised while publishing.
        @return: ``True`` if ``e`` is a stanza error with condition ``conflict`` and the
            application condition ``precondition-not-met``.
        """
        return (
            isinstance(e, error.StanzaError)
            and e.condition == "conflict"
            and e.appCondition is not None
            # pylint: disable=no-member
            and e.appCondition.name == "precondition-not-met"
        )

    def _item_payload(item: Any) -> Optional[ET.Element]:
        """Return the first child element of a pubsub item.

        @param item: A pubsub item as returned by ``xep_0060.get_items``.
        @return: The first child (as an ElementTree element), or ``None`` if the item has
            no child element.
        """
        return next(
            iter(xml_tools.domish_elt_2_et_elt(cast(domish.Element, item))), None
        )

    def _get_trust_system() -> str:
        """@return: The name of the trust system configured for the profile."""
        return cast(
            str,
            sat.memory.param_get_a(PARAM_NAME, PARAM_CATEGORY, profile_key=profile),
        )

    def _blind_trust_applies(devices: Iterable[omemo.DeviceInformation]) -> bool:
        """Tell whether blind trust is still in effect for a set of devices.

        Blind trust applies as long as no device has been manually (dis)trusted, i.e.
        every device is either undecided or blindly trusted. An empty set qualifies.

        @param devices: The devices to check, typically all devices of one bare JID.
        @return: Whether all devices are undecided or blindly trusted.
        @raise ValueError: if a device carries an unknown trust level name.
        """
        blind_trust_levels = {TrustLevel.UNDECIDED, TrustLevel.BLINDLY_TRUSTED}
        return all(
            TrustLevel(dev.trust_level_name) in blind_trust_levels for dev in devices
        )

    class SessionManagerImpl(omemo.SessionManager):
        """
        Session manager implementation handling XMPP interactions through the profile's
        client and the XEP-0060 plugin, and trust through the profile's trust system
        parameter.
        """

        @staticmethod
        async def _upload_bundle(bundle: omemo.Bundle) -> None:
            """Publish one of our own bundles on our bare JID.

            @param bundle: The bundle to publish, either a twomemo or oldmemo bundle.
            @raise BundleUploadFailed: if publishing failed.
            @raise UnknownNamespace: if the bundle belongs to neither backend.
            """
            if isinstance(bundle, twomemo.twomemo.BundleImpl):
                # urn:xmpp:omemo:2 uses one shared node holding one item per device, the
                # item ID being the device ID.
                element = twomemo.etree.serialize_bundle(bundle)

                node = "urn:xmpp:omemo:2:bundles"
                try:
                    await xep_0060.send_item(
                        client,
                        client.jid.userhostJID(),
                        node,
                        xml_tools.et_elt_2_domish_elt(element),
                        item_id=str(bundle.device_id),
                        extra={
                            XEP_0060.EXTRA_PUBLISH_OPTIONS: {
                                XEP_0060.OPT_ACCESS_MODEL: "open",
                                XEP_0060.OPT_MAX_ITEMS: "max",
                            },
                            XEP_0060.EXTRA_ON_PRECOND_NOT_MET: "force",
                        },
                    )
                except Exception as e:
                    if _is_precondition_not_met(e):
                        # The server still rejected our publish options, despite
                        # EXTRA_ON_PRECOND_NOT_MET being "force" (NOTE(review):
                        # presumably the XEP-0060 plugin's attempt to reconfigure the
                        # node failed — confirm there).
                        # TODO: What can I do here? The correct node configuration is a
                        # MUST in the XEP.
                        raise omemo.BundleUploadFailed(
                            f"precondition-not-met: {bundle}"
                        ) from e

                    raise omemo.BundleUploadFailed(
                        f"Bundle upload failed: {bundle}"
                    ) from e

                return

            if isinstance(bundle, oldmemo.oldmemo.BundleImpl):
                # The legacy namespace uses one node per device, holding a single item.
                element = oldmemo.etree.serialize_bundle(bundle)

                node = f"eu.siacs.conversations.axolotl.bundles:{bundle.device_id}"
                try:
                    await xep_0060.send_item(
                        client,
                        client.jid.userhostJID(),
                        node,
                        xml_tools.et_elt_2_domish_elt(element),
                        item_id=xep_0060.ID_SINGLETON,
                        extra={
                            XEP_0060.EXTRA_PUBLISH_OPTIONS: {
                                XEP_0060.OPT_ACCESS_MODEL: "open",
                                XEP_0060.OPT_MAX_ITEMS: 1,
                            },
                            XEP_0060.EXTRA_ON_PRECOND_NOT_MET: "publish_without_options",
                        },
                    )
                except Exception as e:
                    raise omemo.BundleUploadFailed(
                        f"Bundle upload failed: {bundle}"
                    ) from e

                return

            raise omemo.UnknownNamespace(f"Unknown namespace: {bundle.namespace}")

        @staticmethod
        async def _download_bundle(
            namespace: str, bare_jid: str, device_id: int
        ) -> omemo.Bundle:
            """Fetch and parse the bundle of a device.

            @param namespace: The OMEMO namespace of the bundle.
            @param bare_jid: The bare JID owning the device.
            @param device_id: The device ID.
            @return: The parsed bundle.
            @raise BundleDownloadFailed: if fetching or parsing failed.
            @raise UnknownNamespace: if the namespace is not supported.
            """
            if namespace == twomemo.twomemo.NAMESPACE:
                node = "urn:xmpp:omemo:2:bundles"

                try:
                    items, __ = await xep_0060.get_items(
                        client, jid.JID(bare_jid), node, item_ids=[str(device_id)]
                    )
                except Exception as e:
                    raise omemo.BundleDownloadFailed(
                        f"Bundle download failed for {bare_jid}: {device_id} under"
                        f" namespace {namespace}"
                    ) from e

                if len(items) != 1:
                    raise omemo.BundleDownloadFailed(
                        f"Bundle download failed for {bare_jid}: {device_id} under"
                        f" namespace {namespace}: Unexpected number of items retrieved:"
                        f" {len(items)}."
                    )

                element = _item_payload(items[0])
                if element is None:
                    raise omemo.BundleDownloadFailed(
                        f"Bundle download failed for {bare_jid}: {device_id} under"
                        f" namespace {namespace}: Item download succeeded but the item"
                        " contains no payload element."
                    )

                try:
                    return twomemo.etree.parse_bundle(element, bare_jid, device_id)
                except Exception as e:
                    raise omemo.BundleDownloadFailed(
                        f"Bundle parsing failed for {bare_jid}: {device_id} under"
                        f" namespace {namespace}"
                    ) from e

            if namespace == oldmemo.oldmemo.NAMESPACE:
                return await download_oldmemo_bundle(
                    client, xep_0060, bare_jid, device_id
                )

            raise omemo.UnknownNamespace(f"Unknown namespace: {namespace}")

        @staticmethod
        async def _delete_bundle(namespace: str, device_id: int) -> None:
            """Remove one of our own bundles from pubsub.

            For twomemo the device's item is retracted from the shared node; for oldmemo
            the per-device node is deleted.

            @param namespace: The OMEMO namespace of the bundle.
            @param device_id: The device ID whose bundle to remove.
            @raise BundleDeletionFailed: if the removal failed.
            @raise UnknownNamespace: if the namespace is not supported.
            """
            if namespace == twomemo.twomemo.NAMESPACE:
                node = "urn:xmpp:omemo:2:bundles"

                try:
                    await xep_0060.retract_items(
                        client,
                        client.jid.userhostJID(),
                        node,
                        [str(device_id)],
                        notify=False,
                    )
                except Exception as e:
                    raise omemo.BundleDeletionFailed(
                        f"Bundle deletion failed for {device_id} under namespace"
                        f" {namespace}"
                    ) from e

                return

            if namespace == oldmemo.oldmemo.NAMESPACE:
                node = f"eu.siacs.conversations.axolotl.bundles:{device_id}"

                try:
                    await xep_0060.deleteNode(client, client.jid.userhostJID(), node)
                except Exception as e:
                    raise omemo.BundleDeletionFailed(
                        f"Bundle deletion failed for {device_id} under namespace"
                        f" {namespace}"
                    ) from e

                return

            raise omemo.UnknownNamespace(f"Unknown namespace: {namespace}")

        @staticmethod
        async def _upload_device_list(
            namespace: str, device_list: Dict[int, Optional[str]]
        ) -> None:
            """Publish our own device list.

            @param namespace: The OMEMO namespace of the device list.
            @param device_list: Mapping of device IDs to optional labels.
            @raise DeviceListUploadFailed: if publishing failed.
            @raise UnknownNamespace: if the namespace is not supported.
            """
            element: Optional[ET.Element] = None
            node: Optional[str] = None

            if namespace == twomemo.twomemo.NAMESPACE:
                element = twomemo.etree.serialize_device_list(device_list)
                node = TWOMEMO_DEVICE_LIST_NODE
            elif namespace == oldmemo.oldmemo.NAMESPACE:
                element = oldmemo.etree.serialize_device_list(device_list)
                node = OLDMEMO_DEVICE_LIST_NODE

            if element is None or node is None:
                raise omemo.UnknownNamespace(f"Unknown namespace: {namespace}")

            try:
                await xep_0060.send_item(
                    client,
                    client.jid.userhostJID(),
                    node,
                    xml_tools.et_elt_2_domish_elt(element),
                    item_id=xep_0060.ID_SINGLETON,
                    extra={
                        XEP_0060.EXTRA_PUBLISH_OPTIONS: {
                            XEP_0060.OPT_MAX_ITEMS: 1,
                            XEP_0060.OPT_ACCESS_MODEL: "open",
                        },
                        XEP_0060.EXTRA_ON_PRECOND_NOT_MET: "force",
                    },
                )
            except Exception as e:
                if _is_precondition_not_met(e):
                    # The server still rejected our publish options, despite
                    # EXTRA_ON_PRECOND_NOT_MET being "force" (NOTE(review): presumably
                    # the XEP-0060 plugin's attempt to reconfigure the node failed —
                    # confirm there).
                    # TODO: What can I do here? The correct node configuration is a MUST
                    # in the XEP.
                    raise omemo.DeviceListUploadFailed(
                        f"precondition-not-met for namespace {namespace}"
                    ) from e

                raise omemo.DeviceListUploadFailed(
                    f"Device list upload failed for namespace {namespace}"
                ) from e

        @staticmethod
        async def _download_device_list(
            namespace: str, bare_jid: str
        ) -> Dict[int, Optional[str]]:
            """Fetch and parse the device list of a bare JID.

            @param namespace: The OMEMO namespace of the device list.
            @param bare_jid: The bare JID whose device list to fetch.
            @return: Mapping of device IDs to optional labels. The mapping is empty if
                the node doesn't exist or holds no item.
            @raise DeviceListDownloadFailed: if fetching or parsing failed.
            @raise UnknownNamespace: if the namespace is not supported.
            """
            node: Optional[str] = None

            if namespace == twomemo.twomemo.NAMESPACE:
                node = TWOMEMO_DEVICE_LIST_NODE
            elif namespace == oldmemo.oldmemo.NAMESPACE:
                node = OLDMEMO_DEVICE_LIST_NODE

            if node is None:
                raise omemo.UnknownNamespace(f"Unknown namespace: {namespace}")

            try:
                items, __ = await xep_0060.get_items(client, jid.JID(bare_jid), node)
            except exceptions.NotFound:
                # No node means no published device
                return {}
            except Exception as e:
                raise omemo.DeviceListDownloadFailed(
                    f"Device list download failed for {bare_jid} under namespace"
                    f" {namespace}"
                ) from e

            if not items:
                return {}

            if len(items) != 1:
                raise omemo.DeviceListDownloadFailed(
                    f"Device list download failed for {bare_jid} under namespace"
                    f" {namespace}: Unexpected number of items retrieved: {len(items)}."
                )

            element = _item_payload(items[0])
            if element is None:
                raise omemo.DeviceListDownloadFailed(
                    f"Device list download failed for {bare_jid} under namespace"
                    f" {namespace}: Item download succeeded but the item contains no"
                    " payload element."
                )

            try:
                if namespace == twomemo.twomemo.NAMESPACE:
                    return twomemo.etree.parse_device_list(element)
                if namespace == oldmemo.oldmemo.NAMESPACE:
                    return oldmemo.etree.parse_device_list(element)
            except Exception as e:
                raise omemo.DeviceListDownloadFailed(
                    f"Device list download failed for {bare_jid} under namespace"
                    f" {namespace}"
                ) from e

            raise omemo.UnknownNamespace(f"Unknown namespace: {namespace}")

        async def _evaluate_custom_trust_level(
            self, device: omemo.DeviceInformation
        ) -> omemo.TrustLevel:
            """Map the custom trust level stored for a device to a core trust level.

            @param device: The device whose trust level to evaluate.
            @return: The core trust level.
            @raise UnknownTrustLevel: if the stored trust level name is unknown.
            @raise InternalError: if the configured trust system is unknown.
            """
            # Get the custom trust level
            try:
                trust_level = TrustLevel(device.trust_level_name)
            except ValueError as e:
                raise omemo.UnknownTrustLevel(
                    f"Unknown trust level name {device.trust_level_name}"
                ) from e

            # The first three cases are a straight-forward mapping
            if trust_level is TrustLevel.TRUSTED:
                return omemo.TrustLevel.TRUSTED
            if trust_level is TrustLevel.UNDECIDED:
                return omemo.TrustLevel.UNDECIDED
            if trust_level is TrustLevel.DISTRUSTED:
                return omemo.TrustLevel.DISTRUSTED

            # The blindly trusted case is more complicated, since its evaluation depends
            # on the trust system and phase
            if trust_level is TrustLevel.BLINDLY_TRUSTED:
                trust_system = _get_trust_system()

                # If the trust model is BTBV, blind trust is always enabled
                if trust_system == "btbv":
                    return omemo.TrustLevel.TRUSTED

                # If the trust model is ATM, blind trust is disabled in the second phase
                # and counts as undecided. Phase one lasts as long as none of the bare
                # JID's devices has been manually (dis)trusted.
                if trust_system == "atm":
                    devices = await self.get_device_information(device.bare_jid)

                    if _blind_trust_applies(devices):
                        return omemo.TrustLevel.TRUSTED

                    return omemo.TrustLevel.UNDECIDED

                raise exceptions.InternalError(
                    f"Unknown trust system active: {trust_system}"
                )

            assert_never(trust_level)

        async def _make_trust_decision(
            self, undecided: FrozenSet[omemo.DeviceInformation], identifier: Optional[str]
        ) -> None:
            """Decide on the trust of undecided devices before encrypting.

            Devices of bare JIDs still in the blind trust phase are blindly trusted right
            away. For all others, the user is prompted for a manual decision.

            @param undecided: The devices whose trust is undecided.
            @param identifier: The feedback JID, passed through by the encryption call.
            @raise TrustDecisionFailed: if no identifier is given or the user cancels the
                manual trust prompt.
            """
            if identifier is None:
                raise omemo.TrustDecisionFailed(
                    "The identifier must contain the feedback JID."
                )

            # The feedback JID is transferred via the identifier
            feedback_jid = jid.JID(identifier).userhostJID()

            # Both the ATM and the BTBV trust models work with blind trust before the
            # first manual verification is performed. Thus, the bare JIDs are split into
            # the ones for which blind trust is still active and the ones for which
            # manual trust is used instead.
            blind_trust_bare_jids: Set[str] = set()
            for bare_jid in {device.bare_jid for device in undecided}:
                # Consider all known devices of the bare JID, not only undecided ones
                if _blind_trust_applies(await self.get_device_information(bare_jid)):
                    blind_trust_bare_jids.add(bare_jid)

            # Sorted so that the feedback lists devices in a stable order
            blindly_trusted_devices = sorted(
                (dev for dev in undecided if dev.bare_jid in blind_trust_bare_jids),
                key=lambda dev: (dev.bare_jid, dev.device_id),
            )
            manually_trusted_devices = frozenset(
                dev for dev in undecided if dev.bare_jid not in blind_trust_bare_jids
            )

            # Blindly trust devices handled by blind trust
            if blindly_trusted_devices:
                for device in blindly_trusted_devices:
                    await self.set_trust(
                        device.bare_jid,
                        device.identity_key,
                        TrustLevel.BLINDLY_TRUSTED.name,
                    )

                blindly_trusted_devices_stringified = ", ".join(
                    f"device {device.device_id} of {device.bare_jid} under namespace(s)"
                    f" {', '.join(sorted(device.namespaces))}"
                    for device in blindly_trusted_devices
                )

                client.feedback(
                    feedback_jid,
                    D_(
                        "Not all destination devices are trusted, unknown devices will be"
                        " blindly trusted.\nFollowing devices have been automatically"
                        f" trusted: {blindly_trusted_devices_stringified}."
                    ),
                )

            # Prompt the user for manual trust decisions on the devices handled by manual
            # trust
            if manually_trusted_devices:
                client.feedback(
                    feedback_jid,
                    D_(
                        "Not all destination devices are trusted, we can't encrypt"
                        " message in such a situation. Please indicate if you trust"
                        " those devices or not in the trust manager before we can"
                        " send this message."
                    ),
                )
                await self.__prompt_manual_trust(manually_trusted_devices, feedback_jid)

        @staticmethod
        async def _send_message(message: omemo.Message, bare_jid: str) -> None:
            """Send an OMEMO message (e.g. a key exchange) without other content.

            @param message: The message to send.
            @param bare_jid: The recipient's bare JID.
            @raise UnknownNamespace: if the message belongs to neither backend.
            @raise MessageSendingFailed: if sending failed.
            """
            element: Optional[ET.Element] = None

            if message.namespace == twomemo.twomemo.NAMESPACE:
                element = twomemo.etree.serialize_message(message)
            elif message.namespace == oldmemo.oldmemo.NAMESPACE:
                element = oldmemo.etree.serialize_message(message)

            if element is None:
                raise omemo.UnknownNamespace(f"Unknown namespace: {message.namespace}")

            message_data = client.generate_message_xml(
                MessageData(
                    {
                        "from": client.jid,
                        "to": jid.JID(bare_jid),
                        "uid": str(uuid.uuid4()),
                        "message": {},
                        "subject": {},
                        "type": C.MESS_TYPE_CHAT,
                        "extra": {},
                        "timestamp": time.time(),
                    }
                )
            )

            message_data["xml"].addChild(xml_tools.et_elt_2_domish_elt(element))

            try:
                await client.a_send(message_data["xml"])
            except Exception as e:
                raise omemo.MessageSendingFailed() from e

        async def __prompt_manual_trust(
            self, undecided: FrozenSet[omemo.DeviceInformation], feedback_jid: jid.JID
        ) -> None:
            """Ask the user to decide on the manual trust level of a set of devices.

            Blocks until the user has made a decision. Each device whose checkbox is
            returned is set to TRUSTED or DISTRUSTED via :meth:`set_trust`. If ATM is
            the active trust system, trust messages are cached/sent afterwards.

            @param undecided: The set of devices to prompt manual trust for.
            @param feedback_jid: The bare JID to redirect feedback to. In case of a one to
                one message, the recipient JID. In case of a MUC message, the room JID.
            @raise TrustDecisionFailed: if no encryption session exists for the feedback
                JID, or if the user cancels the prompt.
            """

            # This session manager handles encryption with both twomemo and oldmemo, but
            # both are currently registered as different plugins and the `defer_xmlui`
            # below requires a single namespace identifying the encryption plugin. Thus,
            # get the namespace of the requested encryption method from the encryption
            # session using the feedback JID.
            encryption = client.encryption.getSession(feedback_jid)
            if encryption is None:
                raise omemo.TrustDecisionFailed(
                    f"Encryption not requested for {feedback_jid.userhost()}."
                )

            namespace = encryption["plugin"].namespace

            # Casting this to Any, otherwise all calls on the variable cause type errors
            # pylint: disable=no-member
            trust_ui = cast(
                Any,
                xml_tools.XMLUI(
                    panel_type=C.XMLUI_FORM,
                    title=D_("OMEMO trust management"),
                    submit_id="",
                ),
            )
            trust_ui.addText(
                D_(
                    "This is OMEMO trusting system. You'll see below the devices of your "
                    "contacts, and a checkbox to trust them or not. A trusted device "
                    "can read your messages in plain text, so be sure to only validate "
                    "devices that you are sure are belonging to your contact. It's better "
                    "to do this when you are next to your contact and their device, so "
                    'you can check the "fingerprint" (the number next to the device) '
                    "yourself. Do *not* validate a device if the fingerprint is wrong!"
                )
            )

            own_device, __ = await self.get_own_device_information()

            trust_ui.change_container("label")
            trust_ui.addLabel(D_("This device ID"))
            trust_ui.addText(str(own_device.device_id))
            trust_ui.addLabel(D_("This device's fingerprint"))
            trust_ui.addText(" ".join(self.format_identity_key(own_device.identity_key)))
            trust_ui.addEmpty()
            trust_ui.addEmpty()

            # At least sort the devices by bare JID such that they aren't listed
            # completely random. The index in this list is used as the checkbox ID.
            undecided_ordered = sorted(undecided, key=lambda device: device.bare_jid)

            for index, device in enumerate(undecided_ordered):
                trust_ui.addLabel(D_("Contact"))
                trust_ui.addJid(jid.JID(device.bare_jid))
                trust_ui.addLabel(D_("Device ID"))
                trust_ui.addText(str(device.device_id))
                trust_ui.addLabel(D_("Fingerprint"))
                trust_ui.addText(" ".join(self.format_identity_key(device.identity_key)))
                trust_ui.addLabel(D_("Trust this device?"))
                trust_ui.addBool(f"trust_{index}", value=C.bool_const(False))
                trust_ui.addEmpty()
                trust_ui.addEmpty()

            trust_ui_result = await xml_tools.defer_xmlui(
                sat,
                trust_ui,
                action_extra={"meta_encryption_trust": namespace},
                profile=profile,
            )

            if C.bool(trust_ui_result.get("cancelled", "false")):
                raise omemo.TrustDecisionFailed("Trust UI cancelled.")

            data_form_result = cast(
                Dict[str, str], xml_tools.xmlui_result_2_data_form_result(trust_ui_result)
            )

            trust_updates: Set[TrustUpdate] = set()

            for key, value in data_form_result.items():
                if not key.startswith("trust_"):
                    continue

                # Map the checkbox ID back to the device it was generated for
                device = undecided_ordered[int(key[len("trust_") :])]
                target_trust = C.bool(value)
                trust_level = (
                    TrustLevel.TRUSTED if target_trust else TrustLevel.DISTRUSTED
                )

                await self.set_trust(
                    device.bare_jid, device.identity_key, trust_level.name
                )

                trust_updates.add(
                    TrustUpdate(
                        target_jid=jid.JID(device.bare_jid).userhostJID(),
                        target_key=device.identity_key,
                        target_trust=target_trust,
                    )
                )

            # Check whether ATM is enabled and handle everything in case it is
            if _get_trust_system() == "atm":
                await manage_trust_message_cache(client, self, frozenset(trust_updates))
                await send_trust_messages(client, self, frozenset(trust_updates))

    return SessionManagerImpl


async def prepare_for_profile(
    sat: LiberviaBackend,
    profile: str,
    initial_own_label: Optional[str],
    signed_pre_key_rotation_period: int = 7 * 24 * 60 * 60,
    pre_key_refill_threshold: int = 99,
    max_num_per_session_skipped_keys: int = 1000,
    max_num_per_message_skipped_keys: Optional[int] = None,
) -> omemo.SessionManager:
    """Set up OMEMO (storage, backends and session manager) for a given profile.

    Steps, in order:
    1. Migrate legacy OMEMO data into the new storage.
    2. Create a session manager with a twomemo and an oldmemo backend.
    3. Run a data consistency check.
    4. Leave history synchronization mode.

    @param sat: The Libervia backend instance.
    @param profile: The profile to prepare OMEMO for.
    @param initial_own_label: Optional label given to this device on creation, if any of
        the backends supports labels.
    @param signed_pre_key_rotation_period: How often the signed pre key is rotated, in
        seconds. Recommended: between one week (the default) and one month.
    @param pre_key_refill_threshold: Number of remaining pre keys at which the pool is
        refilled to 100. The default of 99 replaces each consumed pre key immediately.
        Can not be set below 25.
    @param max_num_per_session_skipped_keys: Upper bound on skipped message keys kept
        per session; the oldest keys are dropped once it is reached. Accessible via
        :attr:`max_num_per_session_skipped_keys`.
    @param max_num_per_message_skipped_keys: Upper bound on skipped message keys
        accepted in a single message. ``None`` (the default) means "same as the
        per-session maximum". May only be 0 if the per-session maximum is 0, otherwise
        must lie between 1 and the per-session maximum. Accessible via
        :attr:`max_num_per_message_skipped_keys`.
    @return: A session manager for this profile supporting ``urn:xmpp:omemo:2`` and
        ``eu.siacs.conversations.axolotl``.
    @raise BundleUploadFailed: forwarded from
        :meth:`~omemo.session_manager.SessionManager.create`.
    @raise BundleDownloadFailed: forwarded from
        :meth:`~omemo.session_manager.SessionManager.create`.
    @raise BundleDeletionFailed: forwarded from
        :meth:`~omemo.session_manager.SessionManager.create`.
    @raise DeviceListUploadFailed: forwarded from
        :meth:`~omemo.session_manager.SessionManager.create`.
    @raise DeviceListDownloadFailed: forwarded from
        :meth:`~omemo.session_manager.SessionManager.create`.
    """

    profile_client = sat.get_client(profile)
    pubsub_plugin = cast(XEP_0060, sat.plugins["XEP-0060"])

    omemo_storage = StorageImpl(profile)

    def fetch_legacy_bundle(bare_jid, device_id):
        # Bundle fetcher handed over to the legacy data migration
        return download_oldmemo_bundle(
            profile_client, pubsub_plugin, bare_jid, device_id
        )

    # Move data stored by the legacy OMEMO implementation into the new storage.
    # TODO: Untested
    await oldmemo.migrations.migrate(
        LegacyStorageImpl(profile, profile_client.jid.userhost()),
        omemo_storage,
        # TODO: Do we want BLINDLY_TRUSTED or TRUSTED here?
        TrustLevel.BLINDLY_TRUSTED.name,
        TrustLevel.UNDECIDED.name,
        TrustLevel.DISTRUSTED.name,
        fetch_legacy_bundle,
    )

    manager_class = make_session_manager(sat, profile)

    backends = [
        twomemo.Twomemo(
            omemo_storage,
            max_num_per_session_skipped_keys,
            max_num_per_message_skipped_keys,
        ),
        oldmemo.Oldmemo(
            omemo_storage,
            max_num_per_session_skipped_keys,
            max_num_per_message_skipped_keys,
        ),
    ]

    manager = await manager_class.create(
        backends,
        omemo_storage,
        profile_client.jid.userhost(),
        initial_own_label,
        TrustLevel.UNDECIDED.value,
        signed_pre_key_rotation_period,
        pre_key_refill_threshold,
        omemo.AsyncFramework.TWISTED,
    )

    # Cheap enough here, as we don't target overly constrained devices.
    # TODO: Consider ensuring data consistency regularly/in response to certain events
    await manager.ensure_data_consistency()

    # TODO: Properly entering/leaving history synchronization mode matters little for
    # now: it only guards against a very unlikely race where several devices pick the
    # same pre key for new sessions while this device was offline, and other clients
    # hardly defend against it either. Eventually, triggers for the start and end of
    # history sync (MAM, MUC catch-up, etc.) would be nice to react to.
    await manager.after_history_sync()

    return manager


# XML definition of the "OMEMO default trust policy" parameter, registered through
# ``host.memory.update_params`` in ``OMEMO.__init__``. It is a list parameter with two
# options:
# - "atm": Automatic Trust Management;
# - "btbv": Blind Trust Before Verification, the one marked as selected by default.
# The chosen value is read back with ``param_get_a(PARAM_NAME, PARAM_CATEGORY, ...)``
# when evaluating blind trust.
# Translated labels go through ``quoteattr`` so they end up as properly quoted/escaped
# XML attribute values.
# NOTE(review): the ``individual`` section presumably makes this a per-profile
# setting — confirm against the params documentation.
DEFAULT_TRUST_MODEL_PARAM = f"""
<params>
<individual>
<category name="{PARAM_CATEGORY}" label={quoteattr(D_('Security'))}>
    <param name="{PARAM_NAME}"
        label={quoteattr(D_('OMEMO default trust policy'))}
        type="list" security="3">
        <option value="atm"
            label={quoteattr(D_('Automatic Trust Management (more secure)'))} />
        <option value="btbv"
            label={quoteattr(D_('Blind Trust Before Verification (more user friendly)'))}
            selected="true" />
    </param>
</category>
</individual>
</params>
"""


class OMEMO:
    """
    Plugin equipping Libervia with OMEMO capabilities under the (modern)
    ``urn:xmpp:omemo:2`` namespace and the (legacy) ``eu.siacs.conversations.axolotl``
    namespace. Both versions of the protocol are handled by this plugin and compatibility
    between the two is maintained. MUC messages are supported next to one to one messages.
    For trust management, the two trust models "ATM" and "BTBV" are supported.
    """

    # Shortcuts to the namespaces of the two supported OMEMO versions (modern twomemo
    # and legacy oldmemo)
    NS_TWOMEMO = twomemo.twomemo.NAMESPACE
    NS_OLDMEMO = oldmemo.oldmemo.NAMESPACE

    # SCE (XEP-0420) affix policies used for twomemo content of groupchat messages,
    # e.g. when unpacking received stanzas in ``_message_received_trigger``.
    # For MUC/MIX message stanzas, the <to/> affix is a MUST
    SCE_PROFILE_GROUPCHAT = SCEProfile(
        rpad_policy=SCEAffixPolicy.REQUIRED,
        time_policy=SCEAffixPolicy.OPTIONAL,
        to_policy=SCEAffixPolicy.REQUIRED,
        from_policy=SCEAffixPolicy.OPTIONAL,
        custom_policies={},
    )

    # SCE (XEP-0420) affix policies used for twomemo content of all other messages.
    # For everything but MUC/MIX message stanzas, the <to/> affix is a MAY
    SCE_PROFILE = SCEProfile(
        rpad_policy=SCEAffixPolicy.REQUIRED,
        time_policy=SCEAffixPolicy.OPTIONAL,
        to_policy=SCEAffixPolicy.OPTIONAL,
        from_policy=SCEAffixPolicy.OPTIONAL,
        custom_policies={},
    )

    def __init__(self, host: LiberviaBackend) -> None:
        """Set up the plugin: parameters, triggers, encryption plugins, PEP events and
        text commands.

        @param host: The Libervia backend instance.
        """

        self.host = host

        # Add configuration option to choose between ATM and BTBV as the trust model
        host.memory.update_params(DEFAULT_TRUST_MODEL_PARAM)

        # Plugins
        self._j = cast(XEP_0060, host.plugins["XEP-0060"])
        self.__xep_0045 = cast(Optional[XEP_0045], host.plugins.get("XEP-0045"))
        self.__xep_0334 = cast(XEP_0334, host.plugins["XEP-0334"])
        self.__xep_0359 = cast(Optional[XEP_0359], host.plugins.get("XEP-0359"))
        self.__xep_0420 = cast(XEP_0420, host.plugins["XEP-0420"])

        # In contrast to one to one messages, MUC messages are reflected to the sender.
        # Thus, the sender does not add messages to their local message log when sending
        # them, but when the reflection is received. This approach does not pair well with
        # OMEMO, since for security reasons it is forbidden to encrypt messages for the
        # own device. Thus, when the reflection of an OMEMO message is received, it can't
        # be decrypted and added to the local message log as usual. To counteract this,
        # the plaintext of encrypted messages sent to MUCs are cached in this field, such
        # that when the reflection is received, the plaintext can be looked up from the
        # cache and added to the local message log.
        # TODO: The old plugin expired this cache after some time. I'm not sure that's
        # really necessary.
        self.__muc_plaintext_cache: Dict[MUCPlaintextCacheKey, bytes] = {}

        # Mapping from profile name to corresponding session manager
        self.__session_managers: Dict[str, omemo.SessionManager] = {}

        # Calls waiting for a specific session manager to be built
        self.__session_manager_waiters: Dict[str, List[defer.Deferred]] = {}

        # These triggers are used by oldmemo, which doesn't do SCE and only applies to
        # messages. Temporarily, until a more fitting trigger for SCE-based encryption is
        # added, the message_received trigger is also used for twomemo.
        host.trigger.add(
            "message_received", self._message_received_trigger, priority=100050
        )

        host.trigger.add("send", self.__send_trigger, priority=0)
        # TODO: Add new triggers here for freshly received and about-to-be-sent stanzas,
        # including IQs.

        # Give twomemo a (slightly) higher priority than oldmemo
        host.register_encryption_plugin(self, "OMEMO", twomemo.twomemo.NAMESPACE, 101)
        host.register_encryption_plugin(
            self, "OMEMO_legacy", oldmemo.oldmemo.NAMESPACE, 100
        )

        # Device list updates of both OMEMO versions are handled by the same coroutine
        xep_0163 = cast(XEP_0163, host.plugins["XEP-0163"])
        xep_0163.add_pep_event(
            "TWOMEMO_DEVICES",
            TWOMEMO_DEVICE_LIST_NODE,
            lambda items_event, profile: defer.ensureDeferred(
                self.__on_device_list_update(items_event, profile)
            ),
        )
        xep_0163.add_pep_event(
            "OLDMEMO_DEVICES",
            OLDMEMO_DEVICE_LIST_NODE,
            lambda items_event, profile: defer.ensureDeferred(
                self.__on_device_list_update(items_event, profile)
            ),
        )

        # Text commands (e.g. the OMEMO reset command) are optional
        try:
            self.__text_commands = cast(TextCommands, host.plugins[C.TEXT_CMDS])
        except KeyError:
            log.info(_("Text commands not available"))
        else:
            self.__text_commands.register_text_commands(self)

    def profile_connected(  # pylint: disable=invalid-name
        self, client: SatXMPPClient
    ) -> None:
        """Kick off the (asynchronous) preparation of the profile's session manager.

        The preparation is not awaited; the connection flow continues immediately.

        @param client: The client that just connected.
        """

        profile_name = cast(str, client.profile)
        preparation = self.get_session_manager(profile_name)
        defer.ensureDeferred(preparation)

    async def cmd_omemo_reset(
        self, client: SatXMPPClient, mess_data: MessageData
    ) -> Literal[False]:
        """Text command replacing the sessions with every device of the recipient.

        Only ever meant to be triggered by the user, as a last resort when a session
        looks broken (encrypted messages can't be exchanged anymore) and the problem has
        been confirmed with the recipient.

        @param client: The client.
        @param mess_data: The message data; its ``to`` attribute is the bare JID whose
            sessions get reset.
        @return: Always ``False``, telling the text commands plugin that the message
            must not be sent.
        """

        # Both checks are always evaluated, one per supported OMEMO version
        requested_flags = [
            client.encryption.is_encryption_requested(mess_data, namespace)
            for namespace in (twomemo.twomemo.NAMESPACE, oldmemo.oldmemo.NAMESPACE)
        ]

        if not any(requested_flags):
            self.__text_commands.feed_back(
                client,
                _("You need to have OMEMO encryption activated to reset the session"),
                mess_data,
            )
            return False

        recipient_bare_jid = mess_data["to"].userhost()
        manager = await self.get_session_manager(client.profile)

        for device in await manager.get_device_information(recipient_bare_jid):
            log.debug(f"Replacing sessions with device {device}")
            await manager.replace_sessions(device)

        self.__text_commands.feed_back(
            client, _("OMEMO session has been reset"), mess_data
        )
        return False

    async def get_trust_ui(  # pylint: disable=invalid-name
        self, client: SatXMPPClient, entity: jid.JID
    ) -> xml_tools.XMLUI:
        """Build the XMLUI form used to manage the trust levels of an entity's devices.

        If ``entity`` is a joined MUC, the devices of all its participants are listed.

        @param client: The client.
        @param entity: The entity whose device trust levels to manage.
        @return: An XMLUI instance which opens a form to manage the trust level of all
            devices belonging to the entity.
        @raise ValueError: if ``entity`` is not a bare JID.
        """

        if entity.resource:
            raise ValueError("A bare JID is expected.")

        bare_jids: Set[str]
        if self.__xep_0045 is not None and self.__xep_0045.is_joined_room(client, entity):
            bare_jids = self.__get_joined_muc_users(client, self.__xep_0045, entity)
        else:
            bare_jids = {entity.userhost()}

        session_manager = await self.get_session_manager(client.profile)

        # At least sort the devices by bare JID such that they aren't listed completely
        # random
        devices = sorted(
            cast(Set[omemo.DeviceInformation], set()).union(
                *[
                    await session_manager.get_device_information(bare_jid)
                    for bare_jid in bare_jids
                ]
            ),
            key=lambda device: device.bare_jid,
        )

        async def callback(data: Any, profile: str) -> Dict[Never, Never]:
            """Apply the trust levels submitted through the trust UI form.

            @param data: The XMLUI result produced by the trust UI form.
            @param profile: The profile.
            @return: An empty dictionary. The type of the return value was chosen
                conservatively since the exact options are neither known nor needed here.
            """

            if C.bool(data.get("cancelled", "false")):
                return {}

            data_form_result = cast(
                Dict[str, str], xml_tools.xmlui_result_2_data_form_result(data)
            )

            trust_updates: Set[TrustUpdate] = set()

            for key, value in data_form_result.items():
                if not key.startswith("trust_"):
                    continue

                # The field name encodes the index of the device in ``devices``
                device = devices[int(key[len("trust_") :])]
                trust_level_name = value

                if device.trust_level_name != trust_level_name:
                    await session_manager.set_trust(
                        device.bare_jid, device.identity_key, trust_level_name
                    )

                    # Only explicit trust/distrust decisions are turned into trust
                    # updates (used for ATM below)
                    target_trust: Optional[bool] = None

                    if TrustLevel(trust_level_name) is TrustLevel.TRUSTED:
                        target_trust = True
                    elif TrustLevel(trust_level_name) is TrustLevel.DISTRUSTED:
                        target_trust = False

                    if target_trust is not None:
                        trust_updates.add(
                            TrustUpdate(
                                target_jid=jid.JID(device.bare_jid).userhostJID(),
                                target_key=device.identity_key,
                                target_trust=target_trust,
                            )
                        )

            # Check whether ATM is enabled and handle everything in case it is
            trust_system = cast(
                str,
                self.host.memory.param_get_a(
                    PARAM_NAME, PARAM_CATEGORY, profile_key=profile
                ),
            )

            if trust_system == "atm":
                if len(trust_updates) > 0:
                    await manage_trust_message_cache(
                        client, session_manager, frozenset(trust_updates)
                    )

                    await send_trust_messages(
                        client, session_manager, frozenset(trust_updates)
                    )

            return {}

        submit_id = self.host.register_callback(callback, with_data=True, one_shot=True)

        result = xml_tools.XMLUI(
            panel_type=C.XMLUI_FORM,
            title=D_("OMEMO trust management"),
            submit_id=submit_id,
        )
        # Casting this to Any, otherwise all calls on the variable cause type errors
        # pylint: disable=no-member
        trust_ui = cast(Any, result)
        trust_ui.addText(
            D_(
                "This is OMEMO trusting system. You'll see below the devices of your"
                " contacts, and a list selection to trust them or not. A trusted device"
                " can read your messages in plain text, so be sure to only validate"
                " devices that you are sure are belonging to your contact. It's better"
                " to do this when you are next to your contact and their device, so"
                ' you can check the "fingerprint" (the number next to the device)'
                " yourself. Do *not* validate a device if the fingerprint is wrong!"
                " Note that manually validating a fingerprint disables any form of automatic"
                " trust."
            )
        )

        own_device, __ = await session_manager.get_own_device_information()

        trust_ui.change_container("label")
        trust_ui.addLabel(D_("This device ID"))
        trust_ui.addText(str(own_device.device_id))
        trust_ui.addLabel(D_("This device's fingerprint"))
        trust_ui.addText(
            " ".join(session_manager.format_identity_key(own_device.identity_key))
        )
        trust_ui.addEmpty()
        trust_ui.addEmpty()

        for index, device in enumerate(devices):
            trust_ui.addLabel(D_("Contact"))
            trust_ui.addJid(jid.JID(device.bare_jid))
            trust_ui.addLabel(D_("Device ID"))
            trust_ui.addText(str(device.device_id))
            trust_ui.addLabel(D_("Fingerprint"))
            trust_ui.addText(
                " ".join(session_manager.format_identity_key(device.identity_key))
            )
            trust_ui.addLabel(D_("Trust this device?"))

            current_trust_level = TrustLevel(device.trust_level_name)
            # Offer the current level next to trusting and distrusting. An ordered,
            # duplicate-free sequence is used instead of a set: enum members hash by
            # name and string hashing is randomized per process, so a set would make the
            # option order change from one run to the next.
            available_trust_levels = list(
                dict.fromkeys(
                    (current_trust_level, TrustLevel.TRUSTED, TrustLevel.DISTRUSTED)
                )
            )

            trust_ui.addList(
                f"trust_{index}",
                options=[trust_level.name for trust_level in available_trust_levels],
                selected=current_trust_level.name,
                styles=["inline"],
            )

            # ``device.active`` maps namespaces to activity flags; a missing namespace
            # means the device is not available for that OMEMO version
            twomemo_active = dict(device.active).get(twomemo.twomemo.NAMESPACE)
            if twomemo_active is None:
                trust_ui.addEmpty()
                trust_ui.addLabel(D_("(not available for Twomemo)"))
            if twomemo_active is False:
                trust_ui.addEmpty()
                trust_ui.addLabel(D_("(inactive for Twomemo)"))

            oldmemo_active = dict(device.active).get(oldmemo.oldmemo.NAMESPACE)
            if oldmemo_active is None:
                trust_ui.addEmpty()
                trust_ui.addLabel(D_("(not available for Oldmemo)"))
            if oldmemo_active is False:
                trust_ui.addEmpty()
                trust_ui.addLabel(D_("(inactive for Oldmemo)"))

            trust_ui.addEmpty()
            trust_ui.addEmpty()

        return result

    @staticmethod
    def __get_joined_muc_users(
        client: SatXMPPClient, xep_0045: XEP_0045, room_jid: jid.JID
    ) -> Set[str]:
        """Collect the bare JIDs of all participants of a joined MUC.

        @param client: The client.
        @param xep_0045: A MUC plugin instance.
        @param room_jid: The room JID.
        @return: A set containing the bare JIDs of the MUC participants.
        @raise InternalError: if the MUC is not joined or the real JID of a participant
            isn't available.
        """

        try:
            room = cast(muc.Room, xep_0045.get_room(client, room_jid))
        except exceptions.NotFound as e:
            raise exceptions.InternalError(
                "Participant list of unjoined MUC requested."
            ) from e

        bare_jids: Set[str] = set()

        for user in cast(Dict[str, muc.User], room.roster).values():
            # ``muc.User.entity`` holds the participant's real JID (a ``jid.JID``, as
            # also assumed by the message receiving code), not a client object; reading
            # a ``.jid`` attribute from it would fail.
            entity = cast(Optional[jid.JID], user.entity)
            if entity is None:
                raise exceptions.InternalError(
                    f"Participant list of MUC requested, but the entity information of"
                    f" the participant {user} is not available."
                )

            bare_jids.add(entity.userhost())

        return bare_jids

    async def get_session_manager(self, profile: str) -> omemo.SessionManager:
        """Return the session manager of a profile, building it on first use.

        Concurrent calls made while the session manager of the same profile is being
        built don't start a second build; they wait for the ongoing one instead.

        @param profile: The profile to prepare for.
        @return: A session manager instance for this profile. Creates a new instance if
            none was prepared before.
        @raise Exception: Whatever the build raised; it is also forwarded to every call
            waiting for that build.
        """

        session_manager = self.__session_managers.get(profile)
        if session_manager is not None:
            return session_manager

        # A session manager being built is signified by the profile key existing on
        # __session_manager_waiters.
        waiters = self.__session_manager_waiters.get(profile)
        if waiters is not None:
            # The session manager is being built, add ourselves to the waiting queue
            deferred = defer.Deferred()
            waiters.append(deferred)
            return cast(omemo.SessionManager, await deferred)

        # If the session manager is not being built, do so here.
        self.__session_manager_waiters[profile] = []

        try:
            session_manager = await prepare_for_profile(
                self.host, profile, initial_own_label="Libervia"
            )
        except Exception as e:
            # Detach the waiters before notifying them: an errback resumes the waiting
            # code synchronously, and a retry from there must start a fresh build
            # instead of being appended to (and failed along with) this very list.
            for waiter in self.__session_manager_waiters.pop(profile):
                waiter.errback(e)

            # Re-raise the exception
            raise

        self.__session_managers[profile] = session_manager

        # Notify the waiters, detaching them first for the same reason as above
        for waiter in self.__session_manager_waiters.pop(profile):
            waiter.callback(session_manager)

        return session_manager

    async def __message_received_trigger_atm(
        self,
        client: SatXMPPClient,
        message_elt: domish.Element,
        session_manager: omemo.SessionManager,
        sender_device_information: omemo.DeviceInformation,
        timestamp: datetime,
    ) -> None:
        """Check a newly decrypted message stanza for ATM content and perform ATM in case.

        Trust updates from ATM trust messages for twomemo are merged into the persistent
        trust message cache. A new update is dropped if the cache holds a newer one for
        the same target (bare JID and key); otherwise it supersedes the cached one. If
        the sending device is trusted, the new updates are applied immediately instead
        of being cached, and the ATM cache update logic is run for them.

        @param client: The client which received the message.
        @param message_elt: The message element. Can be modified.
        @param session_manager: The session manager.
        @param sender_device_information: Information about the device that sent/encrypted
            the message.
        @param timestamp: Timestamp extracted from the SCE time affix.
        @raise ParsingError: if a <trust-message/> element fails schema validation.
        """

        trust_message_cache = persistent.LazyPersistentBinaryDict(
            "XEP-0384/TM", client.profile
        )

        sender_jid = jid.JID(sender_device_information.bare_jid)
        new_cache_entries: Set[TrustMessageCacheEntry] = set()

        for trust_message_elt in message_elt.elements(NS_TM, "trust-message"):
            assert isinstance(trust_message_elt, domish.Element)

            try:
                TRUST_MESSAGE_SCHEMA.validate(trust_message_elt.toXml())
            except xmlschema.XMLSchemaValidationError as e:
                raise exceptions.ParsingError(
                    "<trust-message/> element doesn't pass schema validation."
                ) from e

            if trust_message_elt["usage"] != NS_ATM:
                # Skip non-ATM trust message
                continue

            if trust_message_elt["encryption"] != OMEMO.NS_TWOMEMO:
                # Skip non-twomemo trust message
                continue

            for key_owner_elt in trust_message_elt.elements(NS_TM, "key-owner"):
                assert isinstance(key_owner_elt, domish.Element)

                key_owner_jid = jid.JID(key_owner_elt["jid"]).userhostJID()

                # <trust/> and <distrust/> elements both carry a base64-encoded identity
                # key as text; only the resulting trust decision differs.
                for element_name, target_trust in (("trust", True), ("distrust", False)):
                    for key_elt in key_owner_elt.elements(NS_TM, element_name):
                        assert isinstance(key_elt, domish.Element)

                        new_cache_entries.add(
                            TrustMessageCacheEntry(
                                sender_jid=sender_jid,
                                sender_key=sender_device_information.identity_key,
                                timestamp=timestamp,
                                trust_update=TrustUpdate(
                                    target_jid=key_owner_jid,
                                    target_key=base64.b64decode(str(key_elt)),
                                    target_trust=target_trust,
                                ),
                            )
                        )

        # Load existing cache entries
        existing_cache_entries = {
            TrustMessageCacheEntry.from_dict(d)
            for d in await trust_message_cache.get("cache", [])
        }

        # Index the existing cache entries by target, for the timestamp comparison below
        existing_by_target = {
            (
                cache_entry.trust_update.target_jid.userhostJID(),
                cache_entry.trust_update.target_key,
            ): cache_entry
            for cache_entry in existing_cache_entries
        }

        # Iterate over a copy here, such that new_cache_entries can be modified
        for new_cache_entry in set(new_cache_entries):
            existing_cache_entry = existing_by_target.get(
                (
                    new_cache_entry.trust_update.target_jid.userhostJID(),
                    new_cache_entry.trust_update.target_key,
                )
            )

            if existing_cache_entry is None:
                continue

            if existing_cache_entry.timestamp > new_cache_entry.timestamp:
                # If the existing cache entry is newer than the new cache entry,
                # discard the new one in favor of the existing one
                new_cache_entries.remove(new_cache_entry)
            else:
                # Otherwise, discard the existing cache entry. This includes the case
                # when both cache entries have matching timestamps. Several new entries
                # may target the same key, in which case the existing entry may already
                # be gone: discard() (unlike remove()) doesn't raise KeyError then.
                existing_cache_entries.discard(existing_cache_entry)

        # If the sending device is trusted, apply the new cache entries
        applied_trust_updates: Set[TrustUpdate] = set()

        if TrustLevel(sender_device_information.trust_level_name) is TrustLevel.TRUSTED:
            # Iterate over a copy such that new_cache_entries can be modified
            for cache_entry in set(new_cache_entries):
                trust_update = cache_entry.trust_update

                trust_level = (
                    TrustLevel.TRUSTED
                    if trust_update.target_trust
                    else TrustLevel.DISTRUSTED
                )

                await session_manager.set_trust(
                    trust_update.target_jid.userhost(),
                    trust_update.target_key,
                    trust_level.name,
                )

                applied_trust_updates.add(trust_update)

                # Applied updates don't need to be cached
                new_cache_entries.remove(cache_entry)

        # Store the remaining existing and new cache entries
        await trust_message_cache.force(
            "cache", [tm.to_dict() for tm in existing_cache_entries | new_cache_entries]
        )

        # If the trust of at least one device was modified, run the ATM cache update logic
        if applied_trust_updates:
            await manage_trust_message_cache(
                client, session_manager, frozenset(applied_trust_updates)
            )

    async def _message_received_trigger(
        self,
        client: SatXMPPClient,
        message_elt: domish.Element,
        post_treat: defer.Deferred,
    ) -> bool:
        """
        @param client: The client which received the message.
        @param message_elt: The message element. Can be modified.
        @param post_treat: A deferred which evaluates to a :class:`MessageData` once the
            message has fully progressed through the message receiving flow. Can be used
            to apply treatments to the fully processed message, like marking it as
            encrypted.
        @return: Whether to continue the message received flow.
        """
        if client.is_component:
            return True
        muc_plaintext_cache_key: Optional[MUCPlaintextCacheKey] = None

        sender_jid = jid.JID(message_elt["from"])
        feedback_jid: jid.JID

        message_type = message_elt.getAttribute("type", C.MESS_TYPE_NORMAL)
        is_muc_message = message_type == C.MESS_TYPE_GROUPCHAT
        if is_muc_message:
            if self.__xep_0045 is None:
                log.warning(
                    "Ignoring MUC message since plugin XEP-0045 is not available."
                )
                # Can't handle a MUC message without XEP-0045, let the flow continue
                # normally
                return True

            room_jid = feedback_jid = sender_jid.userhostJID()

            try:
                room = cast(muc.Room, self.__xep_0045.get_room(client, room_jid))
            except exceptions.NotFound:
                log.warning(
                    f"Ignoring MUC message from a room that has not been joined:"
                    f" {room_jid}"
                )
                # Whatever, let the flow continue
                return True

            sender_user = cast(Optional[muc.User], room.getUser(sender_jid.resource))
            if sender_user is None:
                log.warning(
                    f"Ignoring MUC message from room {room_jid} since the sender's user"
                    f" wasn't found {sender_jid.resource}"
                )
                # Whatever, let the flow continue
                return True

            sender_user_jid = cast(Optional[jid.JID], sender_user.entity)
            if sender_user_jid is None:
                log.warning(
                    f"Ignoring MUC message from room {room_jid} since the sender's bare"
                    f" JID couldn't be found from its user information: {sender_user}"
                )
                # Whatever, let the flow continue
                return True

            sender_jid = sender_user_jid

            message_uid: Optional[str] = None
            if self.__xep_0359 is not None:
                message_uid = self.__xep_0359.get_origin_id(message_elt)
            if message_uid is None:
                message_uid = message_elt.getAttribute("id")
            if message_uid is not None:
                muc_plaintext_cache_key = MUCPlaintextCacheKey(
                    client, room_jid, message_uid
                )
        else:
            # I'm not sure why this check is required, this code is copied from the old
            # plugin.
            if sender_jid.userhostJID() == client.jid.userhostJID():
                try:
                    feedback_jid = jid.JID(message_elt["to"])
                except KeyError:
                    feedback_jid = client.server_jid
            else:
                feedback_jid = sender_jid

        sender_bare_jid = sender_jid.userhost()

        message: Optional[omemo.Message] = None
        encrypted_elt: Optional[domish.Element] = None

        twomemo_encrypted_elt = cast(
            Optional[domish.Element],
            next(message_elt.elements(twomemo.twomemo.NAMESPACE, "encrypted"), None),
        )

        oldmemo_encrypted_elt = cast(
            Optional[domish.Element],
            next(message_elt.elements(oldmemo.oldmemo.NAMESPACE, "encrypted"), None),
        )

        try:
            session_manager = await self.get_session_manager(cast(str, client.profile))
        except Exception as e:
            log.error(f"error while preparing profile for {client.profile}: {e}")
            # we don't want to block the workflow
            return True

        if twomemo_encrypted_elt is not None:
            try:
                message = twomemo.etree.parse_message(
                    xml_tools.domish_elt_2_et_elt(twomemo_encrypted_elt), sender_bare_jid
                )
            except (ValueError, XMLSchemaValidationError):
                log.warning(
                    f"Ingoring malformed encrypted message for namespace"
                    f" {twomemo.twomemo.NAMESPACE}: {twomemo_encrypted_elt.toXml()}"
                )
            else:
                encrypted_elt = twomemo_encrypted_elt

        if oldmemo_encrypted_elt is not None:
            try:
                message = await oldmemo.etree.parse_message(
                    xml_tools.domish_elt_2_et_elt(oldmemo_encrypted_elt),
                    sender_bare_jid,
                    client.jid.userhost(),
                    session_manager,
                )
            except (ValueError, XMLSchemaValidationError):
                log.warning(
                    f"Ingoring malformed encrypted message for namespace"
                    f" {oldmemo.oldmemo.NAMESPACE}: {oldmemo_encrypted_elt.toXml()}"
                )
            except omemo.SenderNotFound:
                log.warning(
                    f"Ingoring encrypted message for namespace"
                    f" {oldmemo.oldmemo.NAMESPACE} by unknown sender:"
                    f" {oldmemo_encrypted_elt.toXml()}"
                )
            else:
                encrypted_elt = oldmemo_encrypted_elt

        if message is None or encrypted_elt is None:
            # None of our business, let the flow continue
            return True

        message_elt.children.remove(encrypted_elt)

        log.debug(
            f"{message.namespace} message of type {message_type} received from"
            f" {sender_bare_jid}"
        )

        plaintext: Optional[bytes]
        device_information: omemo.DeviceInformation

        if (
            muc_plaintext_cache_key is not None
            and muc_plaintext_cache_key in self.__muc_plaintext_cache
        ):
            # Use the cached plaintext
            plaintext = self.__muc_plaintext_cache.pop(muc_plaintext_cache_key)

            # Since this message was sent by us, use the own device information here
            device_information, __ = await session_manager.get_own_device_information()
        else:
            try:
                plaintext, device_information, __ = await session_manager.decrypt(message)
            except omemo.MessageNotForUs:
                # The difference between this being a debug or a warning is whether there
                # is a body included in the message. Without a body, we can assume that
                # it's an empty OMEMO message used for protocol stability reasons, which
                # is not expected to be sent to all devices of all recipients. If a body
                # is included, we can assume that the message carries content and we
                # missed out on something.
                if len(list(message_elt.elements(C.NS_CLIENT, "body"))) > 0:
                    client.feedback(
                        feedback_jid,
                        D_(
                            f"An OMEMO message from {sender_jid.full()} has not been"
                            f" encrypted for our device, we can't decrypt it."
                        ),
                        {C.MESS_EXTRA_INFO: C.EXTRA_INFO_DECR_ERR},
                    )
                    log.warning("Message not encrypted for us.")
                else:
                    log.debug("Message not encrypted for us.")

                # No point in further processing this message.
                return False
            except Exception as e:
                log.warning(
                    _("Can't decrypt message: {reason}\n{xml}").format(
                        reason=e, xml=message_elt.toXml()
                    )
                )
                client.feedback(
                    feedback_jid,
                    D_(
                        f"An OMEMO message from {sender_jid.full()} can't be decrypted:"
                        f" {e}"
                    ),
                    {C.MESS_EXTRA_INFO: C.EXTRA_INFO_DECR_ERR},
                )
                # No point in further processing this message
                return False

        affix_values: Optional[SCEAffixValues] = None

        if message.namespace == twomemo.twomemo.NAMESPACE:
            if plaintext is not None:
                # XEP_0420.unpack_stanza handles the whole unpacking, including the
                # relevant modifications to the element
                sce_profile = (
                    OMEMO.SCE_PROFILE_GROUPCHAT if is_muc_message else OMEMO.SCE_PROFILE
                )
                try:
                    affix_values = self.__xep_0420.unpack_stanza(
                        sce_profile, message_elt, plaintext
                    )
                except Exception as e:
                    log.warning(
                        D_(f"Error unpacking SCE-encrypted message: {e}\n{plaintext}")
                    )
                    client.feedback(
                        feedback_jid,
                        D_(
                            f"An OMEMO message from {sender_jid.full()} was rejected:"
                            f" {e}"
                        ),
                        {C.MESS_EXTRA_INFO: C.EXTRA_INFO_DECR_ERR},
                    )
                    # No point in further processing this message
                    return False
                else:
                    if affix_values.timestamp is not None:
                        # TODO: affix_values.timestamp contains the timestamp included in
                        # the encrypted element here. The XEP says it SHOULD be displayed
                        # with the plaintext by clients.
                        pass

        if message.namespace == oldmemo.oldmemo.NAMESPACE:
            # Remove all body elements from the original element, since those act as
            # fallbacks in case the encryption protocol is not supported
            for child in message_elt.elements():
                if child.name == "body":
                    message_elt.children.remove(child)

            if plaintext is not None:
                # Add the decrypted body
                message_elt.addElement("body", content=plaintext.decode("utf-8"))

        # Mark the message as trusted or untrusted. Undecided counts as untrusted here.
        trust_level = await session_manager._evaluate_custom_trust_level(
            device_information
        )

        if trust_level is omemo.TrustLevel.TRUSTED:
            post_treat.addCallback(client.encryption.mark_as_trusted)
        else:
            post_treat.addCallback(client.encryption.mark_as_untrusted)

        # Mark the message as originally encrypted
        post_treat.addCallback(
            client.encryption.mark_as_encrypted, namespace=message.namespace
        )

        # Handle potential ATM trust updates
        if affix_values is not None and affix_values.timestamp is not None:
            await self.__message_received_trigger_atm(
                client,
                message_elt,
                session_manager,
                device_information,
                affix_values.timestamp,
            )

        # Message processed successfully, continue with the flow
        return True

    async def __send_trigger(self, client: SatXMPPClient, stanza: domish.Element) -> bool:
        """Encrypt an outgoing stanza when an OMEMO session is active for its recipient.

        Only message stanzas are handled for now (IQ encryption is temporarily
        disabled). If the bare recipient has an encryption session whose plugin
        namespace is twomemo or oldmemo, the stanza is encrypted in place through
        :meth:`encrypt`. A message stanza then additionally gets a ``store``
        processing hint.

        @param client: The client sending this stanza.
        @param stanza: The stanza that is about to be sent. It may be modified.
        @return: Whether the send flow should continue. This is always ``True`` when
            the method returns normally.
        @raise exceptions.InternalError: A message stanza has no ``to`` attribute.
            Processing is blocked to avoid leaking plaintext.
        """
        # SCE is only applicable to message and IQ stanzas.
        # FIXME: IQ stanza encryption is temporarily disabled; re-add "iq" here.
        if stanza.name != "message":
            return True

        to_attr = stanza.getAttribute("to", None)
        if to_attr is None:
            if stanza.name != "message":
                # Recipient-less IQs target the server itself and are therefore not
                # eligible for end-to-end encryption.
                return True
            # A message without a recipient must never go out, plaintext could leak.
            raise exceptions.InternalError(
                f"Message without recipient encountered. Blocking further processing"
                f" to avoid leaking plaintext data: {stanza.toXml()}"
            )

        bare_recipient = jid.JID(to_attr).userhostJID()

        session = client.encryption.getSession(bare_recipient)
        if session is None:
            # No encryption requested for this recipient.
            return True

        session_ns = session["plugin"].namespace
        if session_ns not in (twomemo.twomemo.NAMESPACE, oldmemo.oldmemo.NAMESPACE):
            # Encryption requested, but by a plugin other than OMEMO.
            return True

        is_groupchat = (
            stanza.getAttribute("type", C.MESS_TYPE_NORMAL) == C.MESS_TYPE_GROUPCHAT
        )
        await self.encrypt(
            client,
            session_ns,
            stanza,
            bare_recipient,
            is_groupchat,
            stanza.getAttribute("id", None),
        )

        if stanza.name == "message":
            # Encrypted messages should be stored by the server (XEP-0334 hint).
            self.__xep_0334.add_hint_elements(stanza, ["store"])

        return True

    async def download_missing_device_lists(
        self,
        client: SatXMPPClient,
        namespace: NamespaceType,
        recipients: Iterable[jid.JID],
        session_manager: omemo.SessionManager,
    ) -> None:
        """Retrieve missing device lists for recipients outside the profile's roster.

        A device list is downloaded for a bare JID when it is not in the roster and
        either of these holds:
        - the session manager knows no device for it;
        - at least one known device does not support ``namespace``.
        In that case the latest item of the matching device list PEP node is fetched
        and passed to :meth:`_update_device_list`. Retrieval errors are logged and
        otherwise ignored.

        @param client: XMPP client.
        @param namespace: The namespace of the OMEMO version to use.
        @param recipients: Recipients to verify device list presence. Full JIDs are
            reduced to bare JIDs, and each bare JID is handled only once.
        @param session_manager: OMEMO session manager.
        @raise ValueError: ``namespace`` is neither twomemo nor oldmemo, and a
            device list had to be downloaded.
        """
        # dict.fromkeys deduplicates while keeping the original order, so several
        # full JIDs of the same entity only trigger a single lookup/download.
        bare_jids = dict.fromkeys(j.userhostJID() for j in recipients)
        not_in_roster = [j for j in bare_jids if not client.roster.is_jid_in_roster(j)]
        for bare_jid in not_in_roster:
            device_information = await session_manager.get_device_information(
                bare_jid.userhost()
            )
            if device_information and all(
                namespace in di.namespaces for di in device_information
            ):
                # Device list already known for this namespace, nothing to do.
                continue

            if namespace == self.NS_TWOMEMO:
                algo, node = "OMEMO", TWOMEMO_DEVICE_LIST_NODE
            elif namespace == self.NS_OLDMEMO:
                algo, node = "OMEMO_legacy", OLDMEMO_DEVICE_LIST_NODE
            else:
                raise ValueError(f"Invalid namespace: {namespace!r}")

            try:
                # Only the latest item is needed: it holds the current device list.
                items, __ = await self._j.get_items(client, bare_jid, node, 1)
            except Exception:
                log.exception(f"Can't find {algo} devices list for {bare_jid}.")
            else:
                await self._update_device_list(client, bare_jid, items)
                # A successful update is normal operation, not a warning.
                log.info(f"{algo} devices list updated for {bare_jid}.")

    async def encrypt(
        self,
        client: SatXMPPClient,
        namespace: NamespaceType,
        stanza: domish.Element,
        recipient_jids: Union[jid.JID, Set[jid.JID]],
        is_muc_message: bool,
        stanza_id: Optional[str],
    ) -> None:
        """Encrypt a stanza in place using the given OMEMO version.

        @param client: The client.
        @param namespace: The namespace of the OMEMO version to use.
        @param stanza: The stanza. Twomemo encrypts the whole stanza using SCE.
            Oldmemo encrypts only the body and strips the ``body`` and ``html``
            elements from the plain stanza. The stanza is modified by this call.
        @param recipient_jids: The JID(s) of the recipients. They can be bare (aka
            "userhost") JIDs but don't have to be. A single JID can be used.
        @param is_muc_message: Whether the stanza is a message stanza to a MUC room.
            In that case exactly one JID (the room) must be given. The message is
            then encrypted for the bare JIDs returned by ``__get_joined_muc_users``.
        @param stanza_id: The id of this stanza. Especially relevant for message stanzas
            to MUC rooms such that the outgoing plaintext can be cached for MUC message
            reflection handling.
        @raise exceptions.InternalError: Raised in any of these cases:
            - no recipient is given;
            - several recipients are given for a MUC message;
            - plugin XEP-0045 or the stanza id is unavailable for a MUC message;
            - the session manager produced no message for ``namespace``.
        @raise Exception: Any error raised by the session manager while encrypting is
            reported to the user via feedback and re-raised.

        @warning: The calling code MUST take care of adding the store message processing
            hint to the stanza if applicable! This can be done before or after this call,
            the order doesn't matter.
        """
        if isinstance(recipient_jids, jid.JID):
            recipient_jids = {recipient_jids}
        if not recipient_jids:
            raise exceptions.InternalError("At least one JID must be specified")
        recipient_jid = next(iter(recipient_jids))

        muc_plaintext_cache_key: Optional[MUCPlaintextCacheKey] = None

        recipient_bare_jids: Set[str]
        feedback_jid: jid.JID

        if is_muc_message:
            if len(recipient_jids) != 1:
                raise exceptions.InternalError(
                    'Only one JID can be set when "is_muc_message" is set'
                )
            if self.__xep_0045 is None:
                raise exceptions.InternalError(
                    "Encryption of MUC message requested, but plugin XEP-0045 is not"
                    " available."
                )

            if stanza_id is None:
                raise exceptions.InternalError(
                    "Encryption of MUC message requested, but stanza id not available."
                )

            room_jid = feedback_jid = recipient_jid.userhostJID()

            recipient_bare_jids = self.__get_joined_muc_users(
                client, self.__xep_0045, room_jid
            )

            muc_plaintext_cache_key = MUCPlaintextCacheKey(
                client=client, room_jid=room_jid, message_uid=stanza_id
            )
        else:
            recipient_bare_jids = {r.userhost() for r in recipient_jids}
            feedback_jid = recipient_jid.userhostJID()

        log.debug(
            f"Intercepting message that is to be encrypted by {namespace} for"
            f" {recipient_bare_jids}"
        )

        def prepare_stanza() -> Optional[bytes]:
            """Prepares the stanza for encryption.

            Does so by removing all parts that are not supposed to be sent in plain. Also
            extracts/prepares the plaintext to encrypt.

            @return: The plaintext to encrypt. Returns ``None`` in case body-only
                encryption is requested and no body was found. The function should
                gracefully return in that case, i.e. it's not a critical error that should
                abort the message sending flow.
            """

            if namespace == twomemo.twomemo.NAMESPACE:
                return self.__xep_0420.pack_stanza(
                    OMEMO.SCE_PROFILE_GROUPCHAT if is_muc_message else OMEMO.SCE_PROFILE,
                    stanza,
                )

            if namespace == oldmemo.oldmemo.NAMESPACE:
                plaintext: Optional[bytes] = None

                # Iterate over a snapshot. ``stanza.elements()`` lazily walks the live
                # ``stanza.children`` list. Removing from that list during iteration
                # would skip the element that follows each removed one (e.g. an
                # ``html`` element directly after ``body``), which would then be sent
                # in plain.
                for child in list(stanza.elements()):
                    if child.name == "body" and plaintext is None:
                        plaintext = str(child).encode("utf-8")

                    # Any other sensitive elements to remove here?
                    if child.name in {"body", "html"}:
                        stanza.children.remove(child)

                if plaintext is None:
                    log.warning(
                        "No body found in intercepted message to be encrypted with"
                        f" oldmemo. [{client.profile}]"
                    )

                return plaintext

            return assert_never(namespace)

        # The stanza/plaintext preparation was moved into its own little function for type
        # safety reasons.
        plaintext = prepare_stanza()
        if plaintext is None:
            return

        log.debug(f"Plaintext to encrypt: {plaintext}")

        session_manager = await self.get_session_manager(client.profile)
        # Make sure device lists are known for the entities actually encrypted for.
        # For MUC messages these are the occupants' bare JIDs, not the room JID, which
        # is not expected to publish an OMEMO device list itself.
        await self.download_missing_device_lists(
            client,
            namespace,
            [jid.JID(bare_jid) for bare_jid in recipient_bare_jids],
            session_manager,
        )

        try:
            messages, encryption_errors = await session_manager.encrypt(
                frozenset(recipient_bare_jids),
                {namespace: plaintext},
                backend_priority_order=[namespace],
                identifier=feedback_jid.userhost(),
            )
        except Exception as e:
            # Translate the template first, then format it: translating an already
            # formatted string can never match an entry of the translation catalog.
            msg = _(
                "Can't encrypt message for {entities}: {reason}"
            ).format(  # pylint: disable=consider-using-f-string
                entities=", ".join(recipient_bare_jids), reason=e
            )
            log.warning(msg)
            client.feedback(feedback_jid, msg, {C.MESS_EXTRA_INFO: C.EXTRA_INFO_ENCR_ERR})
            raise

        if len(encryption_errors) > 0:
            log.warning(
                f"Ignored the following non-critical encryption errors:"
                f" {encryption_errors}"
            )

            encrypted_errors_stringified = ", ".join(
                [
                    f"device {err.device_id} of {err.bare_jid} under namespace"
                    f" {err.namespace}"
                    for err in encryption_errors
                ]
            )

            client.feedback(
                feedback_jid,
                D_(
                    "There were non-critical errors during encryption resulting in some"
                    " of your destinees' devices potentially not receiving the message."
                    " This happens when the encryption data/key material of a device is"
                    " incomplete or broken, which shouldn't happen for actively used"
                    " devices, and can usually be ignored. The following devices are"
                    f" affected: {encrypted_errors_stringified}."
                ),
            )

        # ``next`` without a default would surface as an opaque RuntimeError inside
        # this coroutine, so report the missing message explicitly.
        message = next(
            (message for message in messages if message.namespace == namespace), None
        )
        if message is None:
            raise exceptions.InternalError(
                f"The session manager didn't produce any message for {namespace}."
            )

        # Add the encrypted element
        if namespace == twomemo.twomemo.NAMESPACE:
            stanza.addChild(
                xml_tools.et_elt_2_domish_elt(twomemo.etree.serialize_message(message))
            )
        elif namespace == oldmemo.oldmemo.NAMESPACE:
            stanza.addChild(
                xml_tools.et_elt_2_domish_elt(oldmemo.etree.serialize_message(message))
            )

        if muc_plaintext_cache_key is not None:
            # Our own message will be reflected by the MUC; keep the plaintext so
            # the reflection can be handled without decrypting it.
            self.__muc_plaintext_cache[muc_plaintext_cache_key] = plaintext

    async def __on_device_list_update(
        self, items_event: pubsub.ItemsEvent, profile: str
    ) -> None:
        """Forward a PEP device list notification to :meth:`_update_device_list`.

        @param items_event: The pubsub items event received via PEP.
        @param profile: The profile this event belongs to. It is used to look up
            the client.
        """
        await self._update_device_list(
            self.host.get_client(profile),
            cast(jid.JID, items_event.sender),
            cast(List[domish.Element], items_event.items),
        )

    async def _update_device_list(
        self, client: SatXMPPEntity, sender: jid.JID, items: list[domish.Element]
    ) -> None:
        """Feed a device list pubsub item to the client's OMEMO session manager.

        Exactly one item is expected. Its payload is looked up first as a twomemo
        ``devices`` element, then as an oldmemo ``list`` element. If both parse
        successfully, the oldmemo one (parsed last) is used. The update is logged and
        ignored in these cases:
        - it carries several items;
        - it carries no item;
        - it carries no valid device list.

        @param client: The client whose session manager is updated.
        @param sender: The entity owning the device list; only its bare JID is used.
        @param items: The pubsub items of the update.
        """
        if len(items) > 1:
            log.warning("Ignoring device list update with more than one element.")
            return

        first_item = next(iter(items), None)
        if first_item is None:
            log.debug("Ignoring empty device list update.")
            return

        item_elt = xml_tools.domish_elt_2_et_elt(first_item)

        # (qualified tag of the list element, parser, OMEMO namespace), in parse order.
        parsers = (
            (
                f"{{{twomemo.twomemo.NAMESPACE}}}devices",
                twomemo.etree.parse_device_list,
                twomemo.twomemo.NAMESPACE,
            ),
            (
                f"{{{oldmemo.oldmemo.NAMESPACE}}}list",
                oldmemo.etree.parse_device_list,
                oldmemo.oldmemo.NAMESPACE,
            ),
        )

        device_list: Dict[int, Optional[str]] = {}
        namespace: Optional[str] = None

        for qualified_tag, parse, candidate_ns in parsers:
            list_elt = item_elt.find(qualified_tag)
            if list_elt is None:
                continue
            try:
                parsed = parse(list_elt)
            except XMLSchemaValidationError:
                # Invalid list for this version; the other version may still match.
                continue
            device_list, namespace = parsed, candidate_ns

        if namespace is None:
            log.warning(
                f"Malformed device list update item:"
                f" {ET.tostring(item_elt, encoding='unicode')}"
            )
            return

        session_manager = await self.get_session_manager(client.profile)
        await session_manager.update_device_list(
            namespace, sender.userhost(), device_list
        )