diff libervia/backend/plugins/plugin_xep_0448.py @ 4334:111dce64dcb5

plugins XEP-0300, XEP-0446, XEP-0447, XEP-0448 and others: Refactoring to use Pydantic: Pydantic models are used more and more in Libervia, for the bridge API, and also to convert `domish.Element` to internal representation. Type hints have also been added in many places. rel 453
author Goffi <goffi@goffi.org>
date Tue, 03 Dec 2024 00:12:38 +0100
parents 0d7bb4df2343
children
line wrap: on
line diff
--- a/libervia/backend/plugins/plugin_xep_0448.py	Tue Dec 03 00:11:00 2024 +0100
+++ b/libervia/backend/plugins/plugin_xep_0448.py	Tue Dec 03 00:12:38 2024 +0100
@@ -21,13 +21,14 @@
 from pathlib import Path
 import secrets
 from textwrap import dedent
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Self, Tuple, Union, cast
 
 from cryptography.exceptions import AlreadyFinalized
 from cryptography.hazmat import backends
 from cryptography.hazmat.primitives import ciphers
 from cryptography.hazmat.primitives.ciphers import CipherContext, modes
 from cryptography.hazmat.primitives.padding import PKCS7, PaddingContext
+from pydantic import BaseModel, ValidationError
 import treq
 from twisted.internet import defer
 from twisted.words.protocols.jabber.xmlstream import XMPPHandler
@@ -40,6 +41,10 @@
 from libervia.backend.core.core_types import SatXMPPEntity
 from libervia.backend.core.i18n import _
 from libervia.backend.core.log import getLogger
+from libervia.backend.plugins.plugin_misc_download import DownloadPlugin
+from libervia.backend.plugins.plugin_xep_0103 import XEP_0103
+from libervia.backend.plugins.plugin_xep_0300 import NS_HASHES, XEP_0300, Hash
+from libervia.backend.plugins.plugin_xep_0447 import XEP_0447, Source
 from libervia.backend.tools import stream
 from libervia.backend.tools.web import treq_client_no_ssl
 
@@ -79,90 +84,121 @@
 NS_AES_256_CBC = "urn:xmpp:ciphers:aes-256-cbc-pkcs7:0"
 
 
+class EncryptedSource(Source):
+    """Source model for the <encrypted/> element (NS_ESFS namespace)"""
+
+    type = "encrypted"
+    encrypted = True
+    cipher: str
+    key: str
+    iv: str
+    hashes: list[Hash]
+    sources: list[Source]
+    _hash: XEP_0300 | None = None
+    _sfs: XEP_0447 | None = None
+
+    @classmethod
+    def from_element(cls, element: domish.Element) -> Self:
+        """Parse an <encrypted> element and return corresponding EncryptedSource model
+
+        @param element: <encrypted> element to parse
+        @raise exceptions.DataError: invalid element, no sources, or nested encryption
+
+        """
+        assert cls._hash is not None, "_hash attribute is not set"
+        assert cls._sfs is not None, "_sfs attribute is not set"
+        try:
+            cipher = element["cipher"]
+            key = str(next(element.elements(NS_ESFS, "key")))
+            iv = str(next(element.elements(NS_ESFS, "iv")))
+        except (KeyError, StopIteration) as e:
+            raise exceptions.DataError(
+                f"invalid <encrypted/> element: {element.toXml()}"
+            ) from e
+        sources = cls._sfs.parse_sources_elt(element)
+        if not sources:
+            raise exceptions.DataError(f"Sources are missing in {element.toXml()}")
+
+        if any(isinstance(source, cls) for source in sources):
+            raise exceptions.DataError(
+                "EncryptedData is used as a source of another EncryptedData"
+            )
+
+        return cls(
+            cipher=cipher,
+            key=key,
+            iv=iv,
+            hashes=Hash.from_parent(element),
+            sources=sources,
+        )
+
+    def to_element(self) -> domish.Element:
+        """Convert EncryptedSource model to an <encrypted> element
+
+        @return: <encrypted> element with cipher, key, iv and hashes (no <sources/>)
+
+        """
+        assert self._hash is not None, "_hash attribute is not set"
+        encrypted_elt = domish.Element((NS_ESFS, "encrypted"))
+        encrypted_elt["cipher"] = self.cipher
+        encrypted_elt.addElement("key").addContent(self.key)
+        encrypted_elt.addElement("iv").addContent(self.iv)
+        for hash_ in self.hashes:
+            encrypted_elt.addChild(hash_.to_element())
+
+        return encrypted_elt
+
+
 class XEP_0448:
 
     def __init__(self, host):
         self.host = host
         log.info(_("XEP_0448 plugin initialization"))
         host.register_namespace("esfs", NS_ESFS)
-        self._u = host.plugins["XEP-0103"]
-        self._h = host.plugins["XEP-0300"]
+        self._u = cast(XEP_0103, host.plugins["XEP-0103"])
+        self._h = cast(XEP_0300, host.plugins["XEP-0300"])
         self._hints = host.plugins["XEP-0334"]
         self._http_upload = host.plugins["XEP-0363"]
         self._o = host.plugins["XEP-0384"]
