diff sat/plugins/plugin_xep_0363.py @ 3089:e75024e41f81

plugin upload, XEP-0363: code modernisation + preparation for extension: - use of async/await syntax - fileUpload's options are now serialised, allowing non string values - (XEP-0363) Slot is now a dataclass, so it can be modified by other plugins - (XEP-0363) Moved SSL related code to the new tools.web module - (XEP-0363) added `XEP-0363_upload_size` and `XEP-0363_upload` trigger points - a Deferred is not used anymore for `progress_id`, the value is directly returned
author Goffi <goffi@goffi.org>
date Fri, 20 Dec 2019 12:28:04 +0100
parents fee60f17ebac
children 9d0df638c8b4
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0363.py	Fri Dec 20 12:28:04 2019 +0100
+++ b/sat/plugins/plugin_xep_0363.py	Fri Dec 20 12:28:04 2019 +0100
@@ -17,30 +17,26 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-from sat.core.i18n import _
-from sat.core.constants import Const as C
-from sat.core.log import getLogger
-
-log = getLogger(__name__)
-from sat.core import exceptions
+import os.path
+import mimetypes
+from dataclasses import dataclass
 from wokkel import disco, iwokkel
 from zope.interface import implementer
 from twisted.words.protocols.jabber import jid
 from twisted.words.protocols.jabber.xmlstream import XMPPHandler
 from twisted.internet import reactor
 from twisted.internet import defer
-from twisted.internet import ssl
-from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
 from twisted.web import client as http_client
 from twisted.web import http_headers
