diff sat_pubsub/pgsql_storage.py @ 357:1167e48e5f52

handle single node on subscriptions request
author Goffi <goffi@goffi.org>
date Fri, 08 Sep 2017 08:02:05 +0200
parents 95c83899b5e9
children 8bd8be6815ab
line wrap: on
line diff
--- a/sat_pubsub/pgsql_storage.py	Fri Sep 08 08:02:05 2017 +0200
+++ b/sat_pubsub/pgsql_storage.py	Fri Sep 08 08:02:05 2017 +0200
@@ -346,22 +346,41 @@
         rows = cursor.fetchall()
         return [tuple(r) for r in rows]
 
-    def getSubscriptions(self, entity, pep, recipient=None):
+    def getSubscriptions(self, entity, nodeIdentifier=None, pep=False, recipient=None):
+        """retrieve subscriptions of an entity
+
+        @param entity(jid.JID): entity to check
+        @param nodeIdentifier(unicode, None): node identifier
+            None to retrieve all subscriptions
+        @param pep: True if we are in PEP mode
+        @param recipient: jid of the recipient
+        """
+
         def toSubscriptions(rows):
             subscriptions = []
             for row in rows:
-                subscriber = jid.internJID('%s/%s' % (row[1],
-                                                      row[2]))
-                subscription = Subscription(row[0], subscriber, row[3])
+                subscriber = jid.internJID('%s/%s' % (row.jid,
+                                                      row.resource))
+                subscription = Subscription(row.node, subscriber, row.state)
                 subscriptions.append(subscription)
             return subscriptions
 
-        d = self.dbpool.runQuery("""SELECT node, jid, resource, state
-                                     FROM entities
-                                     NATURAL JOIN subscriptions
-                                     NATURAL JOIN nodes
-                                     WHERE jid=%s AND nodes.pep=%s""",
-                                  (entity.userhost(), recipient.userhost() if pep else None))
+        query = ["""SELECT node,
+                           jid,
+                           resource,
+                           state
+                    FROM entities
+                    NATURAL JOIN subscriptions
+                    NATURAL JOIN nodes
+                    WHERE jid=%s"""]
+
+        args = [entity.userhost()]
+
+        if nodeIdentifier is not None:
+            query.append("AND node=%s")
+            args.append(nodeIdentifier)
+
+        d = self.dbpool.runQuery(*withPEP(' '.join(query), args, pep, recipient))
         d.addCallback(toSubscriptions)
         return d
 
@@ -545,16 +564,16 @@
 
         subscriptions = []
         for row in rows:
-            subscriber = jid.JID(u'%s/%s' % (row[1], row[2]))
+            subscriber = jid.JID(u'%s/%s' % (row.jid, row.resource))
 
             options = {}
-            if row[4]:
-                options['pubsub#subscription_type'] = row[4];
-            if row[5]:
-                options['pubsub#subscription_depth'] = row[5];
+            if row.subscription_type:
+                options['pubsub#subscription_type'] = row.subscription_type;
+            if row.subscription_depth:
+                options['pubsub#subscription_depth'] = row.subscription_depth;
 
-            subscriptions.append(Subscription(row[0], subscriber,
-                                              row[3], options))
+            subscriptions.append(Subscription(row.node, subscriber,
+                                              row.state, options))
 
         return subscriptions