view sat/plugins/plugin_xep_0215.py @ 4052:2ced30f6d5de

plugin XEP-0166, 0176, 0234: minor renaming + type hints
author Goffi <goffi@goffi.org>
date Mon, 29 May 2023 13:32:19 +0200
parents 3900626bc100
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin
# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from typing import Dict, Final, List, Optional, Optional

from twisted.internet import defer
from twisted.words.protocols.jabber import error, jid
from twisted.words.protocols.jabber.xmlstream import XMPPHandler
from twisted.words.xish import domish
from wokkel import data_form, disco, iwokkel
from zope.interface import implementer

from sat.core import exceptions
from sat.core.constants import Const as C
from sat.core.core_types import SatXMPPEntity
from sat.core.i18n import _
from sat.core.log import getLogger
from sat.tools import xml_tools
from sat.tools import utils
from sat.tools.common import data_format

# module-level logger, named after this module
log = getLogger(__name__)


# Plugin metadata. C.PI_MAIN holds the name of the plugin class defined in this
# module; the other keys are presumably read by the plugin loader (TODO confirm).
PLUGIN_INFO = {
    C.PI_NAME: "External Service Discovery",
    C.PI_IMPORT_NAME: "XEP-0215",
    C.PI_TYPE: "XEP",
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: [],
    C.PI_DEPENDENCIES: [],
    C.PI_RECOMMENDATIONS: [],
    C.PI_MAIN: "XEP_0215",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: _("""Discover services external to the XMPP network"""),
}

# XML namespace of the External Service Discovery elements
NS_EXTDISCO: Final = "urn:xmpp:extdisco:2"
# observer expression built from C.IQ_SET, matching <services/> elements
# qualified by NS_EXTDISCO (used to catch services pushes)
IQ_PUSH: Final = f'{C.IQ_SET}/services[@xmlns="{NS_EXTDISCO}"]'


