diff libervia/backend/core/patches.py @ 4237:a1e7e82a8921

core: implement SCRAM-SHA auth algorithm: Twisted auth mechanism are outdated, and as a result, Libervia was not supporting the mandatory SCRAM-SHA auth mechanism. This patch implements it for SCRAM-SHA-1, SCRAM-SHA-256 and SCRAM-SHA-512 variants.
author Goffi <goffi@goffi.org>
date Mon, 08 Apr 2024 12:29:40 +0200
parents 4b842c1fb686
children c14e904eee13
line wrap: on
line diff
--- a/libervia/backend/core/patches.py	Sat Apr 06 15:21:00 2024 +0200
+++ b/libervia/backend/core/patches.py	Mon Apr 08 12:29:40 2024 +0200
@@ -1,6 +1,20 @@
+import base64
 import copy
-from twisted.words.protocols.jabber import xmlstream, sasl, client as tclient, jid
+import secrets
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes, hmac
+from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+from twisted.words.protocols.jabber import (
+    client as tclient,
+    jid,
+    sasl,
+    sasl_mechanisms,
+    xmlstream,
+)
 from wokkel import client
+from zope.interface import implementer
+
 from libervia.backend.core.constants import Const as C
 from libervia.backend.core.log import getLogger
 
@@ -13,21 +27,147 @@
    """
 
 
+# SCRAM-SHA implementation
+
+
+@implementer(sasl_mechanisms.ISASLMechanism)
+class ScramSha:
+    """Implements the SCRAM-SHA SASL authentication mechanism.
+
+    This mechanism is defined in RFC 5802.
+    """
+
+    ALLOWED_ALGORITHMS = ("SHA-1", "SHA-256", "SHA-512")
+    backend = default_backend()
+
+    def __init__(self, username: str, password: str, algorithm: str) -> None:
+        """Initialize SCRAM-SHA mechanism with user credentials.
+
+        @param username: The user's username.
+        @param password: The user's password.
+        """
+        if algorithm not in self.ALLOWED_ALGORITHMS:
+            raise ValueError(f"Invalid algorithm: {algorithm!r}")
+
+        self.username = username
+        self.password = password
+        self.algorithm = getattr(hashes, algorithm.replace("-", "", 1))()
+        self.name = f"SCRAM-{algorithm}"
+        self.client_nonce = base64.b64encode(secrets.token_bytes(24)).decode()
+        self.server_nonce = None
+        self.salted_password = None
+
+    def digest(self, data: bytes) -> bytes:
+        hasher = hashes.Hash(self.algorithm)
+        hasher.update(data)
+        return hasher.finalize()
+
+    def _hmac(self, key: bytes, msg: bytes) -> bytes:
+        """Compute HMAC-SHA"""
+        h = hmac.HMAC(key, self.algorithm, backend=self.backend)
+        h.update(msg)
+        return h.finalize()
+
+    def _hi(self, password: str, salt: bytes, iterations: int) -> bytes:
+        kdf = PBKDF2HMAC(
+            algorithm=self.algorithm,
+            length=self.algorithm.digest_size,
+            salt=salt,
+            iterations=iterations,
+            backend=default_backend(),
+        )
+        return kdf.derive(password.encode())
+
+    def getInitialResponse(self) -> bytes:
+        """Builds the initial client response message."""
+        return f"n,,n={self.username},r={self.client_nonce}".encode()
+
+    def getResponse(self, challenge: bytes) -> bytes:
+        """SCRAM-SHA authentication final step. Building proof of having the password.
+
+        @param challenge: Challenge string from the server.
+        @return: Client proof.
+        """
+        challenge_parts = dict(item.split("=") for item in challenge.decode().split(","))
+        self.server_nonce = challenge_parts["r"]
+        salt = base64.b64decode(challenge_parts["s"])
+        iterations = int(challenge_parts["i"])
+        self.salted_password = self._hi(self.password, salt, iterations)
+
+        client_key = self._hmac(self.salted_password, b"Client Key")
+        stored_key = self.digest(client_key)
+        auth_message = (
+            f"n={self.username},r={self.client_nonce},{challenge.decode()},c=biws,"
+            f"r={self.server_nonce}"
+        ).encode()
+        client_signature = self._hmac(stored_key, auth_message)
+        client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature))
+        client_final_message = (
+            f"c=biws,r={self.server_nonce},p={base64.b64encode(client_proof).decode()}"
+        )
+        return client_final_message.encode()
+
+
+class SASLInitiatingInitializer(sasl.SASLInitiatingInitializer):
+
+    def setMechanism(self):
+        """
+        Select and setup authentication mechanism.
+
+        Uses the authenticator's C{jid} and C{password} attribute for the
+        authentication credentials. If no supported SASL mechanisms are
+        advertized by the receiving party, a failing deferred is returned with
+        a L{SASLNoAcceptableMechanism} exception.
+        """
+
+        jid = self.xmlstream.authenticator.jid
+        password = self.xmlstream.authenticator.password
+
+        mechanisms = sasl.get_mechanisms(self.xmlstream)
+        if jid.user is not None:
+            if "SCRAM-SHA-512" in mechanisms:
+                self.mechanism = ScramSha(jid.user, password, algorithm="SHA-512")
+            elif "SCRAM-SHA-256" in mechanisms:
+                self.mechanism = ScramSha(jid.user, password, algorithm="SHA-256")
+            elif "SCRAM-SHA-1" in mechanisms:
+                self.mechanism = ScramSha(jid.user, password, algorithm="SHA-1")
+            # FIXME: PLAIN should probably be disabled.
+            elif "PLAIN" in mechanisms:
+                self.mechanism = sasl_mechanisms.Plain(None, jid.user, password)
+            else:
+                raise sasl.SASLNoAcceptableMechanism()
+        else:
+            if "ANONYMOUS" in mechanisms:
+                self.mechanism = sasl_mechanisms.Anonymous()
+            else:
+                raise sasl.SASLNoAcceptableMechanism()
+
+
 ## certificate validation patches
 
 
 class XMPPClient(client.XMPPClient):
 
-    def __init__(self, jid, password, host=None, port=5222,
-                 tls_required=True, configurationForTLS=None):
+    def __init__(
+        self,
+        jid,
+        password,
+        host=None,
+        port=5222,
+        tls_required=True,
+        configurationForTLS=None,
+    ):
         self.jid = jid
-        self.domain = jid.host.encode('idna')
+        self.domain = jid.host.encode("idna")
         self.host = host
         self.port = port
 
         factory = HybridClientFactory(
-            jid, password, tls_required=tls_required,
-            configurationForTLS=configurationForTLS)
+            jid,
+            password,
+            tls_required=tls_required,
+            configurationForTLS=configurationForTLS,
+        )
 
         client.StreamManager.__init__(self, factory)
 
@@ -52,10 +192,13 @@
         xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
 
         tlsInit = xmlstream.TLSInitiatingInitializer(
-            xs, required=self.tls_required, configurationForTLS=self.configurationForTLS)
-        xs.initializers = [client.client.CheckVersionInitializer(xs),
-                           tlsInit,
-                           CheckAuthInitializer(xs, self.res_binding)]
+            xs, required=self.tls_required, configurationForTLS=self.configurationForTLS
+        )
+        xs.initializers = [
+            client.client.CheckVersionInitializer(xs),
+            tlsInit,
+            CheckAuthInitializer(xs, self.res_binding),
+        ]
 
 
 # XmlStream triggers
@@ -111,8 +254,8 @@
         # XXX: modification of client.CheckAuthInitializer which has optional
         #      resource binding, and which doesn't do deprecated
         #      SessionInitializer
-        if (sasl.NS_XMPP_SASL, 'mechanisms') in self.xmlstream.features:
-            inits = [(sasl.SASLInitiatingInitializer, True)]
+        if (sasl.NS_XMPP_SASL, "mechanisms") in self.xmlstream.features:
+            inits = [(SASLInitiatingInitializer, True)]
             if self.res_binding:
                 inits.append((tclient.BindInitializer, True)),
 
@@ -120,15 +263,15 @@
                 init = initClass(self.xmlstream)
                 init.required = required
                 self.xmlstream.initializers.append(init)
-        elif (tclient.NS_IQ_AUTH_FEATURE, 'auth') in self.xmlstream.features:
-            self.xmlstream.initializers.append(
-                    tclient.IQAuthInitializer(self.xmlstream))
+        elif (tclient.NS_IQ_AUTH_FEATURE, "auth") in self.xmlstream.features:
+            self.xmlstream.initializers.append(tclient.IQAuthInitializer(self.xmlstream))
         else:
             raise Exception("No available authentication method found")
 
 
 # jid fix
 
+
 def internJID(jidstring):
     """
     Return interned JID.