changeset 341:b49f75a26156

backend, pgsql: implemented subscriptionsGet and subscriptionsSet
author Goffi <goffi@goffi.org>
date Sun, 20 Aug 2017 11:56:04 +0200
parents 567e486bce24
children 28c9579901d3
files sat_pubsub/backend.py sat_pubsub/pgsql_storage.py
diffstat 2 files changed, 91 insertions(+), 38 deletions(-) [+]
line wrap: on
line diff
--- a/sat_pubsub/backend.py	Sun Aug 20 11:16:51 2017 +0200
+++ b/sat_pubsub/backend.py	Sun Aug 20 11:56:04 2017 +0200
@@ -611,6 +611,57 @@
         return d
 
 
+    def getSubscriptionsOwner(self, nodeIdentifier, requestor, pep, recipient):
+        d = self.storage.getNode(nodeIdentifier, pep, recipient)
+        d.addCallback(_getAffiliation, requestor)
+        d.addCallback(self._doGetSubscriptionsOwner)
+        return d
+
+
+    def _doGetSubscriptionsOwner(self, result):
+        node, affiliation = result
+
+        if affiliation != 'owner':
+            raise error.Forbidden()
+        return node.getSubscriptions()
+
+
+    def setSubscriptionsOwner(self, nodeIdentifier, requestor, subscriptions, pep, recipient):
+        d = self.storage.getNode(nodeIdentifier, pep, recipient)
+        d.addCallback(_getAffiliation, requestor)
+        d.addCallback(self._doSetSubscriptionsOwner, requestor, subscriptions)
+        return d
+
+    def unwrapFirstError(self, failure):
+        failure.trap(defer.FirstError)
+        return failure.value.subFailure
+
+    def _doSetSubscriptionsOwner(self, result, requestor, subscriptions):
+        # Check that requestor is allowed to set subscriptions, and delete entities
+        # with "none" subscription
+
+        # TODO: return error with failed subscriptions in case of failure
+        node, requestor_affiliation = result
+
+        if requestor_affiliation != 'owner':
+            raise error.Forbidden()
+
+        d_list = []
+
+        for subscription in subscriptions.copy():
+            if subscription.state == 'none':
+                subscriptions.remove(subscription)
+                d_list.append(node.removeSubscription(subscription.subscriber))
+
+        if subscriptions:
+            d_list.append(node.setSubscriptions(subscriptions))
+
+        d = defer.gatherResults(d_list, consumeErrors=True)
+        d.addCallback(lambda _: None)
+        d.addErrback(self.unwrapFirstError)
+        return d
+
+
     def getItems(self, nodeIdentifier, requestor, recipient, maxItems=None,
                        itemIdentifiers=None, ext_data=None):
         d = self.getItemsData(nodeIdentifier, requestor, recipient, maxItems, itemIdentifiers, ext_data)
@@ -1351,6 +1402,27 @@
         return d.addErrback(self._mapErrors)
 
 
+    def subscriptionsGet(self, request):
+        """Retrieve subscriptions for owner (cf. XEP-0060 ยง8.8.1)
+
+        retrieve all affiliations for a node
+        """
+        d = self.backend.getSubscriptionsOwner(request.nodeIdentifier,
+                                               request.sender,
+                                               self._isPep(request),
+                                               request.recipient)
+        return d.addErrback(self._mapErrors)
+
+
+    def subscriptionsSet(self, request):
+        d = self.backend.setSubscriptionsOwner(request.nodeIdentifier,
+                                              request.sender,
+                                              request.subscriptions,
+                                              self._isPep(request),
+                                              request.recipient)
+        return d.addErrback(self._mapErrors)
+
+
     def items(self, request):
         ext_data = {}
         if const.FLAG_ENABLE_RSM and request.rsm is not None:
--- a/sat_pubsub/pgsql_storage.py	Sun Aug 20 11:16:51 2017 +0200
+++ b/sat_pubsub/pgsql_storage.py	Sun Aug 20 11:56:04 2017 +0200
@@ -155,7 +155,6 @@
         """
         return self.dbpool.runInteraction(self._getNodeById, nodeDbId)
 
-
     def _getNodeById(self, cursor, nodeDbId):
         cursor.execute("""SELECT node_id,
                                  node,
@@ -175,7 +174,6 @@
     def getNode(self, nodeIdentifier, pep, recipient=None):
         return self.dbpool.runInteraction(self._getNode, nodeIdentifier, pep, recipient)
 
