diff sat_pubsub/pgsql_storage.py @ 478:b544109ab4c4

Privileged Entity update + Pubsub Account Management partial implementation + Public Pubsub Subscription /!\ pgsql schema needs to be updated /!\ /!\ server conf needs to be updated for privileged entity: only the new `urn:xmpp:privilege:2` namespace is handled now /!\ Privileged entity has been updated to hanlde the new namespace and IQ permission. Roster pushes are not managed yet. XEP-0376 (Pubsub Account Management) is partially implemented. The XEP is not fully specified at the moment, and my messages on standard@ haven't seen any reply. Thus for now only "Subscribing", "Unsubscribing" and "Listing Subscriptions" is implemented, "Auto Subscriptions" and "Filtering" is not. Public Pubsub Subscription (https://xmpp.org/extensions/inbox/pubsub-public-subscriptions.html) is implemented; the XEP has been accepted by council but is not yet published. It will be updated to use subscription options instead of the <public> element actually specified, I'm waiting for publication to update the XEP. unsubscribe has been updated to return the `<subscription>` element as expected by XEP-0060 (sat_tmp needs to be updated). database schema has been updated to add columns necessary to keep track of subscriptions to external nodes and to mark subscriptions as public.
author Goffi <goffi@goffi.org>
date Wed, 11 May 2022 13:39:08 +0200
parents d993e8b0fd60
children e814c98ef07a
line wrap: on
line diff
--- a/sat_pubsub/pgsql_storage.py	Mon Jan 03 16:48:22 2022 +0100
+++ b/sat_pubsub/pgsql_storage.py	Wed May 11 13:39:08 2022 +0200
@@ -49,7 +49,7 @@
 # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
-
+from typing import Optional, List
 import copy, logging
 from datetime import datetime, timezone
 
@@ -80,7 +80,7 @@
 parseXml = lambda unicode_data: generic.parseXml(unicode_data.encode('utf-8'))
 ITEMS_SEQ_NAME = 'node_{node_id}_seq'
 PEP_COL_NAME = 'pep'
-CURRENT_VERSION = '9'
+CURRENT_VERSION = '10'
 # retrieve the maximum integer item id + 1
 NEXT_ITEM_ID_QUERY = r"SELECT COALESCE(max(item::integer)+1,1) as val from items where node_id={node_id} and item ~ E'^\\d+$'"
 
@@ -406,8 +406,15 @@
         rows = cursor.fetchall()
         return [tuple(r) for r in rows]
 
-    def getSubscriptions(self, entity, nodeIdentifier=None, pep=False, recipient=None):
-        """retrieve subscriptions of an entity
+    def getSubscriptions(
+        self,
+        entity: jid.JID,
+        nodeIdentifier: Optional[str] = None,
+        public: Optional[bool] = None,
+        pep: bool = False,
+        recipient: Optional[jid.JID]=None
+    ) -> List[Subscription]:
+        """Retrieve local subscriptions of an entity
 
         @param entity(jid.JID): entity to check
         @param nodeIdentifier(unicode, None): node identifier
@@ -425,25 +432,136 @@
                 subscriptions.append(subscription)
             return subscriptions
 
-        query = ["""SELECT node,
+        query = ["""SELECT nodes.node,
                            jid,
                            resource,
                            state
                     FROM entities
                     NATURAL JOIN subscriptions
-                    NATURAL JOIN nodes
-                    WHERE jid=%s"""]
-
+                    LEFT JOIN nodes ON nodes.node_id=subscriptions.node_id
+                    WHERE jid=%s AND subscriptions.node_id IS NOT NULL"""]
         args = [entity.userhost()]
 
+        if public is not None:
+            query.append("AND subscriptions.public=%s")
+            args.append(public)
+
         if nodeIdentifier is not None:
-            query.append("AND node=%s")
+            query.append("AND nodes.node=%s")
             args.append(nodeIdentifier)
 
         d = self.dbpool.runQuery(*withPEP(' '.join(query), args, pep, recipient))
         d.addCallback(toSubscriptions)
         return d
 
