diff libervia/backend/plugins/plugin_xep_0300.py @ 4334:111dce64dcb5

plugins XEP-0300, XEP-0446, XEP-0447, XEP0448 and others: Refactoring to use Pydantic: Pydantic models are used more and more in Libervia, for the bridge API, and also to convert `domish.Element` to internal representation. Type hints have also been added in many places. rel 453
author Goffi <goffi@goffi.org>
date Tue, 03 Dec 2024 00:12:38 +0100
parents 0d7bb4df2343
children
line wrap: on
line diff
--- a/libervia/backend/plugins/plugin_xep_0300.py	Tue Dec 03 00:11:00 2024 +0100
+++ b/libervia/backend/plugins/plugin_xep_0300.py	Tue Dec 03 00:12:38 2024 +0100
@@ -1,8 +1,7 @@
 #!/usr/bin/env python3
 
-
-# SAT plugin for Hash functions (XEP-0300)
-# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)
+# Libervia plugin for Hash functions (XEP-0300)
+# Copyright (C) 2009-2024 Jérôme Poisson (goffi@goffi.org)
 
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as published by
@@ -17,13 +16,16 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-from typing import Tuple
 import base64
 from collections import OrderedDict
 import hashlib
+from typing import BinaryIO, Callable, Self, TYPE_CHECKING
 
+from _hashlib import HASH
+from pydantic import BaseModel, Field
 from twisted.internet import threads
 from twisted.internet import defer
+from twisted.words.protocols.jabber import jid
 from twisted.words.protocols.jabber.xmlstream import XMPPHandler
 from twisted.words.xish import domish
 from wokkel import disco, iwokkel
@@ -31,9 +33,13 @@
 
 from libervia.backend.core import exceptions
 from libervia.backend.core.constants import Const as C
+from libervia.backend.core.core_types import SatXMPPEntity
 from libervia.backend.core.i18n import _
 from libervia.backend.core.log import getLogger
 
+if TYPE_CHECKING:
+    from libervia.backend.core.main import LiberviaBackend
+
 log = getLogger(__name__)
 
 
@@ -54,9 +60,91 @@
 ALGO_DEFAULT = "sha-256"
 
 
-class XEP_0300(object):
+class Hash(BaseModel):
+    """
+    Model for hash data.
+    """
+
+    algo: str = Field(description="The algorithm used for hashing.")
+    hash_: str = Field(min_length=16, description="The base64-encoded hash value.")
+
+    @classmethod
+    def from_element(cls, hash_elt: domish.Element) -> Self:
+        """
+        Create a HashModel instance from a <hash> element.
+
+        @param hash_elt: The <hash> element.
+        @return: HashModel instance.
+        @raise exceptions.NotFound: If the <hash> element is not found.
+        """
+        if hash_elt.uri != NS_HASHES or hash_elt.name != "hash":
+            raise exceptions.NotFound("<hash> element not found")
+        algo = hash_elt.getAttribute("algo")
+        hash_value = str(hash_elt)
+        return cls(algo=algo, hash_=hash_value)
+
+    @classmethod
+    def from_parent(cls, parent_elt: domish.Element) -> list[Self]:
+        """Find and return child <hash> element in given parent.
+
+        @param parent_elt: Element which may content child <hash> elements.
+        @return: list of Hash corresponding to found elements
+        """
+        return [
+            cls.from_element(hash_elt)
+            for hash_elt in parent_elt.elements(NS_HASHES, "hash")
+        ]
+
+    def to_element(self) -> domish.Element:
+        """Build the <hash> element from this instance's data.
+
+        @return: <hash> element.
+        """
+        hash_elt = domish.Element((NS_HASHES, "hash"))
+        hash_elt["algo"] = self.algo
+        hash_elt.addContent(self.hash_)
+        return hash_elt
+
+
+class HashUsed(BaseModel):
+    """
+    Model for hash-used data.
+    """
+
+    algo: str = Field(description="The algorithm used for hashing.")
+
+    @classmethod
+    def from_element(cls, hash_used_elt: domish.Element) -> Self:
+        """Create a HashUsedModel instance from a <hash-used> element.
+
+        @param hash_used_elt: The <hash-used> element.
+        @return: HashUsedModel instance.
+        @raise exceptions.NotFound: If the <hash-used> element is not found.
+        """
+        if hash_used_elt.uri != NS_HASHES or hash_used_elt.name != "hash-used":
+            child_hash_used_elt = next(
+                hash_used_elt.elements(NS_HASHES, "hash-used"), None
+            )
+            if child_hash_used_elt is None:
+                raise exceptions.NotFound("<hash-used> element not found")
+            else:
+                hash_used_elt = child_hash_used_elt
+        algo = hash_used_elt.getAttribute("algo")
+        return cls(algo=algo)
+
+    def to_element(self) -> domish.Element:
+        """Build the <hash-used> element from this instance's data.
+
+        @return: <hash-used> element.
+        """
+        hash_used_elt = domish.Element((NS_HASHES, "hash-used"))
+        hash_used_elt["algo"] = self.algo
+        return hash_used_elt
+
+
+class XEP_0300:
     # TODO: add blake after moving to Python 3
