# HG changeset patch # User Goffi # Date 1503222964 -7200 # Node ID b49f75a26156390fc8e2efe64abf7e8872140284 # Parent 567e486bce24e861afd4c3dd69f962cee3216bc9 backend, pgsql: implemented subscriptionsGet and subscriptionsSet diff -r 567e486bce24 -r b49f75a26156 sat_pubsub/backend.py --- a/sat_pubsub/backend.py Sun Aug 20 11:16:51 2017 +0200 +++ b/sat_pubsub/backend.py Sun Aug 20 11:56:04 2017 +0200 @@ -611,6 +611,57 @@ return d + def getSubscriptionsOwner(self, nodeIdentifier, requestor, pep, recipient): + d = self.storage.getNode(nodeIdentifier, pep, recipient) + d.addCallback(_getAffiliation, requestor) + d.addCallback(self._doGetSubscriptionsOwner) + return d + + + def _doGetSubscriptionsOwner(self, result): + node, affiliation = result + + if affiliation != 'owner': + raise error.Forbidden() + return node.getSubscriptions() + + + def setSubscriptionsOwner(self, nodeIdentifier, requestor, subscriptions, pep, recipient): + d = self.storage.getNode(nodeIdentifier, pep, recipient) + d.addCallback(_getAffiliation, requestor) + d.addCallback(self._doSetSubscriptionsOwner, requestor, subscriptions) + return d + + def unwrapFirstError(self, failure): + failure.trap(defer.FirstError) + return failure.value.subFailure + + def _doSetSubscriptionsOwner(self, result, requestor, subscriptions): + # Check that requestor is allowed to set subscriptions, and delete entities + # with "none" subscription + + # TODO: return error with failed subscriptions in case of failure + node, requestor_affiliation = result + + if requestor_affiliation != 'owner': + raise error.Forbidden() + + d_list = [] + + for subscription in subscriptions.copy(): + if subscription.state == 'none': + subscriptions.remove(subscription) + d_list.append(node.removeSubscription(subscription.subscriber)) + + if subscriptions: + d_list.append(node.setSubscriptions(subscriptions)) + + d = defer.gatherResults(d_list, consumeErrors=True) + d.addCallback(lambda _: None) + d.addErrback(self.unwrapFirstError) + return d + + def getItems(self, nodeIdentifier, requestor, recipient, maxItems=None, itemIdentifiers=None, ext_data=None): d = self.getItemsData(nodeIdentifier, requestor, recipient, maxItems, itemIdentifiers, ext_data) @@ -1351,6 +1402,27 @@ return d.addErrback(self._mapErrors) + def subscriptionsGet(self, request): + """Retrieve subscriptions for owner (cf. XEP-0060 ยง8.8.1) + + retrieve all affiliations for a node + """ + d = self.backend.getSubscriptionsOwner(request.nodeIdentifier, + request.sender, + self._isPep(request), + request.recipient) + return d.addErrback(self._mapErrors) + + + def subscriptionsSet(self, request): + d = self.backend.setSubscriptionsOwner(request.nodeIdentifier, + request.sender, + request.subscriptions, + self._isPep(request), + request.recipient) + return d.addErrback(self._mapErrors) + + def items(self, request): ext_data = {} if const.FLAG_ENABLE_RSM and request.rsm is not None: diff -r 567e486bce24 -r b49f75a26156 sat_pubsub/pgsql_storage.py --- a/sat_pubsub/pgsql_storage.py Sun Aug 20 11:16:51 2017 +0200 +++ b/sat_pubsub/pgsql_storage.py Sun Aug 20 11:56:04 2017 +0200 @@ -155,7 +155,6 @@ """ return self.dbpool.runInteraction(self._getNodeById, nodeDbId) - def _getNodeById(self, cursor, nodeDbId): cursor.execute("""SELECT node_id, node, @@ -175,7 +174,6 @@ def getNode(self, nodeIdentifier, pep, recipient=None): return self.dbpool.runInteraction(self._getNode, nodeIdentifier, pep, recipient) - def _getNode(self, cursor, nodeIdentifier, pep, recipient): cursor.execute(*withPEP("""SELECT node_id, node, @@ -198,12 +196,10 @@ d.addCallback(lambda results: [r[0] for r in results]) return d - def createNode(self, nodeIdentifier, owner, config, pep, recipient=None): return self.dbpool.runInteraction(self._createNode, nodeIdentifier, owner, config, pep, recipient) - def _createNode(self, cursor, nodeIdentifier, owner, config, pep, recipient): if config['pubsub#node_type'] != 'leaf': raise error.NoCollections() @@ -287,7 +283,6 @@ def deleteNode(self, nodeIdentifier, pep, recipient=None): return self.dbpool.runInteraction(self._deleteNode, nodeIdentifier, pep, recipient) - def _deleteNode(self, cursor, nodeIdentifier, pep, recipient): cursor.execute(*withPEP("""DELETE FROM nodes WHERE node=%s""", (nodeIdentifier,), pep, recipient)) @@ -313,7 +308,6 @@ rows = cursor.fetchall() return [tuple(r) for r in rows] - def getSubscriptions(self, entity, pep, recipient=None): def toSubscriptions(rows): subscriptions = [] @@ -333,11 +327,9 @@ d.addCallback(toSubscriptions) return d - def getDefaultConfiguration(self, nodeType): return self.defaultConfig[nodeType] - def formatLastItems(self, result): last_items = [] for pep_jid_s, node, data, item_access_model in result: @@ -346,7 +338,6 @@ last_items.append((pep_jid, node, item, item_access_model)) return last_items - def getLastItems(self, entities, nodes, node_accesses, item_accesses, pep): """get last item for several nodes and entities in a single request""" if not entities or not nodes or not node_accesses or not item_accesses: @@ -380,14 +371,12 @@ self.nodeIdentifier = nodeIdentifier self._config = config - def _checkNodeExists(self, cursor): cursor.execute("""SELECT 1 as exist FROM nodes WHERE node_id=%s""", (self.nodeDbId,)) if not cursor.fetchone(): raise error.NodeNotFound() - def getType(self): return self.nodeType @@ -396,11 +385,9 @@ d.addCallback(lambda rows: [jid.JID(r[0]) for r in rows]) return d - def getConfiguration(self): return self._config - def setConfiguration(self, options): config = copy.copy(self._config) @@ -412,7 +399,6 @@ d.addCallback(self._setCachedConfiguration, config) return d - def _setConfiguration(self, cursor, config): self._checkNodeExists(cursor) cursor.execute("""UPDATE nodes SET persist_items=%s, @@ -428,21 +414,17 @@ config[const.OPT_PUBLISH_MODEL], self.nodeDbId)) - def _setCachedConfiguration(self, void, config): self._config = config - def getMetaData(self): config = copy.copy(self._config) config["pubsub#node_type"] = self.nodeType return config - def getAffiliation(self, entity): return self.dbpool.runInteraction(self._getAffiliation, entity) - def _getAffiliation(self, cursor, entity): self._checkNodeExists(cursor) cursor.execute("""SELECT affiliation FROM affiliations @@ -457,15 +439,12 @@ except TypeError: return None - def getAccessModel(self): return self._config[const.OPT_ACCESS_MODEL] - def getSubscription(self, subscriber): return self.dbpool.runInteraction(self._getSubscription, subscriber) - def _getSubscription(self, cursor, subscriber): self._checkNodeExists(cursor) @@ -486,11 +465,9 @@ else: return Subscription(self.nodeIdentifier, subscriber, row[0]) - def getSubscriptions(self, state=None): return self.dbpool.runInteraction(self._getSubscriptions, state) - def _getSubscriptions(self, cursor, state): self._checkNodeExists(cursor) @@ -524,12 +501,10 @@ return subscriptions - def addSubscription(self, subscriber, state, config): return self.dbpool.runInteraction(self._addSubscription, subscriber, state, config) - def _addSubscription(self, cursor, subscriber, state, config): self._checkNodeExists(cursor) @@ -561,12 +536,10 @@ except cursor._pool.dbapi.IntegrityError: raise error.SubscriptionExists() - def removeSubscription(self, subscriber): return self.dbpool.runInteraction(self._removeSubscription, subscriber) - def _removeSubscription(self, cursor, subscriber): self._checkNodeExists(cursor) @@ -586,11 +559,28 @@ return None + def setSubscriptions(self, subscriptions): + return self.dbpool.runInteraction(self._setSubscriptions, subscriptions) + + def _setSubscriptions(self, cursor, subscriptions): + self._checkNodeExists(cursor) + + entities = self.getOrCreateEntities(cursor, [s.subscriber for s in subscriptions]) + entities_map = {jid.JID(e.jid): e for e in entities} + + # then we construct values for subscriptions update according to entity_id we just got + placeholders = ','.join(len(subscriptions) * ["%s"]) + values = [] + for subscription in subscriptions: + entity_id = entities_map[subscription.subscriber].entity_id + resource = subscription.subscriber.resource or u'' + values.append((self.nodeDbId, entity_id, resource, subscription.state, None, None)) + # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5 + cursor.execute("INSERT INTO subscriptions(node_id, entity_id, resource, state, subscription_type, subscription_depth) VALUES " + placeholders + " ON CONFLICT (entity_id, resource, node_id) DO UPDATE SET state=EXCLUDED.state", [v for v in values]) def isSubscribed(self, entity): return self.dbpool.runInteraction(self._isSubscribed, entity) - def _isSubscribed(self, cursor, entity): self._checkNodeExists(cursor) @@ -604,11 +594,9 @@ return cursor.fetchone() is not None - def getAffiliations(self): return self.dbpool.runInteraction(self._getAffiliations) - def _getAffiliations(self, cursor): self._checkNodeExists(cursor) @@ -626,7 +614,7 @@ Entities will be inserted it they don't exist @param entities_jid(list[jid.JID]): entities to get or create - @return list[record(entity_jid,jid)]]: list of entity_id and jid (as plain string) + @return list[record(entity_id,jid)]]: list of entity_id and jid (as plain string) both existing and inserted entities are returned """ # cf. http://stackoverflow.com/a/35265559 @@ -655,7 +643,6 @@ def setAffiliations(self, affiliations): return self.dbpool.runInteraction(self._setAffiliations, affiliations) - def _setAffiliations(self, cursor, affiliations): self._checkNodeExists(cursor) @@ -669,11 +656,9 @@ # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5 cursor.execute("INSERT INTO affiliations(entity_id,affiliation,node_id) VALUES " + placeholders + " ON CONFLICT (entity_id,node_id) DO UPDATE SET affiliation=EXCLUDED.affiliation", values) - def deleteAffiliations(self, entities): return self.dbpool.runInteraction(self._deleteAffiliations, entities) - def _deleteAffiliations(self, cursor, entities): """delete affiliations and subscriptions for this entity""" self._checkNodeExists(cursor) @@ -1078,7 +1063,6 @@ def __init__(self, dbpool): self.dbpool = dbpool - def _countCallbacks(self, cursor, service, nodeIdentifier): """ Count number of callbacks registered for a node. @@ -1090,7 +1074,6 @@ results = cursor.fetchall() return results[0][0] - def addCallback(self, service, nodeIdentifier, callback): def interaction(cursor): cursor.execute("""SELECT 1 as bool FROM callbacks @@ -1110,7 +1093,6 @@ return self.dbpool.runInteraction(interaction) - def removeCallback(self, service, nodeIdentifier, callback): def interaction(cursor): cursor.execute("""DELETE FROM callbacks @@ -1142,7 +1124,6 @@ return self.dbpool.runInteraction(interaction) - def hasCallbacks(self, service, nodeIdentifier): def interaction(cursor): return bool(self._countCallbacks(cursor, service, nodeIdentifier))