view libervia/backend/core/patches.py @ 4237:a1e7e82a8921

core: implement SCRAM-SHA auth algorithm: Twisted's auth mechanisms are outdated, and as a result, Libervia was not supporting the mandatory SCRAM-SHA auth mechanism. This patch implements it for the SCRAM-SHA-1, SCRAM-SHA-256 and SCRAM-SHA-512 variants.
author Goffi <goffi@goffi.org>
date Mon, 08 Apr 2024 12:29:40 +0200
parents 4b842c1fb686
children
line wrap: on
line source

import base64
import copy
import secrets

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from twisted.words.protocols.jabber import (
    client as tclient,
    jid,
    sasl,
    sasl_mechanisms,
    xmlstream,
)
from wokkel import client
from zope.interface import implementer

from libervia.backend.core.constants import Const as C
from libervia.backend.core.log import getLogger

log = getLogger(__name__)

# This module applies monkey patches to Twisted and Wokkel (installed by apply()):
#   - SCRAM-SHA SASL authentication (SCRAM-SHA-1, SCRAM-SHA-256, SCRAM-SHA-512),
#     preferred over PLAIN when offered by the server;
#   - certificate validation during XMPP connection (temporary, until merged
#     upstream);
#   - send and receive hooks on XmlStream (send and onElement methods);
#   - optional resource binding during authentication;
#   - internJID returning copies of cached JIDs, so cached ones can't be mutated.


# SCRAM-SHA implementation