-
     def _getNode(self, cursor, nodeIdentifier, pep, recipient):
         cursor.execute(*withPEP("""SELECT node_id,
                                           node,
@@ -198,12 +196,10 @@
         d.addCallback(lambda results: [r[0] for r in results])
         return d
 
-
     def createNode(self, nodeIdentifier, owner, config, pep, recipient=None):
         return self.dbpool.runInteraction(self._createNode, nodeIdentifier,
                                            owner, config, pep, recipient)
 
-
     def _createNode(self, cursor, nodeIdentifier, owner, config, pep, recipient):
         if config['pubsub#node_type'] != 'leaf':
             raise error.NoCollections()
@@ -287,7 +283,6 @@
     def deleteNode(self, nodeIdentifier, pep, recipient=None):
         return self.dbpool.runInteraction(self._deleteNode, nodeIdentifier, pep, recipient)
 
-
     def _deleteNode(self, cursor, nodeIdentifier, pep, recipient):
         cursor.execute(*withPEP("""DELETE FROM nodes WHERE node=%s""",
                                 (nodeIdentifier,), pep, recipient))
@@ -313,7 +308,6 @@
         rows = cursor.fetchall()
         return [tuple(r) for r in rows]
 
-
     def getSubscriptions(self, entity, pep, recipient=None):
         def toSubscriptions(rows):
             subscriptions = []
@@ -333,11 +327,9 @@
         d.addCallback(toSubscriptions)
         return d
 
-
     def getDefaultConfiguration(self, nodeType):
         return self.defaultConfig[nodeType]
 
-
     def formatLastItems(self, result):
         last_items = []
         for pep_jid_s, node, data, item_access_model in result:
@@ -346,7 +338,6 @@
             last_items.append((pep_jid, node, item, item_access_model))
         return last_items
 
-
     def getLastItems(self, entities, nodes, node_accesses, item_accesses, pep):
         """get last item for several nodes and entities in a single request"""
         if not entities or not nodes or not node_accesses or not item_accesses:
@@ -380,14 +371,12 @@
         self.nodeIdentifier = nodeIdentifier
         self._config = config
 
-
     def _checkNodeExists(self, cursor):
         cursor.execute("""SELECT 1 as exist FROM nodes WHERE node_id=%s""",
                        (self.nodeDbId,))
         if not cursor.fetchone():
             raise error.NodeNotFound()
 
-
     def getType(self):
         return self.nodeType
 
@@ -396,11 +385,9 @@
         d.addCallback(lambda rows: [jid.JID(r[0]) for r in rows])
         return d
 
-
     def getConfiguration(self):
         return self._config
 
-
     def setConfiguration(self, options):
         config = copy.copy(self._config)
 
@@ -412,7 +399,6 @@
         d.addCallback(self._setCachedConfiguration, config)
         return d
 
-
     def _setConfiguration(self, cursor, config):
         self._checkNodeExists(cursor)
         cursor.execute("""UPDATE nodes SET persist_items=%s,
@@ -428,21 +414,17 @@
                         config[const.OPT_PUBLISH_MODEL],
                         self.nodeDbId))
 
-
     def _setCachedConfiguration(self, void, config):
         self._config = config
 
-
     def getMetaData(self):
         config = copy.copy(self._config)
         config["pubsub#node_type"] = self.nodeType
         return config
 
-
     def getAffiliation(self, entity):
         return self.dbpool.runInteraction(self._getAffiliation, entity)
 
