Mercurial > libervia-backend
diff sat/plugins/plugin_xep_0363.py @ 3089:e75024e41f81
plugin upload, XEP-0363: code modernisation + preparation for extension:
- use of async/await syntax
- fileUpload's options are now serialised, allowing non string values
- (XEP-0363) Slot is now a dataclass, so it can be modified by other plugins
- (XEP-0363) Moved SSL related code to the new tools.web module
- (XEP-0363) added `XEP-0363_upload_size` and `XEP-0363_upload` trigger points
- a Deferred is not used anymore for `progress_id`, the value is directly returned
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 20 Dec 2019 12:28:04 +0100 |
parents | fee60f17ebac |
children | 9d0df638c8b4 |
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0363.py Fri Dec 20 12:28:04 2019 +0100 +++ b/sat/plugins/plugin_xep_0363.py Fri Dec 20 12:28:04 2019 +0100 @@ -17,30 +17,26 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. -from sat.core.i18n import _ -from sat.core.constants import Const as C -from sat.core.log import getLogger - -log = getLogger(__name__) -from sat.core import exceptions +import os.path +import mimetypes +from dataclasses import dataclass from wokkel import disco, iwokkel from zope.interface import implementer from twisted.words.protocols.jabber import jid from twisted.words.protocols.jabber.xmlstream import XMPPHandler from twisted.internet import reactor from twisted.internet import defer -from twisted.internet import ssl -from twisted.internet.interfaces import IOpenSSLClientConnectionCreator from twisted.web import client as http_client from twisted.web import http_headers -from twisted.web import iweb -from twisted.python import failure -from collections import namedtuple -from OpenSSL import SSL -import os.path -import mimetypes +from sat.core.i18n import _ +from sat.core.constants import Const as C +from sat.core.log import getLogger +from sat.core import exceptions +from sat.tools import web as sat_web +log = getLogger(__name__) + PLUGIN_INFO = { C.PI_NAME: "HTTP File Upload", C.PI_IMPORT_NAME: "XEP-0363", @@ -56,37 +52,12 @@ ALLOWED_HEADERS = ('authorization', 'cookie', 'expires') -Slot = namedtuple("Slot", ["put", "get", "headers"]) - - -@implementer(IOpenSSLClientConnectionCreator) -class NoCheckConnectionCreator(object): - def __init__(self, hostname, ctx): - self._ctx = ctx - - def clientConnectionForTLS(self, tlsProtocol): - context = self._ctx - connection = SSL.Connection(context, None) - connection.set_app_data(tlsProtocol) - return connection - - -@implementer(iweb.IPolicyForHTTPS) -class NoCheckContextFactory(ssl.ClientContextFactory): - """Context factory which doesn't do TLS certificate check - - /!\\ it's obvisously a security flaw to use this class, - and it should be used only with explicite agreement from the end used - """ - - def creatorForNetloc(self, hostname, port): - log.warning( - "TLS check disabled for {host} on port {port}".format( - host=hostname, port=port - ) - ) - certificateOptions = ssl.CertificateOptions(trustRoot=None) - return NoCheckConnectionCreator(hostname, certificateOptions.getContext()) +@dataclass +class Slot: + """Upload slot""" + put: str + get: str + headers: list class XEP_0363(object): @@ -115,8 +86,7 @@ def getHandler(self, client): return XEP_0363_handler() - @defer.inlineCallbacks - def getHTTPUploadEntity(self, upload_jid=None, profile=C.PROF_KEY_NONE): + async def getHTTPUploadEntity(self, client, upload_jid=None): """Get HTTP upload capable entity upload_jid is checked, then its components @@ -124,35 +94,35 @@ @return(D(jid.JID)): first HTTP upload capable entity @raise exceptions.NotFound: no entity found """ - client = self.host.getClient(profile) try: entity = client.http_upload_service except AttributeError: - found_entities = yield self.host.findFeaturesSet(client, (NS_HTTP_UPLOAD,)) + found_entities = await self.host.findFeaturesSet(client, (NS_HTTP_UPLOAD,)) try: entity = client.http_upload_service = next(iter(found_entities)) except StopIteration: entity = client.http_upload_service = None if entity is None: - raise failure.Failure(exceptions.NotFound("No HTTP upload entity found")) + raise exceptions.NotFound("No HTTP upload entity found") - defer.returnValue(entity) + return entity def _fileHTTPUpload(self, filepath, filename="", upload_jid="", ignore_tls_errors=False, profile=C.PROF_KEY_NONE): assert os.path.isabs(filepath) and os.path.isfile(filepath) - progress_id_d, __ = self.fileHTTPUpload( + client = self.host.getClient(profile) + progress_id_d, __ = defer.ensureDeferred(self.fileHTTPUpload( + client, filepath, filename or None, jid.JID(upload_jid) if upload_jid else None, {"ignore_tls_errors": ignore_tls_errors}, - profile, - ) + )) return progress_id_d - def fileHTTPUpload(self, filepath, filename=None, upload_jid=None, options=None, - profile=C.PROF_KEY_NONE): + async def fileHTTPUpload( + self, client, filepath, filename=None, upload_jid=None, options=None): """Upload a file through HTTP @param filepath(str): absolute path of the file @@ -169,137 +139,96 @@ if options is None: options = {} ignore_tls_errors = options.get("ignore_tls_errors", False) - client = self.host.getClient(profile) filename = filename or os.path.basename(filepath) size = os.path.getsize(filepath) - progress_id_d = defer.Deferred() - download_d = defer.Deferred() - d = self.getSlot(client, filename, size, upload_jid=upload_jid) - d.addCallbacks( - self._getSlotCb, - self._getSlotEb, - (client, progress_id_d, download_d, filepath, size, ignore_tls_errors), - None, - (client, progress_id_d, download_d), - ) - return progress_id_d, download_d + - def _getSlotEb(self, fail, client, progress_id_d, download_d): - """an error happened while trying to get slot""" - log.warning("Can't get upload slot: {reason}".format(reason=fail.value)) - progress_id_d.errback(fail) - download_d.errback(fail) + size_adjust = [] + #: this trigger can be used to modify the requested size, it is notably useful + #: with encryption. The size_adjust is a list which can be filled by int to add + #: to the initial size + self.host.trigger.point( + "XEP-0363_upload_size", client, options, filepath, size, size_adjust, + triggers_no_cancel=True) + if size_adjust: + size = sum([size, *size_adjust]) + try: + slot = await self.getSlot(client, filename, size, upload_jid=upload_jid) + except Exception as e: + log.warning(_("Can't get upload slot: {reason}").format(reason=e)) + raise e + else: + log.debug(f"Got upload slot: {slot}") + sat_file = self.host.plugins["FILE"].File( + self.host, client, filepath, size=size, auto_end_signals=False + ) + progress_id = sat_file.uid - def _getSlotCb(self, slot, client, progress_id_d, download_d, path, size, - ignore_tls_errors=False): - """Called when slot is received, try to do the upload + file_producer = http_client.FileBodyProducer(sat_file) + + if ignore_tls_errors: + agent = http_client.Agent(reactor, sat_web.NoCheckContextFactory()) + else: + agent = http_client.Agent(reactor) - @param slot(Slot): slot instance with the get and put urls - @param progress_id_d(defer.Deferred): Deferred to call when progress_id is known - @param progress_id_d(defer.Deferred): Deferred to call with URL when upload is - done - @param path(str): path to the file to upload - @param size(int): size of the file to upload - @param ignore_tls_errors(bool): ignore TLS certificate is True - @return (tuple - """ - log.debug(f"Got upload slot: {slot}") - sat_file = self.host.plugins["FILE"].File( - self.host, client, path, size=size, auto_end_signals=False - ) - progress_id_d.callback(sat_file.uid) - file_producer = http_client.FileBodyProducer(sat_file) - if ignore_tls_errors: - agent = http_client.Agent(reactor, NoCheckContextFactory()) - else: - agent = http_client.Agent(reactor) + headers = {"User-Agent": [C.APP_NAME.encode("utf-8")]} + + for name, value in slot.headers: + name = name.encode('utf-8') + value = value.encode('utf-8') + headers[name] = value + + + await self.host.trigger.asyncPoint( + "XEP-0363_upload", client, options, sat_file, file_producer, slot, + triggers_no_cancel=True) - headers = {"User-Agent": [C.APP_NAME.encode("utf-8")]} - for name, value in slot.headers: - name = name.encode('utf-8') - value = value.encode('utf-8') - headers[name] = value + download_d = agent.request( + b"PUT", + slot.put.encode("utf-8"), + http_headers.Headers(headers), + file_producer, + ) + download_d.addCallbacks( + self._uploadCb, + self._uploadEb, + (sat_file, slot), + None, + (sat_file), + ) - d = agent.request( - b"PUT", - slot.put.encode("utf-8"), - http_headers.Headers(headers), - file_producer, - ) - d.addCallbacks( - self._uploadCb, - self._uploadEb, - (sat_file, slot, download_d), - None, - (sat_file, download_d), - ) - return d + return progress_id, download_d - def _uploadCb(self, __, sat_file, slot, download_d): + def _uploadCb(self, __, sat_file, slot): """Called once file is successfully uploaded @param sat_file(SatFile): file used for the upload - should be closed, be is needed to send the progressFinished signal + should be closed, but it is needed to send the progressFinished signal @param slot(Slot): put/get urls """ log.info("HTTP upload finished") sat_file.progressFinished({"url": slot.get}) - download_d.callback(slot.get) + return slot.get - def _uploadEb(self, fail, sat_file, download_d): + def _uploadEb(self, failure_, sat_file): """Called on unsuccessful upload @param sat_file(SatFile): file used for the upload should be closed, be is needed to send the progressError signal """ - download_d.errback(fail) try: - wrapped_fail = fail.value.reasons[0] + wrapped_fail = failure_.value.reasons[0] except (AttributeError, IndexError) as e: log.warning(_("upload failed: {reason}").format(reason=e)) - sat_file.progressError(str(fail)) - raise fail + sat_file.progressError(str(failure_)) else: - if wrapped_fail.check(SSL.Error): + if wrapped_fail.check(sat_web.SSLError): msg = "TLS validation error, can't connect to HTTPS server" else: msg = "can't upload file" log.warning(msg + ": " + str(wrapped_fail.value)) sat_file.progressError(msg) - - def _gotSlot(self, iq_elt, client): - """Slot have been received - - This method convert the iq_elt result to a Slot instance - @param iq_elt(domish.Element): <IQ/> result as specified in XEP-0363 - """ - try: - slot_elt = next(iq_elt.elements(NS_HTTP_UPLOAD, "slot")) - put_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "put")) - put_url = put_elt['url'] - get_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "get")) - get_url = get_elt['url'] - except (StopIteration, KeyError): - raise exceptions.DataError("Incorrect stanza received from server") - headers = [] - for header_elt in put_elt.elements(NS_HTTP_UPLOAD, "header"): - try: - name = header_elt["name"] - value = str(header_elt) - except KeyError: - log.warning(_("Invalid header element: {xml}").format( - iq_elt.toXml())) - continue - name = name.replace('\n', '') - value = value.replace('\n', '') - if name.lower() not in ALLOWED_HEADERS: - log.warning(_('Ignoring unauthorised header "{name}": {xml}') - .format(name=name, xml = iq_elt.toXml())) - continue - headers.append((name, value)) - - slot = Slot(put=put_url, get=get_url, headers=tuple(headers)) - return slot + raise failure_ def _getSlot(self, filename, size, content_type, upload_jid, profile_key=C.PROF_KEY_NONE): @@ -312,13 +241,13 @@ @param content_type(unicode, None): MIME type of the content empty string or None to guess automatically """ + client = self.host.getClient(profile_key) filename = filename.replace("/", "_") - client = self.host.getClient(profile_key) - return self.getSlot( + return defer.ensureDeferred(self.getSlot( client, filename, size, content_type or None, upload_jid or None - ) + )) - def getSlot(self, client, filename, size, content_type=None, upload_jid=None): + async def getSlot(self, client, filename, size, content_type=None, upload_jid=None): """Get a slot (i.e. download/upload links) @param filename(unicode): name to use for the upload @@ -340,18 +269,12 @@ try: upload_jid = client.http_upload_service except AttributeError: - d = self.getHTTPUploadEntity(profile=client.profile) - d.addCallback( - lambda found_entity: self.getSlot( - client, filename, size, content_type, found_entity - ) - ) - return d + found_entity = await self.getHTTPUploadEntity(profile=client.profile) + return await self.getSlot( + client, filename, size, content_type, found_entity) else: if upload_jid is None: - raise failure.Failure( - exceptions.NotFound("No HTTP upload entity found") - ) + raise exceptions.NotFound("No HTTP upload entity found") iq_elt = client.IQ("get") iq_elt["to"] = upload_jid.full() @@ -361,10 +284,35 @@ if content_type is not None: request_elt["content-type"] = content_type - d = iq_elt.send() - d.addCallback(self._gotSlot, client) + iq_result_elt = await iq_elt.send() + + try: + slot_elt = next(iq_result_elt.elements(NS_HTTP_UPLOAD, "slot")) + put_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "put")) + put_url = put_elt['url'] + get_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "get")) + get_url = get_elt['url'] + except (StopIteration, KeyError): + raise exceptions.DataError("Incorrect stanza received from server") - return d + headers = [] + for header_elt in put_elt.elements(NS_HTTP_UPLOAD, "header"): + try: + name = header_elt["name"] + value = str(header_elt) + except KeyError: + log.warning(_("Invalid header element: {xml}").format( + iq_result_elt.toXml())) + continue + name = name.replace('\n', '') + value = value.replace('\n', '') + if name.lower() not in ALLOWED_HEADERS: + log.warning(_('Ignoring unauthorised header "{name}": {xml}') + .format(name=name, xml = iq_result_elt.toXml())) + continue + headers.append((name, value)) + + return Slot(put=put_url, get=get_url, headers=headers) @implementer(iwokkel.IDisco)