view libervia/backend/core/patches.py @ 4306:94e0968987cd

plugin XEP-0033: code modernisation, improve delivery, data validation: - Code has been rewritten using Pydantic models and `async` coroutines for data validation and cleaner element parsing/generation. - Delivery has been completely rewritten. It now works even if the server doesn't support multicast, and sends to the local multicast service first. Delivering to the local multicast service first is due to poor support of XEP-0033 in servers (notably Prosody, which has an incomplete implementation), and the current impossibility to detect whether a sub-domain service fully handles multicast or only handles it for local domains. This is a workaround to have a good balance between backward compatibility and bandwidth use, and to make it work with the incoming email gateway implementation (the gateway will only deliver to entities of its own domain). - disco feature checking now uses `async` coroutines. The `host` implementation still uses Deferred return values for compatibility with legacy code. rel 450
author Goffi <goffi@goffi.org>
date Thu, 26 Sep 2024 16:12:01 +0200
parents c14e904eee13
children
line wrap: on
line source

import base64
import copy
import secrets

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from twisted.words.protocols.jabber import (
    client as tclient,
    jid,
    sasl,
    sasl_mechanisms,
    xmlstream,
)
from wokkel import client
from zope.interface import implementer

from libervia.backend.core.constants import Const as C
from libervia.backend.core.log import getLogger

# Module-level logger; XmlStream.add_hook uses it to warn about hook conflicts.
log = getLogger(__name__)

"""This module applies monkey patches to Twisted and Wokkel
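   The first part handles certificate validation during XMPP connection and is
   temporary (until merged upstream).
   The second part adds trigger points to the send and onElement methods of XmlStream.
   The module also provides a SCRAM-SHA SASL mechanism, optional resource binding and
   a JID interning fix.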
   """


# SCRAM-SHA implementation


@implementer(sasl_mechanisms.ISASLMechanism)
class ScramSha:
    """Implements the SCRAM-SHA SASL authentication mechanism.

    This mechanism is defined in RFC 5802 (SCRAM-SHA-256 in RFC 7677). Channel
    binding is not used: the GS2 header is always "n,," (base64 "biws").

    NOTE(review): the server-final message (server signature "v=...") is not
    processed by this class, so the server is not authenticated through SCRAM
    here. Confirm whether it should be verified by the SASL initializer.
    NOTE(review): username and password are not SASLprep-normalised.
    """

    ALLOWED_ALGORITHMS = ("SHA-1", "SHA-256", "SHA-512")
    backend = default_backend()

    def __init__(self, username: str, password: str, algorithm: str) -> None:
        """Initialize SCRAM-SHA mechanism with user credentials.

        @param username: The user's username (authentication identity).
        @param password: The user's password.
        @param algorithm: Hash algorithm, one of ALLOWED_ALGORITHMS.
        @raise ValueError: algorithm is not in ALLOWED_ALGORITHMS.
        """
        if algorithm not in self.ALLOWED_ALGORITHMS:
            raise ValueError(f"Invalid algorithm: {algorithm!r}")

        self.username = username
        self.password = password
        # "SHA-256" -> hashes.SHA256(), "SHA-1" -> hashes.SHA1(), etc.
        self.algorithm = getattr(hashes, algorithm.replace("-", "", 1))()
        self.name = f"SCRAM-{algorithm}"
        # base64 output never contains "," so it is usable as a SCRAM nonce.
        self.client_nonce = base64.b64encode(secrets.token_bytes(24)).decode()
        self.server_nonce = None
        self.salted_password = None

    @staticmethod
    def _escape_saslname(name: str) -> str:
        """Escape a username as a SCRAM "saslname" (RFC 5802 section 5.1).

        "=" is replaced before "," so that the "=" introduced by "=2C" is not
        escaped a second time.
        """
        return name.replace("=", "=3D").replace(",", "=2C")

    def _client_first_message_bare(self) -> str:
        """Return the client-first-message-bare ("n=<user>,r=<client nonce>")."""
        return f"n={self._escape_saslname(self.username)},r={self.client_nonce}"

    def digest(self, data: bytes) -> bytes:
        """Hash data with the selected algorithm."""
        hasher = hashes.Hash(self.algorithm, backend=self.backend)
        hasher.update(data)
        return hasher.finalize()

    def _hmac(self, key: bytes, msg: bytes) -> bytes:
        """Compute HMAC-SHA of msg with key, using the selected algorithm."""
        h = hmac.HMAC(key, self.algorithm, backend=self.backend)
        h.update(msg)
        return h.finalize()

    def _hi(self, password: str, salt: bytes, iterations: int) -> bytes:
        """Compute SCRAM's Hi() function (PBKDF2 with HMAC-SHA).

        The output length is the digest size of the selected algorithm.
        """
        kdf = PBKDF2HMAC(
            algorithm=self.algorithm,
            length=self.algorithm.digest_size,
            salt=salt,
            iterations=iterations,
            backend=self.backend,
        )
        return kdf.derive(password.encode())

    def getInitialResponse(self) -> bytes:
        """Builds the initial client response message (client-first-message)."""
        return f"n,,{self._client_first_message_bare()}".encode()

    def getResponse(self, challenge: bytes) -> bytes:
        """SCRAM-SHA authentication final step. Building proof of having the password.

        @param challenge: Challenge string from the server (server-first-message).
        @return: Client proof (client-final-message).
        @raise ValueError: the challenge is malformed, requires an unsupported
            mandatory extension ("m"), or its nonce does not extend the client
            nonce.
        @raise KeyError: a required attribute ("r", "s" or "i") is missing.
        """
        server_first_message = challenge.decode()
        challenge_parts = dict(
            item.split("=", 1) for item in server_first_message.split(",")
        )
        if "m" in challenge_parts:
            # RFC 5802: the presence of "m" must cause authentication failure.
            raise ValueError(
                "Server requested an unsupported mandatory SCRAM extension"
            )
        server_nonce = challenge_parts["r"]
        # RFC 5802: the client MUST check that the server nonce starts with the
        # nonce it sent; otherwise the exchange is not bound to this session.
        if not server_nonce.startswith(self.client_nonce):
            raise ValueError("SCRAM server nonce doesn't start with the client nonce")
        self.server_nonce = server_nonce
        salt = base64.b64decode(challenge_parts["s"])
        iterations = int(challenge_parts["i"])
        self.salted_password = self._hi(self.password, salt, iterations)

        client_key = self._hmac(self.salted_password, b"Client Key")
        stored_key = self.digest(client_key)
        # "biws" is base64("n,,"): the GS2 header without channel binding.
        client_final_without_proof = f"c=biws,r={self.server_nonce}"
        auth_message = ",".join(
            (
                self._client_first_message_bare(),
                server_first_message,
                client_final_without_proof,
            )
        ).encode()
        client_signature = self._hmac(stored_key, auth_message)
        client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature))
        client_final_message = (
            f"{client_final_without_proof},p={base64.b64encode(client_proof).decode()}"
        )
        return client_final_message.encode()


class SASLInitiatingInitializer(sasl.SASLInitiatingInitializer):
    """SASL initializer which prefers SCRAM-SHA mechanisms over PLAIN."""

    def setMechanism(self):
        """
        Select and setup authentication mechanism.

        The authenticator's C{jid} and C{password} attributes are used as
        credentials. When the JID has a user part, the strongest advertised
        SCRAM-SHA variant is chosen (SHA-512, then SHA-256, then SHA-1), with
        PLAIN as last resort. Without a user part, ANONYMOUS is used.

        @raise sasl.SASLNoAcceptableMechanism: none of the candidate mechanisms
            is advertised by the receiving party.
        """
        authenticator = self.xmlstream.authenticator
        user_jid = authenticator.jid
        password = authenticator.password
        offered = sasl.get_mechanisms(self.xmlstream)

        if user_jid.user is None:
            if "ANONYMOUS" not in offered:
                raise sasl.SASLNoAcceptableMechanism()
            self.mechanism = sasl_mechanisms.Anonymous()
            return

        # strongest hash first
        for scram_algorithm in ("SHA-512", "SHA-256", "SHA-1"):
            if f"SCRAM-{scram_algorithm}" in offered:
                self.mechanism = ScramSha(
                    user_jid.user, password, algorithm=scram_algorithm
                )
                return

        # FIXME: PLAIN should probably be disabled.
        if "PLAIN" not in offered:
            raise sasl.SASLNoAcceptableMechanism()
        self.mechanism = sasl_mechanisms.Plain(None, user_jid.user, password)


## certificate validation patches


class XMPPClient(client.XMPPClient):
    """XMPP client whose stream factory honours TLS requirement and configuration.

    The stream factory is built with HybridClientFactory, which forwards
    tls_required and configurationForTLS to the authenticator.
    """

    def __init__(
        self,
        jid,
        password,
        host=None,
        port=5222,
        tls_required=True,
        configurationForTLS=None,
    ):
        """
        @param jid: JID of the account to connect.
        @param password: password of the account.
        @param host: explicit host to connect to, or None.
        @param port: port to connect to.
        @param tls_required: whether TLS must be negotiated.
        @param configurationForTLS: TLS configuration forwarded to the TLS
            initializer.
        """
        stream_factory = HybridClientFactory(
            jid,
            password,
            tls_required=tls_required,
            configurationForTLS=configurationForTLS,
        )

        self.jid = jid
        # IDNA-encoded (bytes) form of the JID's domain
        self.domain = jid.host.encode("idna")
        self.host, self.port = host, port

        client.StreamManager.__init__(self, stream_factory)


def HybridClientFactory(jid, password, tls_required=True, configurationForTLS=None):
    """Build an XmlStreamFactory authenticated by a HybridAuthenticator.

    @param jid: JID of the account.
    @param password: password of the account.
    @param tls_required: whether TLS must be negotiated.
    @param configurationForTLS: TLS configuration for the TLS initializer.
    @return: the stream factory.
    """
    return xmlstream.XmlStreamFactory(
        HybridAuthenticator(jid, password, tls_required, configurationForTLS)
    )


class HybridAuthenticator(client.HybridAuthenticator):
    """Authenticator with configurable TLS and optional resource binding."""

    # Forwarded to CheckAuthInitializer: when True, resource binding is queued
    # after SASL authentication.
    res_binding = True

    def __init__(self, jid, password, tls_required=True, configurationForTLS=None):
        """
        @param jid: JID of the account; its domain is the stream's target.
        @param password: password of the account.
        @param tls_required: whether TLS must be negotiated.
        @param configurationForTLS: TLS configuration for the TLS initializer.
        """
        # The wokkel parent __init__ is not called; ConnectAuthenticator is
        # initialised directly with the JID's domain.
        xmlstream.ConnectAuthenticator.__init__(self, jid.host)
        self.jid = jid
        self.password = password
        self.tls_required = tls_required
        self.configurationForTLS = configurationForTLS

    def associateWithStream(self, xs):
        """Attach to the stream and install its initializers.

        Initializers run in order: stream version check, TLS (using
        tls_required and configurationForTLS), then authentication.
        """
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)

        tls_init = xmlstream.TLSInitiatingInitializer(
            xs, required=self.tls_required, configurationForTLS=self.configurationForTLS
        )
        xs.initializers = [
            # use Twisted's jabber client module directly instead of reaching it
            # through wokkel's internal import (client.client)
            tclient.CheckVersionInitializer(xs),
            tls_init,
            CheckAuthInitializer(xs, self.res_binding),
        ]


# XmlStream triggers


class XmlStream(xmlstream.XmlStream):
    """XmlStream which allows to add hooks.

    Hooks receive the element (or sent object) before the parent implementation
    handles it, and run in registration order. They are not meant to modify
    content, so unlike triggers they have no priority handling.
    """

    def __init__(self, authenticator):
        xmlstream.XmlStream.__init__(self, authenticator)
        # callbacks called with every received element (see onElement)
        self._onElementHooks = []
        # callbacks called with every sent object (see send)
        self._sendHooks = []

    def add_hook(self, hook_type, callback):
        """Add a send or receive hook.

        A callback already registered for the same hook type is not added a
        second time; a warning is logged instead.

        @param hook_type: C.STREAM_HOOK_RECEIVE or C.STREAM_HOOK_SEND.
        @param callback: callable taking the element/object as only argument.
        @raise ValueError: hook_type is not one of the accepted values.
        """
        conflict_msg = f"Hook conflict: can't add {hook_type} hook {callback}"
        if hook_type == C.STREAM_HOOK_RECEIVE:
            registry = self._onElementHooks
        elif hook_type == C.STREAM_HOOK_SEND:
            registry = self._sendHooks
        else:
            raise ValueError(f"Invalid hook type: {hook_type}")

        if callback in registry:
            log.warning(conflict_msg)
        else:
            registry.append(callback)

    def onElement(self, element):
        """Call receive hooks with element, then process it normally."""
        for receive_hook in self._onElementHooks:
            receive_hook(element)
        xmlstream.XmlStream.onElement(self, element)

    def send(self, obj):
        """Call send hooks with obj, then send it normally."""
        for send_hook in self._sendHooks:
            send_hook(obj)
        xmlstream.XmlStream.send(self, obj)


# Binding activation (needed for stream management, XEP-0198)


class CheckAuthInitializer(client.CheckAuthInitializer):
    """Authentication initializer with optional resource binding.

    Modification of client.CheckAuthInitializer which makes resource binding
    optional and doesn't use the deprecated SessionInitializer. SASL goes
    through this module's SASLInitiatingInitializer.
    """

    def __init__(self, xs, res_binding):
        """
        @param xs: XML stream to initialize.
        @param res_binding: if True, a required resource binding initializer is
            queued after SASL authentication.
        """
        super().__init__(xs)
        self.res_binding = res_binding

    def initialize(self):
        """Queue the initializers needed to authenticate on the stream.

        SASL is used when the server advertises it; legacy IQ authentication is
        used otherwise.

        @raise Exception: no supported authentication method is advertised.
        """
        features = self.xmlstream.features
        if (sasl.NS_XMPP_SASL, "mechanisms") in features:
            inits = [(SASLInitiatingInitializer, True)]
            if self.res_binding:
                inits.append((tclient.BindInitializer, True))

            for init_class, required in inits:
                init = init_class(self.xmlstream)
                init.required = required
                self.xmlstream.initializers.append(init)
        elif (tclient.NS_IQ_AUTH_FEATURE, "auth") in features:
            self.xmlstream.initializers.append(tclient.IQAuthInitializer(self.xmlstream))
        else:
            raise Exception("No available authentication method found")


# jid fix


def internJID(jidstring):
    """
    Return interned JID.

    Unlike Twisted's version, the returned value is a copy of the cached JID,
    so callers modifying it (JID is mutable) don't alter the cached instance.
    A JID is created and cached on first request for a given string.

    @rtype: L{JID}
    """
    # TODO: propose this upstream
    # Twisted's private module-level cache of interned JIDs
    cache = jid.__internJIDs
    cached = cache.get(jidstring)
    if cached is None:
        cached = cache[jidstring] = jid.JID(jidstring)
    return copy.copy(cached)


def apply():
    """Install this module's monkey patches into Twisted and Wokkel.

    Replaces wokkel's XMPPClient, the protocol class of Twisted's
    XmlStreamFactory, and Twisted's jid.internJID with the versions defined in
    this module.
    """
    replacements = (
        # certificate validation
        (client, "XMPPClient", XMPPClient),
        # XmlStream triggers
        (xmlstream.XmlStreamFactory, "protocol", XmlStream),
        # jid fix
        (jid, "internJID", internJID),
    )
    for target, attribute, patched in replacements:
        setattr(target, attribute, patched)