-from twisted.web import iweb
-from twisted.python import failure
-from collections import namedtuple
-from OpenSSL import SSL
-import os.path
-import mimetypes
+from sat.core.i18n import _
+from sat.core.constants import Const as C
+from sat.core.log import getLogger
+from sat.core import exceptions
+from sat.tools import web as sat_web
 
 
+log = getLogger(__name__)
+
 PLUGIN_INFO = {
     C.PI_NAME: "HTTP File Upload",
     C.PI_IMPORT_NAME: "XEP-0363",
@@ -56,37 +52,12 @@
 ALLOWED_HEADERS = ('authorization', 'cookie', 'expires')
 
 
-Slot = namedtuple("Slot", ["put", "get", "headers"])
-
-
-@implementer(IOpenSSLClientConnectionCreator)
-class NoCheckConnectionCreator(object):
-    def __init__(self, hostname, ctx):
-        self._ctx = ctx
-
-    def clientConnectionForTLS(self, tlsProtocol):
-        context = self._ctx
-        connection = SSL.Connection(context, None)
-        connection.set_app_data(tlsProtocol)
-        return connection
-
-
-@implementer(iweb.IPolicyForHTTPS)
-class NoCheckContextFactory(ssl.ClientContextFactory):
-    """Context factory which doesn't do TLS certificate check
-
-    /!\\ it's obvisously a security flaw to use this class,
-    and it should be used only with explicite agreement from the end used
-    """
-
-    def creatorForNetloc(self, hostname, port):
-        log.warning(
-            "TLS check disabled for {host} on port {port}".format(
-                host=hostname, port=port
-            )
-        )
-        certificateOptions = ssl.CertificateOptions(trustRoot=None)
-        return NoCheckConnectionCreator(hostname, certificateOptions.getContext())
+@dataclass
+class Slot:
+    """Upload slot"""
+    put: str
+    get: str
+    headers: list
 
 
 class XEP_0363(object):
@@ -115,8 +86,7 @@
     def getHandler(self, client):
         return XEP_0363_handler()
 
-    @defer.inlineCallbacks
-    def getHTTPUploadEntity(self, upload_jid=None, profile=C.PROF_KEY_NONE):
+    async def getHTTPUploadEntity(self, client, upload_jid=None):
         """Get HTTP upload capable entity
 
          upload_jid is checked, then its components
@@ -124,35 +94,35 @@
          @return(D(jid.JID)): first HTTP upload capable entity
          @raise exceptions.NotFound: no entity found
          """
-        client = self.host.getClient(profile)
         try:
             entity = client.http_upload_service
         except AttributeError:
-            found_entities = yield self.host.findFeaturesSet(client, (NS_HTTP_UPLOAD,))
+            found_entities = await self.host.findFeaturesSet(client, (NS_HTTP_UPLOAD,))
             try:
                 entity = client.http_upload_service = next(iter(found_entities))
             except StopIteration:
                 entity = client.http_upload_service = None
 
         if entity is None:
-            raise failure.Failure(exceptions.NotFound("No HTTP upload entity found"))
+            raise exceptions.NotFound("No HTTP upload entity found")
 
-        defer.returnValue(entity)
+        return entity
 
     def _fileHTTPUpload(self, filepath, filename="", upload_jid="",
                         ignore_tls_errors=False, profile=C.PROF_KEY_NONE):
         assert os.path.isabs(filepath) and os.path.isfile(filepath)
-        progress_id_d, __ = self.fileHTTPUpload(
+        client = self.host.getClient(profile)
+        progress_id_d, __ = defer.ensureDeferred(self.fileHTTPUpload(
+            client,
             filepath,
             filename or None,
             jid.JID(upload_jid) if upload_jid else None,
             {"ignore_tls_errors": ignore_tls_errors},
-            profile,
-        )
+        ))
         return progress_id_d
 
-    def fileHTTPUpload(self, filepath, filename=None, upload_jid=None, options=None,
-                       profile=C.PROF_KEY_NONE):
+    async def fileHTTPUpload(
+        self, client, filepath, filename=None, upload_jid=None, options=None):
         """Upload a file through HTTP
 
         @param filepath(str): absolute path of the file
@@ -169,137 +139,96 @@
         if options is None:
             options = {}
         ignore_tls_errors = options.get("ignore_tls_errors", False)
-        client = self.host.getClient(profile)
         filename = filename or os.path.basename(filepath)
         size = os.path.getsize(filepath)
-        progress_id_d = defer.Deferred()
-        download_d = defer.Deferred()
-        d = self.getSlot(client, filename, size, upload_jid=upload_jid)
-        d.addCallbacks(
-            self._getSlotCb,
-            self._getSlotEb,
-            (client, progress_id_d, download_d, filepath, size, ignore_tls_errors),
-            None,
-            (client, progress_id_d, download_d),
-        )
-        return progress_id_d, download_d
+
 
-    def _getSlotEb(self, fail, client, progress_id_d, download_d):
-        """an error happened while trying to get slot"""
-        log.warning("Can't get upload slot: {reason}".format(reason=fail.value))
-        progress_id_d.errback(fail)
-        download_d.errback(fail)
+        size_adjust = []
+        #: this trigger can be used to modify the requested size, it is notably useful
+        #: with encryption. The size_adjust is a list which can be filled by int to add
+        #: to the initial size
+        self.host.trigger.point(
+            "XEP-0363_upload_size", client, options, filepath, size, size_adjust,
+            triggers_no_cancel=True)
+        if size_adjust:
+            size = sum([size, *size_adjust])
+        try:
+            slot = await self.getSlot(client, filename, size, upload_jid=upload_jid)
+        except Exception as e:
+            log.warning(_("Can't get upload slot: {reason}").format(reason=e))
+            raise e
+        else:
+            log.debug(f"Got upload slot: {slot}")
+            sat_file = self.host.plugins["FILE"].File(
+                self.host, client, filepath, size=size, auto_end_signals=False
+            )
+            progress_id = sat_file.uid
 
-    def _getSlotCb(self, slot, client, progress_id_d, download_d, path, size,
-                   ignore_tls_errors=False):
-        """Called when slot is received, try to do the upload
+            file_producer = http_client.FileBodyProducer(sat_file)
+
+            if ignore_tls_errors:
+                agent = http_client.Agent(reactor, sat_web.NoCheckContextFactory())
+            else:
+                agent = http_client.Agent(reactor)
 
-        @param slot(Slot): slot instance with the get and put urls
-        @param progress_id_d(defer.Deferred): Deferred to call when progress_id is known
-        @param progress_id_d(defer.Deferred): Deferred to call with URL when upload is
-            done
-        @param path(str): path to the file to upload
-        @param size(int): size of the file to upload
-        @param ignore_tls_errors(bool): ignore TLS certificate is True
-        @return (tuple
-        """
-        log.debug(f"Got upload slot: {slot}")
-        sat_file = self.host.plugins["FILE"].File(
-            self.host, client, path, size=size, auto_end_signals=False
-        )
-        progress_id_d.callback(sat_file.uid)
-        file_producer = http_client.FileBodyProducer(sat_file)
-        if ignore_tls_errors:
-            agent = http_client.Agent(reactor, NoCheckContextFactory())
-        else:
-            agent = http_client.Agent(reactor)
+            headers = {"User-Agent": [C.APP_NAME.encode("utf-8")]}
+
+            for name, value in slot.headers:
+                name = name.encode('utf-8')
+                value = value.encode('utf-8')
+                headers[name] = value
+
+
+            await self.host.trigger.asyncPoint(
+                "XEP-0363_upload", client, options, sat_file, file_producer, slot,
+                triggers_no_cancel=True)
 
-        headers = {"User-Agent": [C.APP_NAME.encode("utf-8")]}
-        for name, value in slot.headers:
-            name = name.encode('utf-8')
-            value = value.encode('utf-8')
-            headers[name] = value
+            download_d = agent.request(
+                b"PUT",
+                slot.put.encode("utf-8"),
+                http_headers.Headers(headers),
+                file_producer,
+            )
+            download_d.addCallbacks(
+                self._uploadCb,
+                self._uploadEb,
+                (sat_file, slot),
+                None,
+                (sat_file),
+            )
 
-        d = agent.request(
-            b"PUT",
-            slot.put.encode("utf-8"),
-            http_headers.Headers(headers),
-            file_producer,
-        )
-        d.addCallbacks(
-            self._uploadCb,
-            self._uploadEb,
-            (sat_file, slot, download_d),
-            None,
-            (sat_file, download_d),
-        )
-        return d
+            return progress_id, download_d
 
-    def _uploadCb(self, __, sat_file, slot, download_d):
+    def _uploadCb(self, __, sat_file, slot):
         """Called once file is successfully uploaded
 
         @param sat_file(SatFile): file used for the upload
-            should be closed, be is needed to send the progressFinished signal
+            should be closed, but it is needed to send the progressFinished signal
         @param slot(Slot): put/get urls
         """
         log.info("HTTP upload finished")
         sat_file.progressFinished({"url": slot.get})
-        download_d.callback(slot.get)
+        return slot.get
 
-    def _uploadEb(self, fail, sat_file, download_d):
+    def _uploadEb(self, failure_, sat_file):
         """Called on unsuccessful upload
 
         @param sat_file(SatFile): file used for the upload
             should be closed, be is needed to send the progressError signal
         """
-        download_d.errback(fail)
         try:
-            wrapped_fail = fail.value.reasons[0]
+            wrapped_fail = failure_.value.reasons[0]
         except (AttributeError, IndexError) as e:
             log.warning(_("upload failed: {reason}").format(reason=e))
-            sat_file.progressError(str(fail))
-            raise fail
+            sat_file.progressError(str(failure_))
         else:
-            if wrapped_fail.check(SSL.Error):
+            if wrapped_fail.check(sat_web.SSLError):
                 msg = "TLS validation error, can't connect to HTTPS server"
             else:
                 msg = "can't upload file"
             log.warning(msg + ": " + str(wrapped_fail.value))
             sat_file.progressError(msg)
-
-    def _gotSlot(self, iq_elt, client):
-        """Slot have been received
-
-        This method convert the iq_elt result to a Slot instance
-        @param iq_elt(domish.Element): <IQ/> result as specified in XEP-0363
-        """
-        try:
-            slot_elt = next(iq_elt.elements(NS_HTTP_UPLOAD, "slot"))
-            put_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "put"))
-            put_url = put_elt['url']
-            get_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "get"))
-            get_url = get_elt['url']
-        except (StopIteration, KeyError):
-            raise exceptions.DataError("Incorrect stanza received from server")
-        headers = []
-        for header_elt in put_elt.elements(NS_HTTP_UPLOAD, "header"):
-            try:
-                name = header_elt["name"]
-                value = str(header_elt)
-            except KeyError:
-                log.warning(_("Invalid header element: {xml}").format(
-                    iq_elt.toXml()))
-                continue
-            name = name.replace('\n', '')
-            value = value.replace('\n', '')
-            if name.lower() not in ALLOWED_HEADERS:
-                log.warning(_('Ignoring unauthorised header "{name}": {xml}')
-                    .format(name=name, xml = iq_elt.toXml()))
-                continue
-            headers.append((name, value))
-
-        slot = Slot(put=put_url, get=get_url, headers=tuple(headers))
-        return slot
+        raise failure_
 
     def _getSlot(self, filename, size, content_type, upload_jid,
                  profile_key=C.PROF_KEY_NONE):
@@ -312,13 +241,13 @@
         @param content_type(unicode, None): MIME type of the content
             empty string or None to guess automatically
         """
+        client = self.host.getClient(profile_key)
         filename = filename.replace("/", "_")
-        client = self.host.getClient(profile_key)
-        return self.getSlot(
+        return defer.ensureDeferred(self.getSlot(
             client, filename, size, content_type or None, upload_jid or None
-        )
+        ))
 
-    def getSlot(self, client, filename, size, content_type=None, upload_jid=None):
+    async def getSlot(self, client, filename, size, content_type=None, upload_jid=None):
         """Get a slot (i.e. download/upload links)
 
         @param filename(unicode): name to use for the upload
@@ -340,18 +269,12 @@
             try:
                 upload_jid = client.http_upload_service
             except AttributeError:
-                d = self.getHTTPUploadEntity(profile=client.profile)
-                d.addCallback(
-                    lambda found_entity: self.getSlot(
-                        client, filename, size, content_type, found_entity
-                    )
-                )
-                return d
+                found_entity = await self.getHTTPUploadEntity(profile=client.profile)
+                return await self.getSlot(
+                    client, filename, size, content_type, found_entity)
             else:
                 if upload_jid is None:
-                    raise failure.Failure(
-                        exceptions.NotFound("No HTTP upload entity found")
-                    )
+                    raise exceptions.NotFound("No HTTP upload entity found")
 
         iq_elt = client.IQ("get")
         iq_elt["to"] = upload_jid.full()
@@ -361,10 +284,35 @@
         if content_type is not None:
             request_elt["content-type"] = content_type
 
-        d = iq_elt.send()
-        d.addCallback(self._gotSlot, client)
+        iq_result_elt = await iq_elt.send()
+
+        try:
+            slot_elt = next(iq_result_elt.elements(NS_HTTP_UPLOAD, "slot"))
+            put_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "put"))
+            put_url = put_elt['url']
+            get_elt = next(slot_elt.elements(NS_HTTP_UPLOAD, "get"))
+            get_url = get_elt['url']
+        except (StopIteration, KeyError):
+            raise exceptions.DataError("Incorrect stanza received from server")
 
-        return d
+        headers = []
+        for header_elt in put_elt.elements(NS_HTTP_UPLOAD, "header"):
+            try:
+                name = header_elt["name"]
+                value = str(header_elt)
+            except KeyError:
+                log.warning(_("Invalid header element: {xml}").format(
+                    iq_result_elt.toXml()))
+                continue
+            name = name.replace('\n', '')
+            value = value.replace('\n', '')
+            if name.lower() not in ALLOWED_HEADERS:
+                log.warning(_('Ignoring unauthorised header "{name}": {xml}')
+                    .format(name=name, xml = iq_result_elt.toXml()))
+                continue
+            headers.append((name, value))
+
+        return Slot(put=put_url, get=get_url, headers=headers)
 
 
 @implementer(iwokkel.IDisco)