# HG changeset patch # User Goffi # Date 1665064925 -7200 # Node ID 0ff2657254893f249bae27bb5b1c1740a5df4404 # Parent cc27052257781f8fd308613861ee4f0501319b9e plugin XEP-0447: handle attachment and download: - plugin XEP-0447 can now be used in message attachments and to retrieve an attachment - plugin attach: `attachment` being processed is added to `extra` so the handler can inspect it - plugin attach: `size` is added to attachment - plugin download: a whole attachment dict is now used in `download` and `file_download`/`file_download_complete`. `download_uri` can be used as a shortcut when just a URI is used. In addition to URI scheme handler, whole attachment handlers can now be registered with `register_download_handler` - plugin XEP-0363: `file_http_upload` `XEP-0363_upload_size` triggers have been renamed to `XEP-0363_upload_pre_slot` and is now using a dict with arguments, allowing for the size but also the filename to be modified, which is necessary for encryption (filename may be hidden from URL this way). - plugin XEP-0446: fix wrong element name - plugin XEP-0447: source handler can now be registered (`url-data` is registered by default) - plugin XEP-0447: source parsing has been put in a separated `parse_sources_elt` method, as it may be useful to do it independently (notably with XEP-0448) - plugin XEP-0447: parse received message and complete attachments when suitable - plugin XEP-0447: can now be used with message attachments - plugin XEP-0447: can now be used with attachments download - renamed `options` arguments to `extra` for consistency - some style change (progressive move from legacy camelCase to PEP8 snake_case) - some typing rel 379 diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_blog_import.py --- a/sat/plugins/plugin_blog_import.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_blog_import.py Thu Oct 06 16:02:05 2022 +0200 @@ -307,7 +307,7 @@ "%", "_" ) # FIXME: tmp workaround for a bug in prosody http upload __, download_d = yield self._u.upload( - client, tmp_file, filename, options=upload_options + client, tmp_file, filename, extra=upload_options ) download_url = yield download_d except Exception as e: diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_misc_attach.py --- a/sat/plugins/plugin_misc_attach.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_misc_attach.py Thu Oct 06 16:02:05 2022 +0200 @@ -16,15 +16,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from pathlib import Path from collections import namedtuple -from twisted.internet import defer import mimetypes -import tempfile +from pathlib import Path import shutil -from sat.core.i18n import _ +import tempfile +from typing import Callable, Optional + +from twisted.internet import defer + from sat.core import exceptions from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity +from sat.core.i18n import _ from sat.core.log import getLogger from sat.tools import utils from sat.tools import image @@ -37,6 +41,7 @@ C.PI_NAME: "File Attach", C.PI_IMPORT_NAME: "ATTACH", C.PI_TYPE: C.PLUG_TYPE_MISC, + C.PI_MODES: C.PLUG_MODE_BOTH, C.PI_DEPENDENCIES: ["UPLOAD"], C.PI_MAIN: "AttachPlugin", C.PI_HANDLER: "no", @@ -152,7 +157,12 @@ return data - async def uploadFiles(self, client, data, upload_cb=None): + async def upload_files( + self, + client: SatXMPPEntity, + data: dict, + upload_cb: Optional[Callable] = None + ): """Upload file, and update attachments invalid attachments will be removed @@ -160,7 +170,7 @@ @param data(dict): message data @param upload_cb(coroutine, Deferred, None): method to use for upload if None, upload method from UPLOAD plugin will be used. - Otherwise, following kwargs will be use with the cb: + Otherwise, following kwargs will be used with the cb: - client - filepath - filename @@ -179,7 +189,7 @@ for attachment in attachments: try: - # we pop path because we don't want it to be stored, as the image can be + # we pop path because we don't want it to be stored, as the file can be # only in a temporary location path = Path(attachment.pop("path")) except KeyError: @@ -198,14 +208,18 @@ except KeyError: name = attachment["name"] = path.name - options = {} + attachment["size"] = path.stat().st_size + + extra = { + "attachment": attachment + } progress_id = attachment.pop("progress_id", None) if progress_id: - options["progress_id"] = progress_id + extra["progress_id"] = progress_id check_certificate = self.host.memory.getParamA( "check_certificate", "Connection", profile_key=client.profile) if not check_certificate: - options['ignore_tls_errors'] = True + extra['ignore_tls_errors'] = True log.warning( _("certificate check disabled for upload, this is dangerous!")) @@ -213,7 +227,7 @@ client=client, filepath=path, filename=name, - options=options, + extra=extra, ) uploads_d.append(upload_d) @@ -246,9 +260,11 @@ return True async def defaultAttach(self, client, data): - await self.uploadFiles(client, data) + await self.upload_files(client, data) # TODO: handle xhtml-im - body_elt = next(data["xml"].elements(C.NS_CLIENT, "body")) + body_elt = data["xml"].body + if body_elt is None: + body_elt = data["xml"].addElement("body") attachments = data["extra"][C.MESS_KEY_ATTACHMENTS] if attachments: body_links = '\n'.join(a['url'] for a in attachments) diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_misc_download.py --- a/sat/plugins/plugin_misc_download.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_misc_download.py Thu Oct 06 16:02:05 2022 +0200 @@ -16,19 +16,23 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import hashlib from pathlib import Path -from urllib.parse import urlparse, unquote -import hashlib +from typing import Any, Dict, Optional, Union, Tuple, Callable +from urllib.parse import unquote, urlparse + import treq from twisted.internet import defer from twisted.words.protocols.jabber import error as jabber_error -from sat.core.i18n import _, D_ -from sat.core.constants import Const as C -from sat.core.log import getLogger + from sat.core import exceptions +from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity +from sat.core.i18n import D_, _ +from sat.core.log import getLogger from sat.tools import xml_tools +from sat.tools import stream from sat.tools.common import data_format -from sat.tools import stream from sat.tools.web import treq_client_no_ssl log = getLogger(__name__) @@ -38,6 +42,7 @@ C.PI_NAME: "File Download", C.PI_IMPORT_NAME: "DOWNLOAD", C.PI_TYPE: C.PLUG_TYPE_MISC, + C.PI_MODES: C.PLUG_MODE_BOTH, C.PI_MAIN: "DownloadPlugin", C.PI_HANDLER: "no", C.PI_DESCRIPTION: _("""File download management"""), @@ -53,7 +58,7 @@ "fileDownload", ".plugin", in_sign="ssss", - out_sign="a{ss}", + out_sign="s", method=self._fileDownload, async_=True, ) @@ -66,18 +71,29 @@ async_=True, ) self._download_callbacks = {} - self.registerScheme('http', self.downloadHTTP) - self.registerScheme('https', self.downloadHTTP) + self._scheme_callbacks = {} + self.register_scheme('http', self.download_http) + self.register_scheme('https', self.download_http) - def _fileDownload(self, uri, dest_path, options_s, profile): - client = self.host.getClient(profile) - options = data_format.deserialise(options_s) + def _fileDownload( + self, attachment_s: str, dest_path: str, extra_s: str, profile: str + ) -> defer.Deferred: + d = defer.ensureDeferred(self.file_download( + self.host.getClient(profile), + data_format.deserialise(attachment_s), + Path(dest_path), + data_format.deserialise(extra_s) + )) + d.addCallback(lambda ret: data_format.serialise(ret)) + return d - return defer.ensureDeferred(self.fileDownload( - client, uri, Path(dest_path), options - )) - - async def fileDownload(self, client, uri, dest_path, options=None): + async def file_download( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + dest_path: Path, + extra: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Download a file using best available method parameters are the same as for [download] @@ -85,7 +101,7 @@ message """ try: - progress_id, __ = await self.download(client, uri, dest_path, options) + progress_id, __ = await self.download(client, attachment, dest_path, extra) except Exception as e: if (isinstance(e, jabber_error.StanzaError) and e.condition == 'not-acceptable'): @@ -102,45 +118,44 @@ else: return {"progress": progress_id} - def _fileDownloadComplete(self, uri, dest_path, options_s, profile): - client = self.host.getClient(profile) - options = data_format.deserialise(options_s) - - d = defer.ensureDeferred(self.fileDownloadComplete( - client, uri, dest_path, options + def _fileDownloadComplete( + self, attachment_s: str, dest_path: str, extra_s: str, profile: str + ) -> defer.Deferred: + d = defer.ensureDeferred(self.file_download_complete( + self.host.getClient(profile), + data_format.deserialise(attachment_s), + Path(dest_path), + data_format.deserialise(extra_s) )) d.addCallback(lambda path: str(path)) return d - async def fileDownloadComplete(self, client, uri, dest_path, options=None): + async def file_download_complete( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + dest_path: Path, + extra: Optional[Dict[str, Any]] = None + ) -> str: """Helper method to fully download a file and return its path parameters are the same as for [download] @return (str): path to the downloaded file use empty string to store the file in cache """ - __, download_d = await self.download(client, uri, dest_path, options) + __, download_d = await self.download(client, attachment, dest_path, extra) dest_path = await download_d return dest_path - async def download(self, client, uri, dest_path, options=None): - """Send a file using best available method - - @param uri(str): URI to the file to download - @param dest_path(str, Path): where the file must be downloaded - if empty string, the file will be stored in local path - @param options(dict, None): options depending on scheme handler - Some common options: - - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification - used only if HTTPS transport is needed - @return (tuple[unicode,D(unicode)]): progress_id and a Deferred which fire - download URL when download is finished - progress_id can be empty string if the file already exist and is not - downloaded again (can happen if cache is used with empty dest_path) - """ - if options is None: - options = {} - + async def download_uri( + self, + client: SatXMPPEntity, + uri: str, + dest_path: Union[Path, str], + extra: Optional[Dict[str, Any]] = None + ) -> Tuple[str, defer.Deferred]: + if extra is None: + extra = {} uri_parsed = urlparse(uri, 'http') if dest_path: dest_path = Path(dest_path) @@ -171,18 +186,18 @@ check_certificate = self.host.memory.getParamA( "check_certificate", "Connection", profile_key=client.profile) if not check_certificate: - options['ignore_tls_errors'] = True + extra['ignore_tls_errors'] = True log.warning( _("certificate check disabled for download, this is dangerous!")) try: - callback = self._download_callbacks[uri_parsed.scheme] + callback = self._scheme_callbacks[uri_parsed.scheme] except KeyError: raise exceptions.NotFound(f"Can't find any handler for uri {uri}") else: try: progress_id, download_d = await callback( - client, uri_parsed, dest_path, options) + client, uri_parsed, dest_path, extra) except Exception as e: log.warning(_( "Can't download URI {uri}: {reason}").format( @@ -195,11 +210,96 @@ download_d.addCallback(lambda __: dest_path) return progress_id, download_d - def registerScheme(self, scheme, download_cb): + + async def download( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + dest_path: Union[Path, str], + extra: Optional[Dict[str, Any]] = None + ) -> Tuple[str, defer.Deferred]: + """Download a file from URI using suitable method + + @param uri: URI to the file to download + @param dest_path: where the file must be downloaded + if empty string, the file will be stored in local path + @param extra: options depending on scheme handler + Some common options: + - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification + used only if HTTPS transport is needed + @return: ``progress_id`` and a Deferred which fire download URL when download is + finished. + ``progress_id`` can be empty string if the file already exist and is not + downloaded again (can happen if cache is used with empty ``dest_path``). + """ + uri = attachment.get("uri") + if uri: + return await self.download_uri(client, uri, dest_path, extra) + else: + for source in attachment.get("sources", []): + source_type = source.get("type") + if not source_type: + log.warning( + "source type is missing for source: {source}\nattachment: " + f"{attachment}" + ) + continue + try: + cb = self._download_callbacks[source_type] + except KeyError: + log.warning( + f"no source handler registered for {source_type!r}" + ) + else: + try: + return await cb(client, attachment, source, dest_path, extra) + except exceptions.CancelError as e: + # the handler can't or doesn't want to handle this source + log.debug( + f"Following source handling by {cb} has been cancelled ({e}):" + f"{source}" + ) + + log.warning( + "no source could be handled, we can't download the attachment:\n" + f"{attachment}" + ) + raise exceptions.FeatureNotFound("no handler could manage the attachment") + + def register_download_handler( + self, + source_type: str, + callback: Callable[ + [ + SatXMPPEntity, Dict[str, Any], Dict[str, Any], Union[str, Path], + Dict[str, Any] + ], + Tuple[str, defer.Deferred] + ] + ) -> None: + """Register a handler to manage a type of attachment source + + @param source_type: ``type`` of source handled + This is usually the namespace of the protocol used + @param callback: method to call to manage the source. + Call arguments are the same as for [download], with an extra ``source`` dict + which is used just after ``attachment`` to give a quick reference to the + source used. + The callabke must return a tuple with: + - progress ID + - a Deferred which fire whant the file is fully downloaded + """ + if source_type is self._download_callbacks: + raise exceptions.ConflictError( + f"The is already a callback registered for source type {source_type!r}" + ) + self._download_callbacks[source_type] = callback + + def register_scheme(self, scheme: str, download_cb: Callable) -> None: """Register an URI scheme handler - @param scheme(unicode): URI scheme this callback is handling - @param download_cb(callable): callback to download a file + @param scheme: URI scheme this callback is handling + @param download_cb: callback to download a file arguments are: - (SatXMPPClient) client - (urllib.parse.SplitResult) parsed URI @@ -208,19 +308,19 @@ must return a tuple with progress_id and a Deferred which fire when download is finished """ - if scheme in self._download_callbacks: + if scheme in self._scheme_callbacks: raise exceptions.ConflictError( f"A method with scheme {scheme!r} is already registered" ) - self._download_callbacks[scheme] = download_cb + self._scheme_callbacks[scheme] = download_cb def unregister(self, scheme): try: - del self._download_callbacks[scheme] + del self._scheme_callbacks[scheme] except KeyError: raise exceptions.NotFound(f"No callback registered for scheme {scheme!r}") - def errbackDownload(self, file_obj, download_d, resp): + def errback_download(self, file_obj, download_d, resp): """Set file_obj and download deferred appropriatly after a network error @param file_obj(SatFile): file where the download must be done @@ -231,7 +331,7 @@ file_obj.close(error=msg) download_d.errback(exceptions.NetworkError(msg)) - async def downloadHTTP(self, client, uri_parsed, dest_path, options): + async def download_http(self, client, uri_parsed, dest_path, options): url = uri_parsed.geturl() if options.get('ignore_tls_errors', False): @@ -264,5 +364,5 @@ d.addBoth(lambda _: file_obj.close()) else: d = defer.Deferred() - self.errbackDownload(file_obj, d, resp) + self.errback_download(file_obj, d, resp) return progress_id, d diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_misc_upload.py --- a/sat/plugins/plugin_misc_upload.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_misc_upload.py Thu Oct 06 16:02:05 2022 +0200 @@ -18,15 +18,20 @@ import os import os.path +from pathlib import Path +from typing import Optional, Tuple, Union + from twisted.internet import defer from twisted.words.protocols.jabber import jid from twisted.words.protocols.jabber import error as jabber_error -from sat.core.i18n import _, D_ + +from sat.core import exceptions from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity +from sat.core.i18n import D_, _ +from sat.core.log import getLogger +from sat.tools import xml_tools from sat.tools.common import data_format -from sat.core.log import getLogger -from sat.core import exceptions -from sat.tools import xml_tools log = getLogger(__name__) @@ -99,26 +104,32 @@ else: return {"progress": progress_id} - async def upload(self, client, filepath, filename=None, upload_jid=None, - options=None): + async def upload( + self, + client: SatXMPPEntity, + filepath: Union[Path, str], + filename: Optional[str] = None, + upload_jid: Optional[jid.JID] = None, + extra: Optional[dict]=None + ) -> Tuple[str, defer.Deferred]: """Send a file using best available method - @param filepath(str): absolute path to the file - @param filename(None, unicode): name to use for the upload + @param filepath: absolute path to the file + @param filename: name to use for the upload None to use basename of the path - @param upload_jid(jid.JID, None): upload capable entity jid, + @param upload_jid: upload capable entity jid, or None to use autodetected, if possible - @param options(dict): option to use for the upload, may be: + @param extra: extra data/options to use for the upload, may be: - ignore_tls_errors(bool): True to ignore SSL/TLS certificate verification used only if HTTPS transport is needed - progress_id(str): id to use for progression if not specified, one will be generated @param profile: %(doc_profile)s - @return (tuple[unicode,D(unicode)]): progress_id and a Deferred which fire - download URL when upload is finished + @return: progress_id and a Deferred which fire download URL when upload is + finished """ - if options is None: - options = {} + if extra is None: + extra = {} if not os.path.isfile(filepath): raise exceptions.DataError("The given path doesn't link to a file") for method_name, available_cb, upload_cb, priority in self._upload_callbacks: @@ -132,7 +143,7 @@ "{name} method will be used to upload the file".format(name=method_name) ) progress_id, download_d = await upload_cb( - client, filepath, filename, upload_jid, options + client, filepath, filename, upload_jid, extra ) return progress_id, download_d diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_sec_aesgcm.py --- a/sat/plugins/plugin_sec_aesgcm.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_sec_aesgcm.py Thu Oct 06 16:02:05 2022 +0200 @@ -63,13 +63,13 @@ log.info(_("AESGCM plugin initialization")) self._http_upload = host.plugins['XEP-0363'] self._attach = host.plugins["ATTACH"] - host.plugins["DOWNLOAD"].registerScheme( + host.plugins["DOWNLOAD"].register_scheme( "aesgcm", self.download ) self._attach.register( self.canHandleAttachment, self.attach, encrypted=True) - host.trigger.add("XEP-0363_upload_size", self._uploadSizeTrigger) - host.trigger.add("XEP-0363_upload", self._uploadTrigger) + host.trigger.add("XEP-0363_upload_pre_slot", self._upload_pre_slot) + host.trigger.add("XEP-0363_upload", self._upload_trigger) host.trigger.add("messageReceived", self._messageReceivedTrigger) async def download(self, client, uri_parsed, dest_path, options): @@ -129,7 +129,7 @@ decryptor=decryptor)) else: d = defer.Deferred() - self.host.plugins["DOWNLOAD"].errbackDownload(file_obj, d, resp) + self.host.plugins["DOWNLOAD"].errback_download(file_obj, d, resp) return progress_id, d async def canHandleAttachment(self, client, data): @@ -140,13 +140,13 @@ else: return True - async def _uploadCb(self, client, filepath, filename, options): - options['encryption'] = C.ENC_AES_GCM - return await self._http_upload.fileHTTPUpload( + async def _upload_cb(self, client, filepath, filename, extra): + extra['encryption'] = C.ENC_AES_GCM + return await self._http_upload.file_http_upload( client=client, filepath=filepath, filename=filename, - options=options + extra=extra ) async def attach(self, client, data): @@ -160,14 +160,16 @@ if not data['message'] or data['message'] == {'': ''}: extra_attachments = attachments[1:] del attachments[1:] - await self._attach.uploadFiles(client, data, upload_cb=self._uploadCb) + await self._attach.upload_files(client, data, upload_cb=self._upload_cb) else: # we have a message, we must send first attachment separately extra_attachments = attachments[:] attachments.clear() del data["extra"][C.MESS_KEY_ATTACHMENTS] - body_elt = next(data["xml"].elements(C.NS_CLIENT, "body")) + body_elt = data["xml"].body + if body_elt is None: + body_elt = data["xml"].addElement("body") for attachment in attachments: body_elt.addContent(attachment["url"]) @@ -219,11 +221,11 @@ decrypted = decryptor.update(data) file_obj.write(decrypted) - def _uploadSizeTrigger(self, client, options, file_path, size, size_adjust): - if options.get('encryption') != C.ENC_AES_GCM: + def _upload_pre_slot(self, client, extra, file_metadata): + if extra.get('encryption') != C.ENC_AES_GCM: return True # the tag is appended to the file - size_adjust.append(16) + file_metadata["size"] += 16 return True def _encrypt(self, data, encryptor): @@ -239,8 +241,8 @@ # as we have already finalized, we can now send EOF return b'' - def _uploadTrigger(self, client, options, sat_file, file_producer, slot): - if options.get('encryption') != C.ENC_AES_GCM: + def _upload_trigger(self, client, extra, sat_file, file_producer, slot): + if extra.get('encryption') != C.ENC_AES_GCM: return True log.debug("encrypting file with AES-GCM") iv = secrets.token_bytes(12) @@ -255,7 +257,7 @@ # so we need to check with final data length to avoid a warning on close() sat_file.check_size_with_read = True - # file_producer get length directly from file, and this cause trouble has + # file_producer get length directly from file, and this cause trouble as # we have to change the size because of encryption. So we adapt it here, # else the producer would stop reading prematurely file_producer.length = sat_file.size diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_xep_0363.py --- a/sat/plugins/plugin_xep_0363.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_xep_0363.py Thu Oct 06 16:02:05 2022 +0200 @@ -16,25 +16,29 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import os.path +from dataclasses import dataclass import mimetypes -from typing import NamedTuple, Callable, Optional -from dataclasses import dataclass +import os.path +from pathlib import Path +from typing import Callable, NamedTuple, Optional, Tuple from urllib import parse -from wokkel import disco, iwokkel -from zope.interface import implementer -from twisted.words.protocols.jabber import jid, xmlstream, error -from twisted.words.xish import domish + from twisted.internet import reactor from twisted.internet import defer from twisted.web import client as http_client from twisted.web import http_headers -from sat.core.i18n import _ -from sat.core.xmpp import SatXMPPComponent +from twisted.words.protocols.jabber import error, jid, xmlstream +from twisted.words.xish import domish +from wokkel import disco, iwokkel +from zope.interface import implementer + +from sat.core import exceptions from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity +from sat.core.i18n import _ from sat.core.log import getLogger -from sat.core import exceptions -from sat.tools import web as sat_web, utils +from sat.core.xmpp import SatXMPPComponent +from sat.tools import utils, web as sat_web log = getLogger(__name__) @@ -87,7 +91,7 @@ ".plugin", in_sign="sssbs", out_sign="", - method=self._fileHTTPUpload, + method=self._file_http_upload, ) host.bridge.addMethod( "fileHTTPUploadGetSlot", @@ -98,7 +102,7 @@ async_=True, ) host.plugins["UPLOAD"].register( - "HTTP Upload", self.getHTTPUploadEntity, self.fileHTTPUpload + "HTTP Upload", self.getHTTPUploadEntity, self.file_http_upload ) # list of callbacks used when a request is done to a component self.handlers = [] @@ -151,58 +155,67 @@ return entity - def _fileHTTPUpload(self, filepath, filename="", upload_jid="", + def _file_http_upload(self, filepath, filename="", upload_jid="", ignore_tls_errors=False, profile=C.PROF_KEY_NONE): assert os.path.isabs(filepath) and os.path.isfile(filepath) client = self.host.getClient(profile) - progress_id_d, __ = defer.ensureDeferred(self.fileHTTPUpload( + return defer.ensureDeferred(self.file_http_upload( client, filepath, filename or None, jid.JID(upload_jid) if upload_jid else None, {"ignore_tls_errors": ignore_tls_errors}, )) - return progress_id_d - async def fileHTTPUpload( - self, client, filepath, filename=None, upload_jid=None, options=None): + async def file_http_upload( + self, + client: SatXMPPEntity, + filepath: Path, + filename: Optional[str] = None, + upload_jid: Optional[jid.JID] = None, + extra: Optional[dict] = None + ) -> Tuple[str, defer.Deferred]: """Upload a file through HTTP - @param filepath(str): absolute path of the file - @param filename(None, unicode): name to use for the upload + @param filepath: absolute path of the file + @param filename: name to use for the upload None to use basename of the path - @param upload_jid(jid.JID, None): upload capable entity jid, + @param upload_jid: upload capable entity jid, or None to use autodetected, if possible - @param options(dict): options where key can be: + @param extra: options where key can be: - ignore_tls_errors(bool): if True, SSL certificate will not be checked + - attachment(dict): file attachment data @param profile: %(doc_profile)s - @return (D(tuple[D(unicode), D(unicode)])): progress id and Deferred which fire - download URL + @return: progress id and Deferred which fire download URL """ - if options is None: - options = {} - ignore_tls_errors = options.get("ignore_tls_errors", False) - filename = filename or os.path.basename(filepath) - size = os.path.getsize(filepath) + if extra is None: + extra = {} + ignore_tls_errors = extra.get("ignore_tls_errors", False) + file_metadata = { + "filename": filename or os.path.basename(filepath), + "filepath": filepath, + "size": os.path.getsize(filepath), + } - size_adjust = [] - #: this trigger can be used to modify the requested size, it is notably useful - #: with encryption. The size_adjust is a list which can be filled by int to add - #: to the initial size + #: this trigger can be used to modify the filename or size requested when geting + #: the slot, it is notably useful with encryption. self.host.trigger.point( - "XEP-0363_upload_size", client, options, filepath, size, size_adjust, - triggers_no_cancel=True) - if size_adjust: - size = sum([size, *size_adjust]) + "XEP-0363_upload_pre_slot", client, extra, file_metadata, + triggers_no_cancel=True + ) try: - slot = await self.getSlot(client, filename, size, upload_jid=upload_jid) + slot = await self.getSlot( + client, file_metadata["filename"], file_metadata["size"], + upload_jid=upload_jid + ) except Exception as e: log.warning(_("Can't get upload slot: {reason}").format(reason=e)) raise e else: log.debug(f"Got upload slot: {slot}") sat_file = self.host.plugins["FILE"].File( - self.host, client, filepath, uid=options.get("progress_id"), size=size, + self.host, client, filepath, uid=extra.get("progress_id"), + size=file_metadata["size"], auto_end_signals=False ) progress_id = sat_file.uid @@ -223,7 +236,7 @@ await self.host.trigger.asyncPoint( - "XEP-0363_upload", client, options, sat_file, file_producer, slot, + "XEP-0363_upload", client, extra, sat_file, file_producer, slot, triggers_no_cancel=True) download_d = agent.request( @@ -233,8 +246,8 @@ file_producer, ) download_d.addCallbacks( - self._uploadCb, - self._uploadEb, + self._upload_cb, + self._upload_eb, (sat_file, slot), None, (sat_file,), @@ -242,7 +255,7 @@ return progress_id, download_d - def _uploadCb(self, __, sat_file, slot): + def _upload_cb(self, __, sat_file, slot): """Called once file is successfully uploaded @param sat_file(SatFile): file used for the upload @@ -253,7 +266,7 @@ sat_file.progressFinished({"url": slot.get}) return slot.get - def _uploadEb(self, failure_, sat_file): + def _upload_eb(self, failure_, sat_file): """Called on unsuccessful upload @param sat_file(SatFile): file used for the upload diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_xep_0446.py --- a/sat/plugins/plugin_xep_0446.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_xep_0446.py Thu Oct 06 16:02:05 2022 +0200 @@ -107,19 +107,19 @@ self, file_metadata_elt: domish.Element ) -> Dict[str, Any]: - """Parse element + """Parse element - @param file_metadata_elt: element + @param file_metadata_elt: element a parent element can also be used - @return: file-metadata data. It's a dict whose keys correspond to + @return: file metadata. It's a dict whose keys correspond to [get_file_metadata_elt] parameters - @raise exceptions.NotFound: no element has been found + @raise exceptions.NotFound: no element has been found """ - if file_metadata_elt.name != "file-metadata": + if file_metadata_elt.name != "file": try: file_metadata_elt = next( - file_metadata_elt.elements(NS_FILE_METADATA, "file-metadata") + file_metadata_elt.elements(NS_FILE_METADATA, "file") ) except StopIteration: raise exceptions.NotFound @@ -158,7 +158,7 @@ from sat.tools.xml_tools import pFmtElt log.warning("invalid element:\n{pFmtElt(file_metadata_elt)}") else: - data["file_hash"] = (algo, hash_.decode()) + data["file_hash"] = (algo, hash_) # TODO: thumbnails diff -r cc2705225778 -r 0ff265725489 sat/plugins/plugin_xep_0447.py --- a/sat/plugins/plugin_xep_0447.py Thu Oct 06 16:02:05 2022 +0200 +++ b/sat/plugins/plugin_xep_0447.py Thu Oct 06 16:02:05 2022 +0200 @@ -15,14 +15,23 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Dict, List, Tuple, Union, Any +from collections import namedtuple +from functools import partial +import mimetypes +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import treq +from twisted.internet import defer from twisted.words.xish import domish +from sat.core import exceptions from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity from sat.core.i18n import _ from sat.core.log import getLogger -from sat.core import exceptions +from sat.tools import stream +from sat.tools.web import treq_client_no_ssl log = getLogger(__name__) @@ -33,23 +42,130 @@ C.PI_TYPE: "XEP", C.PI_MODES: C.PLUG_MODE_BOTH, C.PI_PROTOCOLS: ["XEP-0447"], - C.PI_DEPENDENCIES: ["XEP-0103", "XEP-0446"], + C.PI_DEPENDENCIES: ["XEP-0103", "XEP-0334", "XEP-0446", "ATTACH", "DOWNLOAD"], + C.PI_RECOMMENDATIONS: ["XEP-0363"], C.PI_MAIN: "XEP_0447", C.PI_HANDLER: "no", C.PI_DESCRIPTION: _("""Implementation of XEP-0447 (Stateless File Sharing)"""), } NS_SFS = "urn:xmpp:sfs:0" +SourceHandler = namedtuple("SourceHandler", ["callback", "encrypted"]) class XEP_0447: namespace = NS_SFS def __init__(self, host): + self.host = host log.info(_("XEP-0447 (Stateless File Sharing) plugin initialization")) host.registerNamespace("sfs", NS_SFS) + self._sources_handlers = {} self._u = host.plugins["XEP-0103"] + self._hints = host.plugins["XEP-0334"] self._m = host.plugins["XEP-0446"] + self._http_upload = host.plugins.get("XEP-0363") + self._attach = host.plugins["ATTACH"] + self._attach.register( + self.can_handle_attachment, self.attach, priority=1000 + ) + self.register_source_handler( + self._u.namespace, "url-data", self._u.parse_url_data_elt + ) + host.plugins["DOWNLOAD"].register_download_handler(self._u.namespace, self.download) + host.trigger.add("messageReceived", self._message_received_trigger) + + def register_source_handler( + self, namespace: str, element_name: str, + callback: Callable[[domish.Element], Dict[str, Any]], + encrypted: bool = False + ) -> None: + """Register a handler for file source + + @param namespace: namespace of the element supported + @param element_name: name of the element supported + @param callback: method to call to parse the element + get the matching element as argument, must return the parsed data + @param encrypted: if True, the source is encrypted (the transmitting channel + should then be end2end encrypted to avoir leaking decrypting data to servers). + """ + key = (namespace, element_name) + if key in self._sources_handlers: + raise exceptions.ConflictError( + f"There is already a resource handler for namespace {namespace!r} and " + f"name {element_name!r}" + ) + self._sources_handlers[key] = SourceHandler(callback, encrypted) + + async def download( + self, + client: SatXMPPEntity, + attachment: Dict[str, Any], + source: Dict[str, Any], + dest_path: Union[Path, str], + extra: Optional[Dict[str, Any]] = None + ) -> Tuple[str, defer.Deferred]: + # TODO: handle url-data headers + if extra is None: + extra = {} + try: + download_url = source["url"] + except KeyError: + raise ValueError(f"{source} has missing URL") + + if extra.get('ignore_tls_errors', False): + log.warning( + "TLS certificate check disabled, this is highly insecure" + ) + treq_client = treq_client_no_ssl + else: + treq_client = treq + + try: + file_size = int(attachment["size"]) + except (KeyError, ValueError): + head_data = await treq_client.head(download_url) + file_size = int(head_data.headers.getRawHeaders('content-length')[0]) + + file_obj = stream.SatFile( + self.host, + client, + dest_path, + mode="wb", + size = file_size, + ) + + progress_id = file_obj.uid + + resp = await treq_client.get(download_url, unbuffered=True) + if resp.code == 200: + d = treq.collect(resp, file_obj.write) + d.addCallback(lambda __: file_obj.close()) + else: + d = defer.Deferred() + self.host.plugins["DOWNLOAD"].errback_download(file_obj, d, resp) + return progress_id, d + + async def can_handle_attachment(self, client, data): + if self._http_upload is None: + return False + try: + await self._http_upload.getHTTPUploadEntity(client) + except exceptions.NotFound: + return False + else: + return True + + def get_sources_elt( + self, + children: Optional[List[domish.Element]] = None + ) -> domish.Element: + """Generate element""" + sources_elt = domish.Element((NS_SFS, "sources")) + if children: + for child in children: + sources_elt.addChild(child) + return sources_elt def get_file_sharing_elt( self, @@ -75,6 +191,8 @@ file_sharing_elt = domish.Element((NS_SFS, "file-sharing")) if disposition is not None: file_sharing_elt["disposition"] = disposition + if media_type is None and name: + media_type = mimetypes.guess_type(name, strict=False)[0] file_sharing_elt.addChild( self._m.get_file_metadata_elt( name=name, @@ -89,7 +207,8 @@ thumbnail=thumbnail, ) ) - sources_elt = file_sharing_elt.addElement("sources") + sources_elt = self.get_sources_elt() + file_sharing_elt.addChild(sources_elt) for source_data in sources: if "url" in source_data: sources_elt.addChild( @@ -102,6 +221,42 @@ return file_sharing_elt + def parse_sources_elt( + self, + sources_elt: domish.Element + ) -> List[Dict[str, Any]]: + """Parse element + + @param sources_elt: element, or a direct parent element + @return: list of found sources data + @raise: exceptions.NotFound: Can't find element + """ + if sources_elt.name != "sources" or sources_elt.uri != NS_SFS: + try: + sources_elt = next(sources_elt.elements(NS_SFS, "sources")) + except StopIteration: + raise exceptions.NotFound( + f" element is missing: {sources_elt.toXml()}") + sources = [] + for elt in sources_elt.elements(): + if not elt.uri: + log.warning("ignoring source element {elt.toXml()}") + continue + key = (elt.uri, elt.name) + try: + source_handler = self._sources_handlers[key] + except KeyError: + log.warning(f"unmanaged file sharing element: {elt.toXml}") + continue + else: + source_data = source_handler.callback(elt) + if source_handler.encrypted: + source_data[C.MESS_KEY_ENCRYPTED] = True + if "type" not in source_data: + source_data["type"] = elt.uri + sources.append(source_data) + return sources + def parse_file_sharing_elt( self, file_sharing_elt: domish.Element @@ -126,17 +281,95 @@ disposition = file_sharing_elt.getAttribute("disposition") if disposition is not None: data["disposition"] = disposition - sources = data["sources"] = [] try: - sources_elt = next(file_sharing_elt.elements(NS_SFS, "sources")) - except StopIteration: - raise ValueError(f" element is missing: {file_sharing_elt.toXml()}") - for elt in sources_elt.elements(): - if elt.name == "url-data" and elt.uri == self._u.namespace: - source_data = self._u.parse_url_data_elt(elt) - else: - log.warning(f"unmanaged file sharing element: {elt.toXml}") - continue - sources.append(source_data) + data["sources"] = self.parse_sources_elt(file_sharing_elt) + except exceptions.NotFound as e: + raise ValueError(str(e)) + + return data + + def _add_file_sharing_attachments( + self, + client: SatXMPPEntity, + message_elt: domish.Element, + data: Dict[str, Any] + ) -> Dict[str, Any]: + """Check for a shared file, and add it as an attachment""" + # XXX: XEP-0447 doesn't support several attachments in a single message, thus only + # one attachment can be added + try: + attachment = self.parse_file_sharing_elt(message_elt) + except exceptions.NotFound: + return data + + if any( + s.get(C.MESS_KEY_ENCRYPTED, False) + for s in attachment["sources"] + ) and client.encryption.isEncrypted(data): + # we don't add the encrypted flag if the message itself is not encrypted, + # because the decryption key is part of the link, so sending it over + # unencrypted channel is like having no encryption at all. + attachment[C.MESS_KEY_ENCRYPTED] = True + + attachments = data['extra'].setdefault(C.MESS_KEY_ATTACHMENTS, []) + attachments.append(attachment) return data + + async def attach(self, client, data): + # XXX: for now, XEP-0447 only allow to send one file per , thus we need + # to send each file in a separate message + attachments = data["extra"][C.MESS_KEY_ATTACHMENTS] + if not data['message'] or data['message'] == {'': ''}: + extra_attachments = attachments[1:] + del attachments[1:] + else: + # we have a message, we must send first attachment separately + extra_attachments = attachments[:] + attachments.clear() + del data["extra"][C.MESS_KEY_ATTACHMENTS] + + if attachments: + if len(attachments) > 1: + raise exceptions.InternalError( + "There should not be more that one attachment at this point" + ) + await self._attach.upload_files(client, data) + self._hints.addHintElements(data["xml"], [self._hints.HINT_STORE]) + for attachment in attachments: + try: + file_hash = (attachment["hash_algo"], attachment["hash"]) + except KeyError: + file_hash = None + file_sharing_elt = self.get_file_sharing_elt( + [{"url": attachment["url"]}], + name=attachment["name"], + size=attachment["size"], + file_hash=file_hash + ) + data["xml"].addChild(file_sharing_elt) + + for attachment in extra_attachments: + # we send all remaining attachment in a separate message + await client.sendMessage( + to_jid=data['to'], + message={'': ''}, + subject=data['subject'], + mess_type=data['type'], + extra={C.MESS_KEY_ATTACHMENTS: [attachment]}, + ) + + if ((not data['extra'] + and (not data['message'] or data['message'] == {'': ''}) + and not data['subject'])): + # nothing left to send, we can cancel the message + raise exceptions.CancelError("Cancelled by XEP_0447 attachment handling") + + def _message_received_trigger(self, client, message_elt, post_treat): + # we use a post_treat callback instead of "message_parse" trigger because we need + # to check if the "encrypted" flag is set to decide if we add the same flag to the + # attachment + post_treat.addCallback( + partial(self._add_file_sharing_attachments, client, message_elt) + ) + return True