view libervia/backend/memory/disco.py @ 4183:6784d07b99c8

plugin XEP-053, component AP gateway: use the new `trigger.add_with_check` method
author Goffi <goffi@goffi.org>
date Sat, 09 Dec 2023 19:20:13 +0100
parents 4b842c1fb686
children 0d7bb4df2343
line wrap: on
line source

#!/usr/bin/env python3


# SAT: a jabber client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from typing import Optional
from libervia.backend.core.i18n import _
from libervia.backend.core import exceptions
from libervia.backend.core.log import getLogger
from libervia.backend.core.core_types import SatXMPPEntity

from twisted.words.protocols.jabber import jid
from twisted.words.protocols.jabber.error import StanzaError
from twisted.internet import defer
from twisted.internet import reactor
from twisted.python import failure
from libervia.backend.core.constants import Const as C
from libervia.backend.tools import xml_tools
from libervia.backend.memory import persistent
from wokkel import disco
from base64 import b64encode
from hashlib import sha1


log = getLogger(__name__)


TIMEOUT = 15
CAP_HASH_ERROR = "ERROR"


class HashGenerationError(Exception):
    """Error dedicated to capability hash generation

    NOTE(review): nothing in the visible part of this module raises it.
    """


class ByteIdentity(object):
    """Bytes view of a disco identity

    Every part is stored UTF-8 encoded so identities can be ordered with the i;octet
    collation while generating the capability hash.
    """

    def __init__(self, identity, lang=None):
        """
        @param identity(disco.DiscoIdentity): identity to convert
        @param lang(str, None): language of the identity, if any
        """
        assert isinstance(identity, disco.DiscoIdentity)
        self.category = identity.category.encode("utf-8")
        self.idType = identity.type.encode("utf-8")
        if identity.name:
            self.name = identity.name.encode("utf-8")
        else:
            # missing or empty name
            self.name = b""
        if lang is None:
            self.lang = b""
        else:
            self.lang = lang.encode("utf-8")

    def __bytes__(self):
        # "category/type/lang/name"
        return b"/".join((self.category, self.idType, self.lang, self.name))


class HashManager(object):
    """Mapping between capability hashes and disco infos

    Each newly added hash is also written to the persistent storage.
    """

    def __init__(self, persistent):
        """
        @param persistent: dict-like storage where serialised disco infos are saved
        """
        # the CAP_HASH_ERROR entry is always there: it is an empty disco info, used
        # when we can't get disco infos
        self.hashes = {CAP_HASH_ERROR: disco.DiscoInfo()}
        self.persistent = persistent

    def __getitem__(self, key):
        return self.hashes[key]

    def __setitem__(self, hash_, disco_info):
        # a hash is never overwritten once it is known
        if hash_ not in self.hashes:
            self.hashes[hash_] = disco_info
            # the XML serialisation is what is kept in persistent storage
            self.persistent[hash_] = disco_info.toElement().toXml()
        else:
            log.debug("ignoring hash set: it is already known")

    def __contains__(self, hash_):
        return hash_ in self.hashes

    def load(self):
        """Load persistent storage and fill the hashes map with its content

        @return: the result of ``self.persistent.load()``, with hashes filling added
            as callback
        """

        def fill_hashes(stored):
            for cap_hash, raw_xml in stored.items():
                parsed = xml_tools.ElementParser()(raw_xml)
                infos = disco.DiscoInfo.fromElement(parsed)
                for form in infos.extensions.values():
                    # wokkel doesn't call typeCheck on reception, so we do it here
                    form.typeCheck()
                if infos.features or infos.identities:
                    self.hashes[cap_hash] = infos
                    continue
                # nothing usable in this entry, it is not added to the map
                log.warning(
                    _(
                        "no feature/identity found in disco element (hash: {cap_hash}), ignoring: {xml}"
                    ).format(cap_hash=cap_hash, xml=raw_xml)
                )

            log.info("Disco hashes loaded")

        load_d = self.persistent.load()
        load_d.addCallback(fill_hashes)
        return load_d


