diff sat_pubsub/backend.py @ 478:b544109ab4c4

Privileged Entity update + Pubsub Account Management partial implementation + Public Pubsub Subscription /!\ pgsql schema needs to be updated /!\ /!\ server conf needs to be updated for privileged entity: only the new `urn:xmpp:privilege:2` namespace is handled now /!\ Privileged entity has been updated to hanlde the new namespace and IQ permission. Roster pushes are not managed yet. XEP-0376 (Pubsub Account Management) is partially implemented. The XEP is not fully specified at the moment, and my messages on standard@ haven't seen any reply. Thus for now only "Subscribing", "Unsubscribing" and "Listing Subscriptions" is implemented, "Auto Subscriptions" and "Filtering" is not. Public Pubsub Subscription (https://xmpp.org/extensions/inbox/pubsub-public-subscriptions.html) is implemented; the XEP has been accepted by council but is not yet published. It will be updated to use subscription options instead of the <public> element actually specified, I'm waiting for publication to update the XEP. unsubscribe has been updated to return the `<subscription>` element as expected by XEP-0060 (sat_tmp needs to be updated). database schema has been updated to add columns necessary to keep track of subscriptions to external nodes and to mark subscriptions as public.
author Goffi <goffi@goffi.org>
date Wed, 11 May 2022 13:39:08 +0200
parents ed9e12701e0f
children 0e801ae1869f
line wrap: on
line diff
--- a/sat_pubsub/backend.py	Mon Jan 03 16:48:22 2022 +0100
+++ b/sat_pubsub/backend.py	Wed May 11 13:39:08 2022 +0200
@@ -62,6 +62,7 @@
 
 import copy
 import uuid
+import hashlib
 from typing import Optional, List, Tuple
 
 from zope.interface import implementer
@@ -240,10 +241,14 @@
     def _getFTSLanguagesEb(self, failure_):
         log.msg(f"WARNING: can get FTS languages: {failure_}")
 
-    def isAdmin(self, entity_jid):
+    def isAdmin(self, entity_jid: jid.JID) -> bool:
         """Return True if an entity is an administrator"""
         return entity_jid.userhostJID() in self.admins
 
+    def isFromServer(self, entity_jid: jid.JID) -> bool:
+        """Return True if an entity come from our server"""
+        return entity_jid.host == self.server_jid.host
+
     def supportsPublishOptions(self):
         return True
 
@@ -595,18 +600,21 @@
     def registerPurgeNotifier(self, observerfn, *args, **kwargs):
         self.addObserver('//event/pubsub/purge', observerfn, *args, **kwargs)
 
-    def subscribe(self, nodeIdentifier, subscriber, requestor, pep, recipient):
+    async def subscribe(
+        self,
+        nodeIdentifier: str,
+        subscriber: jid.JID,
+        requestor: jid.JID,
+        options: Optional[dict],
+        pep: bool,
+        recipient: jid.JID
+    ) -> pubsub.Subscription:
         subscriberEntity = subscriber.userhostJID()
         if subscriberEntity != requestor.userhostJID():
-            return defer.fail(error.Forbidden())
+            raise error.Forbidden()
 
-        d = self.storage.getNode(nodeIdentifier, pep, recipient)
-        d.addCallback(_getAffiliation, subscriberEntity)
-        d.addCallback(self._doSubscribe, subscriber, pep, recipient)
-        return d
-
-    def _doSubscribe(self, result, subscriber, pep, recipient):
-        node, affiliation = result
+        node = await self.storage.getNode(nodeIdentifier, pep, recipient)
+        __, affiliation = await _getAffiliation(node, subscriberEntity)
 
         if affiliation == 'outcast':
             raise error.Forbidden()
@@ -614,67 +622,61 @@
         access_model = node.getAccessModel()
 
         if access_model == const.VAL_AMODEL_OPEN:
-            d = defer.succeed(None)
+            pass
         elif access_model == const.VAL_AMODEL_PRESENCE:
