diff sat/plugins/plugin_misc_download.py @ 3922:0ff265725489

plugin XEP-0447: handle attachment and download: - plugin XEP-0447 can now be used in message attachments and to retrieve an attachment - plugin attach: `attachment` being processed is added to `extra` so the handler can inspect it - plugin attach: `size` is added to attachment - plugin download: a whole attachment dict is now used in `download` and `file_download`/`file_download_complete`. `download_uri` can be used as a shortcut when just a URI is used. In addition to URI scheme handler, whole attachment handlers can now be registered with `register_download_handler` - plugin XEP-0363: `file_http_upload` `XEP-0363_upload_size` triggers have been renamed to `XEP-0363_upload_pre_slot` and is now using a dict with arguments, allowing for the size but also the filename to be modified, which is necessary for encryption (filename may be hidden from URL this way). - plugin XEP-0446: fix wrong element name - plugin XEP-0447: source handler can now be registered (`url-data` is registered by default) - plugin XEP-0447: source parsing has been put in a separated `parse_sources_elt` method, as it may be useful to do it independently (notably with XEP-0448) - plugin XEP-0447: parse received message and complete attachments when suitable - plugin XEP-0447: can now be used with message attachments - plugin XEP-0447: can now be used with attachments download - renamed `options` arguments to `extra` for consistency - some style change (progressive move from legacy camelCase to PEP8 snake_case) - some typing rel 379
author Goffi <goffi@goffi.org>
date Thu, 06 Oct 2022 16:02:05 +0200
parents be6d91572633
children 524856bd7b19
line wrap: on
line diff
--- a/sat/plugins/plugin_misc_download.py	Thu Oct 06 16:02:05 2022 +0200
+++ b/sat/plugins/plugin_misc_download.py	Thu Oct 06 16:02:05 2022 +0200
@@ -16,19 +16,23 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import hashlib
 from pathlib import Path
-from urllib.parse import urlparse, unquote
-import hashlib
+from typing import Any, Dict, Optional, Union, Tuple, Callable
+from urllib.parse import unquote, urlparse
+
 import treq
 from twisted.internet import defer
 from twisted.words.protocols.jabber import error as jabber_error
-from sat.core.i18n import _, D_
-from sat.core.constants import Const as C
-from sat.core.log import getLogger
+
 from sat.core import exceptions
+from sat.core.constants import Const as C
+from sat.core.core_types import SatXMPPEntity
+from sat.core.i18n import D_, _
+from sat.core.log import getLogger
 from sat.tools import xml_tools
+from sat.tools import stream
 from sat.tools.common import data_format
-from sat.tools import stream
 from sat.tools.web import treq_client_no_ssl
 
 log = getLogger(__name__)
@@ -38,6 +42,7 @@
     C.PI_NAME: "File Download",
     C.PI_IMPORT_NAME: "DOWNLOAD",
     C.PI_TYPE: C.PLUG_TYPE_MISC,
+    C.PI_MODES: C.PLUG_MODE_BOTH,
     C.PI_MAIN: "DownloadPlugin",
     C.PI_HANDLER: "no",
     C.PI_DESCRIPTION: _("""File download management"""),
@@ -53,7 +58,7 @@
             "fileDownload",
             ".plugin",
             in_sign="ssss",
-            out_sign="a{ss}",
+            out_sign="s",
             method=self._fileDownload,
             async_=True,
         )
@@ -66,18 +71,29 @@
             async_=True,
         )
         self._download_callbacks = {}
-        self.registerScheme('http', self.downloadHTTP)
-        self.registerScheme('https', self.downloadHTTP)
+        self._scheme_callbacks = {}
+        self.register_scheme('http', self.download_http)
+        self.register_scheme('https', self.download_http)
 
-    def _fileDownload(self, uri, dest_path, options_s, profile):
-        client = self.host.getClient(profile)
-        options = data_format.deserialise(options_s)
+    def _fileDownload(
+            self, attachment_s: str, dest_path: str, extra_s: str, profile: str
+    ) -> defer.Deferred:
+        d = defer.ensureDeferred(self.file_download(
+            self.host.getClient(profile),
+            data_format.deserialise(attachment_s),
+            Path(dest_path),
+            data_format.deserialise(extra_s)
+        ))
+        d.addCallback(lambda ret: data_format.serialise(ret))
+        return d
 