class Discovery(object):
    """ Manage capabilities of entities """

    def __init__(self, host):
        self.host = host
        # TODO: remove legacy hashes

    def load(self):
        """Create the hash manager and load persistent hashes into it

        @return: result of HashManager.load
        """
        storage = persistent.PersistentDict("disco")
        self.hashes = HashManager(storage)
        return self.hashes.load()

    @defer.inlineCallbacks
    def hasFeature(self, client, feature, jid_=None, node=""):
        """Check if an entity advertises a feature

        @param feature: feature namespace
        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @return: a Deferred which fire a boolean (True if feature is available)
        """
        infos = yield self.get_infos(client, jid_, node)
        is_available = feature in infos.features
        return is_available

    @defer.inlineCallbacks
    def check_feature(self, client, feature, jid_=None, node=""):
        """Like hasFeature, but raise an exception if the feature is not found

        @param feature: feature namespace
        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @return: a Deferred which fire None when the feature is available

        @raise exceptions.FeatureNotFound: the feature is not in entity's disco infos
        """
        disco_infos = yield self.get_infos(client, jid_, node)
        if feature not in disco_infos.features:
            # the exception is raised directly: inlineCallbacks turns it into the
            # Deferred's failure, there is no need to wrap it in a Failure ourselves
            raise exceptions.FeatureNotFound()

    @defer.inlineCallbacks
    def check_features(self, client, features, jid_=None, identity=None, node=""):
        """Like check_feature, but check several features at once, and check also identity

        @param features(iterable[unicode]): features to check
        @param jid_(jid.JID): jid of the target, or None for profile's server
        @param identity(None, tuple(unicode, unicode)): if not None, the entity must
            have an identity with this (category, type) tuple
        @param node(unicode): optional node to use for disco request
        @return: a Deferred which fire None when all requirements are met

        @raise exceptions.FeatureNotFound: at least one feature is missing, or the
            requested identity is not found
        """
        disco_infos = yield self.get_infos(client, jid_, node)
        # exceptions are raised directly: inlineCallbacks turns them into the
        # Deferred's failure, there is no need to wrap them in a Failure ourselves
        if not set(features).issubset(disco_infos.features):
            raise exceptions.FeatureNotFound()

        if identity is not None and identity not in disco_infos.identities:
            raise exceptions.FeatureNotFound()

    async def has_identity(
        self,
        client: SatXMPPEntity,
        category: str,
        type_: str,
        jid_: Optional[jid.JID] = None,
        node: str = ""
    ) -> bool:
        """Check if an entity advertises an identity

        @param category: identity category
        @param type_: identity type
        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @return: True if the entity has the given identity
        """
        infos = await self.get_infos(client, jid_, node)
        # identities are looked up by their (category, type) tuple
        wanted = (category, type_)
        return wanted in infos.identities

    def get_infos(self, client, jid_=None, node="", use_cache=True):
        """get disco infos from jid_, filling capability hash if needed

        The cache is made of the capability hash stored in entity data, and of the
        disco infos associated to this hash in ``self.hashes``.
        NOTE(review): the cache lookup only uses ``jid_``, ``node`` is not part of it.

        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @param use_cache(bool): if True, use cached data if available
        @return: a Deferred which fire disco.DiscoInfo
        """
        if jid_ is None:
            jid_ = jid.JID(client.jid.host)

        if use_cache:
            try:
                cap_hash = self.host.memory.entity_data_get(
                    client, jid_, [C.ENTITY_CAP_HASH]
                )[C.ENTITY_CAP_HASH]
                # the hash is resolved inside the "try" block: if it is not known by
                # the hash manager, a new request is done instead of raising a
                # KeyError synchronously (this method must always return a Deferred)
                disco_infos = self.hashes[cap_hash]
            except (KeyError, exceptions.UnknownEntityError):
                # capability hash is not available, we'll compute one
                pass
            else:
                return defer.succeed(disco_infos)

        def infos_cb(disco_infos):
            cap_hash = self.generate_hash(disco_infos)
            for ext_form in disco_infos.extensions.values():
                # wokkel doesn't call typeCheck on reception, so we do it here
                # to avoid ending up with incorrect types. We have to do it after
                # the hash has been generated (str value is needed to compute the
                # hash)
                ext_form.typeCheck()
            self.hashes[cap_hash] = disco_infos
            self.host.memory.update_entity_data(
                client, jid_, C.ENTITY_CAP_HASH, cap_hash
            )
            return disco_infos

        def infos_eb(fail):
            if fail.check(defer.CancelledError):
                reason = "request time-out"
                fail = failure.Failure(exceptions.TimeOutError(str(fail.value)))
            else:
                try:
                    reason = str(fail.value)
                except AttributeError:
                    reason = str(fail)

            log.warning(
                "can't request disco infos from {jid}: {reason}".format(
                    jid=jid_.full(), reason=reason
                )
            )

            # XXX we set empty disco in cache, to avoid getting an error or waiting
            # for a timeout again the next time
            self.host.memory.update_entity_data(
                client, jid_, C.ENTITY_CAP_HASH, CAP_HASH_ERROR
            )
            # returning the failure keeps the Deferred in error state
            return fail

        d = client.disco.requestInfo(jid_, nodeIdentifier=node)
        d.addCallback(infos_cb)
        d.addErrback(infos_eb)
        return d

    @defer.inlineCallbacks
    def get_items(self, client, jid_=None, node="", use_cache=True):
        """get disco items from jid_, cache them for our own server

        Items are cached only when the target is ``client.server_jid`` and no node is
        set. For any other target, a StanzaError is logged and empty disco items are
        returned instead.

        @param jid_(jid.JID): jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @param use_cache(bool): if True, use cached data if available
        @return: a Deferred which fire disco.DiscoItems
        """
        if jid_ is None:
            jid_ = client.server_jid

        if jid_ == client.server_jid and not node:
            # we cache items only for our own server and if node is not set
            try:
                if not use_cache:
                    # we ignore cache, so we pretend we haven't found it. This is
                    # checked before reading the cache, so nothing is logged as being
                    # "in cache" when the cache is not used
                    raise KeyError
                items = self.host.memory.entity_data_get(
                    client, jid_, ["DISCO_ITEMS"]
                )["DISCO_ITEMS"]
                log.debug("[%s] disco items are in cache" % jid_.full())
            except (KeyError, exceptions.UnknownEntityError):
                log.debug("Caching [%s] disco items" % jid_.full())
                # errors of this request are not caught here, they are propagated
                items = yield client.disco.requestItems(jid_, nodeIdentifier=node)
                self.host.memory.update_entity_data(
                    client, jid_, "DISCO_ITEMS", items
                )
        else:
            try:
                items = yield client.disco.requestItems(jid_, nodeIdentifier=node)
            except StanzaError as e:
                log.warning(
                    "Error while requesting items for {jid}: {reason}".format(
                        jid=jid_.full(), reason=e.condition
                    )
                )
                items = disco.DiscoItems()

        return items

    def _infos_eb(self, failure_, entity_jid):
        """Errback logging a warning when a disco request ends with a StanzaError

        Any other kind of failure is re-raised by ``trap``.

        @param failure_(failure.Failure): failure of the request
        @param entity_jid(jid.JID): entity which was requested
        """
        failure_.trap(StanzaError)
        template = _("Error while requesting [%(jid)s]: %(error)s")
        details = {"jid": entity_jid.full(), "error": failure_.getErrorMessage()}
        log.warning(template % details)

    def find_service_entity(self, client, category, type_, jid_=None):
        """Helper method to find first available entity from find_service_entities

        args are the same as for [find_service_entities]
        @return (jid.JID, None): a Deferred which fire one of the found entities, or
            None if no entity has been found
        """
        # jid_ must be transmitted, otherwise the profile's server would always be
        # used whatever the requested target is
        d = self.find_service_entities(client, category, type_, jid_)
        d.addCallback(lambda entities: entities.pop() if entities else None)
        return d

    def find_service_entities(self, client, category, type_, jid_=None):
        """Return all available items of an entity which correspond to (category, type_)

        The returned Deferred is cancelled if it has not fired after TIMEOUT seconds.

        @param category: identity's category
        @param type_: identity's type
        @param jid_: the jid of the target server (None for profile's server)
        @return: a Deferred which fire a set of found entities
        """
        found_entities = set()

        def infos_cb(infos, entity_jid):
            if (category, type_) in infos.identities:
                found_entities.add(entity_jid)

        def got_items(items):
            defers_list = []
            for item in items:
                if item.entity is None:
                    # get_infos uses profile's server when the jid is None: without
                    # this check the server would be requested, and None could be
                    # added to found entities
                    log.warning(_("received an item without jid"))
                    continue
                info_d = self.get_infos(client, item.entity)
                info_d.addCallbacks(
                    infos_cb, self._infos_eb, [item.entity], None, [item.entity]
                )
                defers_list.append(info_d)
            return defer.DeferredList(defers_list)

        d = self.get_items(client, jid_)
        d.addCallback(got_items)
        d.addCallback(lambda __: found_entities)
        # FIXME: one bad service make a general timeout
        timeout_call = reactor.callLater(TIMEOUT, d.cancel)

        def cancel_timeout(result):
            # the delayed call is not needed anymore once we have a result, we cancel
            # it to avoid keeping it in the reactor until TIMEOUT is reached
            if timeout_call.active():
                timeout_call.cancel()
            return result

        d.addBoth(cancel_timeout)
        return d

    def find_features_set(self, client, features, identity=None, jid_=None):
        """Return entities (including jid_ and its items) offering features

        The returned Deferred is cancelled if it has not fired after TIMEOUT seconds.

        @param features: iterable of features which must be present
        @param identity(None, tuple(unicode, unicode)): if not None, accept only this
            (category/type) identity
        @param jid_: the jid of the target server (None for profile's server)
        @return: a Deferred which fire a set of found entities
        """
        if jid_ is None:
            jid_ = jid.JID(client.jid.host)
        features = set(features)
        found_entities = set()

        def infos_cb(infos, entity):
            if identity is not None and identity not in infos.identities:
                return
            if features.issubset(infos.features):
                found_entities.add(entity)

        def got_items(items):
            defer_list = []
            for entity in [jid_] + [item.entity for item in items]:
                if entity is None:
                    # the item is skipped before doing any request: get_infos would
                    # use profile's server for a None jid, resulting in a useless
                    # request
                    log.warning(_("received an item without jid"))
                    continue
                infos_d = self.get_infos(client, entity)
                infos_d.addCallbacks(infos_cb, self._infos_eb, [entity], None, [entity])
                defer_list.append(infos_d)
            return defer.DeferredList(defer_list)

        d = self.get_items(client, jid_)
        d.addCallback(got_items)
        d.addCallback(lambda __: found_entities)
        # FIXME: one bad service make a general timeout
        timeout_call = reactor.callLater(TIMEOUT, d.cancel)

        def cancel_timeout(result):
            # the delayed call is not needed anymore once we have a result, we cancel
            # it to avoid keeping it in the reactor until TIMEOUT is reached
            if timeout_call.active():
                timeout_call.cancel()
            return result

        d.addBoth(cancel_timeout)
        return d

    def generate_hash(self, services):
        """Generate a unique hash for given service

        hash algorithm is the one described in XEP-0115
        @param services: iterable of disco.DiscoIdentity/disco.DiscoFeature, as returned
            by discoHandler.info. Its ``extensions`` attribute (mapping of data forms)
            is used too. It is not modified by this method.
        @return (str): base64 encoded SHA-1 hash of the verification string
        """
        s = []
        # identities
        byte_identities = [
            ByteIdentity(service)
            for service in services
            if isinstance(service, disco.DiscoIdentity)
        ]  # FIXME: lang must be managed here
        # sorted by category, then by type, then by lang
        byte_identities.sort(key=lambda i: (i.category, i.idType, i.lang))
        for identity in byte_identities:
            s.append(bytes(identity))
            s.append(b"<")
        # features
        byte_features = [
            service.encode("utf-8")
            for service in services
            if isinstance(service, disco.DiscoFeature)
        ]
        byte_features.sort()  # XXX: the default sort has the same behaviour as the requested RFC 4790 i;octet sort
        for feature in byte_features:
            s.append(feature)
            s.append(b"<")

        # extensions
        ext = []
        for form in services.extensions.values():
            if not form.formNamespace:
                # XEP-0115 (§5.4) says that a form without FORM_TYPE must be ignored,
                # and we could not sort or serialise it without a namespace anyway
                log.debug("ignoring extension without form namespace")
                continue
            ext.append(form)
        ext.sort(key=lambda f: f.formNamespace.encode('utf-8'))
        for extension in ext:
            s.append(extension.formNamespace.encode('utf-8'))
            s.append(b"<")
            # we work on a sorted copy, to avoid reordering the fields of the form
            # which has been given by the caller
            fields = sorted(extension.fieldList, key=lambda f: f.var.encode('utf-8'))
            for field in fields:
                s.append(field.var.encode('utf-8'))
                s.append(b"<")
                for value in sorted(v.encode('utf-8') for v in field.values):
                    s.append(value)
                    s.append(b"<")

        cap_hash = b64encode(sha1(b"".join(s)).digest()).decode('utf-8')
        log.debug(_("Capability hash generated: [{cap_hash}]").format(cap_hash=cap_hash))
        return cap_hash

    @defer.inlineCallbacks
    def _disco_infos(
        self, entity_jid_s, node="", use_cache=True, profile_key=C.PROF_KEY_NONE
    ):
        """Discovery method for the bridge

        @param entity_jid_s: entity we want to discover
        @param node(unicode): optional node to use
        @param use_cache(bool): if True, use cached data if available
        @param profile_key: key used to retrieve the client from host

        @return: a Deferred which fire a tuple with:
            - list of features
            - list of (category, type, name) identities
            - dict mapping form type to a list of (field data, field values) tuples
        """
        client = self.host.get_client(profile_key)
        target = jid.JID(entity_jid_s)
        infos = yield self.get_infos(client, target, node, use_cache)

        # FIXME: should extensions be serialised using tools.common.data_format?
        serialised_forms = {}
        for form_type, form in infos.extensions.items():
            serialised_fields = []
            for field in form.fieldList:
                data = {"type": field.fieldType}
                # optional attributes are only set when they have a value
                for attr in ("var", "label", "desc"):
                    attr_value = getattr(field, attr)
                    if attr_value is None:
                        continue
                    data[attr] = attr_value

                if field.value is None:
                    values = field.values
                else:
                    values = [field.value]
                if field.fieldType == "boolean":
                    values = [C.bool_const(v) for v in values]
                serialised_fields.append((data, values))

            # a missing form type is serialised as an empty string
            serialised_forms[form_type or ""] = serialised_fields

        features = [str(f) for f in infos.features]
        identities = [
            (cat, type_, name or "")
            for (cat, type_), name in infos.identities.items()
        ]
        return (features, identities, serialised_forms)

    def items2tuples(self, disco_items):
        """Serialise disco items as tuples of strings

        Items without entity are skipped, with a warning.

        @param disco_items(iterable[disco.DiscoItem]): items
        @return G(tuple[unicode,unicode,unicode]): serialised items, as
            (full jid, node, name) with an empty string for missing node or name
        """
        for disco_item in disco_items:
            entity = disco_item.entity
            if entity:
                yield (
                    entity.full(),
                    disco_item.nodeIdentifier or "",
                    disco_item.name or "",
                )
            else:
                log.warning(_("invalid item (no jid)"))

    @defer.inlineCallbacks
    def _disco_items(
        self, entity_jid_s, node="", use_cache=True, profile_key=C.PROF_KEY_NONE
    ):
        """Discovery method for the bridge

        @param entity_jid_s: entity we want to discover
        @param node(unicode): optional node to use
        @param use_cache(bool): if True, use cached data if available
        @param profile_key: key used to retrieve the client from host
        @return: a Deferred which fire the list of tuples built by [items2tuples]
        """
        client = self.host.get_client(profile_key)
        items = yield self.get_items(client, jid.JID(entity_jid_s), node, use_cache)
        return [*self.items2tuples(items)]