diff sat/plugins/plugin_xep_0060.py @ 3584:edc79cefe968

plugin XEP-0060: `getItem(s)`, `publish` and `(un)subscribe` are now coroutines
author Goffi <goffi@goffi.org>
date Wed, 30 Jun 2021 16:19:14 +0200
parents 02eec2a5b5f9
children 5f65f4e9f8cb
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0060.py	Sun Jun 27 00:15:40 2021 +0200
+++ b/sat/plugins/plugin_xep_0060.py	Wed Jun 30 16:19:14 2021 +0200
@@ -17,7 +17,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 
-from typing import Optional
+from typing import Optional, List, Tuple
 from collections import namedtuple
 import urllib.request, urllib.parse, urllib.error
 from functools import reduce
@@ -35,7 +35,7 @@
 from sat.core.i18n import _
 from sat.core.constants import Const as C
 from sat.core.log import getLogger
-from sat.core.xmpp import SatXMPPEntity
+from sat.core.core_types import SatXMPPEntity
 from sat.core import exceptions
 from sat.tools import sat_defer
 from sat.tools import xml_tools
@@ -471,9 +471,9 @@
         service = None if not service else jid.JID(service)
         payload = xml_tools.parse(payload)
         extra = data_format.deserialise(extra_ser)
-        d = self.sendItem(
+        d = defer.ensureDeferred(self.sendItem(
             client, service, nodeIdentifier, payload, item_id or None, extra
-        )
+        ))
         d.addCallback(lambda ret: ret or "")
         return d
 
@@ -487,23 +487,13 @@
             raise exceptions.DataError(_("Can't parse items: {msg}").format(
                 msg=e))
         extra = data_format.deserialise(extra_ser)