-        return defer.ensureDeferred(self.fileDownload(
-            client, uri, Path(dest_path), options
-        ))
-
-    async def fileDownload(self, client, uri, dest_path, options=None):
+    async def file_download(
+        self,
+        client: SatXMPPEntity,
+        attachment: Dict[str, Any],
+        dest_path: Path,
+        extra: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
         """Download a file using best available method
 
         parameters are the same as for [download]
@@ -85,7 +101,7 @@
             message
         """
         try:
-            progress_id, __ = await self.download(client, uri, dest_path, options)
+            progress_id, __ = await self.download(client, attachment, dest_path, extra)
         except Exception as e:
             if (isinstance(e, jabber_error.StanzaError)
                 and e.condition == 'not-acceptable'):
@@ -102,45 +118,44 @@
         else:
             return {"progress": progress_id}
 
-    def _fileDownloadComplete(self, uri, dest_path, options_s, profile):
-        client = self.host.getClient(profile)
-        options = data_format.deserialise(options_s)
-
-        d = defer.ensureDeferred(self.fileDownloadComplete(
-            client, uri, dest_path, options
+    def _fileDownloadComplete(
+            self, attachment_s: str, dest_path: str, extra_s: str, profile: str
+    ) -> defer.Deferred:
+        d = defer.ensureDeferred(self.file_download_complete(
+            self.host.getClient(profile),
+            data_format.deserialise(attachment_s),
+            Path(dest_path),
+            data_format.deserialise(extra_s)
         ))
         d.addCallback(lambda path: str(path))
         return d
 
-    async def fileDownloadComplete(self, client, uri, dest_path, options=None):
+    async def file_download_complete(
+        self,
+        client: SatXMPPEntity,
+        attachment: Dict[str, Any],
+        dest_path: Path,
+        extra: Optional[Dict[str, Any]] = None
+    ) -> str:
         """Helper method to fully download a file and return its path
 
         parameters are the same as for [download]
         @return (str): path to the downloaded file
             use empty string to store the file in cache
         """
-        __, download_d = await self.download(client, uri, dest_path, options)
+        __, download_d = await self.download(client, attachment, dest_path, extra)
         dest_path = await download_d
         return dest_path
 
-    async def download(self, client, uri, dest_path, options=None):
-        """Send a file using best available method
-
-        @param uri(str): URI to the file to download
-        @param dest_path(str, Path): where the file must be downloaded
-            if empty string, the file will be stored in local path
-        @param options(dict, None): options depending on scheme handler
-            Some common options:
-                - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification
-                  used only if HTTPS transport is needed
-        @return (tuple[unicode,D(unicode)]): progress_id and a Deferred which fire
-            download URL when download is finished
-            progress_id can be empty string if the file already exist and is not
-            downloaded again (can happen if cache is used with empty dest_path)
-        """
-        if options is None:
-            options = {}
-
+    async def download_uri(
+        self,
+        client: SatXMPPEntity,
+        uri: str,
+        dest_path: Union[Path, str],
+        extra: Optional[Dict[str, Any]] = None
+    ) -> Tuple[str, defer.Deferred]:
+        if extra is None:
+            extra = {}
         uri_parsed = urlparse(uri, 'http')
         if dest_path:
             dest_path = Path(dest_path)
@@ -171,18 +186,18 @@
         check_certificate = self.host.memory.getParamA(
             "check_certificate", "Connection", profile_key=client.profile)
         if not check_certificate:
-            options['ignore_tls_errors'] = True
+            extra['ignore_tls_errors'] = True
             log.warning(
                 _("certificate check disabled for download, this is dangerous!"))
 
         try:
-            callback = self._download_callbacks[uri_parsed.scheme]
+            callback = self._scheme_callbacks[uri_parsed.scheme]
         except KeyError:
             raise exceptions.NotFound(f"Can't find any handler for uri {uri}")
         else:
             try:
                 progress_id, download_d = await callback(
-                    client, uri_parsed, dest_path, options)
+                    client, uri_parsed, dest_path, extra)
             except Exception as e:
                 log.warning(_(
                     "Can't download URI {uri}: {reason}").format(
@@ -195,11 +210,96 @@
             download_d.addCallback(lambda __: dest_path)
             return progress_id, download_d
 
-    def registerScheme(self, scheme, download_cb):
+
+    async def download(
+        self,
+        client: SatXMPPEntity,
+        attachment: Dict[str, Any],
+        dest_path: Union[Path, str],
+        extra: Optional[Dict[str, Any]] = None
+    ) -> Tuple[str, defer.Deferred]:
+        """Download a file from URI using suitable method
+
+        @param uri: URI to the file to download
+        @param dest_path: where the file must be downloaded
+            if empty string, the file will be stored in local path
+        @param extra: options depending on scheme handler
+            Some common options:
+                - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification
+                  used only if HTTPS transport is needed
+        @return: ``progress_id`` and a Deferred which fire download URL when download is
+            finished.
+            ``progress_id`` can be empty string if the file already exist and is not
+            downloaded again (can happen if cache is used with empty ``dest_path``).
+        """
+        uri = attachment.get("uri")
+        if uri:
+            return await self.download_uri(client, uri, dest_path, extra)
+        else:
+            for source in attachment.get("sources", []):
+                source_type = source.get("type")
+                if not source_type:
+                    log.warning(
+                        "source type is missing for source: {source}\nattachment: "
+                        f"{attachment}"
+                    )
+                    continue
+                try:
+                    cb = self._download_callbacks[source_type]
+                except KeyError:
+                    log.warning(
+                        f"no source handler registered for {source_type!r}"
+                    )
+                else:
+                    try:
+                        return await cb(client, attachment, source, dest_path, extra)
+                    except exceptions.CancelError as e:
+                        # the handler can't or doesn't want to handle this source
+                        log.debug(
+                            f"Following source handling by {cb} has been cancelled ({e}):"
+                            f"{source}"
+                        )
+
+        log.warning(
+            "no source could be handled, we can't download the attachment:\n"
+            f"{attachment}"
+        )
+        raise exceptions.FeatureNotFound("no handler could manage the attachment")
+
+    def register_download_handler(
+        self,
+        source_type: str,
+        callback: Callable[
+            [
+                SatXMPPEntity, Dict[str, Any], Dict[str, Any], Union[str, Path],
+                Dict[str, Any]
+            ],
+            Tuple[str, defer.Deferred]
+        ]
+    ) -> None:
+        """Register a handler to manage a type of attachment source
+
+        @param source_type: ``type`` of source handled
+            This is usually the namespace of the protocol used
+        @param callback: method to call to manage the source.
+            Call arguments are the same as for [download], with an extra ``source`` dict
+            which is used just after ``attachment`` to give a quick reference to the
+            source used.
+            The callabke must return a tuple with:
+                - progress ID
+                - a Deferred which fire whant the file is fully downloaded
+        """
+        if source_type is self._download_callbacks:
+            raise exceptions.ConflictError(
+                f"The is already a callback registered for source type {source_type!r}"
+            )
+        self._download_callbacks[source_type] = callback
+
+    def register_scheme(self, scheme: str, download_cb: Callable) -> None:
         """Register an URI scheme handler
 
-        @param scheme(unicode): URI scheme this callback is handling
-        @param download_cb(callable): callback to download a file
+        @param scheme: URI scheme this callback is handling
+        @param download_cb: callback to download a file
             arguments are:
                 - (SatXMPPClient) client
                 - (urllib.parse.SplitResult) parsed URI
@@ -208,19 +308,19 @@
             must return a tuple with progress_id and a Deferred which fire when download
             is finished
         """
-        if scheme in self._download_callbacks:
+        if scheme in self._scheme_callbacks:
             raise exceptions.ConflictError(
                 f"A method with scheme {scheme!r} is already registered"
             )
-        self._download_callbacks[scheme] = download_cb
+        self._scheme_callbacks[scheme] = download_cb
 
     def unregister(self, scheme):
         try:
-            del self._download_callbacks[scheme]
+            del self._scheme_callbacks[scheme]
         except KeyError:
             raise exceptions.NotFound(f"No callback registered for scheme {scheme!r}")
 
-    def errbackDownload(self, file_obj, download_d, resp):
+    def errback_download(self, file_obj, download_d, resp):
         """Set file_obj and download deferred appropriatly after a network error
 
         @param file_obj(SatFile): file where the download must be done
@@ -231,7 +331,7 @@
         file_obj.close(error=msg)
         download_d.errback(exceptions.NetworkError(msg))
 
-    async def downloadHTTP(self, client, uri_parsed, dest_path, options):
+    async def download_http(self, client, uri_parsed, dest_path, options):
         url = uri_parsed.geturl()
 
         if options.get('ignore_tls_errors', False):
@@ -264,5 +364,5 @@
             d.addBoth(lambda _: file_obj.close())
         else:
             d = defer.Deferred()
-            self.errbackDownload(file_obj, d, resp)
+            self.errback_download(file_obj, d, resp)
         return progress_id, d