Mercurial > libervia-pubsub
diff sat_pubsub/pgsql_storage.py @ 232:923281d4c5bc
renamed idavoll directory to sat_pubsub
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 17 May 2012 12:48:14 +0200 |
parents | idavoll/pgsql_storage.py@8540825f85e0 |
children | 564ae55219e1 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat_pubsub/pgsql_storage.py Thu May 17 12:48:14 2012 +0200 @@ -0,0 +1,605 @@ +# Copyright (c) 2003-2008 Ralph Meijer +# See LICENSE for details. + +import copy + +from zope.interface import implements + +from twisted.enterprise import adbapi +from twisted.words.protocols.jabber import jid + +from wokkel.generic import parseXml, stripNamespace +from wokkel.pubsub import Subscription + +from idavoll import error, iidavoll + +class Storage: + + implements(iidavoll.IStorage) + + defaultConfig = { + 'leaf': { + "pubsub#persist_items": True, + "pubsub#deliver_payloads": True, + "pubsub#send_last_published_item": 'on_sub', + }, + 'collection': { + "pubsub#deliver_payloads": True, + "pubsub#send_last_published_item": 'on_sub', + } + } + + def __init__(self, dbpool): + self.dbpool = dbpool + + + def getNode(self, nodeIdentifier): + return self.dbpool.runInteraction(self._getNode, nodeIdentifier) + + + def _getNode(self, cursor, nodeIdentifier): + configuration = {} + cursor.execute("""SELECT node_type, + persist_items, + deliver_payloads, + send_last_published_item + FROM nodes + WHERE node=%s""", + (nodeIdentifier,)) + row = cursor.fetchone() + + if not row: + raise error.NodeNotFound() + + if row[0] == 'leaf': + configuration = { + 'pubsub#persist_items': row[1], + 'pubsub#deliver_payloads': row[2], + 'pubsub#send_last_published_item': + row[3]} + node = LeafNode(nodeIdentifier, configuration) + node.dbpool = self.dbpool + return node + elif row[0] == 'collection': + configuration = { + 'pubsub#deliver_payloads': row[2], + 'pubsub#send_last_published_item': + row[3]} + node = CollectionNode(nodeIdentifier, configuration) + node.dbpool = self.dbpool + return node + + + + def getNodeIds(self): + d = self.dbpool.runQuery("""SELECT node from nodes""") + d.addCallback(lambda results: [r[0] for r in results]) + return d + + + def createNode(self, nodeIdentifier, owner, config): + return self.dbpool.runInteraction(self._createNode, nodeIdentifier, + owner, config) + + + def _createNode(self, cursor, nodeIdentifier, owner, config): + if config['pubsub#node_type'] != 'leaf': + raise error.NoCollections() + + owner = owner.userhost() + try: + cursor.execute("""INSERT INTO nodes + (node, node_type, persist_items, + deliver_payloads, send_last_published_item) + VALUES + (%s, 'leaf', %s, %s, %s)""", + (nodeIdentifier, + config['pubsub#persist_items'], + config['pubsub#deliver_payloads'], + config['pubsub#send_last_published_item']) + ) + except cursor._pool.dbapi.IntegrityError: + raise error.NodeExists() + + cursor.execute("""SELECT 1 from entities where jid=%s""", + (owner,)) + + if not cursor.fetchone(): + cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""", + (owner,)) + + cursor.execute("""INSERT INTO affiliations + (node_id, entity_id, affiliation) + SELECT node_id, entity_id, 'owner' FROM + (SELECT node_id FROM nodes WHERE node=%s) as n + CROSS JOIN + (SELECT entity_id FROM entities + WHERE jid=%s) as e""", + (nodeIdentifier, owner)) + + + def deleteNode(self, nodeIdentifier): + return self.dbpool.runInteraction(self._deleteNode, nodeIdentifier) + + + def _deleteNode(self, cursor, nodeIdentifier): + cursor.execute("""DELETE FROM nodes WHERE node=%s""", + (nodeIdentifier,)) + + if cursor.rowcount != 1: + raise error.NodeNotFound() + + + def getAffiliations(self, entity): + d = self.dbpool.runQuery("""SELECT node, affiliation FROM entities + NATURAL JOIN affiliations + NATURAL JOIN nodes + WHERE jid=%s""", + (entity.userhost(),)) + d.addCallback(lambda results: [tuple(r) for r in results]) + return d + + + def getSubscriptions(self, entity): + def toSubscriptions(rows): + subscriptions = [] + for row in rows: + subscriber = jid.internJID('%s/%s' % (row[1], + row[2])) + subscription = Subscription(row[0], subscriber, row[3]) + subscriptions.append(subscription) + return subscriptions + + d = self.dbpool.runQuery("""SELECT node, jid, resource, state + FROM entities + NATURAL JOIN subscriptions + NATURAL JOIN nodes + WHERE jid=%s""", + (entity.userhost(),)) + d.addCallback(toSubscriptions) + return d + + + def getDefaultConfiguration(self, nodeType): + return self.defaultConfig[nodeType] + + + +class Node: + + implements(iidavoll.INode) + + def __init__(self, nodeIdentifier, config): + self.nodeIdentifier = nodeIdentifier + self._config = config + + + def _checkNodeExists(self, cursor): + cursor.execute("""SELECT node_id FROM nodes WHERE node=%s""", + (self.nodeIdentifier,)) + if not cursor.fetchone(): + raise error.NodeNotFound() + + + def getType(self): + return self.nodeType + + + def getConfiguration(self): + return self._config + + + def setConfiguration(self, options): + config = copy.copy(self._config) + + for option in options: + if option in config: + config[option] = options[option] + + d = self.dbpool.runInteraction(self._setConfiguration, config) + d.addCallback(self._setCachedConfiguration, config) + return d + + + def _setConfiguration(self, cursor, config): + self._checkNodeExists(cursor) + cursor.execute("""UPDATE nodes SET persist_items=%s, + deliver_payloads=%s, + send_last_published_item=%s + WHERE node=%s""", + (config["pubsub#persist_items"], + config["pubsub#deliver_payloads"], + config["pubsub#send_last_published_item"], + self.nodeIdentifier)) + + + def _setCachedConfiguration(self, void, config): + self._config = config + + + def getMetaData(self): + config = copy.copy(self._config) + config["pubsub#node_type"] = self.nodeType + return config + + + def getAffiliation(self, entity): + return self.dbpool.runInteraction(self._getAffiliation, entity) + + + def _getAffiliation(self, cursor, entity): + self._checkNodeExists(cursor) + cursor.execute("""SELECT affiliation FROM affiliations + NATURAL JOIN nodes + NATURAL JOIN entities + WHERE node=%s AND jid=%s""", + (self.nodeIdentifier, + entity.userhost())) + + try: + return cursor.fetchone()[0] + except TypeError: + return None + + + def getSubscription(self, subscriber): + return self.dbpool.runInteraction(self._getSubscription, subscriber) + + + def _getSubscription(self, cursor, subscriber): + self._checkNodeExists(cursor) + + userhost = subscriber.userhost() + resource = subscriber.resource or '' + + cursor.execute("""SELECT state FROM subscriptions + NATURAL JOIN nodes + NATURAL JOIN entities + WHERE node=%s AND jid=%s AND resource=%s""", + (self.nodeIdentifier, + userhost, + resource)) + row = cursor.fetchone() + if not row: + return None + else: + return Subscription(self.nodeIdentifier, subscriber, row[0]) + + + def getSubscriptions(self, state=None): + return self.dbpool.runInteraction(self._getSubscriptions, state) + + + def _getSubscriptions(self, cursor, state): + self._checkNodeExists(cursor) + + query = """SELECT jid, resource, state, + subscription_type, subscription_depth + FROM subscriptions + NATURAL JOIN nodes + NATURAL JOIN entities + WHERE node=%s"""; + values = [self.nodeIdentifier] + + if state: + query += " AND state=%s" + values.append(state) + + cursor.execute(query, values); + rows = cursor.fetchall() + + subscriptions = [] + for row in rows: + subscriber = jid.JID('%s/%s' % (row[0], row[1])) + + options = {} + if row[3]: + options['pubsub#subscription_type'] = row[3]; + if row[4]: + options['pubsub#subscription_depth'] = row[4]; + + subscriptions.append(Subscription(self.nodeIdentifier, subscriber, + row[2], options)) + + return subscriptions + + + def addSubscription(self, subscriber, state, config): + return self.dbpool.runInteraction(self._addSubscription, subscriber, + state, config) + + + def _addSubscription(self, cursor, subscriber, state, config): + self._checkNodeExists(cursor) + + userhost = subscriber.userhost() + resource = subscriber.resource or '' + + subscription_type = config.get('pubsub#subscription_type') + subscription_depth = config.get('pubsub#subscription_depth') + + try: + cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""", + (userhost,)) + except cursor._pool.dbapi.IntegrityError: + cursor._connection.rollback() + + try: + cursor.execute("""INSERT INTO subscriptions + (node_id, entity_id, resource, state, + subscription_type, subscription_depth) + SELECT node_id, entity_id, %s, %s, %s, %s FROM + (SELECT node_id FROM nodes + WHERE node=%s) as n + CROSS JOIN + (SELECT entity_id FROM entities + WHERE jid=%s) as e""", + (resource, + state, + subscription_type, + subscription_depth, + self.nodeIdentifier, + userhost)) + except cursor._pool.dbapi.IntegrityError: + raise error.SubscriptionExists() + + + def removeSubscription(self, subscriber): + return self.dbpool.runInteraction(self._removeSubscription, + subscriber) + + + def _removeSubscription(self, cursor, subscriber): + self._checkNodeExists(cursor) + + userhost = subscriber.userhost() + resource = subscriber.resource or '' + + cursor.execute("""DELETE FROM subscriptions WHERE + node_id=(SELECT node_id FROM nodes + WHERE node=%s) AND + entity_id=(SELECT entity_id FROM entities + WHERE jid=%s) AND + resource=%s""", + (self.nodeIdentifier, + userhost, + resource)) + if cursor.rowcount != 1: + raise error.NotSubscribed() + + return None + + + def isSubscribed(self, entity): + return self.dbpool.runInteraction(self._isSubscribed, entity) + + + def _isSubscribed(self, cursor, entity): + self._checkNodeExists(cursor) + + cursor.execute("""SELECT 1 FROM entities + NATURAL JOIN subscriptions + NATURAL JOIN nodes + WHERE entities.jid=%s + AND node=%s AND state='subscribed'""", + (entity.userhost(), + self.nodeIdentifier)) + + return cursor.fetchone() is not None + + + def getAffiliations(self): + return self.dbpool.runInteraction(self._getAffiliations) + + + def _getAffiliations(self, cursor): + self._checkNodeExists(cursor) + + cursor.execute("""SELECT jid, affiliation FROM nodes + NATURAL JOIN affiliations + NATURAL JOIN entities + WHERE node=%s""", + (self.nodeIdentifier,)) + result = cursor.fetchall() + + return [(jid.internJID(r[0]), r[1]) for r in result] + + + +class LeafNode(Node): + + implements(iidavoll.ILeafNode) + + nodeType = 'leaf' + + def storeItems(self, items, publisher): + return self.dbpool.runInteraction(self._storeItems, items, publisher) + + + def _storeItems(self, cursor, items, publisher): + self._checkNodeExists(cursor) + for item in items: + self._storeItem(cursor, item, publisher) + + + def _storeItem(self, cursor, item, publisher): + data = item.toXml() + cursor.execute("""UPDATE items SET date=now(), publisher=%s, data=%s + FROM nodes + WHERE nodes.node_id = items.node_id AND + nodes.node = %s and items.item=%s""", + (publisher.full(), + data, + self.nodeIdentifier, + item["id"])) + if cursor.rowcount == 1: + return + + cursor.execute("""INSERT INTO items (node_id, item, publisher, data) + SELECT node_id, %s, %s, %s FROM nodes + WHERE node=%s""", + (item["id"], + publisher.full(), + data, + self.nodeIdentifier)) + + + def removeItems(self, itemIdentifiers): + return self.dbpool.runInteraction(self._removeItems, itemIdentifiers) + + + def _removeItems(self, cursor, itemIdentifiers): + self._checkNodeExists(cursor) + + deleted = [] + + for itemIdentifier in itemIdentifiers: + cursor.execute("""DELETE FROM items WHERE + node_id=(SELECT node_id FROM nodes + WHERE node=%s) AND + item=%s""", + (self.nodeIdentifier, + itemIdentifier)) + + if cursor.rowcount: + deleted.append(itemIdentifier) + + return deleted + + + def getItems(self, maxItems=None): + return self.dbpool.runInteraction(self._getItems, maxItems) + + + def _getItems(self, cursor, maxItems): + self._checkNodeExists(cursor) + query = """SELECT data FROM nodes + NATURAL JOIN items + WHERE node=%s ORDER BY date DESC""" + if maxItems: + cursor.execute(query + " LIMIT %s", + (self.nodeIdentifier, + maxItems)) + else: + cursor.execute(query, (self.nodeIdentifier,)) + + result = cursor.fetchall() + items = [stripNamespace(parseXml(r[0])) for r in result] + return items + + + def getItemsById(self, itemIdentifiers): + return self.dbpool.runInteraction(self._getItemsById, itemIdentifiers) + + + def _getItemsById(self, cursor, itemIdentifiers): + self._checkNodeExists(cursor) + items = [] + for itemIdentifier in itemIdentifiers: + cursor.execute("""SELECT data FROM nodes + NATURAL JOIN items + WHERE node=%s AND item=%s""", + (self.nodeIdentifier, + itemIdentifier)) + result = cursor.fetchone() + if result: + items.append(parseXml(result[0])) + return items + + + def purge(self): + return self.dbpool.runInteraction(self._purge) + + + def _purge(self, cursor): + self._checkNodeExists(cursor) + + cursor.execute("""DELETE FROM items WHERE + node_id=(SELECT node_id FROM nodes WHERE node=%s)""", + (self.nodeIdentifier,)) + + +class CollectionNode(Node): + + nodeType = 'collection' + + + +class GatewayStorage(object): + """ + Memory based storage facility for the XMPP-HTTP gateway. + """ + + def __init__(self, dbpool): + self.dbpool = dbpool + + + def _countCallbacks(self, cursor, service, nodeIdentifier): + """ + Count number of callbacks registered for a node. + """ + cursor.execute("""SELECT count(*) FROM callbacks + WHERE service=%s and node=%s""", + service.full(), + nodeIdentifier) + results = cursor.fetchall() + return results[0][0] + + + def addCallback(self, service, nodeIdentifier, callback): + def interaction(cursor): + cursor.execute("""SELECT 1 FROM callbacks + WHERE service=%s and node=%s and uri=%s""", + service.full(), + nodeIdentifier, + callback) + if cursor.fetchall(): + return + + cursor.execute("""INSERT INTO callbacks + (service, node, uri) VALUES + (%s, %s, %s)""", + service.full(), + nodeIdentifier, + callback) + + return self.dbpool.runInteraction(interaction) + + + def removeCallback(self, service, nodeIdentifier, callback): + def interaction(cursor): + cursor.execute("""DELETE FROM callbacks + WHERE service=%s and node=%s and uri=%s""", + service.full(), + nodeIdentifier, + callback) + + if cursor.rowcount != 1: + raise error.NotSubscribed() + + last = not self._countCallbacks(cursor, service, nodeIdentifier) + return last + + return self.dbpool.runInteraction(interaction) + + def getCallbacks(self, service, nodeIdentifier): + def interaction(cursor): + cursor.execute("""SELECT uri FROM callbacks + WHERE service=%s and node=%s""", + service.full(), + nodeIdentifier) + results = cursor.fetchall() + + if not results: + raise error.NoCallbacks() + + return [result[0] for result in results] + + return self.dbpool.runInteraction(interaction) + + + def hasCallbacks(self, service, nodeIdentifier): + def interaction(cursor): + return bool(self._countCallbacks(cursor, service, nodeIdentifier)) + + return self.dbpool.runInteraction(interaction)