Mercurial > libervia-backend
diff libervia/backend/memory/disco.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 02 Jun 2023 11:49:51 +0200 |
parents | sat/memory/disco.py@524856bd7b19 |
children | 0d7bb4df2343 |
line wrap: on
line diff
#!/usr/bin/env python3


# SAT: a jabber client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from typing import Optional
from libervia.backend.core.i18n import _
from libervia.backend.core import exceptions
from libervia.backend.core.log import getLogger
from libervia.backend.core.core_types import SatXMPPEntity

from twisted.words.protocols.jabber import jid
from twisted.words.protocols.jabber.error import StanzaError
from twisted.internet import defer
from twisted.internet import reactor
from twisted.python import failure
from libervia.backend.core.constants import Const as C
from libervia.backend.tools import xml_tools
from libervia.backend.memory import persistent
from wokkel import disco
from base64 import b64encode
from hashlib import sha1


log = getLogger(__name__)


TIMEOUT = 15
CAP_HASH_ERROR = "ERROR"


class HashGenerationError(Exception):
    pass


class ByteIdentity(object):
    """Identity handled as bytes (needed for i;octet sort).

    Used for XEP-0115 capability hash generation.
    """

    def __init__(self, identity, lang=None):
        # only disco identities are meaningful for the hash computation
        assert isinstance(identity, disco.DiscoIdentity)
        self.category = identity.category.encode("utf-8")
        self.idType = identity.type.encode("utf-8")
        self.name = identity.name.encode("utf-8") if identity.name else b""
        self.lang = lang.encode("utf-8") if lang is not None else b""

    def __bytes__(self):
        # field order is mandated by XEP-0115 (category/type/lang/name)
        return b"%s/%s/%s/%s" % (self.category, self.idType, self.lang, self.name)


class HashManager(object):
    """Map-like object which manages capability hashes.

    Persistent storage is updated when a new hash is added.
    """

    def __init__(self, persistent):
        # CAP_HASH_ERROR maps to an empty DiscoInfo, used when disco infos
        # could not be retrieved (avoids retrying/timing out on each request)
        self.hashes = {
            CAP_HASH_ERROR: disco.DiscoInfo()  # used when we can't get disco infos
        }
        self.persistent = persistent

    def __getitem__(self, key):
        return self.hashes[key]

    def __setitem__(self, hash_, disco_info):
        if hash_ in self.hashes:
            log.debug("ignoring hash set: it is already known")
            return
        self.hashes[hash_] = disco_info
        # persist as serialised XML so hashes survive restarts
        self.persistent[hash_] = disco_info.toElement().toXml()

    def __contains__(self, hash_):
        return self.hashes.__contains__(hash_)

    def load(self):
        """Fill the in-memory map from persistent storage.

        @return: a Deferred which fires when hashes are loaded
        """
        def fill_hashes(hashes):
            for hash_, xml in hashes.items():
                element = xml_tools.ElementParser()(xml)
                disco_info = disco.DiscoInfo.fromElement(element)
                for ext_form in disco_info.extensions.values():
                    # wokkel doesn't call typeCheck on reception, so we do it here
                    ext_form.typeCheck()
                if not disco_info.features and not disco_info.identities:
                    log.warning(
                        _(
                            "no feature/identity found in disco element (hash: {cap_hash}), ignoring: {xml}"
                        ).format(cap_hash=hash_, xml=xml)
                    )
                else:
                    self.hashes[hash_] = disco_info

            log.info("Disco hashes loaded")

        d = self.persistent.load()
        d.addCallback(fill_hashes)
        return d