-        self._sfs = host.plugins["XEP-0447"]
-        self._sfs.register_source_handler(
-            NS_ESFS, "encrypted", self.parse_encrypted_elt, encrypted=True
-        )
+        self._sfs = cast(XEP_0447, host.plugins["XEP-0447"])
+        self._sfs.register_source(NS_ESFS, "encrypted", EncryptedSource)
         self._attach = host.plugins["ATTACH"]
         self._attach.register(
             self.can_handle_attachment, self.attach, encrypted=True, priority=1000
         )
-        host.plugins["DOWNLOAD"].register_download_handler(NS_ESFS, self.download)
+        EncryptedSource._hash = self._h
+        EncryptedSource._sfs = self._sfs
+        download = cast(DownloadPlugin, host.plugins["DOWNLOAD"])
+        download.register_download_handler(NS_ESFS, self.download)
         host.trigger.add("XEP-0363_upload_pre_slot", self._upload_pre_slot)
         host.trigger.add("XEP-0363_upload", self._upload_trigger)
 
     def get_handler(self, client):
         return XEP0448Handler()
 
-    def parse_encrypted_elt(self, encrypted_elt: domish.Element) -> Dict[str, Any]:
-        """Parse an <encrypted> element and return corresponding source data
-
-        @param encrypted_elt: element to parse
-        @raise exceptions.DataError: the element is invalid
-
-        """
-        sources = self._sfs.parse_sources_elt(encrypted_elt)
-        if not sources:
-            raise exceptions.NotFound("sources are missing in {encrypted_elt.toXml()}")
-        if len(sources) > 1:
-            log.debug(
-                "more that one sources has been found, this is not expected, only the "
-                "first one will be used"
-            )
-        source = sources[0]
-        source["type"] = NS_ESFS
-        try:
-            encrypted_data = source["encrypted_data"] = {
-                "cipher": encrypted_elt["cipher"],
-                "key": str(next(encrypted_elt.elements(NS_ESFS, "key"))),
-                "iv": str(next(encrypted_elt.elements(NS_ESFS, "iv"))),
-            }
-        except (KeyError, StopIteration):
-            raise exceptions.DataError(
-                "invalid <encrypted/> element: {encrypted_elt.toXml()}"
-            )
-        try:
-            hash_algo, hash_value = self._h.parse_hash_elt(encrypted_elt)
-        except exceptions.NotFound:
-            pass
-        else:
-            encrypted_data["hash_algo"] = hash_algo
-            encrypted_data["hash"] = base64.b64encode(hash_value.encode()).decode()
-        return source
-
     async def download(
         self,
         client: SatXMPPEntity,
-        attachment: Dict[str, Any],
-        source: Dict[str, Any],
+        attachment: dict[str, Any],
+        source: dict[str, Any],
         dest_path: Union[Path, str],
-        extra: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, defer.Deferred]:
+        extra: dict[str, Any] | None = None,
+    ) -> tuple[str, defer.Deferred]:
         # TODO: check hash
         if extra is None:
             extra = {}
+        assert source["type"] == "encrypted"
         try:
-            encrypted_data = source["encrypted_data"]
-            cipher = encrypted_data["cipher"]
-            iv = base64.b64decode(encrypted_data["iv"])
-            key = base64.b64decode(encrypted_data["key"])
+            cipher = source["cipher"]
+            iv = base64.b64decode(source["iv"])
+            key = base64.b64decode(source["key"])
         except KeyError as e:
-            raise ValueError(f"{source} has incomplete encryption data: {e}")
+            raise ValueError(f"{source} has incomplete encryption data: {e}") from e
+
         try:
-            download_url = source["url"]
-        except KeyError:
-            raise ValueError(f"{source} has missing URL")
+            download_url = source["sources"][0]["url"]
+        except (IndexError, KeyError) as e:
+            raise ValueError(f"{source} has missing URL") from e
 
         if extra.get("ignore_tls_errors", False):
             log.warning("TLS certificate check disabled, this is highly insecure")
@@ -294,9 +330,9 @@
                     size=attachment["size"],
                     file_hash=file_hash,
                 )
-                encrypted_elt = file_sharing_elt.sources.addElement(
-                    (NS_ESFS, "encrypted")
-                )
+                sources_elt = file_sharing_elt.sources
+                assert sources_elt is not None
+                encrypted_elt = sources_elt.addElement((NS_ESFS, "encrypted"))
                 encrypted_elt["cipher"] = NS_AES_256_GCM
                 encrypted_elt.addElement(
                     "key", content=base64.b64encode(encryption_data["key"]).decode()
@@ -311,7 +347,7 @@
                 )
                 encrypted_elt.addChild(
                     self._sfs.get_sources_elt(
-                        [self._u.get_url_data_elt(attachment["url"])]
+                        [self._u.generate_url_data(attachment["url"]).to_element()]
                     )
                 )
                 data["xml"].addChild(file_sharing_elt)