Mercurial > libervia-backend
diff sat/plugins/plugin_misc_download.py @ 3088:d1464548055a
plugin file download: meta plugin to handle file download:
- first code in backend to use async/await Python syntax \o/
- plugin with file upload
- URL schemes can be registered
- `http` and `https` schemes are handled by default
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 20 Dec 2019 12:28:04 +0100 |
parents | |
children | 9d0df638c8b4 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/plugins/plugin_misc_download.py Fri Dec 20 12:28:04 2019 +0100 @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +# SAT plugin for downloading files +# Copyright (C) 2009-2019 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from pathlib import Path +from urllib.parse import urlparse +import treq +from twisted.internet import defer +from twisted.words.protocols.jabber import error as jabber_error +from sat.core.i18n import _, D_ +from sat.core.constants import Const as C +from sat.core.log import getLogger +from sat.core import exceptions +from sat.tools import xml_tools +from sat.tools.common import data_format +from sat.tools import stream + +log = getLogger(__name__) + + +PLUGIN_INFO = { + C.PI_NAME: "File Download", + C.PI_IMPORT_NAME: "DOWNLOAD", + C.PI_TYPE: C.PLUG_TYPE_MISC, + C.PI_MAIN: "DownloadPlugin", + C.PI_HANDLER: "no", + C.PI_DESCRIPTION: _("""File download management"""), +} + + +class DownloadPlugin(object): + + def __init__(self, host): + log.info(_("plugin Download initialization")) + self.host = host + host.bridge.addMethod( + "fileDownload", + ".plugin", + in_sign="ssss", + out_sign="a{ss}", + method=self._fileDownload, + async_=True, + ) + host.bridge.addMethod( + "fileDownloadComplete", + ".plugin", + in_sign="ssss", + out_sign="s", + method=self._fileDownloadComplete, + async_=True, + ) + self._download_callbacks = {} + self.registerScheme('http', self.downloadHTTP) + self.registerScheme('https', self.downloadHTTP) + + def _fileDownload(self, uri, dest_path, options_s, profile): + client = self.host.getClient(profile) + options = data_format.deserialise(options_s) + + return defer.ensureDeferred(self.fileDownload( + client, uri, Path(dest_path), options + )) + + async def fileDownload(self, client, uri, dest_path, options=None): + """Send a file using best available method + + parameters are the same as for [download] + @return (dict): action dictionary, with progress id in case of success, else xmlui + message + """ + try: + progress_id, __ = await self.download(client, uri, dest_path, options) + except Exception as e: + if (isinstance(e, jabber_error.StanzaError) + and e.condition == 'not-acceptable'): + reason = e.text + else: + reason = str(e) + msg = D_("Can't download file: {reason}").format(reason=reason) + log.warning(msg) + return { + "xmlui": xml_tools.note( + msg, D_("Can't download file"), C.XMLUI_DATA_LVL_WARNING + ).toXml() + } + else: + return {"progress": progress_id} + + def _fileDownloadComplete(self, uri, dest_path, options_s, profile): + client = self.host.getClient(profile) + options = data_format.deserialise(options_s) + + d = defer.ensureDeferred(self.fileDownloadComplete( + client, uri, Path(dest_path), options + )) + d.addCallback(lambda path: str(path)) + return d + + async def fileDownloadComplete(self, client, uri, dest_path, options=None): + """Helper method to fully download a file and return its path + + parameters are the same as for [download] + @return (str): path to the downloaded file + """ + __, download_d = await self.download(client, uri, dest_path, options) + await download_d + return dest_path + + async def download(self, client, uri, dest_path, options=None): + """Send a file using best available method + + @param uri(str): URI to the file to download + @param dest_path(str, Path): where the file must be downloaded + @param options(dict, None): options depending on scheme handler + Some common options: + - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification + used only if HTTPS transport is needed + @return (tuple[unicode,D(unicode)]): progress_id and a Deferred which fire + download URL when download is finished + """ + if options is None: + options = {} + + dest_path = Path(dest_path) + uri_parsed = urlparse(uri, 'http') + try: + callback = self._download_callbacks[uri_parsed.scheme] + except KeyError: + raise exceptions.NotFound(f"Can't find any handler for uri {uri}") + else: + return await callback(client, uri_parsed, dest_path, options) + + def registerScheme(self, scheme, download_cb): + """Register an URI scheme handler + + @param scheme(unicode): URI scheme this callback is handling + @param download_cb(callable): callback to download a file + arguments are: + - (SatXMPPClient) client + - (urllib.parse.SplitResult) parsed URI + - (Path) destination path where the file must be downloaded + - (dict) options + must return a tuple with progress_id and a Deferred which fire when download + is finished + """ + if scheme in self._download_callbacks: + raise exceptions.ConflictError( + f"A method with scheme {scheme!r} is already registered" + ) + self._download_callbacks[scheme] = download_cb + + def unregister(self, scheme): + try: + del self._download_callbacks[scheme] + except KeyError: + raise exceptions.NotFound(f"No callback registered for scheme {scheme!r}") + + async def downloadHTTP(self, client, uri_parsed, dest_path, options): + url = uri_parsed.geturl() + + head_data = await treq.head(url) + try: + content_length = int(head_data.headers.getRawHeaders('content-length')[0]) + except (KeyError, TypeError, IndexError): + content_length = None + log.debug(f"No content lenght found at {url}") + file_obj = stream.SatFile( + self.host, + client, + dest_path, + mode="wb", + size = content_length, + ) + + progress_id = file_obj.uid + + resp = await treq.get(url, unbuffered=True) + d = treq.collect(resp, file_obj.write) + d.addBoth(lambda _: file_obj.close()) + return progress_id, d