class Discovery(object):
    """Manage capabilities of entities (XEP-0030 / XEP-0115)."""

    def __init__(self, host):
        self.host = host
        # TODO: remove legacy hashes

    def load(self):
        """Load persistent hashes"""
        self.hashes = HashManager(persistent.PersistentDict("disco"))
        return self.hashes.load()

    @defer.inlineCallbacks
    def hasFeature(self, client, feature, jid_=None, node=""):
        """Tell if an entity has the required feature

        @param feature: feature namespace
        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @return: a Deferred which fire a boolean (True if feature is available)
        """
        disco_infos = yield self.get_infos(client, jid_, node)
        defer.returnValue(feature in disco_infos.features)

    @defer.inlineCallbacks
    def check_feature(self, client, feature, jid_=None, node=""):
        """Like hasFeature, but raise an exception is feature is not Found

        @param feature: feature namespace
        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request

        @raise: exceptions.FeatureNotFound
        """
        disco_infos = yield self.get_infos(client, jid_, node)
        # FIX: use idiomatic "not in" instead of "not feature in"
        if feature not in disco_infos.features:
            raise failure.Failure(exceptions.FeatureNotFound())

    @defer.inlineCallbacks
    def check_features(self, client, features, jid_=None, identity=None, node=""):
        """Like check_feature, but check several features at once, and check also identity

        @param features(iterable[unicode]): features to check
        @param jid_(jid.JID): jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @param identity(None, tuple(unicode, unicode): if not None, the entity must have an identity with this (category, type) tuple

        @raise: exceptions.FeatureNotFound
        """
        disco_infos = yield self.get_infos(client, jid_, node)
        if not set(features).issubset(disco_infos.features):
            raise failure.Failure(exceptions.FeatureNotFound())

        if identity is not None and identity not in disco_infos.identities:
            raise failure.Failure(exceptions.FeatureNotFound())

    async def has_identity(
        self,
        client: SatXMPPEntity,
        category: str,
        type_: str,
        jid_: Optional[jid.JID] = None,
        node: str = ""
    ) -> bool:
        """Tell if an entity has the requested identity

        @param category: identity category
        @param type_: identity type
        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @return: True if the entity has the given identity
        """
        disco_infos = await self.get_infos(client, jid_, node)
        return (category, type_) in disco_infos.identities

    def get_infos(self, client, jid_=None, node="", use_cache=True):
        """get disco infos from jid_, filling capability hash if needed

        @param jid_: jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @param use_cache(bool): if True, use cached data if available
        @return: a Deferred which fire disco.DiscoInfo
        """
        if jid_ is None:
            jid_ = jid.JID(client.jid.host)
        try:
            if not use_cache:
                # we ignore cache, so we pretend we haven't found it
                raise KeyError
            cap_hash = self.host.memory.entity_data_get(
                client, jid_, [C.ENTITY_CAP_HASH]
            )[C.ENTITY_CAP_HASH]
        except (KeyError, exceptions.UnknownEntityError):
            # capability hash is not available, we'll compute one
            def infos_cb(disco_infos):
                cap_hash = self.generate_hash(disco_infos)
                for ext_form in disco_infos.extensions.values():
                    # wokkel doesn't call typeCheck on reception, so we do it here
                    # to avoid ending up with incorrect types. We have to do it after
                    # the hash has been generated (str value is needed to compute the
                    # hash)
                    ext_form.typeCheck()
                self.hashes[cap_hash] = disco_infos
                self.host.memory.update_entity_data(
                    client, jid_, C.ENTITY_CAP_HASH, cap_hash
                )
                return disco_infos

            def infos_eb(fail):
                if fail.check(defer.CancelledError):
                    reason = "request time-out"
                    fail = failure.Failure(exceptions.TimeOutError(str(fail.value)))
                else:
                    try:
                        reason = str(fail.value)
                    except AttributeError:
                        reason = str(fail)

                log.warning(
                    "can't request disco infos from {jid}: {reason}".format(
                        jid=jid_.full(), reason=reason
                    )
                )

                # XXX we set empty disco in cache, to avoid getting an error or waiting
                # for a timeout again the next time
                self.host.memory.update_entity_data(
                    client, jid_, C.ENTITY_CAP_HASH, CAP_HASH_ERROR
                )
                raise fail

            d = client.disco.requestInfo(jid_, nodeIdentifier=node)
            d.addCallback(infos_cb)
            d.addErrback(infos_eb)
            return d
        else:
            disco_infos = self.hashes[cap_hash]
            return defer.succeed(disco_infos)

    @defer.inlineCallbacks
    def get_items(self, client, jid_=None, node="", use_cache=True):
        """get disco items from jid_, cache them for our own server

        @param jid_(jid.JID): jid of the target, or None for profile's server
        @param node(unicode): optional node to use for disco request
        @param use_cache(bool): if True, use cached data if available
        @return: a Deferred which fire disco.DiscoItems
        """
        if jid_ is None:
            jid_ = client.server_jid

        if jid_ == client.server_jid and not node:
            # we cache items only for our own server and if node is not set
            try:
                # FIX: check use_cache *before* touching the cache (same pattern as
                # get_infos); previously the "in cache" debug message was logged even
                # when the cache was deliberately bypassed
                if not use_cache:
                    # we ignore cache, so we pretend we haven't found it
                    raise KeyError
                items = self.host.memory.entity_data_get(
                    client, jid_, ["DISCO_ITEMS"]
                )["DISCO_ITEMS"]
                log.debug("[%s] disco items are in cache" % jid_.full())
            except (KeyError, exceptions.UnknownEntityError):
                log.debug("Caching [%s] disco items" % jid_.full())
                items = yield client.disco.requestItems(jid_, nodeIdentifier=node)
                self.host.memory.update_entity_data(
                    client, jid_, "DISCO_ITEMS", items
                )
        else:
            try:
                items = yield client.disco.requestItems(jid_, nodeIdentifier=node)
            except StanzaError as e:
                log.warning(
                    "Error while requesting items for {jid}: {reason}".format(
                        jid=jid_.full(), reason=e.condition
                    )
                )
                items = disco.DiscoItems()

        defer.returnValue(items)

    def _infos_eb(self, failure_, entity_jid):
        # only stanza-level errors are expected here; anything else propagates
        failure_.trap(StanzaError)
        log.warning(
            _("Error while requesting [%(jid)s]: %(error)s")
            % {"jid": entity_jid.full(), "error": failure_.getErrorMessage()}
        )

    def find_service_entity(self, client, category, type_, jid_=None):
        """Helper method to find first available entity from find_service_entities

        args are the same as for [find_service_entities]
        @return (jid.JID, None): found entity
        """
        d = self.host.find_service_entities(client, category, type_)
        d.addCallback(lambda entities: entities.pop() if entities else None)
        return d

    def find_service_entities(self, client, category, type_, jid_=None):
        """Return all available items of an entity which correspond to (category, type_)

        @param category: identity's category
        @param type_: identitiy's type
        @param jid_: the jid of the target server (None for profile's server)
        @return: a set of found entities
        @raise defer.CancelledError: the request timed out
        """
        found_entities = set()

        def infos_cb(infos, entity_jid):
            if (category, type_) in infos.identities:
                found_entities.add(entity_jid)

        def got_items(items):
            defers_list = []
            for item in items:
                info_d = self.get_infos(client, item.entity)
                info_d.addCallbacks(
                    infos_cb, self._infos_eb, [item.entity], None, [item.entity]
                )
                defers_list.append(info_d)
            return defer.DeferredList(defers_list)

        d = self.get_items(client, jid_)
        d.addCallback(got_items)
        d.addCallback(lambda __: found_entities)
        reactor.callLater(
            TIMEOUT, d.cancel
        )  # FIXME: one bad service make a general timeout
        return d

    def find_features_set(self, client, features, identity=None, jid_=None):
        """Return entities (including jid_ and its items) offering features

        @param features: iterable of features which must be present
        @param identity(None, tuple(unicode, unicode)): if not None, accept only this
            (category/type) identity
        @param jid_: the jid of the target server (None for profile's server)
        @param profile: %(doc_profile)s
        @return: a set of found entities
        """
        if jid_ is None:
            jid_ = jid.JID(client.jid.host)
        features = set(features)
        found_entities = set()

        def infos_cb(infos, entity):
            if entity is None:
                log.warning(_("received an item without jid"))
                return
            if identity is not None and identity not in infos.identities:
                return
            if features.issubset(infos.features):
                found_entities.add(entity)

        def got_items(items):
            defer_list = []
            for entity in [jid_] + [item.entity for item in items]:
                infos_d = self.get_infos(client, entity)
                infos_d.addCallbacks(infos_cb, self._infos_eb, [entity], None, [entity])
                defer_list.append(infos_d)
            return defer.DeferredList(defer_list)

        d = self.get_items(client, jid_)
        d.addCallback(got_items)
        d.addCallback(lambda __: found_entities)
        reactor.callLater(
            TIMEOUT, d.cancel
        )  # FIXME: one bad service make a general timeout
        return d

    def generate_hash(self, services):
        """Generate a unique hash for given service

        hash algorithm is the one described in XEP-0115
        @param services: iterable of disco.DiscoIdentity/disco.DiscoFeature, as returned by discoHandler.info

        """
        s = []
        # identities
        byte_identities = [
            ByteIdentity(service)
            for service in services
            if isinstance(service, disco.DiscoIdentity)
        ]  # FIXME: lang must be managed here
        # stable multi-key sort: least significant key first (lang, then type,
        # then category), as Python's sort is stable
        byte_identities.sort(key=lambda i: i.lang)
        byte_identities.sort(key=lambda i: i.idType)
        byte_identities.sort(key=lambda i: i.category)
        for identity in byte_identities:
            s.append(bytes(identity))
            s.append(b"<")
        # features
        byte_features = [
            service.encode("utf-8")
            for service in services
            if isinstance(service, disco.DiscoFeature)
        ]
        byte_features.sort()  # XXX: the default sort has the same behaviour as the requested RFC 4790 i;octet sort
        for feature in byte_features:
            s.append(feature)
            s.append(b"<")

        # extensions
        ext = list(services.extensions.values())
        ext.sort(key=lambda f: f.formNamespace.encode('utf-8'))
        for extension in ext:
            s.append(extension.formNamespace.encode('utf-8'))
            s.append(b"<")
            fields = extension.fieldList
            fields.sort(key=lambda f: f.var.encode('utf-8'))
            for field in fields:
                s.append(field.var.encode('utf-8'))
                s.append(b"<")
                values = [v.encode('utf-8') for v in field.values]
                values.sort()
                for value in values:
                    s.append(value)
                    s.append(b"<")

        cap_hash = b64encode(sha1(b"".join(s)).digest()).decode('utf-8')
        log.debug(_("Capability hash generated: [{cap_hash}]").format(cap_hash=cap_hash))
        return cap_hash

    @defer.inlineCallbacks
    def _disco_infos(
        self, entity_jid_s, node="", use_cache=True, profile_key=C.PROF_KEY_NONE
    ):
        """Discovery method for the bridge
        @param entity_jid_s: entity we want to discover
        @param use_cache(bool): if True, use cached data if available
        @param node(unicode): optional node to use

        @return: list of tuples
        """
        client = self.host.get_client(profile_key)
        entity = jid.JID(entity_jid_s)
        disco_infos = yield self.get_infos(client, entity, node, use_cache)
        extensions = {}
        # FIXME: should extensions be serialised using tools.common.data_format?
        for form_type, form in list(disco_infos.extensions.items()):
            fields = []
            for field in form.fieldList:
                data = {"type": field.fieldType}
                for attr in ("var", "label", "desc"):
                    value = getattr(field, attr)
                    if value is not None:
                        data[attr] = value

                values = [field.value] if field.value is not None else field.values
                if field.fieldType == "boolean":
                    values = [C.bool_const(v) for v in values]
                fields.append((data, values))

            extensions[form_type or ""] = fields

        defer.returnValue((
            [str(f) for f in disco_infos.features],
            [(cat, type_, name or "")
             for (cat, type_), name in list(disco_infos.identities.items())],
            extensions))

    def items2tuples(self, disco_items):
        """convert disco items to tuple of strings

        @param disco_items(iterable[disco.DiscoItem]): items
        @return G(tuple[unicode,unicode,unicode]): serialised items
        """
        for item in disco_items:
            if not item.entity:
                log.warning(_("invalid item (no jid)"))
                continue
            yield (item.entity.full(), item.nodeIdentifier or "", item.name or "")

    @defer.inlineCallbacks
    def _disco_items(
        self, entity_jid_s, node="", use_cache=True, profile_key=C.PROF_KEY_NONE
    ):
        """ Discovery method for the bridge

        @param entity_jid_s: entity we want to discover
        @param node(unicode): optional node to use
        @param use_cache(bool): if True, use cached data if available
        @return: list of tuples"""
        client = self.host.get_client(profile_key)
        entity = jid.JID(entity_jid_s)
        disco_items = yield self.get_items(client, entity, node, use_cache)
        ret = list(self.items2tuples(disco_items))
        defer.returnValue(ret)