diff sat/plugins/plugin_xep_0060.py @ 3715:b9718216a1c0 0.9

merge bookmark 0.9
author Goffi <goffi@goffi.org>
date Wed, 01 Dec 2021 16:13:31 +0100
parents 5d108ce026d7
children 1cdb9d9fad6b
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0060.py	Tue Nov 30 23:31:09 2021 +0100
+++ b/sat/plugins/plugin_xep_0060.py	Wed Dec 01 16:13:31 2021 +0100
@@ -17,7 +17,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 
-from typing import Optional
+from typing import Optional, List, Tuple
 from collections import namedtuple
 import urllib.request, urllib.parse, urllib.error
 from functools import reduce
@@ -35,8 +35,9 @@
 from sat.core.i18n import _
 from sat.core.constants import Const as C
 from sat.core.log import getLogger
-from sat.core.xmpp import SatXMPPEntity
+from sat.core.core_types import SatXMPPEntity
 from sat.core import exceptions
+from sat.tools import utils
 from sat.tools import sat_defer
 from sat.tools import xml_tools
 from sat.tools.common import data_format
@@ -90,6 +91,8 @@
     ID_SINGLETON = "current"
     EXTRA_PUBLISH_OPTIONS = "publish_options"
     EXTRA_ON_PRECOND_NOT_MET = "on_precondition_not_met"
+    # extra disco needed for RSM, cf. XEP-0060 § 6.5.4
+    DISCO_RSM = "http://jabber.org/protocol/pubsub#rsm"
 
     def __init__(self, host):
         log.info(_("PubSub plugin initialization"))
@@ -197,7 +200,7 @@
         host.bridge.addMethod(
             "psItemsGet",
             ".plugin",
-            in_sign="ssiassa{ss}s",
+            in_sign="ssiassss",
             out_sign="s",
             method=self._getItems,
             async_=True,
@@ -284,7 +287,7 @@
         host.bridge.addMethod(
             "psGetFromMany",
             ".plugin",
-            in_sign="a(ss)ia{ss}s",
+            in_sign="a(ss)iss",
             out_sign="s",
             method=self._getFromMany,
         )
@@ -391,6 +394,7 @@
             the method must be named after PubSub constants in lower case
             and suffixed with "_cb"
             e.g.: "items_cb" for C.PS_ITEMS, "delete_cb" for C.PS_DELETE
+            note: only C.PS_ITEMS and C.PS_DELETE are implemented so far
         """
         assert node is not None
         assert kwargs
@@ -471,9 +475,9 @@
         service = None if not service else jid.JID(service)
         payload = xml_tools.parse(payload)
         extra = data_format.deserialise(extra_ser)
-        d = self.sendItem(
+        d = defer.ensureDeferred(self.sendItem(
             client, service, nodeIdentifier, payload, item_id or None, extra
-        )
+        ))
         d.addCallback(lambda ret: ret or "")
         return d
 
@@ -487,23 +491,13 @@
             raise exceptions.DataError(_("Can't parse items: {msg}").format(
                 msg=e))
         extra = data_format.deserialise(extra_ser)
-        d = self.sendItems(
+        return defer.ensureDeferred(self.sendItems(
             client, service, nodeIdentifier, items, extra
-        )
-        return d
-
-    def _getPublishedItemId(self, published_ids, original_id):
-        """Return item of published id if found in answer
+        ))
 
-        if not found original_id is returned, which may be None
-        """
-        try:
-            return published_ids[0]
-        except IndexError:
-            return original_id
-
-    def sendItem(self, client, service, nodeIdentifier, payload, item_id=None,
-                 extra=None):
+    async def sendItem(
+        self, client, service, nodeIdentifier, payload, item_id=None, extra=None
+    ):
         """High level method to send one item
 
         @param service(jid.JID, None): service to send the item to
@@ -519,15 +513,17 @@
         if item_id is not None:
             item_elt['id'] = item_id
         item_elt.addChild(payload)
