Mercurial > libervia-backend
diff sat/plugins/plugin_misc_download.py @ 3922:0ff265725489
plugin XEP-0447: handle attachment and download:
- plugin XEP-0447 can now be used in message attachments and to retrieve an attachment
- plugin attach: `attachment` being processed is added to `extra` so the handler can inspect it
- plugin attach: `size` is added to attachment
- plugin download: a whole attachment dict is now used in `download` and
`file_download`/`file_download_complete`. `download_uri` can be used as a shortcut when
just a URI is used. In addition to URI scheme handler, whole attachment handlers can now
be registered with `register_download_handler`
- plugin XEP-0363: `file_http_upload` `XEP-0363_upload_size` triggers have been renamed to
`XEP-0363_upload_pre_slot` and is now using a dict with arguments, allowing for the size
but also the filename to be modified, which is necessary for encryption (filename may
be hidden from URL this way).
- plugin XEP-0446: fix wrong element name
- plugin XEP-0447: source handler can now be registered (`url-data` is registered by
default)
- plugin XEP-0447: source parsing has been put in a separated `parse_sources_elt` method,
as it may be useful to do it independently (notably with XEP-0448)
- plugin XEP-0447: parse received message and complete attachments when suitable
- plugin XEP-0447: can now be used with message attachments
- plugin XEP-0447: can now be used with attachments download
- renamed `options` arguments to `extra` for consistency
- some style change (progressive move from legacy camelCase to PEP8 snake_case)
- some typing
rel 379
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 06 Oct 2022 16:02:05 +0200 |
parents | be6d91572633 |
children | 524856bd7b19 |
line wrap: on
line diff
--- a/sat/plugins/plugin_misc_download.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_misc_download.py Thu Oct 06 16:02:05 2022 +0200 @@ -16,19 +16,23 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +import hashlib from pathlib import Path -from urllib.parse import urlparse, unquote -import hashlib +from typing import Any, Dict, Optional, Union, Tuple, Callable +from urllib.parse import unquote, urlparse + import treq from twisted.internet import defer from twisted.words.protocols.jabber import error as jabber_error -from sat.core.i18n import _, D_ -from sat.core.constants import Const as C -from sat.core.log import getLogger + from sat.core import exceptions +from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity +from sat.core.i18n import D_, _ +from sat.core.log import getLogger from sat.tools import xml_tools +from sat.tools import stream from sat.tools.common import data_format -from sat.tools import stream from sat.tools.web import treq_client_no_ssl log = getLogger(__name__) @@ -38,6 +42,7 @@ C.PI_NAME: "File Download", C.PI_IMPORT_NAME: "DOWNLOAD", C.PI_TYPE: C.PLUG_TYPE_MISC, + C.PI_MODES: C.PLUG_MODE_BOTH, C.PI_MAIN: "DownloadPlugin", C.PI_HANDLER: "no", C.PI_DESCRIPTION: _("""File download management"""), @@ -53,7 +58,7 @@ "fileDownload", ".plugin", in_sign="ssss", - out_sign="a{ss}", + out_sign="s", method=self._fileDownload, async_=True, ) @@ -66,18 +71,29 @@ async_=True, ) self._download_callbacks = {} - self.registerScheme('http', self.downloadHTTP) - self.registerScheme('https', self.downloadHTTP) + self._scheme_callbacks = {} + self.register_scheme('http', self.download_http) + self.register_scheme('https', self.download_http) - def _fileDownload(self, uri, dest_path, options_s, profile): - client = self.host.getClient(profile) - options = data_format.deserialise(options_s) + def _fileDownload( + self, attachment_s: str, dest_path: str, extra_s: str, profile: str + ) -> defer.Deferred: + d = defer.ensureDeferred(self.file_download( + self.host.getClient(profile), + data_format.deserialise(attachment_s), + Path(dest_path), + data_format.deserialise(extra_s) + )) + d.addCallback(lambda ret: data_format.serialise(ret)) + return d - return defer.ensureDeferred(self.fileDownload( - client, uri, Path(dest_path), options - )) - - async def fileDownload(self, client, uri, dest_path, options=None): + async def file_download( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + dest_path: Path, + extra: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Download a file using best available method parameters are the same as for [download] @@ -85,7 +101,7 @@ message """ try: - progress_id, __ = await self.download(client, uri, dest_path, options) + progress_id, __ = await self.download(client, attachment, dest_path, extra) except Exception as e: if (isinstance(e, jabber_error.StanzaError) and e.condition == 'not-acceptable'): @@ -102,45 +118,44 @@ else: return {"progress": progress_id} - def _fileDownloadComplete(self, uri, dest_path, options_s, profile): - client = self.host.getClient(profile) - options = data_format.deserialise(options_s) - - d = defer.ensureDeferred(self.fileDownloadComplete( - client, uri, dest_path, options + def _fileDownloadComplete( + self, attachment_s: str, dest_path: str, extra_s: str, profile: str + ) -> defer.Deferred: + d = defer.ensureDeferred(self.file_download_complete( + self.host.getClient(profile), + data_format.deserialise(attachment_s), + Path(dest_path), + data_format.deserialise(extra_s) )) d.addCallback(lambda path: str(path)) return d - async def fileDownloadComplete(self, client, uri, dest_path, options=None): + async def file_download_complete( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + dest_path: Path, + extra: Optional[Dict[str, Any]] = None + ) -> str: """Helper method to fully download a file and return its path parameters are the same as for [download] @return (str): path to the downloaded file use empty string to store the file in cache """ - __, download_d = await self.download(client, uri, dest_path, options) + __, download_d = await self.download(client, attachment, dest_path, extra) dest_path = await download_d return dest_path - async def download(self, client, uri, dest_path, options=None): - """Send a file using best available method - - @param uri(str): URI to the file to download - @param dest_path(str, Path): where the file must be downloaded - if empty string, the file will be stored in local path - @param options(dict, None): options depending on scheme handler - Some common options: - - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification - used only if HTTPS transport is needed - @return (tuple[unicode,D(unicode)]): progress_id and a Deferred which fire - download URL when download is finished - progress_id can be empty string if the file already exist and is not - downloaded again (can happen if cache is used with empty dest_path) - """ - if options is None: - options = {} - + async def download_uri( + self, + client: SatXMPPEntity, + uri: str, + dest_path: Union[Path, str], + extra: Optional[Dict[str, Any]] = None + ) -> Tuple[str, defer.Deferred]: + if extra is None: + extra = {} uri_parsed = urlparse(uri, 'http') if dest_path: dest_path = Path(dest_path) @@ -171,18 +186,18 @@ check_certificate = self.host.memory.getParamA( "check_certificate", "Connection", profile_key=client.profile) if not check_certificate: - options['ignore_tls_errors'] = True + extra['ignore_tls_errors'] = True log.warning( _("certificate check disabled for download, this is dangerous!")) try: - callback = self._download_callbacks[uri_parsed.scheme] + callback = self._scheme_callbacks[uri_parsed.scheme] except KeyError: raise exceptions.NotFound(f"Can't find any handler for uri {uri}") else: try: progress_id, download_d = await callback( - client, uri_parsed, dest_path, options) + client, uri_parsed, dest_path, extra) except Exception as e: log.warning(_( "Can't download URI {uri}: {reason}").format( @@ -195,11 +210,96 @@ download_d.addCallback(lambda __: dest_path) return progress_id, download_d - def registerScheme(self, scheme, download_cb): + + async def download( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + dest_path: Union[Path, str], + extra: Optional[Dict[str, Any]] = None + ) -> Tuple[str, defer.Deferred]: + """Download a file from URI using suitable method + + @param uri: URI to the file to download + @param dest_path: where the file must be downloaded + if empty string, the file will be stored in local path + @param extra: options depending on scheme handler + Some common options: + - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification + used only if HTTPS transport is needed + @return: ``progress_id`` and a Deferred which fire download URL when download is + finished. + ``progress_id`` can be empty string if the file already exist and is not + downloaded again (can happen if cache is used with empty ``dest_path``). + """ + uri = attachment.get("uri") + if uri: + return await self.download_uri(client, uri, dest_path, extra) + else: + for source in attachment.get("sources", []): + source_type = source.get("type") + if not source_type: + log.warning( + "source type is missing for source: {source}\nattachment: " + f"{attachment}" + ) + continue + try: + cb = self._download_callbacks[source_type] + except KeyError: + log.warning( + f"no source handler registered for {source_type!r}" + ) + else: + try: + return await cb(client, attachment, source, dest_path, extra) + except exceptions.CancelError as e: + # the handler can't or doesn't want to handle this source + log.debug( + f"Following source handling by {cb} has been cancelled ({e}):" + f"{source}" + ) + + log.warning( + "no source could be handled, we can't download the attachment:\n" + f"{attachment}" + ) + raise exceptions.FeatureNotFound("no handler could manage the attachment") + + def register_download_handler( + self, + source_type: str, + callback: Callable[ + [ + SatXMPPEntity, Dict[str, Any], Dict[str, Any], Union[str, Path], + Dict[str, Any] + ], + Tuple[str, defer.Deferred] + ] + ) -> None: + """Register a handler to manage a type of attachment source + + @param source_type: ``type`` of source handled + This is usually the namespace of the protocol used + @param callback: method to call to manage the source. + Call arguments are the same as for [download], with an extra ``source`` dict + which is used just after ``attachment`` to give a quick reference to the + source used. + The callabke must return a tuple with: + - progress ID + - a Deferred which fire whant the file is fully downloaded + """ + if source_type is self._download_callbacks: + raise exceptions.ConflictError( + f"The is already a callback registered for source type {source_type!r}" + ) + self._download_callbacks[source_type] = callback + + def register_scheme(self, scheme: str, download_cb: Callable) -> None: """Register an URI scheme handler - @param scheme(unicode): URI scheme this callback is handling - @param download_cb(callable): callback to download a file + @param scheme: URI scheme this callback is handling + @param download_cb: callback to download a file arguments are: - (SatXMPPClient) client - (urllib.parse.SplitResult) parsed URI @@ -208,19 +308,19 @@ must return a tuple with progress_id and a Deferred which fire when download is finished """ - if scheme in self._download_callbacks: + if scheme in self._scheme_callbacks: raise exceptions.ConflictError( f"A method with scheme {scheme!r} is already registered" ) - self._download_callbacks[scheme] = download_cb + self._scheme_callbacks[scheme] = download_cb def unregister(self, scheme): try: - del self._download_callbacks[scheme] + del self._scheme_callbacks[scheme] except KeyError: raise exceptions.NotFound(f"No callback registered for scheme {scheme!r}") - def errbackDownload(self, file_obj, download_d, resp): + def errback_download(self, file_obj, download_d, resp): """Set file_obj and download deferred appropriatly after a network error @param file_obj(SatFile): file where the download must be done @@ -231,7 +331,7 @@ file_obj.close(error=msg) download_d.errback(exceptions.NetworkError(msg)) - async def downloadHTTP(self, client, uri_parsed, dest_path, options): + async def download_http(self, client, uri_parsed, dest_path, options): url = uri_parsed.geturl() if options.get('ignore_tls_errors', False): @@ -264,5 +364,5 @@ d.addBoth(lambda _: file_obj.close()) else: d = defer.Deferred() - self.errbackDownload(file_obj, d, resp) + self.errback_download(file_obj, d, resp) return progress_id, d