view sat/core/patches.py @ 2768:cb34af35af6f

core (disco): client.disco.requestInfo failures are re-raised: so far, requestInfo failures were logged as a warning, and an empty disco was returned. This prevented the requesting entity from knowing about the error, so now the error is logged and re-raised. CancelError exceptions are converted to TimeoutError (requestInfo is only cancelled on timeout). An empty disco is still cached in case of failure, to avoid waiting again for the entity on the next request.
author Goffi <goffi@goffi.org>
date Fri, 11 Jan 2019 19:48:20 +0100
parents 1ecceac3df96
children 00d905e1b0ef
line wrap: on
line source

from twisted.words.protocols.jabber import xmlstream, sasl, client as tclient
from twisted.internet import ssl
from wokkel import client
from sat.core.constants import Const as C
from sat.core.log import getLogger

log = getLogger(__name__)

"""This module apply monkey patches to Twisted and Wokkel
   The first part handles certificate validation during XMPP connection and is
   temporary (until merged upstream).
   The second part adds a trigger point to the send and onElement methods of XmlStream
   """


## certificate validation patches

class TLSInitiatingInitializer(xmlstream.TLSInitiatingInitializer):
    """TLS initializer whose certificate validation can be switched off.

    When C{check_certificate} is True (the default), the server certificate
    must chain to a root trusted by the platform AND match the domain the
    stream is connected to. When it is False, no validation is done at all.
    """
    # class-level default, may be overridden on an instance
    check_certificate = True

    def onProceed(self, obj):
        """Switch the transport to TLS and restart the stream.

        @param obj: element which triggered this callback (not used here)
        """
        self.xmlstream.removeObserver('/failure', self.onFailure)
        if self.check_certificate:
            # A bare CertificateOptions(trustRoot=...) only validates the
            # trust chain: a certificate issued for any other host would be
            # accepted. optionsForClientTLS uses the platform trust roots by
            # default and also checks that the certificate matches the
            # hostname.
            ctx = ssl.optionsForClientTLS(self.xmlstream.otherEntity.host)
        else:
            # validation explicitly disabled: accept any certificate
            ctx = ssl.CertificateOptions(trustRoot=None)
        self.xmlstream.transport.startTLS(ctx)
        self.xmlstream.reset()
        self.xmlstream.sendHeader()
        self._deferred.callback(xmlstream.Reset)


class XMPPClient(client.XMPPClient):
    """XMPP client whose stream factory honours C{check_certificate}"""

    def __init__(self, jid, password, host=None, port=5222,
                 check_certificate=True):
        """
        @param jid: JID of the account to connect with
        @param password: password of the account
        @param host: stored as C{self.host}
        @param port: stored as C{self.port}
        @param check_certificate: forwarded to L{HybridClientFactory}
        """
        self.jid = jid
        # domain is kept IDNA encoded
        self.domain = jid.host.encode('idna')
        self.host = host
        self.port = port

        # the parent constructor is bypassed on purpose: StreamManager is
        # initialised directly with our own factory
        client.StreamManager.__init__(
            self,
            HybridClientFactory(
                jid, password, check_certificate=check_certificate),
        )


def HybridClientFactory(jid, password, check_certificate=True):
    """Build an XmlStreamFactory driven by a L{HybridAuthenticator}

    @param jid: JID of the account to connect with
    @param password: password of the account
    @param check_certificate: forwarded to the authenticator
    @return: the new xmlstream.XmlStreamFactory instance
    """
    return xmlstream.XmlStreamFactory(
        HybridAuthenticator(jid, password, check_certificate))


