diff idavoll/pgsql_storage.py @ 206:274a45d2a5ab

Implement root collection that includes all leaf nodes.
author Ralph Meijer <ralphm@ik.nu>
date Mon, 04 Aug 2008 13:47:10 +0000
parents b4bf0a5ce50d
children 8540825f85e0
line wrap: on
line diff
--- a/idavoll/pgsql_storage.py	Mon Aug 04 07:10:45 2008 +0000
+++ b/idavoll/pgsql_storage.py	Mon Aug 04 13:47:10 2008 +0000
@@ -4,8 +4,12 @@
 import copy
 
 from zope.interface import implements
+
+from twisted.enterprise import adbapi
 from twisted.words.protocols.jabber import jid
-from wokkel.generic import parseXml
+
+from wokkel.generic import parseXml, stripNamespace
+from wokkel.pubsub import Subscription
 
 from idavoll import error, iidavoll
 
@@ -13,6 +17,17 @@
 
     implements(iidavoll.IStorage)
 
+    defaultConfig = {
+            'leaf': {
+                "pubsub#persist_items": True,
+                "pubsub#deliver_payloads": True,
+                "pubsub#send_last_published_item": 'on_sub',
+            },
+            'collection': {
+                "pubsub#deliver_payloads": True,
+                "pubsub#send_last_published_item": 'on_sub',
+            }
+    }
 
     def __init__(self, dbpool):
         self.dbpool = dbpool
@@ -24,22 +39,36 @@
 
     def _getNode(self, cursor, nodeIdentifier):
         configuration = {}
-        cursor.execute("""SELECT persistent, deliver_payload,
+        cursor.execute("""SELECT node_type,
+                                 persist_items,
+                                 deliver_payloads,
                                  send_last_published_item
                           FROM nodes
                           WHERE node=%s""",
                        (nodeIdentifier,))
-        try:
-            (configuration["pubsub#persist_items"],
-             configuration["pubsub#deliver_payloads"],
-             configuration["pubsub#send_last_published_item"]) = \
-            cursor.fetchone()
-        except TypeError:
+        row = cursor.fetchone()
+
+        if not row:
             raise error.NodeNotFound()
-        else:
+
+        if row.node_type == 'leaf':
+            configuration = {
+                    'pubsub#persist_items': row.persist_items,
+                    'pubsub#deliver_payloads': row.deliver_payloads,
+                    'pubsub#send_last_published_item':
+                        row.send_last_published_item}
             node = LeafNode(nodeIdentifier, configuration)
             node.dbpool = self.dbpool
             return node
+        elif row.node_type == 'collection':
+            configuration = {
+                    'pubsub#deliver_payloads': row.deliver_payloads,
+                    'pubsub#send_last_published_item':
+                        row.send_last_published_item}
+            node = CollectionNode(nodeIdentifier, configuration)
+            node.dbpool = self.dbpool
+            return node
+
 
 
     def getNodeIds(self):
@@ -48,16 +77,27 @@
         return d
 
 
-    def createNode(self, nodeIdentifier, owner, config=None):
+    def createNode(self, nodeIdentifier, owner, config):
         return self.dbpool.runInteraction(self._createNode, nodeIdentifier,
-                                           owner)
+                                           owner, config)
 
 
-    def _createNode(self, cursor, nodeIdentifier, owner):
+    def _createNode(self, cursor, nodeIdentifier, owner, config):
+        if config['pubsub#node_type'] != 'leaf':
+            raise error.NoCollections()
+
         owner = owner.userhost()
         try:
-            cursor.execute("""INSERT INTO nodes (node) VALUES (%s)""",
-                           (nodeIdentifier))
+            cursor.execute("""INSERT INTO nodes
+                              (node, node_type, persist_items,
+                               deliver_payloads, send_last_published_item)
+                              VALUES
+                              (%s, 'leaf', %s, %s, %s)""",
+                           (nodeIdentifier,
+                            config['pubsub#persist_items'],
+                            config['pubsub#deliver_payloads'],
+                            config['pubsub#send_last_published_item'])
+                           )
         except cursor._pool.dbapi.OperationalError:
             raise error.NodeExists()
 
@@ -70,10 +110,11 @@
 
         cursor.execute("""INSERT INTO affiliations
                           (node_id, entity_id, affiliation)
-                          SELECT n.id, e.id, 'owner' FROM
-                          (SELECT id FROM nodes WHERE node=%s) AS n
+                          SELECT node_id, entity_id, 'owner' FROM
+                          (SELECT node_id FROM nodes WHERE node=%s) as n
                           CROSS JOIN
-                          (SELECT id FROM entities WHERE jid=%s) AS e""",
+                          (SELECT entity_id FROM entities
+                                            WHERE jid=%s) as e""",
                        (nodeIdentifier, owner))
 
 
