diff idavoll/pgsql_storage.py @ 204:b4bf0a5ce50d

Implement storage facilities for the HTTP gateway. Author: ralphm. Fixes #12. One of the storage facilities is PostgreSQL based, providing persistence.
author Ralph Meijer <ralphm@ik.nu>
date Wed, 16 Jul 2008 06:38:32 +0000
parents 77c61e2b8c75
children 274a45d2a5ab
line wrap: on
line diff
--- a/idavoll/pgsql_storage.py	Mon Jul 14 09:16:16 2008 +0000
+++ b/idavoll/pgsql_storage.py	Wed Jul 16 06:38:32 2008 +0000
@@ -4,10 +4,7 @@
 import copy
 
 from zope.interface import implements
-
-from twisted.enterprise import adbapi
 from twisted.words.protocols.jabber import jid
-
 from wokkel.generic import parseXml
 
 from idavoll import error, iidavoll
@@ -16,20 +13,13 @@
 
     implements(iidavoll.IStorage)
 
-    def __init__(self, user, database, password=None, host=None, port=None):
-        self._dbpool = adbapi.ConnectionPool('pyPgSQL.PgSQL',
-                                             user=user,
-                                             password=password,
-                                             database=database,
-                                             host=host,
-                                             port=port,
-                                             cp_reconnect=True,
-                                             client_encoding='utf-8'
-                                             )
+
+    def __init__(self, dbpool):
+        self.dbpool = dbpool
 
 
     def getNode(self, nodeIdentifier):
-        return self._dbpool.runInteraction(self._getNode, nodeIdentifier)
+        return self.dbpool.runInteraction(self._getNode, nodeIdentifier)
 
 
     def _getNode(self, cursor, nodeIdentifier):
@@ -48,18 +38,18 @@
             raise error.NodeNotFound()
         else:
             node = LeafNode(nodeIdentifier, configuration)
-            node._dbpool = self._dbpool
+            node.dbpool = self.dbpool
             return node
 
 
     def getNodeIds(self):
-        d = self._dbpool.runQuery("""SELECT node from nodes""")
+        d = self.dbpool.runQuery("""SELECT node from nodes""")
         d.addCallback(lambda results: [r[0] for r in results])
         return d
 
 
     def createNode(self, nodeIdentifier, owner, config=None):
