view libervia/backend/plugins/plugin_xep_0033.py @ 4310:d27228b3c704

test (unit): add test for email gateway: rel 450
author Goffi <goffi@goffi.org>
date Thu, 26 Sep 2024 16:12:01 +0200
parents 94e0968987cd
children 530f86f078cc
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin for Extended Stanza Addressing (XEP-0033)
# Copyright (C) 2009-2024 Jérôme Poisson (goffi@goffi.org)
# Copyright (C) 2013-2016 Adrien Cossa (souliane@mailoo.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from typing import Iterator, Literal, Self

from pydantic import BaseModel, model_validator
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from twisted.words.protocols.jabber.xmlstream import XMPPHandler
from twisted.words.xish import domish
from wokkel import disco, iwokkel
from zope.interface import implementer

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.models.core import MessageData
from libervia.backend.models.types import JIDType
from libervia.backend.tools import trigger
from libervia.backend.tools.xml_tools import element_copy

log = getLogger(__name__)


# TODO: fix Prosody "addressing" plugin to leave the concerned bcc according to the spec:
#   http://xmpp.org/extensions/xep-0033.html#addr-type-bcc "This means that the server
#   MUST remove these addresses before the stanza is delivered to anyone other than the
#   given bcc addressee or the multicast service of the bcc addressee."
#
#   http://xmpp.org/extensions/xep-0033.html#multicast "Each 'bcc' recipient MUST receive
#   only the <address type='bcc'/> associated with that addressee."

# TODO: fix Prosody "addressing" plugin so that it determines by itself whether remote
#   servers support this XEP.


# Plugin metadata, keyed by the ``C.PI_*`` constants (presumably consumed by the plugin
# loader — NOTE(review): confirm against the loader). ``C.PI_MAIN`` names the main class
# of this module, and ``C.PI_NAME`` / ``C.PI_IMPORT_NAME`` are also used in log and error
# messages below.
PLUGIN_INFO = {
    C.PI_NAME: "Extended Stanza Addressing Protocol Plugin",
    C.PI_IMPORT_NAME: "XEP-0033",
    C.PI_TYPE: "XEP",
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: ["XEP-0033"],
    C.PI_DEPENDENCIES: [],
    C.PI_MAIN: "XEP_0033",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: _(
        "Efficiently send messages to several recipients, using metadata to transmit "
        "them with main recipients (to), carbon copies (cc), and blind carbon copies "
        "(bcc) fields in a similar manner as for email."
    ),
}

# Namespace of the <addresses>/<address> elements; also advertised via disco.
NS_ADDRESS = "http://jabber.org/protocol/address"
# Address types which designate actual recipients (iterated by
# ``AddressesData.addresses``), as opposed to reply/origin metadata.
RECIPIENT_FIELDS = ("to", "cc", "bcc")


class AddressType(BaseModel):
    """One ``<address>`` entry of an ``<addresses>`` element.

    Only the ``jid``, ``desc`` and ``delivered`` attributes are modelled; the address
    ``type`` is held by the container (see ``AddressesData``).
    """

    jid: JIDType | None = None
    desc: str | None = None
    delivered: bool | None = None

    def set_attribute(self, address_elt: domish.Element) -> None:
        """Write this address's data as attributes of ``address_elt``.

        ``jid`` and ``desc`` are only written when truthy. ``delivered`` is written as
        "true" or "false" whenever it is not None.

        @param address_elt: Element receiving the attributes.
        """
        attributes: dict[str, str] = {}
        if self.jid:
            attributes["jid"] = str(self.jid)
        if self.desc:
            attributes["desc"] = self.desc
        if self.delivered is not None:
            attributes["delivered"] = "true" if self.delivered else "false"
        for attr_name, attr_value in attributes.items():
            address_elt[attr_name] = attr_value

    @classmethod
    def from_element(cls, address_elt: domish.Element) -> Self:
        """Parse an ``<address>`` element into a new instance.

        @param address_elt: ``<address>`` element in the XEP-0033 namespace.
        @return: The parsed address; attributes absent from the element are left unset.
        @raise ValueError: ``address_elt`` is not an ``<address>`` element of
            NS_ADDRESS.
        """
        is_address = address_elt.uri == NS_ADDRESS and address_elt.name == "address"
        if not is_address:
            raise ValueError("Element is not an <address> element")

        # Attribute name -> conversion of its raw string value.
        converters = {
            "jid": jid.JID,
            "desc": lambda raw_value: raw_value,
            "delivered": lambda raw_value: raw_value == "true",
        }
        parsed = {
            attr_name: convert(address_elt[attr_name])
            for attr_name, convert in converters.items()
            if address_elt.hasAttribute(attr_name)
        }
        return cls(**parsed)

    def to_element(self) -> domish.Element:
        """Serialise this address into a fresh ``<address>`` element.

        @return: New ``<address>`` element in NS_ADDRESS, without a ``type`` attribute.
        """
        element = domish.Element((NS_ADDRESS, "address"))
        self.set_attribute(element)
        return element