class XEP_0215:
    def __init__(self, host):
        """Store the host and expose the plugin methods on the bridge

        @param host: main application object, its ``bridge`` attribute is used to
            register ``external_disco_get`` and ``external_disco_credentials_get``
        """
        log.info(_("External Service Discovery plugin initialization"))
        self.host = host
        host.bridge.add_method(
            "external_disco_get",
            ".plugin",
            # arguments: entity, profile_key
            in_sign="ss",
            out_sign="s",
            method=self._external_disco_get,
            async_=True,
        )
        host.bridge.add_method(
            "external_disco_credentials_get",
            ".plugin",
            # arguments: entity, host, type_, port, profile_key
            # the signature must describe the 5 arguments of the wrapped method
            # (3 strings, 1 integer, 1 string)
            in_sign="sssis",
            out_sign="s",
            method=self._external_disco_credentials_get,
            async_=True,
        )

    def get_handler(self, client):
        """Create the XMPP handler bound to this plugin

        @param client: client requesting the handler (not used here)
        @return: a new XEP_0215_handler wrapping this plugin instance
        """
        handler = XEP_0215_handler(self)
        return handler

    async def profile_connecting(self, client: SatXMPPEntity) -> None:
        """Start with an empty services cache for the connecting client

        The cache is stored on the client itself and is keyed by entity (see
        get_external_services for the way it is filled).
        """
        client._xep_0215_services = dict()

    def parse_services(
        self, element: domish.Element, parent_elt_name: str = "services"
    ) -> List[dict]:
        """Retrieve services from element

        Each <service/> child is converted to a dict holding the attributes which are
        present among: action, expires, host, name, password, port, restricted,
        transport, type, username.
        "expires" and "port" are converted (date parsing and int respectively), and
        "restricted" goes through C.bool. When the conversion of "expires" or "port"
        fails, a warning is logged and the key is left out of the service.
        Data forms found in the service are added to an "extended" list.
        Services lacking the mandatory "host" or "type" are ignored with a warning.

        @param element: <[parent_elt_name]/> element or its parent
        @param parent_elt_name: name of the parent element
            can be "services" or "credentials"
        @return: list of parsed services
        @raise exceptions.InternalError: parent_elt_name is not an expected value
        @raise exceptions.DataError: no <[parent_elt_name]/> element can be found
        """
        if parent_elt_name not in ("services", "credentials"):
            raise exceptions.InternalError(
                f"invalid parent_elt_name: {parent_elt_name!r}"
            )
        if element.name == parent_elt_name and element.uri == NS_EXTDISCO:
            # we already have the wrapping element
            services_elt = element
        else:
            # the wrapping element is expected to be a direct child (e.g. of an <iq/>)
            try:
                services_elt = next(element.elements(NS_EXTDISCO, parent_elt_name))
            except StopIteration:
                raise exceptions.DataError(
                    f"XEP-0215 response is missing <{parent_elt_name}> element"
                )

        services = []
        for service_elt in services_elt.elements(NS_EXTDISCO, "service"):
            service = {}
            for key in [
                "action",
                "expires",
                "host",
                "name",
                "password",
                "port",
                "restricted",
                "transport",
                "type",
                "username",
            ]:
                value = service_elt.getAttribute(key)
                if value is not None:
                    if key == "expires":
                        try:
                            service[key] = utils.parse_xmpp_date(value)
                        except ValueError:
                            log.warning(f"invalid expiration date: {value!r}")
                            continue
                    elif key == "port":
                        try:
                            service[key] = int(value)
                        except ValueError:
                            log.warning(f"invalid port: {value!r}")
                            continue
                    elif key == "restricted":
                        service[key] = C.bool(value)
                    else:
                        service[key] = value
            if not {"host", "type"}.issubset(service):
                # the second part must be an f-string, otherwise the placeholder itself
                # is logged instead of the faulty element
                log.warning(
                    'mandatory "host" or "type" are missing in service, ignoring it: '
                    f"{service_elt.toXml()}"
                )
                continue
            for x_elt in service_elt.elements(data_form.NS_X_DATA, "x"):
                form = data_form.Form.fromElement(x_elt)
                extended = service.setdefault("extended", [])
                extended.append(xml_tools.data_form_2_data_dict(form))
            services.append(service)

        return services

    def _external_disco_get(self, entity: str, profile_key: str) -> defer.Deferred:
        """Bridge wrapper around get_external_services

        @param entity: jid of the entity to query, an empty value means no explicit
            entity (None is then given to get_external_services)
        @param profile_key: key used to retrieve the client
        @return: Deferred firing with the services serialised by data_format
        """
        client = self.host.get_client(profile_key)
        entity_jid = jid.JID(entity) if entity else None
        services_d = defer.ensureDeferred(
            self.get_external_services(client, entity_jid)
        )
        services_d.addCallback(data_format.serialise)
        return services_d

    async def get_external_services(
        self, client: SatXMPPEntity, entity: Optional[jid.JID] = None
    ) -> List[Dict]:
        """Get non XMPP service proposed by the entity

        The result of the first query is kept in ``client._xep_0215_services``; None
        is stored there when the feature is not available or when the request fails
        with a stanza error, so the entity is not requested again.

        @param entity: XMPP entity to query. Default to our own server
        @return: found services (empty list if there is none)
        @raise exceptions.DataError: entity has a resource
        """
        target = client.server_jid if entity is None else entity

        if target.resource:
            raise exceptions.DataError("A bare jid was expected for target entity")

        # cache hit: a cached None (failure or missing feature) gives an empty list
        if target in client._xep_0215_services:
            return client._xep_0215_services[target] or []

        # NOTE(review): the value returned by hasFeature is used directly as a
        # boolean; if it is a Deferred it is always truthy — confirm
        if not self.host.hasFeature(client, NS_EXTDISCO, target):
            client._xep_0215_services[target] = None
            return []

        request_elt = client.IQ("get")
        request_elt["to"] = target.full()
        request_elt.addElement((NS_EXTDISCO, "services"))
        try:
            response_elt = await request_elt.send()
        except error.StanzaError as e:
            log.warning(f"Can't get external services: {e}")
            client._xep_0215_services[target] = None
            return []

        found_services = self.parse_services(response_elt)
        client._xep_0215_services[target] = found_services
        return found_services or []

    def _external_disco_credentials_get(
        self,
        entity: str,
        host: str,
        type_: str,
        port: int = 0,
        profile_key=C.PROF_KEY_NONE,
    ) -> defer.Deferred:
        """Bridge wrapper around request_credentials

        @param entity: jid of the entity to query, an empty value means no explicit
            entity (None is then given to request_credentials)
        @param host: service host
        @param type_: service type
        @param port: service port, a falsy value (0) is converted to None
        @param profile_key: key used to retrieve the client
        @return: Deferred firing with the services serialised by data_format
        """
        client = self.host.get_client(profile_key)
        entity_jid = jid.JID(entity) if entity else None
        credentials_coro = self.request_credentials(
            client, host, type_, port or None, entity_jid
        )
        credentials_d = defer.ensureDeferred(credentials_coro)
        credentials_d.addCallback(data_format.serialise)
        return credentials_d

    async def request_credentials(
        self,
        client: SatXMPPEntity,
        host: str,
        type_: str,
        port: Optional[int] = None,
        entity: Optional[jid.JID] = None,
    ) -> List[dict]:
        """Request credentials for specified service(s)

        While usually a single service is expected, several may be returned if the same
        service is launched on several ports (cf. XEP-0215 §3.3)
        @param entity: XMPP entity to query. Default to our own server
        @param host: service host
        @param type_: service type
        @param port: service port (to be used when several services have same host and
            type but on different ports)
        @return: matching services with filled credentials
        """
        if entity is None:
            entity = client.server_jid

        iq_elt = client.IQ("get")
        iq_elt["to"] = entity.full()
        credentials_elt = iq_elt.addElement((NS_EXTDISCO, "credentials"))
        # the service for which credentials are wanted must be identified with a
        # <service/> child, otherwise the queried entity can't know what is requested
        service_elt = credentials_elt.addElement((NS_EXTDISCO, "service"))
        service_elt["host"] = host
        service_elt["type"] = type_
        if port is not None:
            # attributes values must be strings
            service_elt["port"] = str(port)
        iq_result_elt = await iq_elt.send()
        return self.parse_services(iq_result_elt, parent_elt_name="credentials")

    def get_matching_service(
        self, services: List[dict], host: str, type_: str, port: Optional[int]
    ) -> Optional[dict]:
        """Retrieve service data from its characteristics

        @param services: services to search, each must have "host" and "type" keys
        @param host: host the service must have
        @param type_: type the service must have
        @param port: port the service must have, None to accept any port
        @return: first matching service (the very dict stored in ``services``)
            None if no service matches
        """
        for candidate in services:
            if not (candidate["host"] == host and candidate["type"] == type_):
                continue
            if port is None or candidate.get("port") == port:
                return candidate
        return None

    def on_services_push(self, iq_elt: domish.Element, client: SatXMPPEntity) -> None:
        """Update the services cache according to a pushed <services/> IQ

        The push is only used if services of the sending entity (bare jid) are
        already cached. Each pushed service is applied according to its "action":
            - no action: the matching cached service is replaced, or the service is
              added if there is no match
            - "add": the service is added to the cache
            - "modify": the matching cached service is replaced
            - "delete": the matching cached service is removed
        If a service to modify or delete is not in cache, the cache of the entity is
        dropped, a fresh list is requested, and the remainder of the push is ignored.

        @param iq_elt: pushed IQ stanza, it is marked as handled
        @param client: client which received the push
        """
        iq_elt.handled = True
        entity = jid.JID(iq_elt["from"]).userhostJID()
        cached_services = client._xep_0215_services.get(entity)
        if cached_services is None:
            log.info(f"ignoring services push for uncached entity {entity}")
            return
        try:
            services = self.parse_services(iq_elt)
        except Exception:
            log.exception(f"Can't parse services push: {iq_elt.toXml()}")
            return
        for service in services:
            host = service["host"]
            type_ = service["type"]
            port = service.get("port")

            # the action is not part of the cached data
            action = service.pop("action", None)
            if action is None:
                # action is not specified, we first check if the service exists
                found_service = self.get_matching_service(
                    cached_services, host, type_, port
                )
                if found_service is not None:
                    # existing service, we replace by the new one
                    found_service.clear()
                    found_service.update(service)
                else:
                    # new service
                    cached_services.append(service)
            elif action == "add":
                cached_services.append(service)
            elif action in ("modify", "delete"):
                found_service = self.get_matching_service(
                    cached_services, host, type_, port
                )
                if found_service is None:
                    log.warning(
                        f"{entity} want to {action} an unknow service, we ask for the "
                        "full list again"
                    )
                    # we delete cache and request a fresh list to make a new one
                    del client._xep_0215_services[entity]
                    defer.ensureDeferred(self.get_external_services(client, entity))
                    # cached_services is not the cache anymore and the full list is
                    # going to replace it: we must stop here, going on would update a
                    # stale list and a second unknown service would try to delete the
                    # already removed cache entry (KeyError)
                    return
                elif action == "modify":
                    found_service.clear()
                    found_service.update(service)
                else:
                    cached_services.remove(found_service)
            else:
                log.warning(f"unknown action for services push, ignoring: {action!r}")


@implementer(iwokkel.IDisco)
class XEP_0215_handler(XMPPHandler):
    """XMPP handler forwarding services pushes to the plugin and advertising it"""

    def __init__(self, plugin_parent):
        # plugin instance whose on_services_push handles the pushed stanzas
        self.plugin_parent = plugin_parent

    def connectionInitialized(self):
        """Observe pushed <services/> IQ stanzas on the stream"""
        push_callback = self.plugin_parent.on_services_push
        self.xmlstream.addObserver(IQ_PUSH, push_callback, client=self.parent)

    def getDiscoInfo(self, requestor, target, nodeIdentifier=""):
        """Advertise the External Service Discovery namespace as a feature"""
        extdisco_feature = disco.DiscoFeature(NS_EXTDISCO)
        return [extdisco_feature]

    def getDiscoItems(self, requestor, target, nodeIdentifier=""):
        """This handler provides no disco item"""
        return list()