Mercurial > libervia-backend
view libervia/backend/core/patches.py @ 4265:2417ad1d0f23
core (xmpp): fix message workflow interruption from trigger.
author | Goffi <goffi@goffi.org> |
---|---|
date | Wed, 12 Jun 2024 22:37:04 +0200 |
parents | c14e904eee13 |
children |
line wrap: on
line source
"""This module applies monkey patches to Twisted and Wokkel.

- SCRAM-SHA-1/256/512 SASL authentication (RFC 5802), preferred over PLAIN.
- Certificate validation during the XMPP connection: the TLS requirement and the TLS
  configuration can be specified. These patches are temporary (until merged upstream).
- Hook points on the ``send`` and ``onElement`` methods of XmlStream.
- ``internJID`` returning copies of cached JIDs, so cached instances are never mutated.

Patches are only activated when L{apply} is called.
"""

import base64
import copy
import secrets

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from twisted.words.protocols.jabber import (
    client as tclient,
    jid,
    sasl,
    sasl_mechanisms,
    xmlstream,
)
from wokkel import client
from zope.interface import implementer

from libervia.backend.core.constants import Const as C
from libervia.backend.core.log import getLogger

log = getLogger(__name__)


# SCRAM-SHA implementation


@implementer(sasl_mechanisms.ISASLMechanism)
class ScramSha:
    """Implements the SCRAM-SHA SASL authentication mechanism.

    This mechanism is defined in RFC 5802 (the SHA-256 variant in RFC 7677).
    Channel binding is not supported: the GS2 header is always "n,," ("biws" once
    base64 encoded).
    """

    ALLOWED_ALGORITHMS = ("SHA-1", "SHA-256", "SHA-512")
    backend = default_backend()

    def __init__(self, username: str, password: str, algorithm: str) -> None:
        """Initialize SCRAM-SHA mechanism with user credentials.

        @param username: The user's username.
        @param password: The user's password.
        @param algorithm: Hash algorithm name, one of ALLOWED_ALGORITHMS.
        @raise ValueError: algorithm is not supported.
        """
        if algorithm not in self.ALLOWED_ALGORITHMS:
            raise ValueError(f"Invalid algorithm: {algorithm!r}")
        self.username = username
        self.password = password
        # e.g. "SHA-256" -> hashes.SHA256()
        self.algorithm = getattr(hashes, algorithm.replace("-", "", 1))()
        self.name = f"SCRAM-{algorithm}"
        self.client_nonce = base64.b64encode(secrets.token_bytes(24)).decode()
        self.server_nonce = None
        self.salted_password = None
        # ServerSignature expected from the server, computed when the
        # server-first-message is processed.
        self.server_signature = None
        # True once the server has proven it knows the (salted) password.
        self.server_verified = False

    @staticmethod
    def _escape_saslname(name: str) -> str:
        """Escape a name for use in a SCRAM "n=" attribute (RFC 5802 §5.1)."""
        # "=" must be escaped first, otherwise the "=" of "=2C" would be escaped too.
        return name.replace("=", "=3D").replace(",", "=2C")

    @staticmethod
    def _parse_attributes(message: str) -> dict:
        """Parse a comma separated list of SCRAM "key=value" attributes.

        @param message: Decoded SCRAM message.
        @return: Mapping of attribute name to value (last occurrence wins).
        @raise sasl.SASLError: an item has no "=" separator.
        """
        attributes = {}
        for item in message.split(","):
            key, sep, value = item.partition("=")
            if not sep:
                raise sasl.SASLError(f"Invalid SCRAM attribute: {item!r}")
            attributes[key] = value
        return attributes

    def _client_first_bare(self) -> str:
        """Return the client-first-message-bare (without GS2 header)."""
        return f"n={self._escape_saslname(self.username)},r={self.client_nonce}"

    def digest(self, data: bytes) -> bytes:
        """Hash data with the selected algorithm."""
        hasher = hashes.Hash(self.algorithm, backend=self.backend)
        hasher.update(data)
        return hasher.finalize()

    def _hmac(self, key: bytes, msg: bytes) -> bytes:
        """Compute HMAC-SHA"""
        h = hmac.HMAC(key, self.algorithm, backend=self.backend)
        h.update(msg)
        return h.finalize()

    def _hi(self, password: str, salt: bytes, iterations: int) -> bytes:
        """Compute the SCRAM "Hi" function (PBKDF2 with the selected HMAC)."""
        # NOTE(review): RFC 5802 requires SASLprep normalization of the password,
        # which is not applied here; non-ASCII passwords may fail on some servers.
        kdf = PBKDF2HMAC(
            algorithm=self.algorithm,
            length=self.algorithm.digest_size,
            salt=salt,
            iterations=iterations,
            backend=self.backend,
        )
        return kdf.derive(password.encode())

    def getInitialResponse(self) -> bytes:
        """Builds the initial client response message."""
        return f"n,,{self._client_first_bare()}".encode()

    def getResponse(self, challenge: bytes) -> bytes:
        """Answer a server challenge.

        The first challenge is the server-first-message, answered with the
        client-final-message (proof of having the password). If the server sends its
        server-final-message as a second challenge, it is verified and an empty
        response is returned.

        @param challenge: Challenge string from the server.
        @return: Client proof, or empty bytes after verifying the server-final-message.
        @raise sasl.SASLError: the challenge is invalid or fails verification.
        """
        if self.server_signature is not None:
            self.verify_server_final(challenge)
            return b""

        server_first = challenge.decode()
        attributes = self._parse_attributes(server_first)
        if "m" in attributes:
            raise sasl.SASLError("Unsupported mandatory SCRAM extension requested")
        try:
            server_nonce = attributes["r"]
            salt = base64.b64decode(attributes["s"])
            iterations = int(attributes["i"])
        except (KeyError, ValueError) as e:
            raise sasl.SASLError("Invalid SCRAM server-first-message") from e
        if iterations < 1:
            raise sasl.SASLError(f"Invalid SCRAM iteration count: {iterations}")
        # RFC 5802: the server nonce must start with the nonce we sent.
        if not server_nonce.startswith(self.client_nonce):
            raise sasl.SASLError("SCRAM server nonce doesn't match client nonce")

        self.server_nonce = server_nonce
        self.salted_password = self._hi(self.password, salt, iterations)
        client_key = self._hmac(self.salted_password, b"Client Key")
        stored_key = self.digest(client_key)
        client_final_without_proof = f"c=biws,r={self.server_nonce}"
        auth_message = (
            f"{self._client_first_bare()},{server_first},{client_final_without_proof}"
        ).encode()
        client_signature = self._hmac(stored_key, auth_message)
        client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature))
        # Kept to authenticate the server in the server-final-message.
        server_key = self._hmac(self.salted_password, b"Server Key")
        self.server_signature = self._hmac(server_key, auth_message)
        client_final_message = (
            f"{client_final_without_proof},p={base64.b64encode(client_proof).decode()}"
        )
        return client_final_message.encode()

    def verify_server_final(self, data: bytes) -> None:
        """Check the server-final-message ("v=..." or "e=...").

        On success, L{server_verified} is set to True.

        @param data: Decoded (not base64) server-final-message.
        @raise sasl.SASLError: the server reported an error, the message is invalid, or
            the server signature doesn't match.
        """
        if self.server_signature is None:
            raise sasl.SASLError(
                "SCRAM server-final-message received before server-first-message"
            )
        try:
            attributes = self._parse_attributes(data.decode())
        except UnicodeDecodeError as e:
            raise sasl.SASLError("Invalid SCRAM server-final-message") from e
        if "e" in attributes:
            raise sasl.SASLError(f"SCRAM error reported by server: {attributes['e']}")
        try:
            received_signature = base64.b64decode(attributes["v"])
        except (KeyError, ValueError) as e:
            raise sasl.SASLError("Invalid SCRAM server-final-message") from e
        if not secrets.compare_digest(received_signature, self.server_signature):
            raise sasl.SASLError(
                "SCRAM server signature mismatch: the server can't be authenticated"
            )
        self.server_verified = True