@@ -91,10 +132,8 @@
 
     def getAffiliations(self, entity):
         d = self.dbpool.runQuery("""SELECT node, affiliation FROM entities
-                                        JOIN affiliations ON
-                                        (affiliations.entity_id=entities.id)
-                                        JOIN nodes ON
-                                        (nodes.id=affiliations.node_id)
+                                        NATURAL JOIN affiliations
+                                        NATURAL JOIN nodes
                                         WHERE jid=%s""",
                                      (entity.userhost(),))
         d.addCallback(lambda results: [tuple(r) for r in results])
@@ -102,22 +141,27 @@
 
 
     def getSubscriptions(self, entity):
-        d = self.dbpool.runQuery("""SELECT node, jid, resource, subscription
-                                     FROM entities JOIN subscriptions ON
-                                     (subscriptions.entity_id=entities.id)
-                                     JOIN nodes ON
-                                     (nodes.id=subscriptions.node_id)
+        def toSubscriptions(rows):
+            subscriptions = []
+            for row in rows:
+                subscriber = jid.internJID('%s/%s' % (row.jid,
+                                                      row.resource))
+                subscription = Subscription(row.node, subscriber, row.state)
+                subscriptions.append(subscription)
+            return subscriptions
+
+        d = self.dbpool.runQuery("""SELECT node, jid, resource, state
+                                     FROM entities
+                                     NATURAL JOIN subscriptions
+                                     NATURAL JOIN nodes
                                      WHERE jid=%s""",
                                   (entity.userhost(),))
-        d.addCallback(self._convertSubscriptionJIDs)
+        d.addCallback(toSubscriptions)
         return d
 
 
-    def _convertSubscriptionJIDs(self, subscriptions):
-        return [(node,
-                 jid.internJID('%s/%s' % (subscriber, resource)),
-                 subscription)
-                for node, subscriber, resource, subscription in subscriptions]
+    def getDefaultConfiguration(self, nodeType):
+        return self.defaultConfig[nodeType]
 
 
 
@@ -131,7 +175,7 @@
 
 
     def _checkNodeExists(self, cursor):
-        cursor.execute("""SELECT id FROM nodes WHERE node=%s""",
+        cursor.execute("""SELECT node_id FROM nodes WHERE node=%s""",
                        (self.nodeIdentifier))
         if not cursor.fetchone():
             raise error.NodeNotFound()
@@ -159,7 +203,8 @@
 
     def _setConfiguration(self, cursor, config):
         self._checkNodeExists(cursor)
-        cursor.execute("""UPDATE nodes SET persistent=%s, deliver_payload=%s,
+        cursor.execute("""UPDATE nodes SET persist_items=%s,
+                                           deliver_payloads=%s,
                                            send_last_published_item=%s
                           WHERE node=%s""",
                        (config["pubsub#persist_items"],
@@ -185,8 +230,8 @@
     def _getAffiliation(self, cursor, entity):
         self._checkNodeExists(cursor)
         cursor.execute("""SELECT affiliation FROM affiliations
-                          JOIN nodes ON (node_id=nodes.id)
-                          JOIN entities ON (entity_id=entities.id)
+                          NATURAL JOIN nodes
+                          NATURAL JOIN entities
                           WHERE node=%s AND jid=%s""",
                        (self.nodeIdentifier,
                         entity.userhost()))
@@ -207,31 +252,72 @@
         userhost = subscriber.userhost()
         resource = subscriber.resource or ''
 
-        cursor.execute("""SELECT subscription FROM subscriptions
-                          JOIN nodes ON (nodes.id=subscriptions.node_id)
-                          JOIN entities ON
-                               (entities.id=subscriptions.entity_id)
+        cursor.execute("""SELECT state FROM subscriptions
+                          NATURAL JOIN nodes
+                          NATURAL JOIN entities
                           WHERE node=%s AND jid=%s AND resource=%s""",
                        (self.nodeIdentifier,
                         userhost,
                         resource))
-        try:
-            return cursor.fetchone()[0]
-        except TypeError:
+        row = cursor.fetchone()
+        if not row:
             return None
+        else:
+            return Subscription(self.nodeIdentifier, subscriber, row.state)
+
+
+    def getSubscriptions(self, state=None):
+        return self.dbpool.runInteraction(self._getSubscriptions, state)
 
 
-    def addSubscription(self, subscriber, state):
-        return self.dbpool.runInteraction(self._addSubscription, subscriber,
-                                          state)
+    def _getSubscriptions(self, cursor, state):
+        self._checkNodeExists(cursor)
+
+        query = """SELECT jid, resource, state,
+                          subscription_type, subscription_depth
+                   FROM subscriptions
+                   NATURAL JOIN nodes
+                   NATURAL JOIN entities
+                   WHERE node=%s""";
+        values = [self.nodeIdentifier]
+
+        if state:
+            query += " AND state=%s"
+            values.append(state)
+
+        cursor.execute(query, values);
+        rows = cursor.fetchall()
+
+        subscriptions = []
+        for row in rows:
+            subscriber = jid.JID('%s/%s' % (row.jid, row.resource))
+
+            options = {}
+            if row.subscription_type:
+                options['pubsub#subscription_type'] = row.subscription_type;
+            if row.subscription_depth:
+                options['pubsub#subscription_depth'] = row.subscription_depth;
+
+            subscriptions.append(Subscription(self.nodeIdentifier, subscriber,
+                                              row.state, options))
+
+        return subscriptions
 
 
-    def _addSubscription(self, cursor, subscriber, state):
+    def addSubscription(self, subscriber, state, config):
+        return self.dbpool.runInteraction(self._addSubscription, subscriber,
+                                          state, config)
+
+
+    def _addSubscription(self, cursor, subscriber, state, config):
         self._checkNodeExists(cursor)
 
         userhost = subscriber.userhost()
         resource = subscriber.resource or ''
 
+        subscription_type = config.get('pubsub#subscription_type')
+        subscription_depth = config.get('pubsub#subscription_depth')
+
         try:
             cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
                            (userhost))
@@ -240,13 +326,18 @@
 
         try:
             cursor.execute("""INSERT INTO subscriptions
-                              (node_id, entity_id, resource, subscription)
-                              SELECT n.id, e.id, %s, %s FROM
-                              (SELECT id FROM nodes WHERE node=%s) AS n
+                              (node_id, entity_id, resource, state,
+                               subscription_type, subscription_depth)
+                              SELECT node_id, entity_id, %s, %s, %s, %s FROM
+                              (SELECT node_id FROM nodes
+                                              WHERE node=%s) as n
                               CROSS JOIN
-                              (SELECT id FROM entities WHERE jid=%s) AS e""",
+                              (SELECT entity_id FROM entities
+                                                WHERE jid=%s) as e""",
                            (resource,
                             state,
+                            subscription_type,
+                            subscription_depth,
                             self.nodeIdentifier,
                             userhost))
         except cursor._pool.dbapi.OperationalError:
@@ -265,9 +356,11 @@
         resource = subscriber.resource or ''
 
         cursor.execute("""DELETE FROM subscriptions WHERE
-                          node_id=(SELECT id FROM nodes WHERE node=%s) AND
-                          entity_id=(SELECT id FROM entities WHERE jid=%s)
-                          AND resource=%s""",
+                          node_id=(SELECT node_id FROM nodes
+                                                  WHERE node=%s) AND
+                          entity_id=(SELECT entity_id FROM entities
+                                                      WHERE jid=%s) AND
+                          resource=%s""",
                        (self.nodeIdentifier,
                         userhost,
                         resource))
@@ -277,27 +370,6 @@
         return None
 
 
-    def getSubscribers(self):
-        d = self.dbpool.runInteraction(self._getSubscribers)
-        d.addCallback(self._convertToJIDs)
-        return d
-
-
-    def _getSubscribers(self, cursor):
-        self._checkNodeExists(cursor)
-        cursor.execute("""SELECT jid, resource FROM subscriptions
-                          JOIN nodes ON (node_id=nodes.id)
-                          JOIN entities ON (entity_id=entities.id)
-                          WHERE node=%s AND
-                          subscription='subscribed'""",
-                       (self.nodeIdentifier,))
-        return cursor.fetchall()
-
-
-    def _convertToJIDs(self, list):
-        return [jid.internJID("%s/%s" % (l[0], l[1])) for l in list]
-
-
     def isSubscribed(self, entity):
         return self.dbpool.runInteraction(self._isSubscribed, entity)
 
@@ -306,12 +378,10 @@
         self._checkNodeExists(cursor)
 
         cursor.execute("""SELECT 1 FROM entities
-                          JOIN subscriptions ON
-                          (entities.id=subscriptions.entity_id)
-                          JOIN nodes ON
-                          (nodes.id=subscriptions.node_id)
+                          NATURAL JOIN subscriptions
+                          NATURAL JOIN nodes
                           WHERE entities.jid=%s
-                          AND node=%s AND subscription='subscribed'""",
+                          AND node=%s AND state='subscribed'""",
                        (entity.userhost(),
                        self.nodeIdentifier))
 
@@ -326,10 +396,8 @@
         self._checkNodeExists(cursor)
 
         cursor.execute("""SELECT jid, affiliation FROM nodes
-                          JOIN affiliations ON
-                            (nodes.id = affiliations.node_id)
-                          JOIN entities ON
-                            (affiliations.entity_id = entities.id)
+                          NATURAL JOIN affiliations
+                          NATURAL JOIN entities
                           WHERE node=%s""",
                        self.nodeIdentifier)
         result = cursor.fetchall()
@@ -338,7 +406,9 @@
 
 
 
-class LeafNodeMixin:
+class LeafNode(Node):
+
+    implements(iidavoll.ILeafNode)
 
     nodeType = 'leaf'
 
@@ -356,7 +426,7 @@
         data = item.toXml()
         cursor.execute("""UPDATE items SET date=now(), publisher=%s, data=%s
                           FROM nodes
-                          WHERE nodes.id = items.node_id AND
+                          WHERE nodes.node_id = items.node_id AND
                                 nodes.node = %s and items.item=%s""",
                        (publisher.full(),
                         data,
@@ -366,7 +436,8 @@
             return
 
         cursor.execute("""INSERT INTO items (node_id, item, publisher, data)
-                          SELECT id, %s, %s, %s FROM nodes WHERE node=%s""",
+                          SELECT node_id, %s, %s, %s FROM nodes
+                                                     WHERE node=%s""",
                        (item["id"],
                         publisher.full(),
                         data,
@@ -384,7 +455,8 @@
 
         for itemIdentifier in itemIdentifiers:
             cursor.execute("""DELETE FROM items WHERE
-                              node_id=(SELECT id FROM nodes WHERE node=%s) AND
+                              node_id=(SELECT node_id FROM nodes
+                                                      WHERE node=%s) AND
                               item=%s""",
                            (self.nodeIdentifier,
                             itemIdentifier))
@@ -401,8 +473,8 @@
 
     def _getItems(self, cursor, maxItems):
         self._checkNodeExists(cursor)
-        query = """SELECT data FROM nodes JOIN items ON
-                   (nodes.id=items.node_id)
+        query = """SELECT data FROM nodes
+                   NATURAL JOIN items
                    WHERE node=%s ORDER BY date DESC"""
         if maxItems:
             cursor.execute(query + " LIMIT %s",
@@ -412,7 +484,8 @@
             cursor.execute(query, (self.nodeIdentifier))
 
         result = cursor.fetchall()
-        return [parseXml(r[0]) for r in result]
+        items = [stripNamespace(parseXml(r[0])) for r in result]
+        return items
 
 
     def getItemsById(self, itemIdentifiers):
@@ -423,8 +496,8 @@
         self._checkNodeExists(cursor)
         items = []
         for itemIdentifier in itemIdentifiers:
-            cursor.execute("""SELECT data FROM nodes JOIN items ON
-                              (nodes.id=items.node_id)
+            cursor.execute("""SELECT data FROM nodes
+                              NATURAL JOIN items
                               WHERE node=%s AND item=%s""",
                            (self.nodeIdentifier,
                             itemIdentifier))
@@ -442,14 +515,13 @@
         self._checkNodeExists(cursor)
 
         cursor.execute("""DELETE FROM items WHERE
-                          node_id=(SELECT id FROM nodes WHERE node=%s)""",
+                          node_id=(SELECT node_id FROM nodes WHERE node=%s)""",
                        (self.nodeIdentifier,))
 
 
+class CollectionNode(Node):
 
-class LeafNode(Node, LeafNodeMixin):
-
-    implements(iidavoll.ILeafNode)
+    nodeType = 'collection'
 
 
 
@@ -482,7 +554,7 @@
                            nodeIdentifier,
                            callback)
             if cursor.fetchall():
-                raise error.SubscriptionExists()
+                return
 
             cursor.execute("""INSERT INTO callbacks
                               (service, node, uri) VALUES