changeset 357:1167e48e5f52

handle single node on subscriptions request
author Goffi <goffi@goffi.org>
date Fri, 08 Sep 2017 08:02:05 +0200
parents 95c83899b5e9
children 9e40fc16f4df
files sat_pubsub/backend.py sat_pubsub/pgsql_storage.py
diffstat 2 files changed, 46 insertions(+), 20 deletions(-) [+]
line wrap: on
line diff
--- a/sat_pubsub/backend.py	Fri Sep 08 08:02:05 2017 +0200
+++ b/sat_pubsub/backend.py	Fri Sep 08 08:02:05 2017 +0200
@@ -530,8 +530,14 @@
         d.addCallback(lambda node: node.removeSubscription(subscriber))
         return d
 
-    def getSubscriptions(self, requestor, pep, recipient):
-        return self.storage.getSubscriptions(requestor, pep, recipient)
+    def getSubscriptions(self, requestor, nodeIdentifier, pep, recipient):
+        """retrieve subscriptions of an entity
+
+        @param requestor(jid.JID): entity who want to check subscriptions
+        @param nodeIdentifier(unicode, None): identifier of the node
+            node to get all subscriptions of a service
+        """
+        return self.storage.getSubscriptions(requestor, nodeIdentifier, pep, recipient)
 
     def supportsAutoCreate(self):
         return True
@@ -616,7 +622,7 @@
     def setNodeSchema(self, nodeIdentifier, schema, requestor, pep, recipient):
         """set or remove Schema of a node
 
-        @param NodeIdentifier(unicode): identifier of the pubusb node
+        @param nodeIdentifier(unicode): identifier of the pubusb node
         @param schema(domish.Element, None): schema to set
             None to remove schema
         """
@@ -1556,6 +1562,7 @@
 
     def subscriptions(self, request):
         d = self.backend.getSubscriptions(request.sender,
+                                          request.nodeIdentifier,
                                           self._isPep(request),
                                           request.recipient)
         return d.addErrback(self._mapErrors)
--- a/sat_pubsub/pgsql_storage.py	Fri Sep 08 08:02:05 2017 +0200
+++ b/sat_pubsub/pgsql_storage.py	Fri Sep 08 08:02:05 2017 +0200
@@ -346,22 +346,41 @@
         rows = cursor.fetchall()
         return [tuple(r) for r in rows]
 
-    def getSubscriptions(self, entity, pep, recipient=None):
+    def getSubscriptions(self, entity, nodeIdentifier=None, pep=False, recipient=None):
+        """retrieve subscriptions of an entity
+
+        @param entity(jid.JID): entity to check
+        @param nodeIdentifier(unicode, None): node identifier
+            None to retrieve all subscriptions
+        @param pep: True if we are in PEP mode
+        @param recipient: jid of the recipient
+        """
+
         def toSubscriptions(rows):
             subscriptions = []
             for row in rows:
-                subscriber = jid.internJID('%s/%s' % (row[1],
-                                                      row[2]))
-                subscription = Subscription(row[0], subscriber, row[3])
+                subscriber = jid.internJID('%s/%s' % (row.jid,
+                                                      row.resource))
+                subscription = Subscription(row.node, subscriber, row.state)
                 subscriptions.append(subscription)
             return subscriptions
 
-        d = self.dbpool.runQuery("""SELECT node, jid, resource, state
-                                     FROM entities
-                                     NATURAL JOIN subscriptions
-                                     NATURAL JOIN nodes
-                                     WHERE jid=%s AND nodes.pep=%s""",
-                                  (entity.userhost(), recipient.userhost() if pep else None))
+        query = ["""SELECT node,
+                           jid,
+                           resource,
+                           state
+                    FROM entities
+                    NATURAL JOIN subscriptions
+                    NATURAL JOIN nodes
+                    WHERE jid=%s"""]
+
+        args = [entity.userhost()]
+
+        if nodeIdentifier is not None:
+            query.append("AND node=%s")
+            args.append(nodeIdentifier)
+
+        d = self.dbpool.runQuery(*withPEP(' '.join(query), args, pep, recipient))
         d.addCallback(toSubscriptions)
         return d
 
@@ -545,16 +564,16 @@
 
         subscriptions = []
         for row in rows:
-            subscriber = jid.JID(u'%s/%s' % (row[1], row[2]))
+            subscriber = jid.JID(u'%s/%s' % (row.jid, row.resource))
 
             options = {}
-            if row[4]:
-                options['pubsub#subscription_type'] = row[4];
-            if row[5]:
-                options['pubsub#subscription_depth'] = row[5];
+            if row.subscription_type:
+                options['pubsub#subscription_type'] = row.subscription_type;
+            if row.subscription_depth:
+                options['pubsub#subscription_depth'] = row.subscription_depth;
 
-            subscriptions.append(Subscription(row[0], subscriber,
-                                              row[3], options))
+            subscriptions.append(Subscription(row.node, subscriber,
+                                              row.state, options))
 
         return subscriptions