class SASLInitiatingInitializer(sasl.SASLInitiatingInitializer):
    """SASL initializer preferring SCRAM-SHA, and verifying the SCRAM server."""

    # SCRAM algorithms, in order of preference.
    _SCRAM_ALGORITHMS = ("SHA-512", "SHA-256", "SHA-1")

    def setMechanism(self):
        """
        Select and setup authentication mechanism.

        Uses the authenticator's C{jid} and C{password} attribute for the
        authentication credentials. If no supported SASL mechanisms are
        advertised by the receiving party, a failing deferred is returned with
        a L{SASLNoAcceptableMechanism} exception.
        """
        user_jid = self.xmlstream.authenticator.jid
        password = self.xmlstream.authenticator.password
        mechanisms = sasl.get_mechanisms(self.xmlstream)

        if user_jid.user is None:
            if "ANONYMOUS" in mechanisms:
                self.mechanism = sasl_mechanisms.Anonymous()
                return
            raise sasl.SASLNoAcceptableMechanism()

        for algorithm in self._SCRAM_ALGORITHMS:
            if f"SCRAM-{algorithm}" in mechanisms:
                self.mechanism = ScramSha(user_jid.user, password, algorithm=algorithm)
                return
        # FIXME: PLAIN should probably be disabled.
        if "PLAIN" in mechanisms:
            self.mechanism = sasl_mechanisms.Plain(None, user_jid.user, password)
            return
        raise sasl.SASLNoAcceptableMechanism()

    def _abort(self, reason: Exception) -> None:
        """Stop listening to SASL replies and fail the initialization.

        @param reason: Exception explaining why the exchange is aborted.
        """
        self.xmlstream.removeObserver("/challenge", self.onChallenge)
        self.xmlstream.removeObserver("/success", self.onSuccess)
        self.xmlstream.removeObserver("/failure", self.onFailure)
        self._deferred.errback(reason)

    def onChallenge(self, element):
        """Answer a challenge, failing the initialization on any error.

        Exceptions raised from an observer are only logged by the event dispatcher,
        so without this the deferred would never fire and authentication would hang.
        """
        try:
            challenge = sasl.fromBase64(str(element))
            response = self.mechanism.getResponse(challenge)
        except Exception as e:
            self._abort(e)
        else:
            self.sendResponse(response)

    def onSuccess(self, success):
        """Accept the authentication success, verifying the SCRAM server first."""
        if isinstance(self.mechanism, ScramSha):
            try:
                self._check_scram_success(success)
            except Exception as e:
                self._abort(e)
                return
        super().onSuccess(success)

    def _check_scram_success(self, success) -> None:
        """Verify the server-final-message carried by a <success/> element, if any."""
        data = str(success).strip()
        if data and data != "=":
            self.mechanism.verify_server_final(sasl.fromBase64(data))
        elif not self.mechanism.server_verified:
            # RFC 5802 requires this verification. Authentication is still accepted
            # so that non-compliant servers keep working, but it is reported.
            log.warning(
                "Server didn't send any SCRAM server signature, it can't be "
                "authenticated"
            )