class AddressesData(BaseModel):
    """Extended Stanza Addressing data, i.e. the content of an ``<addresses>`` element.

    At least one of the ``to``, ``cc`` or ``bcc`` recipient lists must be set.
    """

    to: list[AddressType] | None = None
    cc: list[AddressType] | None = None
    bcc: list[AddressType] | None = None
    replyto: list[AddressType] | None = None
    replyroom: list[AddressType] | None = None
    # Set when a "noreply" address is present; conflicts with replyto/replyroom.
    noreply: bool | None = None
    # JID of the "ofrom" address.
    ofrom: JIDType | None = None

    @model_validator(mode="after")
    def check_minimal_data(self) -> Self:
        """Ensure a recipient is set, and drop reply fields conflicting with noreply.

        @raise ValueError: None of "to", "cc" or "bcc" is set (reported by pydantic as a
            validation error).
        """
        # An explicit raise is used instead of ``assert``: assertions are stripped when
        # Python runs with ``-O``, which would silently disable this check.
        if not (self.to or self.cc or self.bcc):
            raise ValueError("At least one recipient must be set")
        if self.noreply and (self.replyto is not None or self.replyroom is not None):
            log.warning(
                '"noreply" can\'t be used with "replyto" or "replyroom". Ignoring reply '
                f'fields ({self.replyto=}, {self.replyroom=}).'
            )
            # We reset instead of raising a ValueError, because this can happen in
            # incoming messages and we should not discard them.
            self.replyto = self.replyroom = None
        return self

    @property
    def addresses(self) -> Iterator[AddressType]:
        """Iterator over all recipient addresses (to, cc, then bcc)."""
        for field in RECIPIENT_FIELDS:
            addresses = getattr(self, field)
            if not addresses:
                continue
            yield from addresses

    @staticmethod
    def add_address_element(
        addresses_elt: domish.Element, type_: str, address: AddressType | None
    ) -> None:
        """Add an <address> element to a parent <addresses> element.

        @param addresses_elt: Parent <addresses> element.
        @param type_: Value of the "type" attribute.
        @param address: Address data, or None for an address with only a type.
        """
        address_elt = addresses_elt.addElement("address")
        address_elt["type"] = type_
        if address is not None:
            address.set_attribute(address_elt)

    @classmethod
    def from_element(cls, addresses_elt: domish.Element) -> Self:
        """Create an AddressesData instance from an <addresses> element.

        Invalid <address> children are logged and skipped, so that a single bad entry
        doesn't prevent the rest of the metadata from being parsed.

        @param addresses_elt: The <addresses> element or its direct parent.
        @return: AddressesData instance.
        @raise NotFound: No <addresses> element found.
        @raise ValueError: The resulting data is invalid (e.g. no recipient); raised by
            pydantic as a ``ValidationError``.
        """
        if addresses_elt.uri != NS_ADDRESS or addresses_elt.name != "addresses":
            child_addresses_elt = next(
                addresses_elt.elements(NS_ADDRESS, "addresses"), None
            )
            if child_addresses_elt is None:
                raise exceptions.NotFound("<addresses> element not found")
            addresses_elt = child_addresses_elt

        kwargs = {}
        for address_elt in addresses_elt.elements(NS_ADDRESS, "address"):
            address_type = address_elt.getAttribute("type")
            if address_type in ("to", "cc", "bcc", "replyto", "replyroom"):
                try:
                    address = AddressType.from_element(address_elt)
                except Exception as e:
                    log.warning(f"Invalid <address> element: {e}\n{address_elt.toXml()}")
                else:
                    kwargs.setdefault(address_type, []).append(address)
            elif address_type == "noreply":
                kwargs["noreply"] = True
            elif address_type == "ofrom":
                # A missing or malformed "jid" must not abort parsing of the whole
                # element: it is handled like other invalid <address> elements.
                try:
                    kwargs["ofrom"] = jid.JID(address_elt["jid"])
                except (KeyError, jid.InvalidFormat) as e:
                    log.warning(f"Invalid <address> element: {e}\n{address_elt.toXml()}")
            else:
                log.warning(
                    f"Invalid <address> element: unknown type {address_type!r}\n"
                    f"{address_elt.toXml()}"
                )
        return cls(**kwargs)

    def to_element(self) -> domish.Element:
        """Build the <addresses> element from this instance's data.

        Addresses are emitted in this order: to, cc, bcc, replyto, replyroom, noreply,
        ofrom.

        @return: <addresses> element.
        """
        addresses_elt = domish.Element((NS_ADDRESS, "addresses"))

        for type_ in ("to", "cc", "bcc", "replyto", "replyroom"):
            for address in getattr(self, type_) or []:
                self.add_address_element(addresses_elt, type_, address)
        if self.noreply:
            self.add_address_element(addresses_elt, "noreply", None)
        if self.ofrom is not None:
            ofrom_elt = addresses_elt.addElement("address")
            ofrom_elt["type"] = "ofrom"
            ofrom_elt["jid"] = self.ofrom.full()

        return addresses_elt