-    ALGOS = OrderedDict(
+    ALGOS: OrderedDict[str, Callable] = OrderedDict(
         (
             ("md5", hashlib.md5),
             ("sha-1", hashlib.sha1),
@@ -66,38 +154,38 @@
     )
     ALGO_DEFAULT = ALGO_DEFAULT
 
-    def __init__(self, host):
+    def __init__(self, host: "LiberviaBackend"):
         log.info(_("plugin Hashes initialization"))
         host.register_namespace("hashes", NS_HASHES)
+        self.host = host
 
-    def get_handler(self, client):
+    def get_handler(self, client: SatXMPPEntity) -> XMPPHandler:
         return XEP_0300_handler()
 
-    def get_hasher(self, algo=ALGO_DEFAULT):
+    def get_hasher(self, algo: str = ALGO_DEFAULT) -> Callable:
         """Return hasher instance
 
-        @param algo(unicode): one of the XEP_300.ALGOS keys
-        @return (hash object): same object s in hashlib.
-           update method need to be called for each chunh
-           diget or hexdigest can be used at the end
+        @param algo: one of the XEP_300.ALGOS keys
+        @return: same object s in hashlib.
+           update method need to be called for each chunk
+           digest or hexdigest can be used at the end
         """
         return self.ALGOS[algo]()
 
-    def get_default_algo(self):
+    def get_default_algo(self) -> str:
         return ALGO_DEFAULT
 
-    @defer.inlineCallbacks
-    def get_best_peer_algo(self, to_jid, profile):
-        """Return the best available hashing algorith of other peer
+    async def get_best_peer_algo(self, to_jid: jid.JID, profile: str) -> str | None:
+        """Return the best available hashing algorithm of other peer
 
-        @param to_jid(jid.JID): peer jid
-        @parm profile: %(doc_profile)s
-        @return (D(unicode, None)): best available algorithm,
+        @param to_jid: peer jid
+        @param profile: %(doc_profile)s
+        @return: best available algorithm,
            or None if hashing is not possible
         """
         client = self.host.get_client(profile)
         for algo in reversed(XEP_0300.ALGOS):
-            has_feature = yield self.host.hasFeature(
+            has_feature = await self.host.hasFeature(
                 client, NS_HASHES_FUNCTIONS.format(algo), to_jid
             )
             if has_feature:
@@ -106,15 +194,15 @@
                         jid=to_jid.full(), algo=algo
                     )
                 )
-                defer.returnValue(algo)
+                return algo
 
-    def _calculate_hash_blocking(self, file_obj, hasher):
+    def _calculate_hash_blocking(self, file_obj: BinaryIO, hasher: HASH) -> str:
         """Calculate hash in a blocking way
 
         /!\\ blocking method, please use calculate_hash instead
-        @param file_obj(file): a file-like object
-        @param hasher(hash object): the method to call to initialise hash object
-        @return (str): the hex digest of the hash
+        @param file_obj: a file-like object
+        @param hasher: the method to call to initialise hash object
+        @return: the hex digest of the hash
         """
         while True:
             buf = file_obj.read(BUFFER_SIZE)
@@ -123,64 +211,49 @@
             hasher.update(buf)
         return hasher.hexdigest()
 
-    def calculate_hash(self, file_obj, hasher):
+    def calculate_hash(self, file_obj: BinaryIO, hasher: HASH) -> defer.Deferred[str]:
         return threads.deferToThread(self._calculate_hash_blocking, file_obj, hasher)
 
-    def calculate_hash_elt(self, file_obj=None, algo=ALGO_DEFAULT):
+    async def calculate_hash_elt(
+        self, file_obj: BinaryIO, algo: str = ALGO_DEFAULT
+    ) -> domish.Element:
         """Compute hash and build hash element
 
-        @param file_obj(file, None): file-like object to use to calculate the hash
-        @param algo(unicode): algorithme to use, must be a key of XEP_0300.ALGOS
-        @return (D(domish.Element)): hash element
+        @param file_obj: file-like object to use to calculate the hash
+        @param algo: algorithm to use, must be a key of XEP_0300.ALGOS
+        @return: hash element
         """
-
-        def hash_calculated(hash_):
-            return self.build_hash_elt(hash_, algo)
-
         hasher = self.get_hasher(algo)
-        hash_d = self.calculate_hash(file_obj, hasher)
-        hash_d.addCallback(hash_calculated)
-        return hash_d
+        hash_ = await self.calculate_hash(file_obj, hasher)
+        return self.build_hash_elt(hash_, algo)
 
-    def build_hash_used_elt(self, algo=ALGO_DEFAULT):
-        hash_used_elt = domish.Element((NS_HASHES, "hash-used"))
-        hash_used_elt["algo"] = algo
-        return hash_used_elt
+    def build_hash_used_elt(self, algo: str = ALGO_DEFAULT) -> domish.Element:
+        hash_used_model = HashUsed(algo=algo)
+        return hash_used_model.to_element()
 
-    def parse_hash_used_elt(self, parent):
+    def parse_hash_used_elt(self, parent_elt: domish.Element) -> str:
         """Find and parse a hash-used element
 
-        @param (domish.Element): parent of <hash/> element
-        @return (unicode): hash algorithm used
+        @param parent: parent of <hash-used/> element
+        @return: hash algorithm used
         @raise exceptions.NotFound: the element is not present
         @raise exceptions.DataError: the element is invalid
         """
-        try:
-            hash_used_elt = next(parent.elements(NS_HASHES, "hash-used"))
-        except StopIteration:
-            raise exceptions.NotFound
-        algo = hash_used_elt["algo"]
-        if not algo:
-            raise exceptions.DataError
-        return algo
+        hash_used_model = HashUsed.from_element(parent_elt)
+        return hash_used_model.algo
 
-    def build_hash_elt(self, hash_, algo=ALGO_DEFAULT):
+    def build_hash_elt(self, hash_hex: str, algo: str = ALGO_DEFAULT) -> domish.Element:
         """Compute hash and build hash element
 
-        @param hash_(str): hash to use
-        @param algo(unicode): algorithme to use, must be a key of XEP_0300.ALGOS
-        @return (domish.Element): computed hash
+        @param hash_: Hexadecimal representation of hash to use.
+        @param algo: Algorithm to use, must be a key of XEP_0300.ALGOS.
+        @return: <hash> element
         """
-        assert hash_
-        assert algo
-        hash_elt = domish.Element((NS_HASHES, "hash"))
-        if hash_ is not None:
-            b64_hash = base64.b64encode(hash_.encode("utf-8")).decode("utf-8")
-            hash_elt.addContent(b64_hash)
-        hash_elt["algo"] = algo
-        return hash_elt
+        b64_hash = base64.b64encode(hash_hex.encode()).decode()
+        hash_model = Hash(algo=algo, hash_=b64_hash)
+        return hash_model.to_element()
 
-    def parse_hash_elt(self, parent: domish.Element) -> Tuple[str, bytes]:
+    def parse_hash_elt(self, parent: domish.Element) -> tuple[str, str]:
         """Find and parse a hash element
 
         if multiple elements are found, the strongest managed one is returned
@@ -195,7 +268,8 @@
         best_algo = None
         best_value = None
         for hash_elt in parent.elements(NS_HASHES, "hash"):
-            algo = hash_elt.getAttribute("algo")
+            hash_model = Hash.from_element(hash_elt)
+            algo = hash_model.algo
             try:
                 idx = algos.index(algo)
             except ValueError:
@@ -205,7 +279,7 @@
 
             if best_algo is None or algos.index(best_algo) < idx:
                 best_algo = algo
-                best_value = base64.b64decode(str(hash_elt)).decode("utf-8")
+                best_value = base64.b64decode(hash_model.hash_).decode()
 
         if not hash_elt:
             raise exceptions.NotFound
@@ -217,12 +291,16 @@
 @implementer(iwokkel.IDisco)
 class XEP_0300_handler(XMPPHandler):
 
-    def getDiscoInfo(self, requestor, target, nodeIdentifier=""):
+    def getDiscoInfo(
+        self, requestor: jid.JID, target: jid.JID, nodeIdentifier: str = ""
+    ) -> list[disco.DiscoFeature]:
         hash_functions_names = [
             disco.DiscoFeature(NS_HASHES_FUNCTIONS.format(algo))
             for algo in XEP_0300.ALGOS
         ]
         return [disco.DiscoFeature(NS_HASHES)] + hash_functions_names
 
-    def getDiscoItems(self, requestor, target, nodeIdentifier=""):
+    def getDiscoItems(
+        self, requestor: jid.JID, target: jid.JID, nodeIdentifier: str = ""
+    ) -> list[disco.DiscoItem]:
         return []