## certificate validation patches


class XMPPClient(client.XMPPClient):
    """Wokkel XMPPClient allowing to specify TLS requirement and configuration."""

    def __init__(
        self,
        jid,
        password,
        host=None,
        port=5222,
        tls_required=True,
        configurationForTLS=None,
    ):
        self.jid = jid
        self.domain = jid.host.encode("idna")
        self.host = host
        self.port = port
        factory = HybridClientFactory(
            jid,
            password,
            tls_required=tls_required,
            configurationForTLS=configurationForTLS,
        )
        client.StreamManager.__init__(self, factory)


def HybridClientFactory(jid, password, tls_required=True, configurationForTLS=None):
    """Build an XmlStreamFactory using L{HybridAuthenticator}."""
    a = HybridAuthenticator(jid, password, tls_required, configurationForTLS)
    return xmlstream.XmlStreamFactory(a)


class HybridAuthenticator(client.HybridAuthenticator):
    """Authenticator passing TLS settings to the TLS initializer.

    Resource binding is enabled when L{res_binding} is True.
    """

    res_binding = True

    def __init__(self, jid, password, tls_required=True, configurationForTLS=None):
        xmlstream.ConnectAuthenticator.__init__(self, jid.host)
        self.jid = jid
        self.password = password
        self.tls_required = tls_required
        self.configurationForTLS = configurationForTLS

    def associateWithStream(self, xs):
        xmlstream.ConnectAuthenticator.associateWithStream(self, xs)

        tls_init = xmlstream.TLSInitiatingInitializer(
            xs, required=self.tls_required, configurationForTLS=self.configurationForTLS
        )

        xs.initializers = [
            tclient.CheckVersionInitializer(xs),
            tls_init,
            CheckAuthInitializer(xs, self.res_binding),
        ]