-        return self._dbpool.runInteraction(self._createNode, nodeIdentifier,
+        return self.dbpool.runInteraction(self._createNode, nodeIdentifier,
                                            owner)
 
 
@@ -88,7 +78,7 @@
 
 
     def deleteNode(self, nodeIdentifier):
-        return self._dbpool.runInteraction(self._deleteNode, nodeIdentifier)
+        return self.dbpool.runInteraction(self._deleteNode, nodeIdentifier)
 
 
     def _deleteNode(self, cursor, nodeIdentifier):
@@ -100,7 +90,7 @@
 
 
     def getAffiliations(self, entity):
-        d = self._dbpool.runQuery("""SELECT node, affiliation FROM entities
+        d = self.dbpool.runQuery("""SELECT node, affiliation FROM entities
                                         JOIN affiliations ON
                                         (affiliations.entity_id=entities.id)
                                         JOIN nodes ON
@@ -112,7 +102,7 @@
 
 
     def getSubscriptions(self, entity):
-        d = self._dbpool.runQuery("""SELECT node, jid, resource, subscription
+        d = self.dbpool.runQuery("""SELECT node, jid, resource, subscription
                                      FROM entities JOIN subscriptions ON
                                      (subscriptions.entity_id=entities.id)
                                      JOIN nodes ON
@@ -162,7 +152,7 @@
             if option in config:
                 config[option] = options[option]
 
-        d = self._dbpool.runInteraction(self._setConfiguration, config)
+        d = self.dbpool.runInteraction(self._setConfiguration, config)
         d.addCallback(self._setCachedConfiguration, config)
         return d
 
@@ -189,7 +179,7 @@
 
 
     def getAffiliation(self, entity):
-        return self._dbpool.runInteraction(self._getAffiliation, entity)
+        return self.dbpool.runInteraction(self._getAffiliation, entity)
 
 
     def _getAffiliation(self, cursor, entity):
@@ -208,7 +198,7 @@
 
 
     def getSubscription(self, subscriber):
-        return self._dbpool.runInteraction(self._getSubscription, subscriber)
+        return self.dbpool.runInteraction(self._getSubscription, subscriber)
 
 
     def _getSubscription(self, cursor, subscriber):
@@ -232,7 +222,7 @@
 
 
     def addSubscription(self, subscriber, state):
-        return self._dbpool.runInteraction(self._addSubscription, subscriber,
+        return self.dbpool.runInteraction(self._addSubscription, subscriber,
                                           state)
 
 
@@ -264,7 +254,7 @@
 
 
     def removeSubscription(self, subscriber):
-        return self._dbpool.runInteraction(self._removeSubscription,
+        return self.dbpool.runInteraction(self._removeSubscription,
                                            subscriber)
 
 
@@ -288,7 +278,7 @@
 
 
     def getSubscribers(self):
-        d = self._dbpool.runInteraction(self._getSubscribers)
+        d = self.dbpool.runInteraction(self._getSubscribers)
         d.addCallback(self._convertToJIDs)
         return d
 
@@ -309,7 +299,7 @@
 
 
     def isSubscribed(self, entity):
-        return self._dbpool.runInteraction(self._isSubscribed, entity)
+        return self.dbpool.runInteraction(self._isSubscribed, entity)
 
 
     def _isSubscribed(self, cursor, entity):
@@ -329,7 +319,7 @@
 
 
     def getAffiliations(self):
-        return self._dbpool.runInteraction(self._getAffiliations)
+        return self.dbpool.runInteraction(self._getAffiliations)
 
 
     def _getAffiliations(self, cursor):
@@ -353,7 +343,7 @@
     nodeType = 'leaf'
 
     def storeItems(self, items, publisher):
-        return self._dbpool.runInteraction(self._storeItems, items, publisher)
+        return self.dbpool.runInteraction(self._storeItems, items, publisher)
 
 
     def _storeItems(self, cursor, items, publisher):
@@ -384,7 +374,7 @@
 
 
     def removeItems(self, itemIdentifiers):
-        return self._dbpool.runInteraction(self._removeItems, itemIdentifiers)
+        return self.dbpool.runInteraction(self._removeItems, itemIdentifiers)
 
 
     def _removeItems(self, cursor, itemIdentifiers):
@@ -406,7 +396,7 @@
 
 
     def getItems(self, maxItems=None):
-        return self._dbpool.runInteraction(self._getItems, maxItems)
+        return self.dbpool.runInteraction(self._getItems, maxItems)
 
 
     def _getItems(self, cursor, maxItems):
@@ -426,7 +416,7 @@
 
 
     def getItemsById(self, itemIdentifiers):
-        return self._dbpool.runInteraction(self._getItemsById, itemIdentifiers)
+        return self.dbpool.runInteraction(self._getItemsById, itemIdentifiers)
 
 
     def _getItemsById(self, cursor, itemIdentifiers):
@@ -445,7 +435,7 @@
 
 
     def purge(self):
-        return self._dbpool.runInteraction(self._purge)
+        return self.dbpool.runInteraction(self._purge)
 
 
     def _purge(self, cursor):
@@ -460,3 +450,84 @@
 class LeafNode(Node, LeafNodeMixin):
 
     implements(iidavoll.ILeafNode)
+
+
+
+class GatewayStorage(object):
+    """
+    Memory based storage facility for the XMPP-HTTP gateway.
+    """
+
+    def __init__(self, dbpool):
+        self.dbpool = dbpool
+
+
+    def _countCallbacks(self, cursor, service, nodeIdentifier):
+        """
+        Count number of callbacks registered for a node.
+        """
+        cursor.execute("""SELECT count(*) FROM callbacks
+                          WHERE service=%s and node=%s""",
+                       service.full(),
+                       nodeIdentifier)
+        results = cursor.fetchall()
+        return results[0][0]
+
+
+    def addCallback(self, service, nodeIdentifier, callback):
+        def interaction(cursor):
+            cursor.execute("""SELECT 1 FROM callbacks
+                              WHERE service=%s and node=%s and uri=%s""",
+                           service.full(),
+                           nodeIdentifier,
+                           callback)
+            if cursor.fetchall():
+                raise error.SubscriptionExists()
+
+            cursor.execute("""INSERT INTO callbacks
+                              (service, node, uri) VALUES
+                              (%s, %s, %s)""",
+                           service.full(),
+                           nodeIdentifier,
+                           callback)
+
+        return self.dbpool.runInteraction(interaction)
+
+
+    def removeCallback(self, service, nodeIdentifier, callback):
+        def interaction(cursor):
+            cursor.execute("""DELETE FROM callbacks
+                              WHERE service=%s and node=%s and uri=%s""",
+                           service.full(),
+                           nodeIdentifier,
+                           callback)
+
+            if cursor.rowcount != 1:
+                raise error.NotSubscribed()
+
+            last = not self._countCallbacks(cursor, service, nodeIdentifier)
+            return last
+
+        return self.dbpool.runInteraction(interaction)
+
+    def getCallbacks(self, service, nodeIdentifier):
+        def interaction(cursor):
+            cursor.execute("""SELECT uri FROM callbacks
+                              WHERE service=%s and node=%s""",
+                           service.full(),
+                           nodeIdentifier)
+            results = cursor.fetchall()
+
+            if not results:
+                raise error.NoCallbacks()
+
+            return [result[0] for result in results]
+
+        return self.dbpool.runInteraction(interaction)
+
+
+    def hasCallbacks(self, service, nodeIdentifier):
+        def interaction(cursor):
+            return bool(self._countCallbacks(cursor, service, nodeIdentifier))
+
+        return self.dbpool.runInteraction(interaction)