-            d = self.checkPresenceSubscription(node, subscriber)
+            await self.checkPresenceSubscription(node, subscriber)
         elif access_model == const.VAL_AMODEL_PUBLISHER_ROSTER:
-            d = self.checkRosterGroups(node, subscriber)
+            await self.checkRosterGroups(node, subscriber)
         elif access_model == const.VAL_AMODEL_WHITELIST:
-            d = self.checkNodeAffiliations(node, subscriber)
+            await self.checkNodeAffiliations(node, subscriber)
         else:
             raise NotImplementedError
 
-        def trapExists(failure):
-            failure.trap(error.SubscriptionExists)
-            return False
+        config = {}
+        if options and options.get(f"{{{const.NS_PPS}}}public"):
+            config["public"] = True
+        try:
+            await node.addSubscription(subscriber, 'subscribed', config)
+        except error.SubscriptionExists:
+            send_last = False
+        else:
+            send_last = True
 
-        def cb(sendLast):
-            d = node.getSubscription(subscriber)
-            if sendLast:
-                d.addCallback(self._sendLastPublished, node, pep, recipient)
-            return d
-
-        d.addCallback(lambda _: node.addSubscription(subscriber, 'subscribed', {}))
-        d.addCallbacks(lambda _: True, trapExists)
-        d.addCallback(cb)
-
-        return d
-
-    def _sendLastPublished(self, subscription, node, pep, recipient):
+        subscription = await node.getSubscription(subscriber)
 
-        def notifyItem(items_data):
-            if items_data:
-                reactor.callLater(0, self.dispatch,
-                                     {'items_data': items_data,
-                                      'node': node,
-                                      'pep': pep,
-                                      'recipient': recipient,
-                                      'subscription': subscription,
-                                     },
-                                     '//event/pubsub/notify')
-
-        config = node.getConfiguration()
-        sendLastPublished = config.get('pubsub#send_last_published_item',
-                                       'never')
-        if sendLastPublished == 'on_sub' and node.nodeType == 'leaf':
-            entity = subscription.subscriber.userhostJID()
-            d = defer.ensureDeferred(
-                self.getItemsData(
-                    node.nodeIdentifier, entity, recipient, maxItems=1, ext_data={'pep': pep}
+        if send_last:
+            config = node.getConfiguration()
+            sendLastPublished = config.get(
+                'pubsub#send_last_published_item', 'never'
+            )
+            if sendLastPublished == 'on_sub' and node.nodeType == 'leaf':
+                entity = subscription.subscriber.userhostJID()
+                items_data, __ = await self.getItemsData(
+                    node.nodeIdentifier, entity, recipient, maxItems=1,
+                    ext_data={'pep': pep}
                 )
-            )
-            d.addCallback(notifyItem)
-            d.addErrback(log.err)
+                if items_data:
+                    reactor.callLater(
+                        0,
+                        self.dispatch,
+                        {'items_data': items_data,
+                         'node': node,
+                         'pep': pep,
+                         'recipient': recipient,
+                         'subscription': subscription,
+                         },
+                        '//event/pubsub/notify'
+                    )
 
         return subscription
 
-    def unsubscribe(self, nodeIdentifier, subscriber, requestor, pep, recipient):
+    async def unsubscribe(self, nodeIdentifier, subscriber, requestor, pep, recipient):
         if subscriber.userhostJID() != requestor.userhostJID():
-            return defer.fail(error.Forbidden())
+            raise error.Forbidden()
 
-        d = self.storage.getNode(nodeIdentifier, pep, recipient)
-        d.addCallback(lambda node: node.removeSubscription(subscriber))
-        return d
+        node = await self.storage.getNode(nodeIdentifier, pep, recipient)
+        await node.removeSubscription(subscriber)
+        return pubsub.Subscription(nodeIdentifier, subscriber, "none")
 
     def getSubscriptions(self, requestor, nodeIdentifier, pep, recipient):
         """retrieve subscriptions of an entity
@@ -685,7 +687,9 @@
         @param pep(bool): True if it's a PEP request
         @param recipient(jid.JID, None): recipient of the PEP request
         """
-        return self.storage.getSubscriptions(requestor, nodeIdentifier, pep, recipient)
+        return self.storage.getSubscriptions(
+            requestor, nodeIdentifier, None, pep, recipient
+        )
 
     def supportsAutoCreate(self):
         return True
@@ -749,6 +753,10 @@
         if not nodeIdentifier:
             return defer.fail(error.NoRootNode())
 
+        if ((nodeIdentifier == const.NS_PPS_SUBSCRIPTIONS
+             or nodeIdentifier.startswith(const.PPS_SUBSCRIBERS_PREFIX))):
+            return defer.succeed({const.OPT_ACCESS_MODEL: const.VAL_AMODEL_OPEN})
+
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(lambda node: node.getConfiguration())
 
@@ -1067,15 +1075,19 @@
         )
         return ids
 