class HybridAuthenticator(client.HybridAuthenticator):
    """Authenticator with switchable certificate check and resource binding"""
    # passed to CheckAuthInitializer to enable/disable resource binding
    res_binding = True

    def __init__(self, jid, password, check_certificate):
        """
        @param jid: JID of the account to connect with
        @param password: password of the account
        @param check_certificate: set on the TLS initializer of each stream
            associated with this authenticator
        """
        xmlstream.ConnectAuthenticator.__init__(self, jid.host)
        self.jid = jid
        self.password = password
        self.check_certificate = check_certificate

    def associateWithStream(self, xs):
        """Install the stream initializers on C{xs}

        @param xs: stream to associate this authenticator with
        """
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)

        # The subclass defined in this module is used directly: looking it
        # up on the twisted module would only return the patched class after
        # apply() has run, and check_certificate would otherwise be silently
        # ignored by the stock initializer.
        tls_init = TLSInitiatingInitializer(xs)
        tls_init.check_certificate = self.check_certificate
        xs.initializers = [
            tclient.CheckVersionInitializer(xs),
            tls_init,
            CheckAuthInitializer(xs, self.res_binding),
        ]


# XmlStream triggers


class XmlStream(xmlstream.XmlStream):
    """XmlStream on which send and receive hooks can be registered"""

    def __init__(self, authenticator):
        xmlstream.XmlStream.__init__(self, authenticator)
        # Hooks registered here are expected to leave the content untouched,
        # so unlike triggers there is no priority to manage: they are simply
        # called in registration order.
        self._onElementHooks = []
        self._sendHooks = []

    def addHook(self, hook_type, callback):
        """Register a hook called on each sent or received item

        @param hook_type: C.STREAM_HOOK_RECEIVE or C.STREAM_HOOK_SEND
        @param callback: callable receiving the element (receive hook) or
            the object given to send (send hook)
        @raise ValueError: hook_type is not one of the two known types
        """
        conflict_msg = (
            u"Hook conflict: can't add {hook_type} hook {callback}".format(
                hook_type=hook_type, callback=callback))

        # select the list matching the requested hook type
        if hook_type == C.STREAM_HOOK_RECEIVE:
            registered = self._onElementHooks
        elif hook_type == C.STREAM_HOOK_SEND:
            registered = self._sendHooks
        else:
            raise ValueError(
                u"Invalid hook type: {hook_type}".format(hook_type=hook_type))

        # a callback is only registered once per hook type
        if callback in registered:
            log.warning(conflict_msg)
            return
        registered.append(callback)

    def onElement(self, element):
        """Call the receive hooks, then process the element as usual"""
        for receive_hook in self._onElementHooks:
            receive_hook(element)
        xmlstream.XmlStream.onElement(self, element)

    def send(self, obj):
        """Call the send hooks, then send the object as usual"""
        for send_hook in self._sendHooks:
            send_hook(obj)
        xmlstream.XmlStream.send(self, obj)


# Binding activation (needed for stream management, XEP-0198)


class CheckAuthInitializer(client.CheckAuthInitializer):
    """Auth initializer with optional resource binding"""

    def __init__(self, xs, res_binding):
        """
        @param xs: stream to initialize
        @param res_binding: if True, a BindInitializer is added after SASL
        """
        super(CheckAuthInitializer, self).__init__(xs)
        self.res_binding = res_binding

    def initialize(self):
        """Queue the initializers matching the advertised stream features

        @raise Exception: neither SASL nor IQ auth is advertised
        """
        # XXX: variant of client.CheckAuthInitializer where resource binding
        #      is optional and the deprecated SessionInitializer is not used
        stream = self.xmlstream
        features = stream.features

        if (sasl.NS_XMPP_SASL, 'mechanisms') in features:
            init_classes = [sasl.SASLInitiatingInitializer]
            if self.res_binding:
                init_classes.append(tclient.BindInitializer)

            for init_class in init_classes:
                initializer = init_class(stream)
                # every initializer queued here is mandatory
                initializer.required = True
                stream.initializers.append(initializer)
            return

        if (tclient.NS_IQ_AUTH_FEATURE, 'auth') in features:
            stream.initializers.append(tclient.IQAuthInitializer(stream))
            return

        raise Exception("No available authentication method found")


def apply():
    """Install the classes of this module in Twisted and Wokkel

    Module attributes of the patched libraries are replaced in place.
    """
    # XmlStream triggers: streams built by the factory get the hookable class
    xmlstream.XmlStreamFactory.protocol = XmlStream
    # certificate validation
    client.XMPPClient = XMPPClient
    xmlstream.TLSInitiatingInitializer = TLSInitiatingInitializer