# XmlStream triggers


class XmlStream(xmlstream.XmlStream):
    """XmlStream which allows to add hooks"""

    def __init__(self, authenticator):
        xmlstream.XmlStream.__init__(self, authenticator)
        # hooks at this level should not modify content
        # so it's not needed to handle priority as with triggers
        self._onElementHooks = []
        self._sendHooks = []

    def add_hook(self, hook_type, callback):
        """Add a send or receive hook

        @param hook_type: C.STREAM_HOOK_RECEIVE or C.STREAM_HOOK_SEND.
        @param callback: Callable called with each received element (receive hook)
            or each sent object (send hook). A callback already registered for this
            hook type is not added again, a warning is logged instead.
        @raise ValueError: hook_type is unknown.
        """
        if hook_type == C.STREAM_HOOK_RECEIVE:
            hooks = self._onElementHooks
        elif hook_type == C.STREAM_HOOK_SEND:
            hooks = self._sendHooks
        else:
            raise ValueError(f"Invalid hook type: {hook_type}")
        if callback in hooks:
            log.warning(f"Hook conflict: can't add {hook_type} hook {callback}")
        else:
            hooks.append(callback)

    def onElement(self, element):
        """Call receive hooks, then the normal element handling."""
        for hook in self._onElementHooks:
            hook(element)
        xmlstream.XmlStream.onElement(self, element)

    def send(self, obj):
        """Call send hooks, then actually send the object."""
        for hook in self._sendHooks:
            hook(obj)
        xmlstream.XmlStream.send(self, obj)


# Binding activation (needed for stream management, XEP-0198)


class CheckAuthInitializer(client.CheckAuthInitializer):
    """Authentication initializer with optional resource binding."""

    def __init__(self, xs, res_binding):
        super().__init__(xs)
        self.res_binding = res_binding

    def initialize(self):
        # XXX: modification of client.CheckAuthInitializer which has optional
        #      resource binding, and which doesn't do deprecated
        #      SessionInitializer
        features = self.xmlstream.features
        if (sasl.NS_XMPP_SASL, "mechanisms") in features:
            inits = [(SASLInitiatingInitializer, True)]
            if self.res_binding:
                inits.append((tclient.BindInitializer, True))

            for init_class, required in inits:
                init = init_class(self.xmlstream)
                init.required = required
                self.xmlstream.initializers.append(init)
        elif (tclient.NS_IQ_AUTH_FEATURE, "auth") in features:
            self.xmlstream.initializers.append(tclient.IQAuthInitializer(self.xmlstream))
        else:
            raise Exception("No available authentication method found")


# jid fix


def internJID(jidstring):
    """
    Return interned JID.

    @rtype: L{JID}
    """
    # XXX: this interJID return a copy of the cached jid
    #      this avoid modification of cached jid as JID is mutable
    # TODO: propose this upstream
    # Module-level function: no name mangling applies to "__internJIDs" here.
    interned = jid.__internJIDs.get(jidstring)
    if interned is None:
        interned = jid.JID(jidstring)
        jid.__internJIDs[jidstring] = interned
    return copy.copy(interned)


def apply():
    """Install the patches of this module into Wokkel and Twisted."""
    # certificate validation
    client.XMPPClient = XMPPClient
    # XmlStream triggers
    xmlstream.XmlStreamFactory.protocol = XmlStream
    # jid fix
    jid.internJID = internJID