diff idavoll/pgsql_storage.py @ 121:4f0113adb7ed

Add Node._check_node_exists() calls to all Node methods, because nodes could have been deleted in between calls. Add Node.get_subscription(). Only fire deferred (with None) on success of Node.add_subscription(). Fix Node.set_configuration() to actually work and only update the Node objects configuration when the SQL query has succeeded. Implement Node.remove_subscription(). Implement Node.is_subscribed(). Implement LeafNode methods (unchecked!).
author Ralph Meijer <ralphm@ik.nu>
date Tue, 12 Apr 2005 12:26:05 +0000
parents dfef919aaf1b
children c4ee16bc48e5
line wrap: on
line diff
--- a/idavoll/pgsql_storage.py	Tue Apr 12 12:20:01 2005 +0000
+++ b/idavoll/pgsql_storage.py	Tue Apr 12 12:26:05 2005 +0000
@@ -107,6 +107,12 @@
         self.id = node_id
         self._config = config
 
+    def _check_node_exists(self, cursor):
+        cursor.execute("""SELECT id FROM nodes WHERE node=%s""",
+                       (self.id.encode('utf8')))
+        if not cursor.fetchone():
+            raise backend.NodeNotFound
+
     def get_type(self):
         return self.type
 
@@ -114,20 +120,26 @@
         return self._config
 
     def set_configuration(self, options):
