diff libervia/backend/plugins/plugin_xep_0198.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author   | Goffi <goffi@goffi.org>
date     | Fri, 02 Jun 2023 11:49:51 +0200
parents  | sat/plugins/plugin_xep_0198.py@524856bd7b19
children | 0d7bb4df2343
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/libervia/backend/plugins/plugin_xep_0198.py	Fri Jun 02 11:49:51 2023 +0200
@@ -0,0 +1,555 @@
+#!/usr/bin/env python3
+
+# SàT plugin for managing Stream-Management
+# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from libervia.backend.core.i18n import _
+from libervia.backend.core.constants import Const as C
+from libervia.backend.core import exceptions
+from libervia.backend.core.log import getLogger
+from twisted.words.protocols.jabber import client as jabber_client
+from twisted.words.protocols.jabber import xmlstream
+from twisted.words.xish import domish
+from twisted.internet import defer
+from twisted.internet import task, reactor
+from functools import partial
+from wokkel import disco, iwokkel
+from zope.interface import implementer
+import collections
+import time
+
+log = getLogger(__name__)
+
+PLUGIN_INFO = {
+    C.PI_NAME: "Stream Management",
+    C.PI_IMPORT_NAME: "XEP-0198",
+    C.PI_TYPE: "XEP",
+    C.PI_MODES: C.PLUG_MODE_BOTH,
+    C.PI_PROTOCOLS: ["XEP-0198"],
+    C.PI_DEPENDENCIES: [],
+    C.PI_RECOMMENDATIONS: ["XEP-0045", "XEP-0313"],
+    C.PI_MAIN: "XEP_0198",
+    C.PI_HANDLER: "yes",
+    C.PI_DESCRIPTION: _("""Implementation of Stream Management"""),
+}
+
+NS_SM = "urn:xmpp:sm:3"
+SM_ENABLED = '/enabled[@xmlns="' + NS_SM + '"]'
+SM_RESUMED = '/resumed[@xmlns="' + NS_SM + '"]'
+SM_FAILED = '/failed[@xmlns="' + NS_SM + '"]'
+SM_R_REQUEST = '/r[@xmlns="' + NS_SM + '"]'
+SM_A_REQUEST = '/a[@xmlns="' + NS_SM + '"]'
+SM_H_REQUEST = '/h[@xmlns="' + NS_SM + '"]'
+# Max number of stanzas to send before requesting an ack
+MAX_STANZA_ACK_R = 5
+# Max number of seconds before requesting an ack
+MAX_DELAY_ACK_R = 30
+MAX_COUNTER = 2**32
+RESUME_MAX = 5*60
+# if we don't have an answer to an ack request after this delay, connection is aborted
+ACK_TIMEOUT = 35
+
+
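+# Each profile gets one ProfileSessionData instance holding the XEP-0198 state:
+# the in/out stanza counters, the deque of sent stanzas waiting for an ack
+# (newest on the left, acked ones popped from the right), and the session id,
+# location and max delay returned by the server, which make resumption possible.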
+class ProfileSessionData(object):
+    out_counter = 0
+    in_counter = 0
+    session_id = None
+    location = None
+    session_max = None
+    # True when an ack answer is expected
+    ack_requested = False
+    last_ack_r = 0
+    disconnected_time = None
+
+    def __init__(self, callback, **kw):
+        self.buffer = collections.deque()
+        self.buffer_idx = 0
+        self._enabled = False
+        self.timer = None
+        # timer used when doing an ack request
+        # when it times out, connection is aborted
+        self.req_timer = None
+        self.callback_data = (callback, kw)
+
+    @property
+    def enabled(self):
+        return self._enabled
+
+    @enabled.setter
+    def enabled(self, enabled):
+        if enabled:
+            if self._enabled:
+                raise exceptions.InternalError(
+                    "Stream Management can't be enabled twice")
+            self._enabled = True
+            callback, kw = self.callback_data
+            self.timer = task.LoopingCall(callback, **kw)
+            self.timer.start(MAX_DELAY_ACK_R, now=False)
+        else:
+            self._enabled = False
+            if self.timer is not None:
+                self.timer.stop()
+                self.timer = None
+
+    @property
+    def resume_enabled(self):
+        return self.session_id is not None
+
+    def reset(self):
+        self.enabled = False
+        self.buffer.clear()
+        self.buffer_idx = 0
+        self.in_counter = self.out_counter = 0
+        self.session_id = self.location = None
+        self.ack_requested = False
+        self.last_ack_r = 0
+        if self.req_timer is not None:
+            if self.req_timer.active():
+                self.req_timer.cancel()
+            else:
+                log.error("req_timer has been called/cancelled but not reset")
+            self.req_timer = None
+
+    def get_buffer_copy(self):
+        return list(self.buffer)
+
+
+class XEP_0198(object):
+    # FIXME: location is not handled yet
+
+    def __init__(self, host):
+        log.info(_("Plugin Stream Management initialization"))
+        self.host = host
+        host.register_namespace('sm', NS_SM)
+        host.trigger.add("stream_hooks", self.add_hooks)
+        host.trigger.add("xml_init", self._xml_init_trigger)
+        host.trigger.add("disconnecting", self._disconnecting_trigger)
+        host.trigger.add("disconnected", self._disconnected_trigger)
+        try:
+            self._ack_timeout = int(host.memory.config_get("", "ack_timeout", ACK_TIMEOUT))
+        except ValueError:
+            log.error(_("Invalid ack_timeout value, please check your configuration"))
+            self._ack_timeout = ACK_TIMEOUT
+        if not self._ack_timeout:
+            log.info(_("Ack timeout disabled"))
+        else:
+            log.info(_("Ack timeout set to {timeout}s").format(
+                timeout=self._ack_timeout))
+
+    def profile_connecting(self, client):
+        client._xep_0198_session = ProfileSessionData(callback=self.check_acks,
+                                                      client=client)
+
+    def get_handler(self, client):
+        return XEP_0198_handler(self)
+
+    def add_hooks(self, client, receive_hooks, send_hooks):
+        """Add hooks to handle in/out stanzas counters"""
+        receive_hooks.append(partial(self.on_receive, client=client))
+        send_hooks.append(partial(self.on_send, client=client))
+        return True
+
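+    # on each (re)connection, the trigger below either starts a new stream
+    # management session (<enable/> with resume='true') or resumes a previous
+    # one (<resume/> with the saved session id and our last IN counter)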
" + "Please ask your server administrator to enable this feature.".format( + namespace=NS_SM))) + return True + session = client._xep_0198_session + + # a disconnect timer from a previous disconnection may still be active + try: + disconnect_timer = session.disconnect_timer + except AttributeError: + pass + else: + if disconnect_timer.active(): + disconnect_timer.cancel() + del session.disconnect_timer + + if session.resume_enabled: + # we are resuming a session + resume_elt = domish.Element((NS_SM, 'resume')) + resume_elt['h'] = str(session.in_counter) + resume_elt['previd'] = session.session_id + client.send(resume_elt) + session.resuming = True + # session.enabled will be set on <resumed/> reception + return False + else: + # we start a new session + assert session.out_counter == 0 + enable_elt = domish.Element((NS_SM, 'enable')) + enable_elt['resume'] = 'true' + client.send(enable_elt) + session.enabled = True + return True + + def _disconnecting_trigger(self, client): + session = client._xep_0198_session + if session.enabled: + self.send_ack(client) + # This is a requested disconnection, so we can reset the session + # to disable resuming and close normally the stream + session.reset() + return True + + def _disconnected_trigger(self, client, reason): + if client.is_component: + return True + session = client._xep_0198_session + session.enabled = False + if session.resume_enabled: + session.disconnected_time = time.time() + session.disconnect_timer = reactor.callLater(session.session_max, + client.disconnect_profile, + reason) + # disconnect_profile must not be called at this point + # because session can be resumed + return False + else: + return True + + def check_acks(self, client): + """Request ack if needed""" + session = client._xep_0198_session + # log.debug("check_acks (in_counter={}, out_counter={}, buf len={}, buf idx={})" + # .format(session.in_counter, session.out_counter, len(session.buffer), + # session.buffer_idx)) + if session.ack_requested or not session.buffer: + return + if (session.out_counter - session.buffer_idx >= MAX_STANZA_ACK_R + or time.time() - session.last_ack_r >= MAX_DELAY_ACK_R): + self.request_ack(client) + session.ack_requested = True + session.last_ack_r = time.time() + + def update_buffer(self, session, server_acked): + """Update buffer and buffer_index""" + if server_acked > session.buffer_idx: + diff = server_acked - session.buffer_idx + try: + for i in range(diff): + session.buffer.pop() + except IndexError: + log.error( + "error while cleaning buffer, invalid index (buffer is empty):\n" + "diff = {diff}\n" + "server_acked = {server_acked}\n" + "buffer_idx = {buffer_id}".format( + diff=diff, server_acked=server_acked, + buffer_id=session.buffer_idx)) + session.buffer_idx += diff + + def replay_buffer(self, client, buffer_, discard_results=False): + """Resend all stanza in buffer + + @param buffer_(collection.deque, list): buffer to replay + the buffer will be cleared by this method + @param discard_results(bool): if True, don't replay IQ result stanzas + """ + while True: + try: + stanza = buffer_.pop() + except IndexError: + break + else: + if ((discard_results + and stanza.name == 'iq' + and stanza.getAttribute('type') == 'result')): + continue + client.send(stanza) + + def send_ack(self, client): + """Send an answer element with current IN counter""" + a_elt = domish.Element((NS_SM, 'a')) + a_elt['h'] = str(client._xep_0198_session.in_counter) + client.send(a_elt) + + def request_ack(self, client): + """Send a request element""" + session = 
+    def update_buffer(self, session, server_acked):
+        """Update buffer and buffer_idx"""
+        if server_acked > session.buffer_idx:
+            diff = server_acked - session.buffer_idx
+            try:
+                for i in range(diff):
+                    session.buffer.pop()
+            except IndexError:
+                log.error(
+                    "error while cleaning buffer, invalid index (buffer is empty):\n"
+                    "diff = {diff}\n"
+                    "server_acked = {server_acked}\n"
+                    "buffer_idx = {buffer_idx}".format(
+                        diff=diff, server_acked=server_acked,
+                        buffer_idx=session.buffer_idx))
+            session.buffer_idx += diff
+
+    def replay_buffer(self, client, buffer_, discard_results=False):
+        """Resend all stanzas in buffer
+
+        @param buffer_(collections.deque, list): buffer to replay
+            the buffer will be cleared by this method
+        @param discard_results(bool): if True, don't replay IQ result stanzas
+        """
+        while True:
+            try:
+                stanza = buffer_.pop()
+            except IndexError:
+                break
+            else:
+                if (discard_results
+                        and stanza.name == 'iq'
+                        and stanza.getAttribute('type') == 'result'):
+                    continue
+                client.send(stanza)
+
+    def send_ack(self, client):
+        """Send an answer element with current IN counter"""
+        a_elt = domish.Element((NS_SM, 'a'))
+        a_elt['h'] = str(client._xep_0198_session.in_counter)
+        client.send(a_elt)
+
+    def request_ack(self, client):
+        """Send a request element"""
+        session = client._xep_0198_session
+        r_elt = domish.Element((NS_SM, 'r'))
+        client.send(r_elt)
+        if session.req_timer is not None:
+            raise exceptions.InternalError("req_timer should not be set")
+        if self._ack_timeout:
+            session.req_timer = reactor.callLater(self._ack_timeout,
+                                                  self.on_ack_time_out, client)
+
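+    # fallback used when the server advertised an alternate location in
+    # <enabled/>: if connecting there fails, we restore the original host/port
+    # and connectionFailed handler, then let the normal failure logic run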
+    def _connectionFailed(self, failure_, connector):
+        normal_host, normal_port = connector.normal_location
+        del connector.normal_location
+        log.warning(_(
+            "Connection failed using location given by server (host: {host}, port: "
+            "{port}), switching to normal host and port (host: {normal_host}, port: "
+            "{normal_port})".format(host=connector.host, port=connector.port,
+                                    normal_host=normal_host,
+                                    normal_port=normal_port)))
+        connector.host, connector.port = normal_host, normal_port
+        connector.connectionFailed = connector.connectionFailed_ori
+        del connector.connectionFailed_ori
+        return connector.connectionFailed(failure_)
+
+    def on_enabled(self, enabled_elt, client):
+        session = client._xep_0198_session
+        session.in_counter = 0
+
+        # we check that resuming is possible and that we have a session id
+        resume = C.bool(enabled_elt.getAttribute('resume'))
+        session_id = enabled_elt.getAttribute('id')
+        if not session_id:
+            log.warning(_('Incorrect <enabled/> element received, no "id" attribute'))
+        if not resume or not session_id:
+            log.warning(_(
+                "Your server doesn't support session resuming with stream management, "
+                "please contact your server administrator to enable it"))
+            return
+
+        session.session_id = session_id
+
+        # XXX: we disable resource binding, which must not be done
+        #      when we resume the session.
+        client.factory.authenticator.res_binding = False
+
+        # location, in case the server wants the resumed session to be elsewhere
+        try:
+            location = enabled_elt['location']
+        except KeyError:
+            pass
+        else:
+            # TODO: handle IPv6 here (in brackets, cf. XEP)
+            try:
+                domain, port = location.split(':', 1)
+                port = int(port)
+            except ValueError:
+                log.warning(_("Invalid location received: {location}")
+                            .format(location=location))
+            else:
+                session.location = (domain, port)
+                # we monkey patch connector to use the new location
+                connector = client.xmlstream.transport.connector
+                connector.normal_location = connector.host, connector.port
+                connector.host = domain
+                connector.port = port
+                connector.connectionFailed_ori = connector.connectionFailed
+                connector.connectionFailed = partial(self._connectionFailed,
+                                                     connector=connector)
+
+        # resuming time
+        try:
+            max_s = int(enabled_elt['max'])
+        except (ValueError, KeyError) as e:
+            if isinstance(e, ValueError):
+                log.warning(_('Invalid "max" attribute'))
+            max_s = RESUME_MAX
+            log.info(_("Using default session max value ({max_s} s).".format(
+                max_s=max_s)))
+            log.info(_("Stream Management enabled"))
+        else:
+            log.info(_(
+                "Stream Management enabled, with a resumption time of {res_m:.2f} min"
+                .format(res_m=max_s/60)))
+        session.session_max = max_s
+
+    def on_resumed(self, enabled_elt, client):
+        session = client._xep_0198_session
+        assert not session.enabled
+        del session.resuming
+        server_acked = int(enabled_elt['h'])
+        self.update_buffer(session, server_acked)
+        resend_count = len(session.buffer)
+        # we resend all stanzas which have not been received properly
+        self.replay_buffer(client, session.buffer)
+        # now we can continue the session
+        session.enabled = True
+        d_time = time.time() - session.disconnected_time
+        log.info(_("Stream session resumed (disconnected for {d_time} s, {count} "
+                   "stanza(s) resent)").format(d_time=int(d_time), count=resend_count))
+
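+    # <failed/> covers two cases, told apart by the session.resuming flag:
+    # either stream management could not be enabled at all, or only the
+    # resumption failed, in which case a full session init is attempted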
+    def on_failed(self, failed_elt, client):
+        session = client._xep_0198_session
+        condition_elt = failed_elt.firstChildElement()
+        buffer_ = session.get_buffer_copy()
+        session.reset()
+
+        try:
+            del session.resuming
+        except AttributeError:
+            # stream management can't be started at all
+            msg = _("Can't use stream management")
+            if condition_elt is None:
+                log.error(msg + '.')
+            else:
+                log.error(_("{msg}: {reason}").format(
+                    msg=msg, reason=condition_elt.name))
+        else:
+            # only stream resumption failed, we can try full session init
+            # XXX: we try to start full session init from this point, with many
+            #      variables/attributes already initialised with a potentially
+            #      different jid. This is experimental and may not be safe. It may
+            #      be safer to abort the connection and restart everything with a
+            #      fresh client.
+            msg = _("stream resumption not possible, restarting full session")
+
+            if condition_elt is None:
+                log.warning('{msg}.'.format(msg=msg))
+            else:
+                log.warning("{msg}: {reason}".format(
+                    msg=msg, reason=condition_elt.name))
+            # stream resumption failed, but we can still do normal stream management
+            # we restore attributes as if the session was new, and init the stream:
+            # we keep everything initialized, and only do binding, roster request
+            # and initial presence sending.
+            if client.conn_deferred.called:
+                client.conn_deferred = defer.Deferred()
+            else:
+                log.error("conn_deferred should be called at this point")
+            plg_0045 = self.host.plugins.get('XEP-0045')
+            plg_0313 = self.host.plugins.get('XEP-0313')
+
+            # FIXME: we should call all loaded plugins with generic callbacks
+            #        (e.g. prepareResume and resume), so a hot resuming can be done
+            #        properly for all plugins.
+
+            if plg_0045 is not None:
+                # we have to remove joined rooms
+                muc_join_args = plg_0045.pop_rooms(client)
+            # we need to recreate roster
+            client.handlers.remove(client.roster)
+            client.roster = client.roster.__class__(self.host)
+            client.roster.setHandlerParent(client)
+            # bind init is not done when resuming is possible, so we have to do it now
+            bind_init = jabber_client.BindInitializer(client.xmlstream)
+            bind_init.required = True
+            d = bind_init.start()
+            # we set the jid, which may have changed
+            d.addCallback(
+                lambda __: setattr(client.factory.authenticator, "jid", client.jid))
+            # we call the trigger which will send the <enable/> element
+            d.addCallback(lambda __: self._xml_init_trigger(client))
+            # then we have to re-request the roster, as changes may have occurred
+            d.addCallback(lambda __: client.roster.request_roster())
+            # we add got_roster to be sure to have the roster before sending
+            # initial presence
+            d.addCallback(lambda __: client.roster.got_roster)
+            if plg_0313 is not None:
+                # we retrieve one2one MAM archives
+                d.addCallback(lambda __: defer.ensureDeferred(plg_0313.resume(client)))
+            # initial presence must be sent manually
+            d.addCallback(lambda __: client.presence.available())
+            if plg_0045 is not None:
+                # we re-join MUC rooms
+                muc_d_list = defer.DeferredList(
+                    [defer.ensureDeferred(plg_0045.join(*args))
+                     for args in muc_join_args]
+                )
+                d.addCallback(lambda __: muc_d_list)
+            # at the end we replay the buffer, as those stanzas have probably not
+            # been received
+            d.addCallback(lambda __: self.replay_buffer(client, buffer_,
+                                                        discard_results=True))
+
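+    # the two hooks below count received and sent stanzas (modulo 2**32, as
+    # specified by XEP-0198) and buffer outgoing ones until they are acked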
+    def on_receive(self, element, client):
+        if not client.is_component:
+            session = client._xep_0198_session
+            if session.enabled and element.name.lower() in C.STANZA_NAMES:
+                session.in_counter = (session.in_counter + 1) % MAX_COUNTER
+
+    def on_send(self, obj, client):
+        if not client.is_component:
+            session = client._xep_0198_session
+            if (session.enabled
+                    and domish.IElement.providedBy(obj)
+                    and obj.name.lower() in C.STANZA_NAMES):
+                session.out_counter = (session.out_counter + 1) % MAX_COUNTER
+                session.buffer.appendleft(obj)
+                self.check_acks(client)
+
+    def on_ack_request(self, r_elt, client):
+        self.send_ack(client)
+
+    def on_ack_answer(self, a_elt, client):
+        session = client._xep_0198_session
+        session.ack_requested = False
+        if self._ack_timeout:
+            if session.req_timer is None:
+                log.error("req_timer should be set")
+            else:
+                session.req_timer.cancel()
+                session.req_timer = None
+        try:
+            server_acked = int(a_elt['h'])
+        except ValueError:
+            log.warning(_("Server returned invalid ack element, disabling stream "
+                          "management: {xml}").format(xml=a_elt))
+            session.enabled = False
+            return
+
+        if server_acked > session.out_counter:
+            log.error(_("Server acked more stanzas than we have sent, disabling stream "
+                        "management."))
+            session.reset()
+            return
+
+        self.update_buffer(session, server_acked)
+        self.check_acks(client)
+
+    def on_ack_time_out(self, client):
+        """Called when a requested ack has not been received in time"""
+        log.info(_("Ack was not received in time, aborting connection"))
+        try:
+            xmlstream = client.xmlstream
+        except AttributeError:
+            log.warning("xmlstream has already been terminated")
+        else:
+            transport = xmlstream.transport
+            if transport is None:
+                log.warning("transport was already removed")
+            else:
+                transport.abortConnection()
+        client._xep_0198_session.req_timer = None
+
+
+@implementer(iwokkel.IDisco)
+class XEP_0198_handler(xmlstream.XMPPHandler):
+
+    def __init__(self, plugin_parent):
+        self.plugin_parent = plugin_parent
+        self.host = plugin_parent.host
+
+    def connectionInitialized(self):
+        self.xmlstream.addObserver(
+            SM_ENABLED, self.plugin_parent.on_enabled, client=self.parent
+        )
+        self.xmlstream.addObserver(
+            SM_RESUMED, self.plugin_parent.on_resumed, client=self.parent
+        )
+        self.xmlstream.addObserver(
+            SM_FAILED, self.plugin_parent.on_failed, client=self.parent
+        )
+        self.xmlstream.addObserver(
+            SM_R_REQUEST, self.plugin_parent.on_ack_request, client=self.parent
+        )
+        self.xmlstream.addObserver(
+            SM_A_REQUEST, self.plugin_parent.on_ack_answer, client=self.parent
+        )
+
+    def getDiscoInfo(self, requestor, target, nodeIdentifier=""):
+        return [disco.DiscoFeature(NS_SM)]
+
+    def getDiscoItems(self, requestor, target, nodeIdentifier=""):
+        return []