diff libervia/backend/memory/disco.py @ 4071:4b842c1fb686

refactoring: renamed `sat` package to `libervia.backend`
author Goffi <goffi@goffi.org>
date Fri, 02 Jun 2023 11:49:51 +0200
parents sat/memory/disco.py@524856bd7b19
children 0d7bb4df2343
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/libervia/backend/memory/disco.py	Fri Jun 02 11:49:51 2023 +0200
@@ -0,0 +1,499 @@
+#!/usr/bin/env python3
+
+
+# SAT: a jabber client
+# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from typing import Optional
+from libervia.backend.core.i18n import _
+from libervia.backend.core import exceptions
+from libervia.backend.core.log import getLogger
+from libervia.backend.core.core_types import SatXMPPEntity
+
+from twisted.words.protocols.jabber import jid
+from twisted.words.protocols.jabber.error import StanzaError
+from twisted.internet import defer
+from twisted.internet import reactor
+from twisted.python import failure
+from libervia.backend.core.constants import Const as C
+from libervia.backend.tools import xml_tools
+from libervia.backend.memory import persistent
+from wokkel import disco
+from base64 import b64encode
+from hashlib import sha1
+
+
+log = getLogger(__name__)
+
+
+TIMEOUT = 15
+CAP_HASH_ERROR = "ERROR"
+
+
+class HashGenerationError(Exception):
+    pass
+
+
+class ByteIdentity(object):
+    """This class manage identity as bytes (needed for i;octet sort), it is used for the hash generation"""
+
+    def __init__(self, identity, lang=None):
+        assert isinstance(identity, disco.DiscoIdentity)
+        self.category = identity.category.encode("utf-8")
+        self.idType = identity.type.encode("utf-8")
+        self.name = identity.name.encode("utf-8") if identity.name else b""
+        self.lang = lang.encode("utf-8") if lang is not None else b""
+
+    def __bytes__(self):
+        return b"%s/%s/%s/%s" % (self.category, self.idType, self.lang, self.name)
+
+
+class HashManager(object):
+    """map object which manage hashes
+
+    persistent storage is update when a new hash is added
+    """
+
+    def __init__(self, persistent):
+        self.hashes = {
+            CAP_HASH_ERROR: disco.DiscoInfo()  # used when we can't get disco infos
+        }
+        self.persistent = persistent
+
+    def __getitem__(self, key):
+        return self.hashes[key]
+
+    def __setitem__(self, hash_, disco_info):
+        if hash_ in self.hashes:
+            log.debug("ignoring hash set: it is already known")
+            return
+        self.hashes[hash_] = disco_info
+        self.persistent[hash_] = disco_info.toElement().toXml()
+
+    def __contains__(self, hash_):
+        return self.hashes.__contains__(hash_)
+
+    def load(self):
+        def fill_hashes(hashes):
+            for hash_, xml in hashes.items():
+                element = xml_tools.ElementParser()(xml)
+                disco_info = disco.DiscoInfo.fromElement(element)
+                for ext_form in disco_info.extensions.values():
+                    # wokkel doesn't call typeCheck on reception, so we do it here
+                    ext_form.typeCheck()
+                if not disco_info.features and not disco_info.identities:
+                    log.warning(
+                        _(
+                            "no feature/identity found in disco element (hash: {cap_hash}), ignoring: {xml}"
+                        ).format(cap_hash=hash_, xml=xml)
+                    )
+                else:
+                    self.hashes[hash_] = disco_info
+
+            log.info("Disco hashes loaded")
+
+        d = self.persistent.load()
+        d.addCallback(fill_hashes)
+        return d
+
+
+class Discovery(object):
+    """ Manage capabilities of entities """
+
+    def __init__(self, host):
+        self.host = host
+        # TODO: remove legacy hashes
+
+    def load(self):
+        """Load persistent hashes"""
+        self.hashes = HashManager(persistent.PersistentDict("disco"))
+        return self.hashes.load()
+
+    @defer.inlineCallbacks
+    def hasFeature(self, client, feature, jid_=None, node=""):
+        """Tell if an entity has the required feature
+
+        @param feature: feature namespace
+        @param jid_: jid of the target, or None for profile's server
+        @param node(unicode): optional node to use for disco request
+        @return: a Deferred which fire a boolean (True if feature is available)
+        """
+        disco_infos = yield self.get_infos(client, jid_, node)
+        defer.returnValue(feature in disco_infos.features)
+
+    @defer.inlineCallbacks
+    def check_feature(self, client, feature, jid_=None, node=""):
+        """Like hasFeature, but raise an exception is feature is not Found
+
+        @param feature: feature namespace
+        @param jid_: jid of the target, or None for profile's server
+        @param node(unicode): optional node to use for disco request
+
+        @raise: exceptions.FeatureNotFound
+        """
+        disco_infos = yield self.get_infos(client, jid_, node)
+        if not feature in disco_infos.features:
+            raise failure.Failure(exceptions.FeatureNotFound())
+
+    @defer.inlineCallbacks
+    def check_features(self, client, features, jid_=None, identity=None, node=""):
+        """Like check_feature, but check several features at once, and check also identity
+
+        @param features(iterable[unicode]): features to check
+        @param jid_(jid.JID): jid of the target, or None for profile's server
+        @param node(unicode): optional node to use for disco request
+        @param identity(None, tuple(unicode, unicode): if not None, the entity must have an identity with this (category, type) tuple
+
+        @raise: exceptions.FeatureNotFound
+        """
+        disco_infos = yield self.get_infos(client, jid_, node)
+        if not set(features).issubset(disco_infos.features):
+            raise failure.Failure(exceptions.FeatureNotFound())
+
+        if identity is not None and identity not in disco_infos.identities:
+            raise failure.Failure(exceptions.FeatureNotFound())
+
+    async def has_identity(
+        self,
+        client: SatXMPPEntity,
+        category: str,
+        type_: str,
+        jid_: Optional[jid.JID] = None,
+        node: str = ""
+    ) -> bool:
+        """Tell if an entity has the requested identity
+
+        @param category: identity category
+        @param type_: identity type
+        @param jid_: jid of the target, or None for profile's server
+        @param node(unicode): optional node to use for disco request
+        @return: True if the entity has the given identity
+        """
+        disco_infos = await self.get_infos(client, jid_, node)
+        return (category, type_) in disco_infos.identities
+
+    def get_infos(self, client, jid_=None, node="", use_cache=True):
+        """get disco infos from jid_, filling capability hash if needed
+
+        @param jid_: jid of the target, or None for profile's server
+        @param node(unicode): optional node to use for disco request
+        @param use_cache(bool): if True, use cached data if available
+        @return: a Deferred which fire disco.DiscoInfo
+        """
+        if jid_ is None:
+            jid_ = jid.JID(client.jid.host)
+        try:
+            if not use_cache:
+                # we ignore cache, so we pretend we haven't found it
+                raise KeyError
+            cap_hash = self.host.memory.entity_data_get(
+                client, jid_, [C.ENTITY_CAP_HASH]
+            )[C.ENTITY_CAP_HASH]
+        except (KeyError, exceptions.UnknownEntityError):
+            # capability hash is not available, we'll compute one
+            def infos_cb(disco_infos):
+                cap_hash = self.generate_hash(disco_infos)
+                for ext_form in disco_infos.extensions.values():
+                    # wokkel doesn't call typeCheck on reception, so we do it here
+                    # to avoid ending up with incorrect types. We have to do it after
+                    # the hash has been generated (str value is needed to compute the
+                    # hash)
+                    ext_form.typeCheck()
+                self.hashes[cap_hash] = disco_infos
+                self.host.memory.update_entity_data(
+                    client, jid_, C.ENTITY_CAP_HASH, cap_hash
+                )
+                return disco_infos
+
+            def infos_eb(fail):
+                if fail.check(defer.CancelledError):
+                    reason = "request time-out"
+                    fail = failure.Failure(exceptions.TimeOutError(str(fail.value)))
+                else:
+                    try:
+                        reason = str(fail.value)
+                    except AttributeError:
+                        reason = str(fail)
+
+                log.warning(
+                    "can't request disco infos from {jid}: {reason}".format(
+                        jid=jid_.full(), reason=reason
+                    )
+                )
+
+                # XXX we set empty disco in cache, to avoid getting an error or waiting
+                # for a timeout again the next time
+                self.host.memory.update_entity_data(
+                    client, jid_, C.ENTITY_CAP_HASH, CAP_HASH_ERROR
+                )
+                raise fail
+
+            d = client.disco.requestInfo(jid_, nodeIdentifier=node)
+            d.addCallback(infos_cb)
+            d.addErrback(infos_eb)
+            return d
+        else:
+            disco_infos = self.hashes[cap_hash]
+            return defer.succeed(disco_infos)
+
+    @defer.inlineCallbacks
+    def get_items(self, client, jid_=None, node="", use_cache=True):
+        """get disco items from jid_, cache them for our own server
+
+        @param jid_(jid.JID): jid of the target, or None for profile's server
+        @param node(unicode): optional node to use for disco request
+        @param use_cache(bool): if True, use cached data if available
+        @return: a Deferred which fire disco.DiscoItems
+        """
+        if jid_ is None:
+            jid_ = client.server_jid
+
+        if jid_ == client.server_jid and not node:
+            # we cache items only for our own server and if node is not set
+            try:
+                items = self.host.memory.entity_data_get(
+                    client, jid_, ["DISCO_ITEMS"]
+                )["DISCO_ITEMS"]
+                log.debug("[%s] disco items are in cache" % jid_.full())
+                if not use_cache:
+                    # we ignore cache, so we pretend we haven't found it
+                    raise KeyError
+            except (KeyError, exceptions.UnknownEntityError):
+                log.debug("Caching [%s] disco items" % jid_.full())
+                items = yield client.disco.requestItems(jid_, nodeIdentifier=node)
+                self.host.memory.update_entity_data(
+                    client, jid_, "DISCO_ITEMS", items
+                )
+        else:
+            try:
+                items = yield client.disco.requestItems(jid_, nodeIdentifier=node)
+            except StanzaError as e:
+                log.warning(
+                    "Error while requesting items for {jid}: {reason}".format(
+                        jid=jid_.full(), reason=e.condition
+                    )
+                )
+                items = disco.DiscoItems()
+
+        defer.returnValue(items)
+
+    def _infos_eb(self, failure_, entity_jid):
+        failure_.trap(StanzaError)
+        log.warning(
+            _("Error while requesting [%(jid)s]: %(error)s")
+            % {"jid": entity_jid.full(), "error": failure_.getErrorMessage()}
+        )
+
+    def find_service_entity(self, client, category, type_, jid_=None):
+        """Helper method to find first available entity from find_service_entities
+
+        args are the same as for [find_service_entities]
+        @return (jid.JID, None): found entity
+        """
+        d = self.host.find_service_entities(client, category, type_)
+        d.addCallback(lambda entities: entities.pop() if entities else None)
+        return d
+
+    def find_service_entities(self, client, category, type_, jid_=None):
+        """Return all available items of an entity which correspond to (category, type_)
+
+        @param category: identity's category
+        @param type_: identitiy's type
+        @param jid_: the jid of the target server (None for profile's server)
+        @return: a set of found entities
+        @raise defer.CancelledError: the request timed out
+        """
+        found_entities = set()
+
+        def infos_cb(infos, entity_jid):
+            if (category, type_) in infos.identities:
+                found_entities.add(entity_jid)
+
+        def got_items(items):
+            defers_list = []
+            for item in items:
+                info_d = self.get_infos(client, item.entity)
+                info_d.addCallbacks(
+                    infos_cb, self._infos_eb, [item.entity], None, [item.entity]
+                )
+                defers_list.append(info_d)
+            return defer.DeferredList(defers_list)
+
+        d = self.get_items(client, jid_)
+        d.addCallback(got_items)
+        d.addCallback(lambda __: found_entities)
+        reactor.callLater(
+            TIMEOUT, d.cancel
+        )  # FIXME: one bad service make a general timeout
+        return d
+
+    def find_features_set(self, client, features, identity=None, jid_=None):
+        """Return entities (including jid_ and its items) offering features
+
+        @param features: iterable of features which must be present
+        @param identity(None, tuple(unicode, unicode)): if not None, accept only this
+            (category/type) identity
+        @param jid_: the jid of the target server (None for profile's server)
+        @param profile: %(doc_profile)s
+        @return: a set of found entities
+        """
+        if jid_ is None:
+            jid_ = jid.JID(client.jid.host)
+        features = set(features)
+        found_entities = set()
+
+        def infos_cb(infos, entity):
+            if entity is None:
+                log.warning(_("received an item without jid"))
+                return
+            if identity is not None and identity not in infos.identities:
+                return
+            if features.issubset(infos.features):
+                found_entities.add(entity)
+
+        def got_items(items):
+            defer_list = []
+            for entity in [jid_] + [item.entity for item in items]:
+                infos_d = self.get_infos(client, entity)
+                infos_d.addCallbacks(infos_cb, self._infos_eb, [entity], None, [entity])
+                defer_list.append(infos_d)
+            return defer.DeferredList(defer_list)
+
+        d = self.get_items(client, jid_)
+        d.addCallback(got_items)
+        d.addCallback(lambda __: found_entities)
+        reactor.callLater(
+            TIMEOUT, d.cancel
+        )  # FIXME: one bad service make a general timeout
+        return d
+
+    def generate_hash(self, services):
+        """ Generate a unique hash for given service
+
+        hash algorithm is the one described in XEP-0115
+        @param services: iterable of disco.DiscoIdentity/disco.DiscoFeature, as returned by discoHandler.info
+
+        """
+        s = []
+        # identities
+        byte_identities = [
+            ByteIdentity(service)
+            for service in services
+            if isinstance(service, disco.DiscoIdentity)
+        ]  # FIXME: lang must be managed here
+        byte_identities.sort(key=lambda i: i.lang)
+        byte_identities.sort(key=lambda i: i.idType)
+        byte_identities.sort(key=lambda i: i.category)
+        for identity in byte_identities:
+            s.append(bytes(identity))
+            s.append(b"<")
+        # features
+        byte_features = [
+            service.encode("utf-8")
+            for service in services
+            if isinstance(service, disco.DiscoFeature)
+        ]
+        byte_features.sort()  # XXX: the default sort has the same behaviour as the requested RFC 4790 i;octet sort
+        for feature in byte_features:
+            s.append(feature)
+            s.append(b"<")
+
+        # extensions
+        ext = list(services.extensions.values())
+        ext.sort(key=lambda f: f.formNamespace.encode('utf-8'))
+        for extension in ext:
+            s.append(extension.formNamespace.encode('utf-8'))
+            s.append(b"<")
+            fields = extension.fieldList
+            fields.sort(key=lambda f: f.var.encode('utf-8'))
+            for field in fields:
+                s.append(field.var.encode('utf-8'))
+                s.append(b"<")
+                values = [v.encode('utf-8') for v in field.values]
+                values.sort()
+                for value in values:
+                    s.append(value)
+                    s.append(b"<")
+
+        cap_hash = b64encode(sha1(b"".join(s)).digest()).decode('utf-8')
+        log.debug(_("Capability hash generated: [{cap_hash}]").format(cap_hash=cap_hash))
+        return cap_hash
+
+    @defer.inlineCallbacks
+    def _disco_infos(
+        self, entity_jid_s, node="", use_cache=True, profile_key=C.PROF_KEY_NONE
+    ):
+        """Discovery method for the bridge
+        @param entity_jid_s: entity we want to discover
+        @param use_cache(bool): if True, use cached data if available
+        @param node(unicode): optional node to use
+
+        @return: list of tuples
+        """
+        client = self.host.get_client(profile_key)
+        entity = jid.JID(entity_jid_s)
+        disco_infos = yield self.get_infos(client, entity, node, use_cache)
+        extensions = {}
+        # FIXME: should extensions be serialised using tools.common.data_format?
+        for form_type, form in list(disco_infos.extensions.items()):
+            fields = []
+            for field in form.fieldList:
+                data = {"type": field.fieldType}
+                for attr in ("var", "label", "desc"):
+                    value = getattr(field, attr)
+                    if value is not None:
+                        data[attr] = value
+
+                values = [field.value] if field.value is not None else field.values
+                if field.fieldType == "boolean":
+                    values = [C.bool_const(v) for v in values]
+                fields.append((data, values))
+
+            extensions[form_type or ""] = fields
+
+        defer.returnValue((
+            [str(f) for f in disco_infos.features],
+            [(cat, type_, name or "")
+             for (cat, type_), name in list(disco_infos.identities.items())],
+            extensions))
+
+    def items2tuples(self, disco_items):
+        """convert disco items to tuple of strings
+
+        @param disco_items(iterable[disco.DiscoItem]): items
+        @return G(tuple[unicode,unicode,unicode]): serialised items
+        """
+        for item in disco_items:
+            if not item.entity:
+                log.warning(_("invalid item (no jid)"))
+                continue
+            yield (item.entity.full(), item.nodeIdentifier or "", item.name or "")
+
+    @defer.inlineCallbacks
+    def _disco_items(
+        self, entity_jid_s, node="", use_cache=True, profile_key=C.PROF_KEY_NONE
+    ):
+        """ Discovery method for the bridge
+
+        @param entity_jid_s: entity we want to discover
+        @param node(unicode): optional node to use
+        @param use_cache(bool): if True, use cached data if available
+        @return: list of tuples"""
+        client = self.host.get_client(profile_key)
+        entity = jid.JID(entity_jid_s)
+        disco_items = yield self.get_items(client, entity, node, use_cache)
+        ret = list(self.items2tuples(disco_items))
+        defer.returnValue(ret)