-        d = self.sendItems(
+        return defer.ensureDeferred(self.sendItems(
             client, service, nodeIdentifier, items, extra
-        )
-        return d
-
-    def _getPublishedItemId(self, published_ids, original_id):
-        """Return item of published id if found in answer
+        ))
 
-        if not found original_id is returned, which may be None
-        """
-        try:
-            return published_ids[0]
-        except IndexError:
-            return original_id
-
-    def sendItem(self, client, service, nodeIdentifier, payload, item_id=None,
-                 extra=None):
+    async def sendItem(
+        self, client, service, nodeIdentifier, payload, item_id=None, extra=None
+    ):
         """High level method to send one item
 
         @param service(jid.JID, None): service to send the item to
@@ -519,15 +509,17 @@
         if item_id is not None:
             item_elt['id'] = item_id
         item_elt.addChild(payload)
-        d = defer.ensureDeferred(self.sendItems(
+        published_ids = await self.sendItems(
             client,
             service,
             nodeIdentifier,
             [item_elt],
             extra
-        ))
-        d.addCallback(self._getPublishedItemId, item_id)
-        return d
+        )
+        try:
+            return published_ids[0]
+        except IndexError:
+            return item_id
 
     async def sendItems(self, client, service, nodeIdentifier, items, extra=None):
         """High level method to send several items at once
@@ -593,8 +585,8 @@
         except AttributeError:
             return []
 
-    def publish(self, client, service, nodeIdentifier, items=None, options=None):
-        return client.pubsub_client.publish(
+    async def publish(self, client, service, nodeIdentifier, items=None, options=None):
+        return await client.pubsub_client.publish(
             service, nodeIdentifier, items, client.pubsub_client.parent.jid,
             options=options
         )
@@ -630,7 +622,7 @@
         service = jid.JID(service) if service else None
         max_items = None if max_items == C.NO_LIMIT else max_items
         extra = self.parseExtra(extra_dict)
-        d = self.getItems(
+        d = defer.ensureDeferred(self.getItems(
             client,
             service,
             node or None,
@@ -639,13 +631,22 @@
             sub_id or None,
             extra.rsm_request,
             extra.extra,
-        )
+        ))
         d.addCallback(self.transItemsData)
         d.addCallback(self.serialiseItems)
         return d
 
-    def getItems(self, client, service, node, max_items=None, item_ids=None, sub_id=None,
-                 rsm_request=None, extra=None):
+    async def getItems(
+        self,
+        client: SatXMPPEntity,
+        service: Optional[jid.JID],
+        node: str,
+        max_items: Optional[int] = None,
+        item_ids: Optional[List[str]] = None,
+        sub_id: Optional[str] = None,
+        rsm_request: Optional[rsm.RSMRequest] = None,
+        extra: Optional[dict] = None
+    ) -> Tuple[List[dict], dict]:
         """Retrieve pubsub items from a node.
 
         @param service (JID, None): pubsub service.
@@ -682,9 +683,10 @@
                 rsm_request = rsm_request
             )
             # we have no MAM data here, so we add None
-            d.addCallback(lambda data: data + (None,))
             d.addErrback(sat_defer.stanza2NotFound)
             d.addTimeout(TIMEOUT, reactor)
+            items, rsm_response = await d
+            mam_response = None
         else:
             # if mam is requested, we have to do a totally different query
             if self._mam is None:
@@ -706,61 +708,49 @@
                     raise exceptions.DataError(
                         "Conflict between RSM request and MAM's RSM request"
                     )
-            d = self._mam.getArchives(client, mam_query, service, self._unwrapMAMMessage)
+            items, rsm_response, mam_response = await self._mam.getArchives(
+                client, mam_query, service, self._unwrapMAMMessage
+            )
 
         try:
             subscribe = C.bool(extra["subscribe"])
         except KeyError:
             subscribe = False
 
-        def subscribeEb(failure, service, node):
-            failure.trap(error.StanzaError)
-            log.warning(
-                "Could not subscribe to node {} on service {}: {}".format(
-                    node, str(service), str(failure.value)
+        if subscribe:
+            try:
+                await self.subscribe(client, service, node)
+            except error.StanzaError as e:
+                log.warning(
+                    f"Could not subscribe to node {node} on service {service}: {e}"
                 )
-            )
-
-        def doSubscribe(data):
-            self.subscribe(client, service, node).addErrback(
-                subscribeEb, service, node
-            )
-            return data
-
-        if subscribe:
-            d.addCallback(doSubscribe)
 
-        def addMetadata(result):
-            # TODO: handle the third argument (mam_response)
-            items, rsm_response, mam_response = result
-            service_jid = service if service else client.jid.userhostJID()
-            metadata = {
-                "service": service_jid,
-                "node": node,
-                "uri": self.getNodeURI(service_jid, node),
-            }
-            if mam_response is not None:
-                # mam_response is a dict with "complete" and "stable" keys
-                # we can put them directly in metadata
-                metadata.update(mam_response)
-            if rsm_request is not None and rsm_response is not None:
-                metadata['rsm'] = rsm_response.toDict()
-                if mam_response is None:
-                    index = rsm_response.index
-                    count = rsm_response.count
-                    if index is None or count is None:
-                        # we don't have enough information to know if the data is complete
-                        # or not
-                        metadata["complete"] = None
-                    else:
-                        # normally we have a strict equality here but XEP-0059 states
-                        # that index MAY be approximative, so just in case…
-                        metadata["complete"] = index + len(items) >= count
+        # TODO: handle mam_response
+        service_jid = service if service else client.jid.userhostJID()
+        metadata = {
+            "service": service_jid,
+            "node": node,
+            "uri": self.getNodeURI(service_jid, node),
+        }
+        if mam_response is not None:
+            # mam_response is a dict with "complete" and "stable" keys
+            # we can put them directly in metadata
+            metadata.update(mam_response)
+        if rsm_request is not None and rsm_response is not None:
+            metadata['rsm'] = rsm_response.toDict()
+            if mam_response is None:
+                index = rsm_response.index
+                count = rsm_response.count
+                if index is None or count is None:
+                    # we don't have enough information to know if the data is complete
+                    # or not
+                    metadata["complete"] = None
+                else:
+                    # normally we have a strict equality here but XEP-0059 states
+                    # that index MAY be approximative, so just in case…
+                    metadata["complete"] = index + len(items) >= count
 
-            return (items, metadata)
-
-        d.addCallback(addMetadata)
-        return d
+        return (items, metadata)
 
     # @defer.inlineCallbacks
     # def getItemsFromMany(self, service, data, max_items=None, sub_id=None, rsm=None, profile_key=C.PROF_KEY_NONE):
@@ -1100,22 +1090,24 @@
     def _subscribe(self, service, nodeIdentifier, options, profile_key=C.PROF_KEY_NONE):
         client = self.host.getClient(profile_key)
         service = None if not service else jid.JID(service)
-        d = self.subscribe(client, service, nodeIdentifier, options=options or None)
+        d = defer.ensureDeferred(
+            self.subscribe(client, service, nodeIdentifier, options=options or None)
+        )
         d.addCallback(lambda subscription: subscription.subscriptionIdentifier or "")
         return d
 
-    def subscribe(self, client, service, nodeIdentifier, sub_jid=None, options=None):
+    async def subscribe(self, client, service, nodeIdentifier, sub_jid=None, options=None):
         # TODO: reimplement a subscribtion cache, checking that we have not subscription before trying to subscribe
-        return client.pubsub_client.subscribe(
+        return await client.pubsub_client.subscribe(
             service, nodeIdentifier, sub_jid or client.jid.userhostJID(), options=options
         )
 
     def _unsubscribe(self, service, nodeIdentifier, profile_key=C.PROF_KEY_NONE):
         client = self.host.getClient(profile_key)
         service = None if not service else jid.JID(service)
-        return self.unsubscribe(client, service, nodeIdentifier)
+        return defer.ensureDeferred(self.unsubscribe(client, service, nodeIdentifier))
 
-    def unsubscribe(
+    async def unsubscribe(
         self,
         client,
         service,
@@ -1124,7 +1116,7 @@
         subscriptionIdentifier=None,
         sender=None,
     ):
-        return client.pubsub_client.unsubscribe(
+        return await client.pubsub_client.unsubscribe(
             service,
             nodeIdentifier,
             sub_jid or client.jid.userhostJID(),
@@ -1394,8 +1386,10 @@
         client = self.host.getClient(profile_key)
         deferreds = {}
         for service, node in node_data:
-            deferreds[(service, node)] = client.pubsub_client.subscribe(
-                service, node, subscriber, options=options
+            deferreds[(service, node)] = defer.ensureDeferred(
+                client.pubsub_client.subscribe(
+                    service, node, subscriber, options=options
+                )
             )
         return self.rt_sessions.newSession(deferreds, client.profile)
         # found_nodes = yield self.listNodes(service, profile=client.profile)
@@ -1475,9 +1469,9 @@
         client = self.host.getClient(profile_key)
         deferreds = {}
         for service, node in node_data:
-            deferreds[(service, node)] = self.getItems(
+            deferreds[(service, node)] = defer.ensureDeferred(self.getItems(
                 client, service, node, max_item, rsm_request=rsm_request, extra=extra
-            )
+            ))
         return self.rt_sessions.newSession(deferreds, client.profile)