+    async def getAllSubscriptions(
+        self,
+        entity: jid.JID,
+        public: Optional[bool] = None
+    ):
+        query = """SELECT  subscription_id::text as id,
+                           node,
+                           pep,
+                           ext_service,
+                           ext_node,
+                           state
+                    FROM entities
+                    NATURAL JOIN subscriptions
+                    LEFT JOIN nodes ON nodes.node_id=subscriptions.node_id
+                    WHERE jid=%s"""
+        args = [entity.userhost()]
+        if public is not None:
+            query += "AND public=%s"
+            args.append(public)
+        rows = await self.dbpool.runQuery(query, args)
+        return [r._asdict() for r in rows]
+
+    def addExternalSubscription(
+        self,
+        entity: jid.JID,
+        service: jid.JID,
+        node: str,
+        state: str,
+        public: bool = False
+    ) -> defer.Deferred:
+        """Store a subscription to an external node
+
+        @param entity: entity being subscribed
+        @param service: pubsub service hosting the node
+        @param node: pubsub node being subscribed to
+        @param state: state of the subscription
+        @param public: True if the subscription is publicly visible
+        """
+        return self.dbpool.runInteraction(
+            self._addExternalSubscription,
+            entity, service, node, state, public
+        )
+
+    def _addExternalSubscription(
+        self,
+        cursor,
+        entity: jid.JID,
+        service: jid.JID,
+        node: str,
+        state: str,
+        public: bool
+    ) -> None:
+
+        try:
+            cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
+                           (entity.userhost(),))
+        except cursor._pool.dbapi.IntegrityError:
+            cursor.connection.rollback()
+
+        cursor.execute("""INSERT INTO subscriptions
+                          (ext_service, ext_node, entity_id, state, public)
+                          SELECT %s, %s, entity_id, %s, %s FROM
+                          (SELECT entity_id FROM entities
+                                            WHERE jid=%s) AS ent_id
+                       ON CONFLICT(entity_id, ext_service, ext_node) DO UPDATE SET public=EXCLUDED.public""",
+                       (service.full(),
+                        node,
+                        state,
+                        public,
+                        entity.userhost()
+                        ))
+
+    def removeExternalSubscription(
+        self,
+        entity: jid.JID,
+        service: jid.JID,
+        node: str,
+    ) -> defer.Deferred:
+        """Remove a subscription from an external node
+
+        @param entity: entity being unsubscribed
+        @param service: pubsub service hosting the node
+        @param node: pubsub node being unsubscribed to
+        """
+        return self.dbpool.runInteraction(
+            self._removeExternalSubscription,
+            entity, service, node
+        )
+
+    def _removeExternalSubscription(
+        self,
+        cursor,
+        entity: jid.JID,
+        service: jid.JID,
+        node: str,
+    ) -> None:
+        cursor.execute("""DELETE FROM subscriptions WHERE
+                          ext_service=%s AND
+                          ext_node=%s AND
+                          entity_id=(SELECT entity_id FROM entities
+                                                      WHERE jid=%s)
+                          """,
+                       (service.full(),
+                        node,
+                        entity.userhost()))
+        if cursor.rowcount != 1:
+            raise error.NotSubscribed()
+
     def getDefaultConfiguration(self, nodeType):
         return self.defaultConfig[nodeType].copy()
 