@implementer(sasl_mechanisms.ISASLMechanism)
class ScramSha:
    """Implements the SCRAM-SHA SASL authentication mechanism.

    This mechanism is defined in RFC 5802. Channel binding is not supported: the
    GS2 header is always "n,," (no channel binding, no authorization identity).
    """

    ALLOWED_ALGORITHMS = ("SHA-1", "SHA-256", "SHA-512")
    backend = default_backend()

    def __init__(self, username: str, password: str, algorithm: str) -> None:
        """Initialize SCRAM-SHA mechanism with user credentials.

        @param username: The user's username.
        @param password: The user's password.
        @param algorithm: Hash algorithm to use, one of ALLOWED_ALGORITHMS.
        @raise ValueError: algorithm is not supported.
        """
        if algorithm not in self.ALLOWED_ALGORITHMS:
            raise ValueError(f"Invalid algorithm: {algorithm!r}")

        self.username = username
        self.password = password
        # e.g. "SHA-256" -> hashes.SHA256()
        self.algorithm = getattr(hashes, algorithm.replace("-", "", 1))()
        self.name = f"SCRAM-{algorithm}"
        self.client_nonce = base64.b64encode(secrets.token_bytes(24)).decode()
        self.server_nonce = None
        self.salted_password = None
        # AuthMessage (RFC 5802), kept to check the server signature.
        self._auth_message = None
        # True once the server signature has been successfully checked.
        self.server_verified = False

    def digest(self, data: bytes) -> bytes:
        """Hash data with the selected algorithm (SCRAM's H() function)."""
        hasher = hashes.Hash(self.algorithm, backend=self.backend)
        hasher.update(data)
        return hasher.finalize()

    def _hmac(self, key: bytes, msg: bytes) -> bytes:
        """Compute HMAC-SHA"""
        h = hmac.HMAC(key, self.algorithm, backend=self.backend)
        h.update(msg)
        return h.finalize()

    def _hi(self, password: str, salt: bytes, iterations: int) -> bytes:
        """Compute SCRAM's Hi() function, i.e. PBKDF2 with HMAC-SHA."""
        kdf = PBKDF2HMAC(
            algorithm=self.algorithm,
            length=self.algorithm.digest_size,
            salt=salt,
            iterations=iterations,
            backend=self.backend,
        )
        return kdf.derive(password.encode())

    @staticmethod
    def _escape_username(username: str) -> str:
        """Escape "=" and "," in username, as required for SCRAM's saslname."""
        # "=" must be escaped first, otherwise the "=" of "=2C" would be escaped too.
        return username.replace("=", "=3D").replace(",", "=2C")

    @staticmethod
    def _parse_attributes(message: str) -> dict:
        """Parse a SCRAM message into an attribute name -> value dict.

        Only the first "=" of each attribute separates name and value, as values
        such as base64 encoded salts or signatures may end with "=" padding.

        @raise ValueError: an attribute has no "=" separator.
        """
        attributes = {}
        for item in message.split(","):
            name, separator, value = item.partition("=")
            if not separator:
                raise ValueError(f"Invalid SCRAM attribute: {item!r}")
            attributes[name] = value
        return attributes

    def _client_first_message_bare(self) -> str:
        """Return client-first-message-bare (client-first-message without GS2 header)."""
        return f"n={self._escape_username(self.username)},r={self.client_nonce}"

    def getInitialResponse(self) -> bytes:
        """Builds the initial client response message."""
        # "n,," is the GS2 header: no channel binding and no authzid.
        return f"n,,{self._client_first_message_bare()}".encode()

    def getResponse(self, challenge: bytes) -> bytes:
        """SCRAM-SHA authentication final step. Building proof of having the password.

        If the client-final-message has already been sent, challenge is expected to
        be the server-final-message: it is verified and an empty response is
        returned.

        @param challenge: Challenge string from the server.
        @return: Client proof (client-final-message), or b"" when challenge was
            the server-final-message.
        @raise ValueError: the challenge is malformed, its nonce doesn't extend the
            client nonce, or the server signature is invalid.
        """
        if self._auth_message is not None:
            self.verify_server_final(challenge)
            return b""

        server_first_message = challenge.decode()
        attributes = self._parse_attributes(server_first_message)
        if "m" in attributes:
            # RFC 5802: presence of this reserved attribute MUST fail authentication.
            raise ValueError("Unsupported mandatory SCRAM extension")
        try:
            server_nonce = attributes["r"]
            salt = base64.b64decode(attributes["s"])
            iterations = int(attributes["i"])
        except KeyError as e:
            raise ValueError(f"Missing SCRAM attribute: {e}") from e
        # The server must extend our nonce, otherwise this may be a replay/MITM.
        if not server_nonce.startswith(self.client_nonce):
            raise ValueError("SCRAM server nonce doesn't start with the client nonce")
        if iterations < 1:
            raise ValueError(f"Invalid SCRAM iteration count: {iterations}")

        self.server_nonce = server_nonce
        self.salted_password = self._hi(self.password, salt, iterations)

        client_key = self._hmac(self.salted_password, b"Client Key")
        stored_key = self.digest(client_key)
        # "biws" is base64("n,,"), i.e. the GS2 header of the initial response.
        client_final_without_proof = f"c=biws,r={self.server_nonce}"
        auth_message = ",".join(
            (
                self._client_first_message_bare(),
                server_first_message,
                client_final_without_proof,
            )
        ).encode()
        client_signature = self._hmac(stored_key, auth_message)
        client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature))
        self._auth_message = auth_message
        client_final_message = (
            f"{client_final_without_proof},p={base64.b64encode(client_proof).decode()}"
        )
        return client_final_message.encode()

    def verify_server_final(self, data: bytes) -> None:
        """Check the server-final-message, proving that the server knows the password.

        @param data: decoded server-final-message (e.g. b"v=...").
        @raise ValueError: no client-final-message has been sent yet, the message is
            malformed or reports an error, or the server signature doesn't match.
        """
        if self._auth_message is None or self.salted_password is None:
            raise ValueError(
                "SCRAM server-final-message received before client-final-message"
            )
        attributes = self._parse_attributes(data.decode())
        if "e" in attributes:
            raise ValueError(f"SCRAM authentication error: {attributes['e']}")
        encoded_signature = attributes.get("v")
        if encoded_signature is None:
            raise ValueError("Missing SCRAM server signature")
        server_signature = base64.b64decode(encoded_signature, validate=True)
        server_key = self._hmac(self.salted_password, b"Server Key")
        expected_signature = self._hmac(server_key, self._auth_message)
        # constant time comparison to avoid leaking information through timing
        if not secrets.compare_digest(server_signature, expected_signature):
            raise ValueError("Invalid SCRAM server signature")
        self.server_verified = True


