changeset 331:e93a9fd329d9

affiliations XMPP handling: /!\ PostGreSQL minimal version raised to 9.5 (for upsert support) - affiliations can now be handler after node creation using XMPP: affiliationsGet, affiliationsSet (both for node owner) and affiliations (for an entity to know with which nodes it is affiliated on the service) are implemented - pgsql: getOrCreateEntities helping method has been implemented, it get entity_id and create missing entities when needed.
author Goffi <goffi@goffi.org>
date Sun, 26 Mar 2017 20:58:48 +0200
parents 82d1259b3e36
children 31cbd8b9fa7f
files INSTALL sat_pubsub/backend.py sat_pubsub/pgsql_storage.py
diffstat 3 files changed, 171 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/INSTALL	Sun Mar 26 20:52:32 2017 +0200
+++ b/INSTALL	Sun Mar 26 20:58:48 2017 +0200
@@ -14,7 +14,7 @@
 
 For the PostgreSQL backend, the following is also required:
 
-- PostgreSQL (including development files for psycopg2)
+- PostgreSQL >= 9.5 (including development files for psycopg2)
 - psycopg2
 
 
--- a/sat_pubsub/backend.py	Sun Mar 26 20:52:32 2017 +0200
+++ b/sat_pubsub/backend.py	Sun Mar 26 20:58:48 2017 +0200
@@ -555,8 +555,61 @@
         return node.setConfiguration(options)
 
 