@@ -683,9 +801,9 @@
         resource = subscriber.resource or ''
 
         cursor.execute("""SELECT state FROM subscriptions
-                          NATURAL JOIN nodes
+                          LEFT JOIN nodes ON nodes.node_id=subscriptions.node_id
                           NATURAL JOIN entities
-                          WHERE node_id=%s AND jid=%s AND resource=%s""",
+                          WHERE subscriptions.node_id=%s AND jid=%s AND resource=%s""",
                        (self.nodeDbId,
                         userhost,
                         resource))
@@ -696,25 +814,38 @@
         else:
             return Subscription(self.nodeIdentifier, subscriber, row[0])
 
-    def getSubscriptions(self, state=None):
-        return self.dbpool.runInteraction(self._getSubscriptions, state)
+    def getSubscriptions(
+        self,
+        state: Optional[str]=None,
+        public: Optional[bool] = None
+    ) -> List[Subscription]:
+        return self.dbpool.runInteraction(self._getSubscriptions, state, public)
 
-    def _getSubscriptions(self, cursor, state):
+    def _getSubscriptions(
+        self,
+        cursor,
+        state: Optional[str],
+        public: Optional[bool] = None,
+    ) -> List[Subscription]:
         self._checkNodeExists(cursor)
 
-        query = """SELECT node, jid, resource, state,
+        query = ["""SELECT subscription_id::text, nodes.node, jid, resource, state,
                           subscription_type, subscription_depth
-                   FROM subscriptions
-                   NATURAL JOIN nodes
-                   NATURAL JOIN entities
-                   WHERE node_id=%s"""
+                    FROM subscriptions
+                    LEFT JOIN nodes ON nodes.node_id=subscriptions.node_id
+                    NATURAL JOIN entities
+                    WHERE subscriptions.node_id=%s"""]
         values = [self.nodeDbId]
 
         if state:
-            query += " AND state=%s"
+            query.append("AND state=%s")
             values.append(state)
 
-        cursor.execute(query, values)
+        if public is not None:
+            query.append("AND public=%s")
+            values.append(public)
+
+        cursor.execute(" ".join(query), values)
         rows = cursor.fetchall()
 
         subscriptions = []
@@ -727,8 +858,9 @@
             if row.subscription_depth:
                 options['pubsub#subscription_depth'] = row.subscription_depth;
 
-            subscriptions.append(Subscription(row.node, subscriber,
-                                              row.state, options))
+            subscription = Subscription(row.node, subscriber, row.state, options)
+            subscription.id = row.subscription_id
+            subscriptions.append(subscription)
 
         return subscriptions
 
@@ -744,6 +876,7 @@
 
         subscription_type = config.get('pubsub#subscription_type')
         subscription_depth = config.get('pubsub#subscription_depth')
+        public = config.get("public", False)
 
         try:
             cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
@@ -751,20 +884,31 @@
         except cursor._pool.dbapi.IntegrityError:
             cursor.connection.rollback()
 
-        try:
-            cursor.execute("""INSERT INTO subscriptions
-                              (node_id, entity_id, resource, state,
-                               subscription_type, subscription_depth)
-                              SELECT %s, entity_id, %s, %s, %s, %s FROM
-                              (SELECT entity_id FROM entities
-                                                WHERE jid=%s) AS ent_id""",
-                           (self.nodeDbId,
-                            resource,
-                            state,
-                            subscription_type,
-                            subscription_depth,
-                            userhost))
-        except cursor._pool.dbapi.IntegrityError:
+        # the RETURNING trick to detect INSERT vs UPDATE comes from
+        # https://stackoverflow.com/a/47001830/4188764 thanks!
+        cursor.execute("""INSERT INTO subscriptions
+                          (node_id, entity_id, resource, state,
+                           subscription_type, subscription_depth, public)
+                          SELECT %s, entity_id, %s, %s, %s, %s, %s FROM
+                          (SELECT entity_id FROM entities
+                                            WHERE jid=%s) AS ent_id
+                       ON CONFLICT (entity_id, node_id, resource) DO UPDATE SET public=EXCLUDED.public
+                       RETURNING (xmax = 0) AS inserted""",
+                       (self.nodeDbId,
+                        resource,
+                        state,
+                        subscription_type,
+                        subscription_depth,
+                        public,
+                        userhost))
+
+        rows = cursor.fetchone()
+        if not rows.inserted:
+            # this was an update, the subscription was already existing
+
+            # we have to explicitly commit, otherwise the exception raised rollbacks the
+            # transation
+            cursor.connection.commit()
             raise error.SubscriptionExists()
 
     def removeSubscription(self, subscriber):