class SASLInitiatingInitializer(sasl.SASLInitiatingInitializer):
    """SASL initializer supporting SCRAM-SHA mechanisms.

    SCRAM-SHA variants are preferred (strongest hash first), then PLAIN. When the
    mechanism can verify the server (SCRAM), the additional data of <success/> is
    checked before authentication is considered successful.
    """

    def setMechanism(self):
        """
        Select and setup authentication mechanism.

        Uses the authenticator's C{jid} and C{password} attribute for the
        authentication credentials. ANONYMOUS is used when the JID has no user
        part.

        @raise sasl.SASLNoAcceptableMechanism: no supported SASL mechanism is
            advertised by the receiving party.
        """
        user_jid = self.xmlstream.authenticator.jid
        password = self.xmlstream.authenticator.password
        mechanisms = sasl.get_mechanisms(self.xmlstream)

        if user_jid.user is None:
            if "ANONYMOUS" not in mechanisms:
                raise sasl.SASLNoAcceptableMechanism()
            self.mechanism = sasl_mechanisms.Anonymous()
            return

        for algorithm in ("SHA-512", "SHA-256", "SHA-1"):
            if f"SCRAM-{algorithm}" in mechanisms:
                self.mechanism = ScramSha(user_jid.user, password, algorithm=algorithm)
                return
        # FIXME: PLAIN should probably be disabled.
        if "PLAIN" in mechanisms:
            self.mechanism = sasl_mechanisms.Plain(None, user_jid.user, password)
            return
        raise sasl.SASLNoAcceptableMechanism()

    def onSuccess(self, success):
        """Verify the server's final data (if supported) before finishing auth.

        With SCRAM, the server proves that it knows the password by sending its
        signature as additional data of <success/>. If the mechanism provides a
        C{verify_server_final} method, that data is checked, and authentication
        fails if it doesn't match. Otherwise the base class handles the success.
        """
        verify_server_final = getattr(self.mechanism, "verify_server_final", None)
        if verify_server_final is not None:
            encoded_data = str(success).strip()
            # a single "=" means zero-length additional data
            if encoded_data and encoded_data != "=":
                try:
                    verify_server_final(base64.b64decode(encoded_data, validate=True))
                except ValueError as e:
                    log.error(f"SASL server verification failed: {e}")
                    # "/success" was a one-time observer; clean up the others like
                    # the base class does on failure.
                    self.xmlstream.removeObserver("/challenge", self.onChallenge)
                    self.xmlstream.removeObserver("/failure", self.onFailure)
                    self._deferred.errback(
                        sasl.SASLAuthError("server-verification-failed")
                    )
                    return
            if not getattr(self.mechanism, "server_verified", True):
                log.warning("SASL server signature has not been verified")
        super().onSuccess(success)


## certificate validation patches


class XMPPClient(client.XMPPClient):
    """Wokkel XMPPClient variant with configurable TLS requirement and settings."""

    def __init__(
        self,
        jid,
        password,
        host=None,
        port=5222,
        tls_required=True,
        configurationForTLS=None,
    ):
        """
        @param jid: JID of the account to connect with.
        @param password: password of the account.
        @param host: host to connect to, if set.
        @param port: port to connect to.
        @param tls_required: passed to the TLS initializer as its C{required} flag.
        @param configurationForTLS: TLS configuration passed to the TLS initializer.
        """
        self.jid = jid
        self.domain = jid.host.encode("idna")
        self.host = host
        self.port = port

        stream_factory = HybridClientFactory(
            jid,
            password,
            tls_required=tls_required,
            configurationForTLS=configurationForTLS,
        )
        # StreamManager is initialised directly with our own factory.
        client.StreamManager.__init__(self, stream_factory)


def HybridClientFactory(jid, password, tls_required=True, configurationForTLS=None):
    """Build an XmlStreamFactory authenticating with a HybridAuthenticator.

    All arguments are forwarded to HybridAuthenticator.
    """
    return xmlstream.XmlStreamFactory(
        HybridAuthenticator(jid, password, tls_required, configurationForTLS)
    )


class HybridAuthenticator(client.HybridAuthenticator):
    """Authenticator with configurable TLS and optional resource binding."""

    # passed to CheckAuthInitializer: add resource binding after authentication
    res_binding = True

    def __init__(self, jid, password, tls_required=True, configurationForTLS=None):
        """
        @param jid: JID to authenticate; its host is the stream's target.
        @param password: password used for authentication.
        @param tls_required: passed as C{required} to the TLS initializer.
        @param configurationForTLS: TLS configuration for the TLS initializer.
        """
        # ConnectAuthenticator.__init__ is called directly (not the parent's).
        xmlstream.ConnectAuthenticator.__init__(self, jid.host)
        self.jid = jid
        self.password = password
        self.tls_required = tls_required
        self.configurationForTLS = configurationForTLS

    def associateWithStream(self, xs):
        """Set the stream initializers: version check, TLS, then authentication."""
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)

        tls_initializer = xmlstream.TLSInitiatingInitializer(
            xs,
            required=self.tls_required,
            configurationForTLS=self.configurationForTLS,
        )
        version_initializer = client.client.CheckVersionInitializer(xs)
        auth_initializer = CheckAuthInitializer(xs, self.res_binding)
        xs.initializers = [version_initializer, tls_initializer, auth_initializer]