-
     def _getAffiliation(self, cursor, entity):
         self._checkNodeExists(cursor)
         cursor.execute("""SELECT affiliation FROM affiliations
@@ -457,15 +439,12 @@
         except TypeError:
             return None
 
-
     def getAccessModel(self):
         return self._config[const.OPT_ACCESS_MODEL]
 
-
     def getSubscription(self, subscriber):
         return self.dbpool.runInteraction(self._getSubscription, subscriber)
 
-
     def _getSubscription(self, cursor, subscriber):
         self._checkNodeExists(cursor)
 
@@ -486,11 +465,9 @@
         else:
             return Subscription(self.nodeIdentifier, subscriber, row[0])
 
-
     def getSubscriptions(self, state=None):
         return self.dbpool.runInteraction(self._getSubscriptions, state)
 
-
     def _getSubscriptions(self, cursor, state):
         self._checkNodeExists(cursor)
 
@@ -524,12 +501,10 @@
 
         return subscriptions
 
-
     def addSubscription(self, subscriber, state, config):
         return self.dbpool.runInteraction(self._addSubscription, subscriber,
                                           state, config)
 
-
     def _addSubscription(self, cursor, subscriber, state, config):
         self._checkNodeExists(cursor)
 
@@ -561,12 +536,10 @@
         except cursor._pool.dbapi.IntegrityError:
             raise error.SubscriptionExists()
 
-
     def removeSubscription(self, subscriber):
         return self.dbpool.runInteraction(self._removeSubscription,
                                            subscriber)
 
-
     def _removeSubscription(self, cursor, subscriber):
         self._checkNodeExists(cursor)
 
@@ -586,11 +559,28 @@
 
         return None
 
+    def setSubscriptions(self, subscriptions):
+        return self.dbpool.runInteraction(self._setSubscriptions, subscriptions)
+
+    def _setSubscriptions(self, cursor, subscriptions):
+        self._checkNodeExists(cursor)
+
+        entities = self.getOrCreateEntities(cursor, [s.subscriber for s in subscriptions])
+        entities_map = {jid.JID(e.jid): e for e in entities}
+
+        # then we construct values for subscriptions update according to entity_id we just got
+        placeholders = ','.join(len(subscriptions) * ["%s"])
+        values = []
+        for subscription in subscriptions:
+            entity_id = entities_map[subscription.subscriber].entity_id
+            resource = subscription.subscriber.resource or u''
+            values.append((self.nodeDbId, entity_id, resource, subscription.state, None, None))
+        # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5
+        cursor.execute("INSERT INTO subscriptions(node_id, entity_id, resource, state, subscription_type, subscription_depth) VALUES " + placeholders + " ON CONFLICT (entity_id, resource, node_id) DO UPDATE SET state=EXCLUDED.state", [v for v in values])
 
     def isSubscribed(self, entity):
         return self.dbpool.runInteraction(self._isSubscribed, entity)
 
-
     def _isSubscribed(self, cursor, entity):
         self._checkNodeExists(cursor)
 
@@ -604,11 +594,9 @@
 
         return cursor.fetchone() is not None
 
-
     def getAffiliations(self):
         return self.dbpool.runInteraction(self._getAffiliations)
 
-
     def _getAffiliations(self, cursor):
         self._checkNodeExists(cursor)
 
@@ -626,7 +614,7 @@
 
         Entities will be inserted it they don't exist
         @param entities_jid(list[jid.JID]): entities to get or create
-        @return list[record(entity_jid,jid)]]: list of entity_id and jid (as plain string)
+        @return list[record(entity_id,jid)]]: list of entity_id and jid (as plain string)
             both existing and inserted entities are returned
         """
         # cf. http://stackoverflow.com/a/35265559
@@ -655,7 +643,6 @@
     def setAffiliations(self, affiliations):
         return self.dbpool.runInteraction(self._setAffiliations, affiliations)
 
-
     def _setAffiliations(self, cursor, affiliations):
         self._checkNodeExists(cursor)
 
@@ -669,11 +656,9 @@
         # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5
         cursor.execute("INSERT INTO affiliations(entity_id,affiliation,node_id) VALUES " + placeholders + " ON CONFLICT  (entity_id,node_id) DO UPDATE SET affiliation=EXCLUDED.affiliation", values)
 
-
     def deleteAffiliations(self, entities):
         return self.dbpool.runInteraction(self._deleteAffiliations, entities)
 
-
     def _deleteAffiliations(self, cursor, entities):
         """delete affiliations and subscriptions for this entity"""
         self._checkNodeExists(cursor)
@@ -1078,7 +1063,6 @@
     def __init__(self, dbpool):
         self.dbpool = dbpool
 
-
     def _countCallbacks(self, cursor, service, nodeIdentifier):
         """
         Count number of callbacks registered for a node.
@@ -1090,7 +1074,6 @@
         results = cursor.fetchall()
         return results[0][0]
 
-
     def addCallback(self, service, nodeIdentifier, callback):
         def interaction(cursor):
             cursor.execute("""SELECT 1 as bool FROM callbacks
@@ -1110,7 +1093,6 @@
 
         return self.dbpool.runInteraction(interaction)
 
-
     def removeCallback(self, service, nodeIdentifier, callback):
         def interaction(cursor):
             cursor.execute("""DELETE FROM callbacks
@@ -1142,7 +1124,6 @@
 
         return self.dbpool.runInteraction(interaction)
 
-
     def hasCallbacks(self, service, nodeIdentifier):
         def interaction(cursor):
             return bool(self._countCallbacks(cursor, service, nodeIdentifier))