class XEP_0033:
    """Implementation of XEP-0033 (Extended Stanza Addressing).

    Outgoing messages carrying an "addresses" entry in their extra data are delivered
    to every recipient (through multicast services when possible), and incoming
    <addresses> elements are parsed into the "addresses" extra data.
    """

    def __init__(self, host):
        log.info(_("Extended Stanza Addressing plugin initialization"))
        self.host = host
        host.register_namespace("address", NS_ADDRESS)
        self.internal_data = {}
        host.trigger.add(
            "sendMessage",
            self.send_message_trigger,
            # We want this trigger to be the last one, as it may send messages.
            trigger.TriggerManager.MIN_PRIORITY,
        )
        host.trigger.add(
            "sendMessageComponent",
            self.send_message_trigger,
            # We want this trigger to be the last one, as it may send messages.
            trigger.TriggerManager.MIN_PRIORITY,
        )
        host.trigger.add("message_received", self.message_received_trigger)

    async def _stop_if_all_delivered(
        self, client: SatXMPPEntity, mess_data: MessageData, addr_data: AddressesData
    ) -> None:
        """Check if all messages have been delivered, and stop the workflow in this case.

        If the workflow is stopped, the message is added to history and a signal is
        sent to the bridge.

        @param client: Client session.
        @param mess_data: Message data.
        @param addr_data: Addresses data.
        @raise exceptions.CancelError: All messages have been delivered and the workflow
            is terminated.
        """
        if all(a.delivered for a in addr_data.addresses):
            await client.message_add_to_history(mess_data)
            await client.message_send_to_bridge(mess_data)
            # The key is the ``C.PI_NAME`` constant: a literal "C.PI_NAME" string key
            # would raise KeyError instead of the expected CancelError.
            raise exceptions.CancelError(
                f"Message has been delivered by {PLUGIN_INFO[C.PI_NAME]}."
            )

    async def _handle_addresses(
        self, client: SatXMPPEntity, mess_data: MessageData
    ) -> MessageData:
        """Handle Extended Stanza Addressing metadata for outgoing messages.

        Messages without an "addresses" extra entry are returned unchanged. Otherwise,
        messages are delivered by us where needed (see ``deliver_messages``), and the
        remaining message gets an <addresses> child and its "to" attribute set from
        ``mess_data["to"]`` (our server when it supports multicast).

        @param client: Client session.
        @param mess_data: Message data.
        @return: Updated message data.
        @raise NotImplementedError: The message is encrypted, or an address has no JID.
        @raise exceptions.CancelError: All recipients have been delivered, and the
            message workflow is stopped.
        """
        if "addresses" not in mess_data["extra"]:
            return mess_data

        if mess_data["extra"].get(C.MESS_KEY_ENCRYPTED, False):
            # TODO: Message must be encrypted for all recipients, and "to" correspond to
            #   multicast service in this case.
            raise NotImplementedError(
                "End-to-end encryption is not supported yet with multicast addressing."
            )

        data = AddressesData(**mess_data["extra"]["addresses"])
        recipients = set()
        domains: dict[str, list[AddressType]] = {}
        for address in data.addresses:
            if address.jid is None:
                raise NotImplementedError("Non JID addresses are not supported yet.")
            recipients.add(address.jid)
            domains.setdefault(address.jid.host, []).append(address)

        to_recipient_jid = mess_data["to"]

        if to_recipient_jid.user and to_recipient_jid not in recipients:
            # If the main recipient is not a service (i.e. it has a "user" part), we want
            # to move it to the XEP-0033's "to" addresses, so we can use the multicast
            # service for <message> "to" attribute.
            to_recipient_addr = AddressType(jid=to_recipient_jid)
            if data.to is None:
                data.to = [to_recipient_addr]
            else:
                data.to.insert(0, to_recipient_addr)
            recipients.add(to_recipient_jid)
            domains.setdefault(to_recipient_jid.host, []).append(to_recipient_addr)

        # XXX: If our server doesn't handle multicast, we don't check sub-services as
        #   requested in §2.2, because except if there is a special arrangement with the
        #   server, a service at a sub-domain can't send message in the name of the main
        #   domain (e.g. "multicast.example.org" can't send message from
        #   "juliet@example.org"). So the specification is a bit dubious here, and we only
        #   use the main server multicast feature if it's present.
        if not await self.host.memory.disco.has_feature(
            client, NS_ADDRESS, client.server_jid
        ):
            # No multicast service
            # The template is translated first and formatted afterwards: passing an
            # f-string to ``_`` would make the translation lookup fail.
            log.warning(
                _(
                    "Server of {profile} does not support XEP-0033 ({plugin}). We will "
                    "send all messages ourselves."
                ).format(
                    profile=client.profile, plugin=PLUGIN_INFO[C.PI_IMPORT_NAME]
                )
            )
            await self.deliver_messages(client, mess_data, data, domains)
            await self._stop_if_all_delivered(client, mess_data, data)
        else:
            # XXX: We deliver ourselves to multicast services because it's not correctly
            #     handled by some multicast services, notably by Prosody mod_addresses.
            # FIXME: Only do this workaround for known incomplete implementations.
            # TODO: remove this workaround when known implementations have been completed.
            if mess_data["to"] != client.server_jid:
                # We send the message to our server which will distribute it to the right
                # locations. The initial ``to`` jid has been moved to ``data.to`` above.
                # FIXME: When sub-services issue is properly handled, a sub-service JID
                #     supporting multicast should be allowed here.
                mess_data["to"] = client.server_jid
            await self.deliver_messages(
                client, mess_data, data, domains, multicast_only=True
            )
            await self._stop_if_all_delivered(client, mess_data, data)

        message_elt = mess_data["xml"]
        message_elt["to"] = str(mess_data["to"])
        message_elt.addChild(data.to_element())
        return mess_data

    async def deliver_messages(
        self,
        client: SatXMPPEntity,
        mess_data: MessageData,
        addr_data: AddressesData,
        domains: dict[str, list[AddressType]],
        multicast_only: bool = False,
    ) -> None:
        """Send messages to requested recipients.

        If a domain handles multicast, a single message will be sent there.

        @param client: Client session.
        @param mess_data: Message data.
        @param addr_data: XEP-0033 addresses data.
        @param domains: Domain to addresses map.
            Note that the address instances in this argument must be the same as in
            ``addr_data`` (their ``delivered`` status will be manipulated).
        @param multicast_only: if True, only multicast domains will be delivered.
        @raise NotImplementedError: An address to deliver has no JID.
        """
        # We'll modify delivered status, so we keep track here of addresses which have
        # already been delivered. Identity (``id``) is used rather than equality, as
        # distinct pydantic instances holding the same values compare equal.
        already_delivered = {id(a) for a in addr_data.addresses if a.delivered}
        multicast_domains = set()
        for domain, domain_addresses in domains.items():
            if domain == client.server_jid.host:
                # ``client.server_jid`` is discarded to avoid sending twice the same
                # message. ``multicast_only`` flag is set when the server supports
                # multicast, so the message will be sent to it at the end of the workflow.
                continue
            if len(domain_addresses) > 1:
                # For domains with multiple recipients, we check if they support
                # multicast and so if we can deliver to them directly.
                if await self.host.memory.disco.has_feature(
                    client, NS_ADDRESS, jid.JID(domain)
                ):
                    multicast_domains.add(domain)

        # We remove bcc, they have a special handling.
        bcc = addr_data.bcc or []
        addr_data.bcc = None

        # Mark all addresses as "delivered" upfront, even if some won't actually be sent
        # by us (when multicast_only is set). This flag signals to multicast services that
        # they shouldn't handle these addresses. We'll remove the "delivered" status from
        # undelivered addresses post-delivery.
        for address in addr_data.addresses:
            address.delivered = True

        # First, we send multicast messages.
        for domain in multicast_domains:
            something_to_deliver = False
            for address in domains[domain]:
                if id(address) in already_delivered:
                    continue
                # We need to mark as non delivered, so the multicast service will deliver
                # itself.
                address.delivered = False
                something_to_deliver = True

            if not something_to_deliver:
                continue

            domain_bcc = [a for a in bcc if a.jid and a.jid.host == domain]
            message_elt = element_copy(mess_data["xml"])
            # The service must only see BCC from its own domain.
            addr_data.bcc = domain_bcc
            message_elt.addChild(addr_data.to_element())
            message_elt["to"] = domain
            await client.a_send(message_elt)
            for address in domains[domain] + domain_bcc:
                # Those addresses have now been delivered.
                address.delivered = True

        if multicast_only:
            # Only addresses from multicast domains must be marked as delivered.
            for address in addr_data.addresses:
                if (
                    address.jid is not None
                    and address.jid.host not in multicast_domains
                    and id(address) not in already_delivered
                ):
                    address.delivered = None

            # We have delivered to all multicast services, we stop here.
            # But first we need to restore BCC, without the delivered ones.
            addr_data.bcc = [a for a in bcc if not a.delivered]
            return

        # Then BCC
        for address in bcc:
            if id(address) in already_delivered:
                continue
            if address.jid is None:
                raise NotImplementedError(
                    "Sending to non JID address is not supported yet"
                )
            if address.jid.host in multicast_domains:
                # Address has already been handled by a multicast domain
                continue
            message_elt = element_copy(mess_data["xml"])
            # The recipient must only get its own BCC
            addr_data.bcc = [address]
            message_elt.addChild(addr_data.to_element())
            message_elt["to"] = address.jid.full()
            await client.a_send(message_elt)

        # BCC address must be removed.
        addr_data.bcc = None

        # and finally, other ones.
        message_elt = mess_data["xml"]
        message_elt.addChild(addr_data.to_element())
        non_bcc_addresses = (addr_data.to or []) + (addr_data.cc or [])
        for address in non_bcc_addresses:
            if id(address) in already_delivered:
                continue
            if address.jid is None:
                raise NotImplementedError(
                    "Sending to non JID address is not supported yet"
                )
            if address.jid.host in multicast_domains:
                # Multicast domains have already been delivered.
                continue
            message_elt["to"] = address.jid.full()
            await client.a_send(message_elt)

    def send_message_trigger(
        self, client, mess_data, pre_xml_treatments, post_xml_treatments
    ) -> Literal[True]:
        """Process the XEP-0033 related data to be sent.

        Adds ``_handle_addresses`` to the post-XML treatments; the workflow always
        continues from here.
        """
        post_xml_treatments.addCallback(
            lambda mess_data: defer.ensureDeferred(
                self._handle_addresses(client, mess_data)
            )
        )
        return True

    def message_received_trigger(
        self,
        client: SatXMPPEntity,
        message_elt: domish.Element,
        post_treat: defer.Deferred,
    ) -> Literal[True]:
        """Parse addresses information and add them to message data.

        Invalid addressing metadata is logged and ignored, so that the message itself
        is still received.
        """
        try:
            addresses = AddressesData.from_element(message_elt)
        except exceptions.NotFound:
            pass
        except ValueError as e:
            # pydantic's ValidationError is a ValueError subclass (e.g. no recipient in
            # incoming metadata); this must not prevent reception of the message.
            log.warning(
                f"Ignoring invalid XEP-0033 addresses metadata: {e}\n"
                f"{message_elt.toXml()}"
            )
        else:

            def post_treat_addr(mess_data: MessageData):
                mess_data["extra"]["addresses"] = addresses.model_dump(
                    mode="json", exclude_none=True
                )
                return mess_data

            post_treat.addCallback(post_treat_addr)
        return True

    def get_handler(self, client):
        return XEP_0033_handler(self, client.profile)


@implementer(iwokkel.IDisco)
class XEP_0033_handler(XMPPHandler):
    """XMPP handler advertising the XEP-0033 feature through service discovery."""

    def __init__(self, plugin_parent, profile):
        # Let XMPPHandler initialise its own state (``parent``/``xmlstream``), which is
        # otherwise missing until the handler is attached to a stream.
        super().__init__()
        self.plugin_parent = plugin_parent
        self.host = plugin_parent.host
        self.profile = profile

    def getDiscoInfo(
        self, requestor: jid.JID, target: jid.JID, nodeIdentifier: str = ""
    ) -> list[disco.DiscoFeature]:
        """Advertise the XEP-0033 namespace, whatever the requested node."""
        return [disco.DiscoFeature(NS_ADDRESS)]

    def getDiscoItems(
        self, requestor: jid.JID, target: jid.JID, nodeIdentifier: str = ""
    ) -> list[disco.DiscoItem]:
        """This handler exposes no disco items."""
        return []