-        d = defer.ensureDeferred(self.sendItems(
+        published_ids = await self.sendItems(
             client,
             service,
             nodeIdentifier,
             [item_elt],
             extra
-        ))
-        d.addCallback(self._getPublishedItemId, item_id)
-        return d
+        )
+        try:
+            return published_ids[0]
+        except IndexError:
+            return item_id
 
     async def sendItems(self, client, service, nodeIdentifier, items, extra=None):
         """High level method to send several items at once
@@ -593,12 +589,25 @@
         except AttributeError:
             return []
 
-    def publish(self, client, service, nodeIdentifier, items=None, options=None):
-        return client.pubsub_client.publish(
+    async def publish(
+        self,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        nodeIdentifier: str,
+        items: Optional[List[domish.Element]] = None,
+        options: Optional[dict] = None
+    ) -> List[str]:
+        published_ids = await client.pubsub_client.publish(
             service, nodeIdentifier, items, client.pubsub_client.parent.jid,
             options=options
         )
 
+        await self.host.trigger.asyncPoint(
+            "XEP-0060_publish", client, service, nodeIdentifier, items, options,
+            published_ids
+        )
+        return published_ids
+
     def _unwrapMAMMessage(self, message_elt):
         try:
             item_elt = reduce(
@@ -621,7 +630,7 @@
         return data_format.serialise(metadata)
 
     def _getItems(self, service="", node="", max_items=10, item_ids=None, sub_id=None,
-                  extra_dict=None, profile_key=C.PROF_KEY_NONE):
+                  extra="", profile_key=C.PROF_KEY_NONE):
         """Get items from pubsub node
 
         @param max_items(int): maximum number of item to get, C.NO_LIMIT for no limit
@@ -629,23 +638,32 @@
         client = self.host.getClient(profile_key)
         service = jid.JID(service) if service else None
         max_items = None if max_items == C.NO_LIMIT else max_items
-        extra = self.parseExtra(extra_dict)
-        d = self.getItems(
+        extra = self.parseExtra(data_format.deserialise(extra))
+        d = defer.ensureDeferred(self.getItems(
             client,
             service,
             node or None,
-            max_items or None,
+            max_items,
             item_ids,
             sub_id or None,
             extra.rsm_request,
             extra.extra,
-        )
+        ))
         d.addCallback(self.transItemsData)
         d.addCallback(self.serialiseItems)
         return d
 