-        return self._dbpool.runInteraction(self._set_node_configuration,
+        return self._dbpool.runInteraction(self._set_configuration,
                                            options)
 
     def _set_configuration(self, cursor, options):
+        self._check_node_exists(cursor)
+
+        config = copy.copy(self._config)
+
         for option in options:
-            if option in self._config:
-                self._config[option] = options[option]
+            if option in config:
+                config[option] = options[option]
         
         cursor.execute("""UPDATE nodes SET persistent=%s, deliver_payload=%s
                           WHERE node=%s""",
-                       (self._config["pubsub#persist_items"].encode('utf8'),
-                        self._config["pubsub#deliver_payloads"].encode('utf8'),
+                       (config["pubsub#persist_items"],
+                        config["pubsub#deliver_payloads"],
                         self.id.encode('utf-8')))
 
+        self._config = config
+
     def get_meta_data(self):
         config = copy.copy(self._config)
         config["pubsub#node_type"] = self.type
@@ -137,6 +149,7 @@
         return self._dbpool.runInteraction(self._get_affiliation, entity)
 
     def _get_affiliation(self, cursor, entity):
+        self._check_node_exists(cursor)
         cursor.execute("""SELECT affiliation FROM affiliations
                           JOIN nodes ON (node_id=nodes.id)
                           JOIN entities ON (entity_id=entities.id)
@@ -149,11 +162,35 @@
         except TypeError:
             return None
 
+    def get_subscription(self, subscriber):
+        return self._dbpool.runInteraction(self._get_subscription, subscriber)
+
+    def _get_subscription(self, cursor, subscriber):
+        self._check_node_exists(cursor)
+
+        userhost = subscriber.userhost()
+        resource = subscriber.resource or ''
+
+        cursor.execute("""SELECT subscription FROM subscriptions
+                          JOIN nodes ON (nodes.id=subscriptions.node_id)
+                          JOIN entities ON
+                               (entities.id=subscriptions.entity_id)
+                          WHERE node=%s AND jid=%s AND resource=%s""",
+                       (self.id.encode('utf8'),
+                        userhost.encode('utf8'),
+                        resource.encode('utf8')))
+        try:
+            return cursor.fetchone()[0]
+        except TypeError:
+            return None
+
     def add_subscription(self, subscriber, state):
         return self._dbpool.runInteraction(self._add_subscription, subscriber,
                                           state)
 
     def _add_subscription(self, cursor, subscriber, state):
+        self._check_node_exists(cursor)
+
         userhost = subscriber.userhost()
         resource = subscriber.resource or ''
 
@@ -175,38 +212,69 @@
                             self.id.encode('utf8'),
                             userhost.encode('utf8')))
         except cursor._pool.dbapi.OperationalError:
-            cursor.execute("""SELECT subscription FROM subscriptions
-                              JOIN nodes ON (nodes.id=subscriptions.node_id)
-                              JOIN entities ON
-                                   (entities.id=subscriptions.entity_id)
-                              WHERE node=%s AND jid=%s AND resource=%s""",
-                           (self.id.encode('utf8'),
-                            userhost.encode('utf8'),
-                            resource.encode('utf8')))
-            state = cursor.fetchone()[0]
+            raise storage.SubscriptionExists
+
+    def remove_subscription(self, subscriber):
+        return self._dbpool.runInteraction(self._remove_subscription,
+                                           subscriber)
+
+    def _remove_subscription(self, cursor, subscriber):
+        self._check_node_exists(cursor)
+
+        userhost = subscriber.userhost()
+        resource = subscriber.resource or ''
 
-        return {'node': self.id,
-                'jid': subscriber,
-                'subscription': state}
+        cursor.execute("""DELETE FROM subscriptions WHERE
+                          node_id=(SELECT id FROM nodes WHERE node=%s) AND
+                          entity_id=(SELECT id FROM entities WHERE jid=%s)
+                          AND resource=%s""",
+                       (self.id.encode('utf8'),
+                        userhost.encode('utf8'),
+                        resource.encode('utf8')))
+        if cursor.rowcount != 1:
+            raise storage.SubscriptionNotFound
 
-    def remove_subscription(self, subscriber, state):
-        pass
+        return None
 
     def get_subscribers(self):
-        d = self._dbpool.runQuery("""SELECT jid, resource FROM subscriptions
-                                     JOIN nodes ON (node_id=nodes.id)
-                                     JOIN entities ON (entity_id=entities.id)
-                                     WHERE node=%s AND
-                                     subscription='subscribed'""",
-                                  (self.id.encode('utf8'),))
+        d = self._dbpool.runInteraction(self._get_subscribers)
         d.addCallback(self._convert_to_jids)
         return d
 
+    def _get_subscribers(self, cursor):
+        self._check_node_exists(cursor)
+        cursor.execute("""SELECT jid, resource FROM subscriptions
+                          JOIN nodes ON (node_id=nodes.id)
+                          JOIN entities ON (entity_id=entities.id)
+                          WHERE node=%s AND
+                          subscription='subscribed'""",
+                       (self.id.encode('utf8'),))
+        return cursor.fetchall()
+
     def _convert_to_jids(self, list):
         return [jid.JID("%s/%s" % (l[0], l[1])) for l in list]
 
     def is_subscribed(self, subscriber):
-        pass
+        return self._dbpool.runInteraction(self._is_subscribed, subscriber)
+
+    def _is_subscribed(self, cursor, subscriber):
+        self._check_node_exists(cursor)
+
+        userhost = subscriber.userhost()
+        resource = subscriber.resource or ''
+
+        cursor.execute("""SELECT 1 FROM entities
+                          JOIN subscriptions ON
+                          (entities.id=subscriptions.entity_id)
+                          JOIN nodes ON
+                          (nodes.id=subscriptions.node_id)
+                          WHERE entities.jid=%s AND resource=%s
+                          AND node=%s AND subscription='subscribed'""",
+                       (userhost.encode('utf8'),
+                       resource.encode('utf8'),
+                       self.id.encode('utf8')))
+
+        return cursor.fetchone() is not None
 
 class LeafNode(Node):
 
@@ -215,16 +283,95 @@
     type = 'leaf'
 
     def store_items(self, items, publisher):
-        return defer.succeed(None)
+        return self._dbpool.runInteraction(self._store_items, items, publisher)
+
+    def _store_items(self, cursor, items, publisher):
+        self._check_node_exists(cursor)
+        for item in items:
+            self._store_item(cursor, item, publisher)
+
+    def _store_item(self, cursor, item, publisher):
+        data = item.toXml()
+        cursor.execute("""UPDATE items SET date=now(), publisher=%s, data=%s
+                          FROM nodes
+                          WHERE nodes.id = items.node_id AND
+                                nodes.node = %s and items.item=%s""",
+                       (publisher.full().encode('utf8'),
+                        data.encode('utf8'),
+                        self.id.encode('utf8'),
+                        item["id"].encode('utf8')))
+        if cursor.rowcount == 1:
+            return
+
+        cursor.execute("""INSERT INTO items (node_id, item, publisher, data)
+                          SELECT id, %s, %s, %s FROM nodes WHERE node=%s""",
+                       (item["id"].encode('utf8'),
+                        publisher.full().encode('utf8'),
+                        data.encode('utf8'),
+                        self.id.encode('utf8')))
 
     def remove_items(self, item_ids):
-        pass
+        return self._dbpool.runInteraction(self._remove_items, item_ids)
+
+    def _remove_items(self, cursor, item_ids):
+        self._check_node_exists(cursor)
+        
+        deleted = []
+
+        for item_id in item_ids:
+            cursor.execute("""DELETE FROM items WHERE
+                              node_id=(SELECT id FROM nodes WHERE node=%s) AND
+                              item=%s""",
+                           (self.id.encode('utf-8'),
+                            item_id.encode('utf-8')))
+
+            if cursor.rowcount:
+                deleted.append(item_id)
+
+        return deleted
 
     def get_items(self, max_items=None):
-        pass
+        return self._dbpool.runInteraction(self._get_items, max_items)
+
+    def _get_items(self, cursor, max_items):
+        self._check_node_exists(cursor)
+        query = """SELECT data FROM nodes JOIN items ON
+                   (nodes.id=items.node_id)
+                   WHERE node=%s ORDER BY date DESC"""
+        if max_items:
+            cursor.execute(query + " LIMIT %s",
+                           (self.id.encode('utf8'),
+                            max_items))
+        else:
+            cursor.execute(query, (self.id.encode('utf8')))
+
+        result = cursor.fetchall()
+        return [r[0] for r in result]
 
-    def get_items_by_id(self, item_ids):
-        pass
+    def get_items_by_ids(self, item_ids):
+        return self._dbpool.runInteraction(self._get_items_by_ids, item_ids)
+
+    def _get_items_by_ids(self, cursor, item_ids):
+        self._check_node_exists(cursor)
+        items = []
+        for item_id in item_ids:
+            cursor.execute("""SELECT data FROM nodes JOIN items ON
+                              (nodes.id=items.node_id)
+                              WHERE node=%s AND item=%s""",
+                           (self.id.encode('utf8'),
+                            item_id.encode('utf8')))
+            result = cursor.fetchone()
+            if result:
+                items.append(result[0])
+        return items
 
     def purge(self):
-        pass
+        return self._dbpool.runInteraction(self._purge)
+
+    def _purge_node(self, cursor):
+        self._check_node_exists(cursor)
+
+        cursor.execute("""DELETE FROM items WHERE
+                          node_id=(SELECT id FROM nodes WHERE node=%s)""",
+                       (self.id.encode('utf-8'),))
+