-    def getItems(self, nodeIdentifier, requestor, recipient, maxItems=None,
-                       itemIdentifiers=None, ext_data=None):
-        d = defer.ensureDeferred(
-            self.getItemsData(
-                nodeIdentifier, requestor, recipient, maxItems, itemIdentifiers, ext_data
-            )
+    async def getItems(
+        self,
+        nodeIdentifier: str,
+        requestor: jid.JID,
+        recipient: jid.JID,
+        maxItems: Optional[int] = None,
+        itemIdentifiers: Optional[List[str]] = None,
+        ext_data: Optional[dict] = None
+    ) -> Tuple[List[domish.Element], Optional[rsm.RSMResponse]]:
+        items_data, rsm_response = await self.getItemsData(
+            nodeIdentifier, requestor, recipient, maxItems, itemIdentifiers, ext_data
         )
-        d.addCallback(lambda items_data: [item_data.item for item_data in items_data])
-        return d
+        return [item_data.item for item_data in items_data], rsm_response
 
     async def getOwnerRoster(self, node, owners=None):
         # FIXME: roster of publisher, not owner, must be used
@@ -1099,20 +1111,40 @@
             return
         return roster
 
-    async def getItemsData(self, nodeIdentifier, requestor, recipient, maxItems=None,
-                       itemIdentifiers=None, ext_data=None):
+    async def getItemsData(
+        self,
+        nodeIdentifier: str,
+        requestor: jid.JID,
+        recipient: jid.JID,
+        maxItems: Optional[int] = None,
+        itemIdentifiers: Optional[List[str]] = None,
+        ext_data: Optional[dict] = None
+    ) -> Tuple[List[container.ItemData], Optional[rsm.RSMResponse]]:
         """like getItems but return the whole ItemData"""
         if maxItems == 0:
             log.msg("WARNING: maxItems=0 on items retrieval")
-            return []
+            return [], None
 
         if ext_data is None:
             ext_data = {}
+
+        if nodeIdentifier == const.NS_PPS_SUBSCRIPTIONS:
+            return await self.getPublicSubscriptions(
+                requestor, maxItems, itemIdentifiers, ext_data,
+                ext_data.pop("pep", False), recipient
+            )
+        elif nodeIdentifier.startswith(f"{const.NS_PPS_SUBSCRIBERS}/"):
+            target_node = nodeIdentifier[len(const.NS_PPS_SUBSCRIBERS)+1:]
+            return await self.getPublicNodeSubscriptions(
+                target_node, requestor, maxItems, itemIdentifiers, ext_data,
+                ext_data.pop("pep", False), recipient
+            )
+
         node = await self.storage.getNode(nodeIdentifier, ext_data.get('pep', False), recipient)
         try:
             affiliation, owner, roster, access_model = await self.checkNodeAccess(node, requestor)
         except error.NotLeafNodeError:
-            return []
+            return [], None
 
         # at this point node access is checked
 
@@ -1155,9 +1187,9 @@
         if schema is not None:
             self.filterItemsWithSchema(items_data, schema, owner)
 
-        await self._items_rsm(
-            items_data, node, requestor_groups, owner, itemIdentifiers, ext_data)
-        return items_data
+        return await self._items_rsm(
+            items_data, node, requestor_groups, owner, itemIdentifiers, ext_data
+        )
 
     def _setCount(self, value, response):
         response.count = value
@@ -1180,48 +1212,135 @@
             rsm_request = ext_data['rsm']
         except KeyError:
             # No RSM in this request, nothing to do
-            return items_data
+            return items_data, None
 
         if itemIdentifiers:
             log.msg("WARNING, itemIdentifiers used with RSM, ignoring the RSM part")
-            return items_data
+            return items_data, None
 
-        response = rsm.RSMResponse()
+        rsm_response = rsm.RSMResponse()
 
         d_count = node.getItemsCount(authorized_groups, owner, ext_data)
-        d_count.addCallback(self._setCount, response)
+        d_count.addCallback(self._setCount, rsm_response)
         d_list = [d_count]
 
         if items_data:
-            response.first = items_data[0].item['id']
-            response.last = items_data[-1].item['id']
+            rsm_response.first = items_data[0].item['id']
+            rsm_response.last = items_data[-1].item['id']
 
             # index handling
             if rsm_request.index is not None:
-                response.index = rsm_request.index
+                rsm_response.index = rsm_request.index
             elif rsm_request.before:
                 # The last page case (before == '') is managed in render method
                 d_index = node.getItemsIndex(rsm_request.before, authorized_groups, owner, ext_data)
-                d_index.addCallback(self._setIndex, response, -len(items_data))
+                d_index.addCallback(self._setIndex, rsm_response, -len(items_data))
                 d_list.append(d_index)
             elif rsm_request.after is not None:
                 d_index = node.getItemsIndex(rsm_request.after, authorized_groups, owner, ext_data)
-                d_index.addCallback(self._setIndex, response, 1)
+                d_index.addCallback(self._setIndex, rsm_response, 1)
                 d_list.append(d_index)
             else:
                 # the first page was requested
-                response.index = 0
+                rsm_response.index = 0
 
 
         await defer.DeferredList(d_list)
 
         if rsm_request.before == '':
             # the last page was requested
-            response.index = response.count - len(items_data)
+            rsm_response.index = rsm_response.count - len(items_data)
+
+        return items_data, rsm_response
+
+    def addEltFromSubDict(
+        self,
+        parent_elt: domish.Element,
+        from_jid: Optional[jid.JID],
+        sub_dict: dict[str, str],
+        namespace: Optional[str] = None,
+    ) -> None:
+        """Generate <subscription> element from storage.getAllSubscriptions's dict
+
+        @param parent_elt: element where the new subscription element must be added
+        @param sub_dict: subscription data as returned by storage.getAllSubscriptions
+        @param namespace: if not None, namespace to use for <subscription> element
+        @param service_attribute: name of the attribute to use for the subscribed service
+        """
+        subscription_elt = parent_elt.addElement(
+            "subscription" if namespace is None else (namespace, "subscription")
+        )
+        if from_jid is not None:
+            subscription_elt["jid"] = from_jid.userhost()
+        if sub_dict["node"] is not None:
+            if sub_dict["pep"] is not None:
+                subscription_elt["service"] = sub_dict["pep"]
+            else:
+                subscription_elt["service"] = self.jid.full()
+            subscription_elt["node"] = sub_dict["node"]
+        else:
+            subscription_elt["service"] = sub_dict["ext_service"]
+            subscription_elt["node"] = sub_dict["ext_node"]
+        subscription_elt["subscription"] = sub_dict["state"]
+
+    async def getPublicSubscriptions(
+        self,
+        requestor: jid.JID,
+        maxItems: Optional[int],
+        itemIdentifiers: Optional[List[str]],
+        ext_data: dict,
+        pep: bool,
+        recipient: jid.JID
+    ) -> Tuple[List[container.ItemData], Optional[rsm.RSMResponse]]:
 
-        items_data.append(container.ItemData(response.toElement()))
+        if itemIdentifiers or ext_data.get("rsm") or ext_data.get("mam"):
+            raise NotImplementedError(
+                "item identifiers, RSM and MAM are not implemented yet"
+            )
+
+        if not pep:
+            return [], None
+
+        subs = await self.storage.getAllSubscriptions(recipient, True)
+        items_data = []
+        for sub in subs:
+            if sub["state"] != "subscribed":
+                continue
+            item = domish.Element((pubsub.NS_PUBSUB, "item"))
+            item["id"] = sub["id"]
+            self.addEltFromSubDict(item, None, sub, const.NS_PPS)
+            items_data.append(container.ItemData(item))
+
+        return items_data, None
 
-        return items_data
+    async def getPublicNodeSubscriptions(
+        self,
+        nodeIdentifier: str,
+        requestor: jid.JID,
+        maxItems: Optional[int],
+        itemIdentifiers: Optional[List[str]],
+        ext_data: dict,
+        pep: bool,
+        recipient: jid.JID
+    ) -> Tuple[List[container.ItemData], Optional[rsm.RSMResponse]]:
+
+        if itemIdentifiers or ext_data.get("rsm") or ext_data.get("mam"):
+            raise NotImplementedError(
+                "item identifiers, RSM and MAM are not implemented yet"
+            )
+
+        node = await self.storage.getNode(nodeIdentifier, pep, recipient)
+
+        subs = await node.getSubscriptions(public=True)
+        items_data = []
+        for sub in subs:
+            item = domish.Element((pubsub.NS_PUBSUB, "item"))
+            item["id"] = sub.id
+            subscriber_elt = item.addElement((const.NS_PPS, "subscriber"))
+            subscriber_elt["jid"] = sub.subscriber.full()
+            items_data.append(container.ItemData(item))
+
+        return items_data, None
 
     async def retractItem(self, nodeIdentifier, itemIdentifiers, requestor, notify, pep, recipient):
         node = await self.storage.getNode(nodeIdentifier, pep, recipient)
@@ -1824,19 +1943,26 @@
         return d.addErrback(self._mapErrors)
 
     def subscribe(self, request):
-        d = self.backend.subscribe(request.nodeIdentifier,
-                                   request.subscriber,
-                                   request.sender,
-                                   self._isPep(request),
-                                   request.recipient)
+        d = defer.ensureDeferred(
+            self.backend.subscribe(
+                request.nodeIdentifier,
+                request.subscriber,
+                request.sender,
+                request.options,
+                self._isPep(request),
+                request.recipient
+            )
+        )
         return d.addErrback(self._mapErrors)
 
     def unsubscribe(self, request):
-        d = self.backend.unsubscribe(request.nodeIdentifier,
+        d = defer.ensureDeferred(
+            self.backend.unsubscribe(request.nodeIdentifier,
                                      request.subscriber,
                                      request.sender,
                                      self._isPep(request),
                                      request.recipient)
+        )
         return d.addErrback(self._mapErrors)
 
     def subscriptions(self, request):
@@ -1933,12 +2059,16 @@
         except AttributeError:
             pass
         ext_data['order_by'] = request.orderBy or []
-        d = self.backend.getItems(request.nodeIdentifier,
-                                  request.sender,
-                                  request.recipient,
-                                  request.maxItems,
-                                  request.itemIdentifiers,
-                                  ext_data)
+        d = defer.ensureDeferred(
+            self.backend.getItems(
+                request.nodeIdentifier,
+                request.sender,
+                request.recipient,
+                request.maxItems,
+                request.itemIdentifiers,
+                ext_data
+            )
+        )
         return d.addErrback(self._mapErrors)
 
     def retract(self, request):
@@ -1987,7 +2117,8 @@
             # cf. https://xmpp.org/extensions/xep-0060.html#subscriber-retrieve-returnsome
             disco.DiscoFeature(const.NS_PUBSUB_RSM),
             disco.DiscoFeature(pubsub.NS_ORDER_BY),
-            disco.DiscoFeature(const.NS_FDP)
+            disco.DiscoFeature(const.NS_FDP),
+            disco.DiscoFeature(const.NS_PPS)
         ]
 
     def getDiscoItems(self, requestor, service, nodeIdentifier=''):