Mercurial > libervia-backend
diff libervia/backend/core/patches.py @ 4237:a1e7e82a8921
core: implement SCRAM-SHA auth algorithm:
Twisted auth mechanism are outdated, and as a result, Libervia was not supporting the
mandatory SCRAM-SHA auth mechanism. This patch implements it for SCRAM-SHA-1,
SCRAM-SHA-256 and SCRAM-SHA-512 variants.
author | Goffi <goffi@goffi.org> |
---|---|
date | Mon, 08 Apr 2024 12:29:40 +0200 |
parents | 4b842c1fb686 |
children | c14e904eee13 |
line wrap: on
line diff
--- a/libervia/backend/core/patches.py Sat Apr 06 15:21:00 2024 +0200 +++ b/libervia/backend/core/patches.py Mon Apr 08 12:29:40 2024 +0200 @@ -1,6 +1,20 @@ +import base64 import copy -from twisted.words.protocols.jabber import xmlstream, sasl, client as tclient, jid +import secrets + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from twisted.words.protocols.jabber import ( + client as tclient, + jid, + sasl, + sasl_mechanisms, + xmlstream, +) from wokkel import client +from zope.interface import implementer + from libervia.backend.core.constants import Const as C from libervia.backend.core.log import getLogger @@ -13,21 +27,147 @@ """ +# SCRAM-SHA implementation + + +@implementer(sasl_mechanisms.ISASLMechanism) +class ScramSha: + """Implements the SCRAM-SHA SASL authentication mechanism. + + This mechanism is defined in RFC 5802. + """ + + ALLOWED_ALGORITHMS = ("SHA-1", "SHA-256", "SHA-512") + backend = default_backend() + + def __init__(self, username: str, password: str, algorithm: str) -> None: + """Initialize SCRAM-SHA mechanism with user credentials. + + @param username: The user's username. + @param password: The user's password. + """ + if algorithm not in self.ALLOWED_ALGORITHMS: + raise ValueError(f"Invalid algorithm: {algorithm!r}") + + self.username = username + self.password = password + self.algorithm = getattr(hashes, algorithm.replace("-", "", 1))() + self.name = f"SCRAM-{algorithm}" + self.client_nonce = base64.b64encode(secrets.token_bytes(24)).decode() + self.server_nonce = None + self.salted_password = None + + def digest(self, data: bytes) -> bytes: + hasher = hashes.Hash(self.algorithm) + hasher.update(data) + return hasher.finalize() + + def _hmac(self, key: bytes, msg: bytes) -> bytes: + """Compute HMAC-SHA""" + h = hmac.HMAC(key, self.algorithm, backend=self.backend) + h.update(msg) + return h.finalize() + + def _hi(self, password: str, salt: bytes, iterations: int) -> bytes: + kdf = PBKDF2HMAC( + algorithm=self.algorithm, + length=self.algorithm.digest_size, + salt=salt, + iterations=iterations, + backend=default_backend(), + ) + return kdf.derive(password.encode()) + + def getInitialResponse(self) -> bytes: + """Builds the initial client response message.""" + return f"n,,n={self.username},r={self.client_nonce}".encode() + + def getResponse(self, challenge: bytes) -> bytes: + """SCRAM-SHA authentication final step. Building proof of having the password. + + @param challenge: Challenge string from the server. + @return: Client proof. + """ + challenge_parts = dict(item.split("=") for item in challenge.decode().split(",")) + self.server_nonce = challenge_parts["r"] + salt = base64.b64decode(challenge_parts["s"]) + iterations = int(challenge_parts["i"]) + self.salted_password = self._hi(self.password, salt, iterations) + + client_key = self._hmac(self.salted_password, b"Client Key") + stored_key = self.digest(client_key) + auth_message = ( + f"n={self.username},r={self.client_nonce},{challenge.decode()},c=biws," + f"r={self.server_nonce}" + ).encode() + client_signature = self._hmac(stored_key, auth_message) + client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature)) + client_final_message = ( + f"c=biws,r={self.server_nonce},p={base64.b64encode(client_proof).decode()}" + ) + return client_final_message.encode() + + +class SASLInitiatingInitializer(sasl.SASLInitiatingInitializer): + + def setMechanism(self): + """ + Select and setup authentication mechanism. + + Uses the authenticator's C{jid} and C{password} attribute for the + authentication credentials. If no supported SASL mechanisms are + advertized by the receiving party, a failing deferred is returned with + a L{SASLNoAcceptableMechanism} exception. + """ + + jid = self.xmlstream.authenticator.jid + password = self.xmlstream.authenticator.password + + mechanisms = sasl.get_mechanisms(self.xmlstream) + if jid.user is not None: + if "SCRAM-SHA-512" in mechanisms: + self.mechanism = ScramSha(jid.user, password, algorithm="SHA-512") + elif "SCRAM-SHA-256" in mechanisms: + self.mechanism = ScramSha(jid.user, password, algorithm="SHA-256") + elif "SCRAM-SHA-1" in mechanisms: + self.mechanism = ScramSha(jid.user, password, algorithm="SHA-1") + # FIXME: PLAIN should probably be disabled. + elif "PLAIN" in mechanisms: + self.mechanism = sasl_mechanisms.Plain(None, jid.user, password) + else: + raise sasl.SASLNoAcceptableMechanism() + else: + if "ANONYMOUS" in mechanisms: + self.mechanism = sasl_mechanisms.Anonymous() + else: + raise sasl.SASLNoAcceptableMechanism() + + ## certificate validation patches class XMPPClient(client.XMPPClient): - def __init__(self, jid, password, host=None, port=5222, - tls_required=True, configurationForTLS=None): + def __init__( + self, + jid, + password, + host=None, + port=5222, + tls_required=True, + configurationForTLS=None, + ): self.jid = jid - self.domain = jid.host.encode('idna') + self.domain = jid.host.encode("idna") self.host = host self.port = port factory = HybridClientFactory( - jid, password, tls_required=tls_required, - configurationForTLS=configurationForTLS) + jid, + password, + tls_required=tls_required, + configurationForTLS=configurationForTLS, + ) client.StreamManager.__init__(self, factory) @@ -52,10 +192,13 @@ xmlstream.ConnectAuthenticator.associateWithStream(self, xs) tlsInit = xmlstream.TLSInitiatingInitializer( - xs, required=self.tls_required, configurationForTLS=self.configurationForTLS) - xs.initializers = [client.client.CheckVersionInitializer(xs), - tlsInit, - CheckAuthInitializer(xs, self.res_binding)] + xs, required=self.tls_required, configurationForTLS=self.configurationForTLS + ) + xs.initializers = [ + client.client.CheckVersionInitializer(xs), + tlsInit, + CheckAuthInitializer(xs, self.res_binding), + ] # XmlStream triggers @@ -111,8 +254,8 @@ # XXX: modification of client.CheckAuthInitializer which has optional # resource binding, and which doesn't do deprecated # SessionInitializer - if (sasl.NS_XMPP_SASL, 'mechanisms') in self.xmlstream.features: - inits = [(sasl.SASLInitiatingInitializer, True)] + if (sasl.NS_XMPP_SASL, "mechanisms") in self.xmlstream.features: + inits = [(SASLInitiatingInitializer, True)] if self.res_binding: inits.append((tclient.BindInitializer, True)), @@ -120,15 +263,15 @@ init = initClass(self.xmlstream) init.required = required self.xmlstream.initializers.append(init) - elif (tclient.NS_IQ_AUTH_FEATURE, 'auth') in self.xmlstream.features: - self.xmlstream.initializers.append( - tclient.IQAuthInitializer(self.xmlstream)) + elif (tclient.NS_IQ_AUTH_FEATURE, "auth") in self.xmlstream.features: + self.xmlstream.initializers.append(tclient.IQAuthInitializer(self.xmlstream)) else: raise Exception("No available authentication method found") # jid fix + def internJID(jidstring): """ Return interned JID.