-    def getAffiliations(self, entity):
-        return self.storage.getAffiliations(entity)
+    def getAffiliations(self, entity, nodeIdentifier, pep, recipient):
+        return self.storage.getAffiliations(entity, nodeIdentifier, pep, recipient)
+
+
+    def getAffiliationsOwner(self, nodeIdentifier, requestor, pep, recipient):
+        d = self.storage.getNode(nodeIdentifier, pep, recipient)
+        d.addCallback(_getAffiliation, requestor)
+        d.addCallback(self._doGetAffiliationsOwner)
+        return d
+
+
+    def _doGetAffiliationsOwner(self, result):
+        node, affiliation = result
+
+        if affiliation != 'owner':
+            raise error.Forbidden()
+        return node.getAffiliations()
+
+
+    def setAffiliationsOwner(self, nodeIdentifier, requestor, affiliations, pep, recipient):
+        d = self.storage.getNode(nodeIdentifier, pep, recipient)
+        d.addCallback(_getAffiliation, requestor)
+        d.addCallback(self._doSetAffiliationsOwner, requestor, affiliations)
+        return d
+
+
+    def _doSetAffiliationsOwner(self, result, requestor, affiliations):
+        # Check that requestor is allowed to set affiliations, and delete entities
+        # with "none" affiliation
+
+        # TODO: return error with failed affiliations in case of failure
+        node, requestor_affiliation = result
+
+        if requestor_affiliation != 'owner':
+            raise error.Forbidden()
+
+        # we don't allow requestor to change its own affiliation
+        requestor_bare = requestor.userhostJID()
+        if requestor_bare in affiliations and affiliations[requestor_bare] != 'owner':
+            # FIXME: it may be interesting to allow the owner to ask for ownership removal
+            #        if at least one other entity is owner for this node
+            raise error.Forbidden("You can't change your own affiliation")
+
+        to_delete = [jid_ for jid_, affiliation in affiliations.iteritems() if affiliation == 'none']
+        for jid_ in to_delete:
+            del affiliations[jid_]
+
+        if to_delete:
+            d = node.deleteAffiliations(to_delete)
+            if affiliations:
+                d.addCallback(lambda dummy: node.setAffiliations(affiliations))
+        else:
+            d = node.setAffiliations(affiliations)
+
+        return d
 
 
     def getItems(self, nodeIdentifier, requestor, recipient, maxItems=None,
@@ -1240,8 +1293,14 @@
 
 
     def affiliations(self, request):
-        d = self.backend.getAffiliations(self._isPep(request),
-                                         request.sender)
+        """Retrieve affiliation for normal entity (cf. XEP-0060 §5.7)
+
+        retrieve all node where this jid is affiliated
+        """
+        d = self.backend.getAffiliations(request.sender,
+                                         request.nodeIdentifier,
+                                         self._isPep(request),
+                                         request.recipient)
         return d.addErrback(self._mapErrors)
 
 
@@ -1276,6 +1335,26 @@
         return d.addErrback(self._mapErrors)
 
 
+    def affiliationsGet(self, request):
+        """Retrieve affiliations for owner (cf. XEP-0060 §8.9.1)
+
+        retrieve all affiliations for a node
+        """
+        d = self.backend.getAffiliationsOwner(request.nodeIdentifier,
+                                              request.sender,
+                                              self._isPep(request),
+                                              request.recipient)
+        return d.addErrback(self._mapErrors)
+
+    def affiliationsSet(self, request):
+        d = self.backend.setAffiliationsOwner(request.nodeIdentifier,
+                                              request.sender,
+                                              request.affiliations,
+                                              self._isPep(request),
+                                              request.recipient)
+        return d.addErrback(self._mapErrors)
+
+
     def items(self, request):
         ext_data = {}
         if const.FLAG_ENABLE_RSM and request.rsm is not None:
--- a/sat_pubsub/pgsql_storage.py	Sun Mar 26 20:52:32 2017 +0200
+++ b/sat_pubsub/pgsql_storage.py	Sun Mar 26 20:58:48 2017 +0200
@@ -291,16 +291,23 @@
         if cursor.rowcount != 1:
             raise error.NodeNotFound()
 
-
+    def getAffiliations(self, entity, nodeIdentifier, pep, recipient=None):
+        return self.dbpool.runInteraction(self._getAffiliations, entity, nodeIdentifier, pep, recipient)
 
-    def getAffiliations(self, entity, pep, recipient=None):
-        d = self.dbpool.runQuery(*withPEP("""SELECT node, affiliation FROM entities
-                                        NATURAL JOIN affiliations
-                                        NATURAL JOIN nodes
-                                        WHERE jid=%s""",
-                                     (entity.userhost(),), pep, recipient, 'nodes'))
-        d.addCallback(lambda results: [tuple(r) for r in results])
-        return d
+    def _getAffiliations(self, cursor, entity, nodeIdentifier, pep, recipient=None):
+        query = ["""SELECT node, affiliation FROM entities
+                    NATURAL JOIN affiliations
+                    NATURAL JOIN nodes
+                    WHERE jid=%s"""]
+        args = [entity.userhost()]
+
+        if nodeIdentifier is not None:
+            query.append("AND node=%s")
+            args.append(nodeIdentifier)
+
+        cursor.execute(*withPEP(' '.join(query), args, pep, recipient))
+        rows = cursor.fetchall()
+        return [tuple(r) for r in rows]
 
 
     def getSubscriptions(self, entity, pep, recipient=None):
@@ -578,6 +585,77 @@
 
         return {jid.internJID(r[0]): r[1] for r in result}
 
+    def getOrCreateEntities(self, cursor, entities_jids):
+        """Get entity_id from entities in entities table
+
+        Entities will be inserted it they don't exist
+        @param entities_jid(list[jid.JID]): entities to get or create
+        @return list[record(entity_jid,jid)]]: list of entity_id and jid (as plain string)
+            both existing and inserted entities are returned
+        """
+        # cf. http://stackoverflow.com/a/35265559
+        placeholders = ','.join(len(entities_jids) * ["(%s)"])
+        query = (
+        """
+        WITH
+        jid_values (jid) AS (
+               VALUES {placeholders}
+        ),
+        inserted (entity_id, jid) AS (
+            INSERT INTO entities (jid)
+            SELECT jid
+            FROM jid_values
+            ON CONFLICT DO NOTHING
+            RETURNING entity_id, jid
+        )
+        SELECT e.entity_id, e.jid
+        FROM entities e JOIN jid_values jv ON jv.jid = e.jid
+        UNION ALL
+        SELECT entity_id, jid
+        FROM inserted""".format(placeholders=placeholders))
+        cursor.execute(query, [j.userhost() for j in entities_jids])
+        return cursor.fetchall()
+
+    def setAffiliations(self, affiliations):
+        return self.dbpool.runInteraction(self._setAffiliations, affiliations)
+
+
+    def _setAffiliations(self, cursor, affiliations):
+        self._checkNodeExists(cursor)
+
+        entities = self.getOrCreateEntities(cursor, affiliations)
+
+        # then we construct values for affiliations update according to entity_id we just got
+        placeholders = ','.join(len(affiliations) * ["(%s,%s,%s)"])
+        values = []
+        map(values.extend, ((e.entity_id, affiliations[jid.JID(e.jid)], self.nodeDbId) for e in entities))
+
+        # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5
+        cursor.execute("INSERT INTO affiliations(entity_id,affiliation,node_id) VALUES " + placeholders + " ON CONFLICT  (entity_id,node_id) DO UPDATE SET affiliation=EXCLUDED.affiliation", values)
+
+
+    def deleteAffiliations(self, entities):
+        return self.dbpool.runInteraction(self._deleteAffiliations, entities)
+
+
+    def _deleteAffiliations(self, cursor, entities):
+        """delete affiliations and subscriptions for this entity"""
+        self._checkNodeExists(cursor)
+        placeholders = ','.join(len(entities) * ["%s"])
+        cursor.execute("DELETE FROM affiliations WHERE node_id=%s AND entity_id in (SELECT entity_id FROM entities WHERE jid IN (" + placeholders + ")) RETURNING entity_id", [self.nodeDbId] + [e.userhost() for e in entities])
+
+        rows = cursor.fetchall()
+        placeholders = ','.join(len(rows) * ["%s"])
+        cursor.execute("DELETE FROM subscriptions WHERE node_id=%s AND entity_id in (" + placeholders + ")", [self.nodeDbId] + [r[0] for r in rows])
+
+    def getAuthorizedGroups(self):
+        return self.dbpool.runInteraction(self._getNodeGroups)
+
+    def _getAuthorizedGroups(self, cursor):
+        cursor.execute("SELECT groupname FROM node_groups_authorized NATURAL JOIN nodes WHERE node=%s",
+                                (self.nodeDbId,))
+        rows = cursor.fetchall()
+        return [row[0] for row in rows]
 
 
 class LeafNode(Node):