# XmlStream triggers


class XmlStream(xmlstream.XmlStream):
    """XmlStream calling registered hooks on received and sent data."""

    def __init__(self, authenticator):
        xmlstream.XmlStream.__init__(self, authenticator)
        # Hooks only observe traffic and should not modify it, hence (unlike
        # triggers) they have no priority handling.
        self._sendHooks = []
        self._onElementHooks = []

    @staticmethod
    def _call_hooks(hooks, data):
        """Call every hook of hooks, in registration order, with data."""
        for hook in hooks:
            hook(data)

    def add_hook(self, hook_type, callback):
        """Register a hook called with each received element or sent object.

        A callback is registered only once per hook type; duplicates are logged
        and ignored.

        @param hook_type: C.STREAM_HOOK_RECEIVE or C.STREAM_HOOK_SEND.
        @param callback: callable receiving the element/object as sole argument.
        @raise ValueError: hook_type is unknown.
        """
        conflict_msg = f"Hook conflict: can't add {hook_type} hook {callback}"
        if hook_type == C.STREAM_HOOK_RECEIVE:
            hooks = self._onElementHooks
        elif hook_type == C.STREAM_HOOK_SEND:
            hooks = self._sendHooks
        else:
            raise ValueError(f"Invalid hook type: {hook_type}")

        if callback in hooks:
            log.warning(conflict_msg)
        else:
            hooks.append(callback)

    def onElement(self, element):
        self._call_hooks(self._onElementHooks, element)
        xmlstream.XmlStream.onElement(self, element)

    def send(self, obj):
        self._call_hooks(self._sendHooks, obj)
        xmlstream.XmlStream.send(self, obj)


# Binding activation (needed for stream management, XEP-0198)


class CheckAuthInitializer(client.CheckAuthInitializer):
    """Authentication initializer with optional resource binding.

    Modification of client.CheckAuthInitializer which makes resource binding
    optional and doesn't use the deprecated SessionInitializer.
    """

    def __init__(self, xs, res_binding):
        """
        @param xs: XML stream to authenticate.
        @param res_binding: if True, a required BindInitializer is added after SASL
            authentication.
        """
        super().__init__(xs)
        self.res_binding = res_binding

    def initialize(self):
        """Append the initializers for the authentication method offered by the server.

        SASL is preferred; legacy IQ authentication is used as a fallback.

        @raise Exception: neither SASL nor IQ authentication is available.
        """
        features = self.xmlstream.features
        if (sasl.NS_XMPP_SASL, "mechanisms") in features:
            inits = [(SASLInitiatingInitializer, True)]
            if self.res_binding:
                inits.append((tclient.BindInitializer, True))

            for init_class, required in inits:
                init = init_class(self.xmlstream)
                init.required = required
                self.xmlstream.initializers.append(init)
        elif (tclient.NS_IQ_AUTH_FEATURE, "auth") in features:
            self.xmlstream.initializers.append(tclient.IQAuthInitializer(self.xmlstream))
        else:
            raise Exception("No available authentication method found")


# jid fix


def internJID(jidstring):
    """
    Return a copy of the interned JID for jidstring.

    The JID is created and cached in Twisted's intern cache on first use. A copy
    is returned so callers can't mutate the cached instance (JID is mutable).

    @rtype: L{JID}
    """
    # TODO: propose this upstream
    # Module-level function: no name mangling applies to "__internJIDs" here.
    cache = jid.__internJIDs
    cached_jid = cache.get(jidstring)
    if cached_jid is None:
        cached_jid = jid.JID(jidstring)
        cache[jidstring] = cached_jid
    return copy.copy(cached_jid)


def apply():
    """Install this module's monkey patches into Wokkel and Twisted."""
    patches = (
        # certificate validation (and SCRAM-aware authentication chain)
        (client, "XMPPClient", XMPPClient),
        # XmlStream send/receive hooks
        (xmlstream.XmlStreamFactory, "protocol", XmlStream),
        # internJID returning copies of cached JIDs
        (jid, "internJID", internJID),
    )
    for target, attribute, replacement in patches:
        setattr(target, attribute, replacement)