-    def getItems(self, client, service, node, max_items=None, item_ids=None, sub_id=None,
-                 rsm_request=None, extra=None):
+    async def getItems(
+        self,
+        client: SatXMPPEntity,
+        service: Optional[jid.JID],
+        node: str,
+        max_items: Optional[int] = None,
+        item_ids: Optional[List[str]] = None,
+        sub_id: Optional[str] = None,
+        rsm_request: Optional[rsm.RSMRequest] = None,
+        extra: Optional[dict] = None
+    ) -> Tuple[List[dict], dict]:
         """Retrieve pubsub items from a node.
 
         @param service (JID, None): pubsub service.
@@ -668,6 +686,12 @@
             raise ValueError("items_id can't be used with rsm")
         if extra is None:
             extra = {}
+        cont, ret = await self.host.trigger.asyncReturnPoint(
+            "XEP-0060_getItems", client, service, node, max_items, item_ids, sub_id,
+            rsm_request, extra
+        )
+        if not cont:
+            return ret
         try:
             mam_query = extra["mam"]
         except KeyError:
@@ -682,9 +706,10 @@
                 rsm_request = rsm_request
             )
             # we have no MAM data here, so we add None
-            d.addCallback(lambda data: data + (None,))
             d.addErrback(sat_defer.stanza2NotFound)
             d.addTimeout(TIMEOUT, reactor)
+            items, rsm_response = await d
+            mam_response = None
         else:
             # if mam is requested, we have to do a totally different query
             if self._mam is None:
@@ -706,61 +731,49 @@
                     raise exceptions.DataError(
                         "Conflict between RSM request and MAM's RSM request"
                     )
-            d = self._mam.getArchives(client, mam_query, service, self._unwrapMAMMessage)
+            items, rsm_response, mam_response = await self._mam.getArchives(
+                client, mam_query, service, self._unwrapMAMMessage
+            )
 
         try:
             subscribe = C.bool(extra["subscribe"])
         except KeyError:
             subscribe = False
 
-        def subscribeEb(failure, service, node):
-            failure.trap(error.StanzaError)
-            log.warning(
-                "Could not subscribe to node {} on service {}: {}".format(
-                    node, str(service), str(failure.value)
+        if subscribe:
+            try:
+                await self.subscribe(client, service, node)
+            except error.StanzaError as e:
+                log.warning(
+                    f"Could not subscribe to node {node} on service {service}: {e}"
                 )
-            )
-
-        def doSubscribe(data):
-            self.subscribe(client, service, node).addErrback(
-                subscribeEb, service, node
-            )
-            return data
-
-        if subscribe:
-            d.addCallback(doSubscribe)
 
-        def addMetadata(result):
-            # TODO: handle the third argument (mam_response)
-            items, rsm_response, mam_response = result
-            service_jid = service if service else client.jid.userhostJID()
-            metadata = {
-                "service": service_jid,
-                "node": node,
-                "uri": self.getNodeURI(service_jid, node),
-            }
-            if mam_response is not None:
-                # mam_response is a dict with "complete" and "stable" keys
-                # we can put them directly in metadata
-                metadata.update(mam_response)
-            if rsm_request is not None and rsm_response is not None:
-                metadata['rsm'] = rsm_response.toDict()
-                if mam_response is None:
-                    index = rsm_response.index
-                    count = rsm_response.count
-                    if index is None or count is None:
-                        # we don't have enough information to know if the data is complete
-                        # or not
-                        metadata["complete"] = None
-                    else:
-                        # normally we have a strict equality here but XEP-0059 states
-                        # that index MAY be approximative, so just in case…
-                        metadata["complete"] = index + len(items) >= count
+        # TODO: handle mam_response
+        service_jid = service if service else client.jid.userhostJID()
+        metadata = {
+            "service": service_jid,
+            "node": node,
+            "uri": self.getNodeURI(service_jid, node),
+        }
+        if mam_response is not None:
+            # mam_response is a dict with "complete" and "stable" keys
+            # we can put them directly in metadata
+            metadata.update(mam_response)
+        if rsm_request is not None and rsm_response is not None:
+            metadata['rsm'] = rsm_response.toDict()
+            if mam_response is None:
+                index = rsm_response.index
+                count = rsm_response.count
+                if index is None or count is None:
+                    # we don't have enough information to know if the data is complete
+                    # or not
+                    metadata["complete"] = None
+                else:
+                    # normally we have a strict equality here but XEP-0059 states
+                    # that index MAY be approximative, so just in case…
+                    metadata["complete"] = index + len(items) >= count
 
-            return (items, metadata)
-
-        d.addCallback(addMetadata)
-        return d
+        return (items, metadata)
 
     # @defer.inlineCallbacks
     # def getItemsFromMany(self, service, data, max_items=None, sub_id=None, rsm=None, profile_key=C.PROF_KEY_NONE):
@@ -1059,7 +1072,7 @@
         notify=True,
     ):
         return client.pubsub_client.retractItems(
-            service, nodeIdentifier, itemIdentifiers, notify=True
+            service, nodeIdentifier, itemIdentifiers, notify=notify
         )
 
     def _renameItem(
@@ -1100,37 +1113,55 @@
     def _subscribe(self, service, nodeIdentifier, options, profile_key=C.PROF_KEY_NONE):
         client = self.host.getClient(profile_key)
         service = None if not service else jid.JID(service)
-        d = self.subscribe(client, service, nodeIdentifier, options=options or None)
+        d = defer.ensureDeferred(
+            self.subscribe(client, service, nodeIdentifier, options=options or None)
+        )
         d.addCallback(lambda subscription: subscription.subscriptionIdentifier or "")
         return d
 
-    def subscribe(self, client, service, nodeIdentifier, sub_jid=None, options=None):
+    async def subscribe(
+        self,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        nodeIdentifier: str,
+        sub_jid: Optional[jid.JID] = None,
+        options: Optional[dict] = None
+    ) -> pubsub.Subscription:
         # TODO: reimplement a subscribtion cache, checking that we have not subscription before trying to subscribe
-        return client.pubsub_client.subscribe(
+        subscription = await client.pubsub_client.subscribe(
             service, nodeIdentifier, sub_jid or client.jid.userhostJID(), options=options
         )
+        await self.host.trigger.asyncPoint(
+            "XEP-0060_subscribe", client, service, nodeIdentifier, sub_jid, options,
+            subscription
+        )
+        return subscription
 
     def _unsubscribe(self, service, nodeIdentifier, profile_key=C.PROF_KEY_NONE):
         client = self.host.getClient(profile_key)
         service = None if not service else jid.JID(service)
-        return self.unsubscribe(client, service, nodeIdentifier)
+        return defer.ensureDeferred(self.unsubscribe(client, service, nodeIdentifier))
 
-    def unsubscribe(
+    async def unsubscribe(
         self,
-        client,
-        service,
-        nodeIdentifier,
+        client: SatXMPPEntity,
+        service: jid.JID,
+        nodeIdentifier: str,
         sub_jid=None,
         subscriptionIdentifier=None,
         sender=None,
     ):
-        return client.pubsub_client.unsubscribe(
+        await client.pubsub_client.unsubscribe(
             service,
             nodeIdentifier,
             sub_jid or client.jid.userhostJID(),
             subscriptionIdentifier,
             sender,
         )
+        await self.host.trigger.asyncPoint(
+            "XEP-0060_unsubscribe", client, service, nodeIdentifier, sub_jid,
+            subscriptionIdentifier, sender
+        )
 
     def _subscriptions(self, service, nodeIdentifier="", profile_key=C.PROF_KEY_NONE):
         client = self.host.getClient(profile_key)
@@ -1394,8 +1425,10 @@
         client = self.host.getClient(profile_key)
         deferreds = {}
         for service, node in node_data:
-            deferreds[(service, node)] = client.pubsub_client.subscribe(
-                service, node, subscriber, options=options
+            deferreds[(service, node)] = defer.ensureDeferred(
+                client.pubsub_client.subscribe(
+                    service, node, subscriber, options=options
+                )
             )
         return self.rt_sessions.newSession(deferreds, client.profile)
         # found_nodes = yield self.listNodes(service, profile=client.profile)
@@ -1445,13 +1478,13 @@
         return d
 
     def _getFromMany(
-        self, node_data, max_item=10, extra_dict=None, profile_key=C.PROF_KEY_NONE
+        self, node_data, max_item=10, extra="", profile_key=C.PROF_KEY_NONE
     ):
         """
         @param max_item(int): maximum number of item to get, C.NO_LIMIT for no limit
         """
         max_item = None if max_item == C.NO_LIMIT else max_item
-        extra = self.parseExtra(extra_dict)
+        extra = self.parseExtra(data_format.deserialise(extra))
         return self.getFromMany(
             [(jid.JID(service), str(node)) for service, node in node_data],
             max_item,
@@ -1475,9 +1508,9 @@
         client = self.host.getClient(profile_key)
         deferreds = {}
         for service, node in node_data:
-            deferreds[(service, node)] = self.getItems(
+            deferreds[(service, node)] = defer.ensureDeferred(self.getItems(
                 client, service, node, max_item, rsm_request=rsm_request, extra=extra
-            )
+            ))
         return self.rt_sessions.newSession(deferreds, client.profile)
 
 
@@ -1513,7 +1546,10 @@
     def itemsReceived(self, event):
         log.debug("Pubsub items received")
         for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_ITEMS):
-            callback(self.parent, event)
+            d = utils.asDeferred(callback, self.parent, event)
+            d.addErrback(lambda f: log.error(
+                f"Error while running items event callback {callback}: {f}"
+            ))
         client = self.parent
         if (event.sender, event.nodeIdentifier) in client.pubsub_watching:
             raw_items = [i.toXml() for i in event.items]
@@ -1528,13 +1564,29 @@
     def deleteReceived(self, event):
         log.debug(("Publish node deleted"))
         for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_DELETE):
-            callback(self.parent, event)
+            d = utils.asDeferred(callback, self.parent, event)
+            d.addErrback(lambda f: log.error(
+                f"Error while running delete event callback {callback}: {f}"
+            ))
         client = self.parent
         if (event.sender, event.nodeIdentifier) in client.pubsub_watching:
             self.host.bridge.psEventRaw(
                 event.sender.full(), event.nodeIdentifier, C.PS_DELETE, [], client.profile
             )
 
+    def purgeReceived(self, event):
+        log.debug(("Publish node purged"))
+        for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_PURGE):
+            d = utils.asDeferred(callback, self.parent, event)
+            d.addErrback(lambda f: log.error(
+                f"Error while running purge event callback {callback}: {f}"
+            ))
+        client = self.parent
+        if (event.sender, event.nodeIdentifier) in client.pubsub_watching:
+            self.host.bridge.psEventRaw(
+                event.sender.full(), event.nodeIdentifier, C.PS_PURGE, [], client.profile
+            )
+
     def subscriptions(self, service, nodeIdentifier, sender=None):
         """Return the list of subscriptions to the given service and node.