Mercurial > libervia-pubsub
view idavoll/pgsql_backend.py @ 123:8f99b4f7aea2
Add tests for all Node methods.
author | Ralph Meijer <ralphm@ik.nu> |
---|---|
date | Tue, 12 Apr 2005 12:28:07 +0000 |
parents | 8d8946e67fcb |
children |
line wrap: on
line source
from twisted.application import service
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from twisted.enterprise import adbapi

import backend


def _format_jid(userhost, resource):
    """Return the JID string for a bare JID plus an optional resource.

    Subscriptions are stored with an empty resource string when the
    subscriber has no resource; joining that with '/' would yield a
    string with a dangling separator, which is not a valid JID.
    """
    if resource:
        return "%s/%s" % (userhost, resource)
    return userhost


class Storage:
    """PostgreSQL-backed storage for a publish-subscribe service.

    Public methods return the Deferred produced by the adbapi connection
    pool. Multi-statement operations are delegated to a private
    ``_name`` counterpart that runs through runInteraction and receives a
    DB-API cursor as its first argument. Text parameters are encoded to
    UTF-8 before being handed to the driver.
    """

    def __init__(self, user, database):
        self.dbpool = adbapi.ConnectionPool('pyPgSQL.PgSQL', user=user,
                                            database=database)

    def _check_node_exists(self, cursor, node_id):
        """Raise backend.NodeNotFound unless a node named node_id exists."""
        cursor.execute("""SELECT id FROM nodes WHERE node=%s""",
                       (node_id.encode('utf8'),))
        if not cursor.fetchone():
            raise backend.NodeNotFound

    def _ensure_entity(self, cursor, entity_jid):
        """Add entity_jid to the entities table if it is not there yet.

        The row is looked up before inserting instead of inserting and
        ignoring the failure: in PostgreSQL a failed statement aborts the
        enclosing transaction, so every later statement of the same
        interaction would be rejected as well.
        """
        encoded = entity_jid.encode('utf8')
        cursor.execute("""SELECT 1 FROM entities WHERE jid=%s""", (encoded,))
        if not cursor.fetchone():
            cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
                           (encoded,))

    def _get_node_configuration(self, cursor, node_id):
        """Return the configuration dict of node_id.

        Raises backend.NodeNotFound when the node does not exist.
        """
        cursor.execute("""SELECT persistent, deliver_payload FROM nodes
                          WHERE node=%s""",
                       (node_id.encode('utf8'),))
        row = cursor.fetchone()
        if row is None:
            raise backend.NodeNotFound
        persistent, deliver_payload = row
        return {"pubsub#persist_items": persistent,
                "pubsub#deliver_payloads": deliver_payload}

    def get_node_configuration(self, node_id):
        """Return a Deferred firing with the node's configuration dict.

        The dict has the keys 'pubsub#persist_items' and
        'pubsub#deliver_payloads'; the Deferred fails with
        backend.NodeNotFound for an unknown node.
        """
        return self.dbpool.runInteraction(self._get_node_configuration,
                                          node_id)

    def _get_affiliation(self, cursor, node_id, entity):
        """Return entity's affiliation to node_id, or None if it has none."""
        self._check_node_exists(cursor, node_id)
        cursor.execute("""SELECT affiliation FROM affiliations
                          JOIN nodes ON (node_id=nodes.id)
                          JOIN entities ON (entity_id=entities.id)
                          WHERE node=%s AND jid=%s""",
                       (node_id.encode('utf8'),
                        entity.full().encode('utf8')))
        row = cursor.fetchone()
        if row is None:
            return None
        return row[0]

    def get_affiliation(self, node_id, entity):
        """Return a Deferred firing with entity's affiliation to node_id.

        Fires with None when there is no affiliation; fails with
        backend.NodeNotFound for an unknown node.
        """
        return self.dbpool.runInteraction(self._get_affiliation, node_id,
                                          entity)

    def get_subscribers(self, node_id):
        """Return a Deferred firing with the full JID strings of all
        entities whose subscription to node_id is 'subscribed'."""
        d = self.dbpool.runInteraction(self._get_subscribers, node_id)
        d.addCallback(self._convert_to_jids)
        return d

    def _get_subscribers(self, cursor, node_id):
        """Fetch (jid, resource) rows of the confirmed subscribers."""
        self._check_node_exists(cursor, node_id)
        cursor.execute("""SELECT jid, resource FROM subscriptions
                          JOIN nodes ON (node_id=nodes.id)
                          JOIN entities ON (entity_id=entities.id)
                          WHERE node=%s AND subscription='subscribed'""",
                       (node_id.encode('utf8'),))
        return cursor.fetchall()

    def _convert_to_jids(self, rows):
        """Turn (jid, resource) rows into full JID strings."""
        return [jid.JID(_format_jid(userhost, resource)).full()
                for userhost, resource in rows]

    def store_items(self, node_id, items, publisher):
        """Store (insert or replace) items in node_id on behalf of
        publisher. Fails with backend.NodeNotFound for an unknown node."""
        return self.dbpool.runInteraction(self._store_items, node_id, items,
                                          publisher)

    def _store_items(self, cursor, node_id, items, publisher):
        self._check_node_exists(cursor, node_id)
        for item in items:
            self._store_item(cursor, node_id, item, publisher)

    def _store_item(self, cursor, node_id, item, publisher):
        """Update the item with the same id if present, else insert it."""
        data = item.toXml()
        cursor.execute("""UPDATE items SET date=now(), publisher=%s, data=%s
                          FROM nodes
                          WHERE nodes.id = items.node_id AND
                                nodes.node = %s and items.item=%s""",
                       (publisher.full().encode('utf8'),
                        data.encode('utf8'),
                        node_id.encode('utf8'),
                        item["id"].encode('utf8')))
        if cursor.rowcount == 1:
            return

        cursor.execute("""INSERT INTO items (node_id, item, publisher, data)
                          SELECT id, %s, %s, %s FROM nodes WHERE node=%s""",
                       (item["id"].encode('utf8'),
                        publisher.full().encode('utf8'),
                        data.encode('utf8'),
                        node_id.encode('utf8')))

    def add_subscription(self, node_id, subscriber, state):
        """Subscribe subscriber to node_id with the given state.

        If a subscription for the same bare JID and resource already
        exists it is left untouched and its stored state is reported.
        The Deferred fires with a dict with the keys 'node', 'jid' and
        'subscription'.
        """
        return self.dbpool.runInteraction(self._add_subscription, node_id,
                                          subscriber, state)

    def _add_subscription(self, cursor, node_id, subscriber, state):
        self._check_node_exists(cursor, node_id)

        userhost = subscriber.userhost()
        resource = subscriber.resource or ''

        self._ensure_entity(cursor, userhost)

        # Look for an existing subscription first rather than relying on a
        # failing INSERT: runInteraction runs inside a transaction, and
        # after an error PostgreSQL rejects every further statement in it,
        # so a fallback query could never run.
        cursor.execute("""SELECT subscription FROM subscriptions
                          JOIN nodes ON (nodes.id=subscriptions.node_id)
                          JOIN entities ON (entities.id=subscriptions.entity_id)
                          WHERE node=%s AND jid=%s AND resource=%s""",
                       (node_id.encode('utf8'),
                        userhost.encode('utf8'),
                        resource.encode('utf8')))
        row = cursor.fetchone()
        if row is not None:
            state = row[0]
        else:
            cursor.execute("""INSERT INTO subscriptions
                              (node_id, entity_id, resource, subscription)
                              SELECT n.id, e.id, %s, %s FROM
                              (SELECT id FROM nodes WHERE node=%s) AS n
                              CROSS JOIN
                              (SELECT id FROM entities WHERE jid=%s) AS e""",
                           (resource.encode('utf8'),
                            state.encode('utf8'),
                            node_id.encode('utf8'),
                            userhost.encode('utf8')))

        return {'node': node_id,
                'jid': subscriber,
                'subscription': state}

    def remove_subscription(self, node_id, subscriber):
        """Remove subscriber's subscription to node_id.

        Fails with backend.NotSubscribed when no single subscription row
        was deleted.
        """
        return self.dbpool.runInteraction(self._remove_subscription, node_id,
                                          subscriber)

    def _remove_subscription(self, cursor, node_id, subscriber):
        self._check_node_exists(cursor, node_id)

        userhost = subscriber.userhost()
        resource = subscriber.resource or ''

        cursor.execute("""DELETE FROM subscriptions WHERE
                          node_id=(SELECT id FROM nodes WHERE node=%s) AND
                          entity_id=(SELECT id FROM entities WHERE jid=%s) AND
                          resource=%s""",
                       (node_id.encode('utf8'),
                        userhost.encode('utf8'),
                        resource.encode('utf8')))
        if cursor.rowcount != 1:
            raise backend.NotSubscribed

        return None

    def create_node(self, node_id, owner):
        """Create node_id with owner affiliated as 'owner'.

        Fails with backend.NodeExists when the node cannot be inserted.
        """
        return self.dbpool.runInteraction(self._create_node, node_id, owner)

    def _create_node(self, cursor, node_id, owner):
        try:
            cursor.execute("""INSERT INTO nodes (node) VALUES (%s)""",
                           (node_id.encode('utf8'),))
        except self.dbpool.dbapi.DatabaseError:
            # self.dbpool.dbapi is the DB-API driver module loaded by the
            # connection pool. A duplicate node (presumably rejected by a
            # unique constraint on nodes.node -- confirm against the schema)
            # surfaces as a database error; other exceptions now propagate
            # instead of being misreported as NodeExists.
            raise backend.NodeExists

        self._ensure_entity(cursor, owner.full())

        cursor.execute("""INSERT INTO affiliations
                          (node_id, entity_id, affiliation)
                          SELECT n.id, e.id, 'owner' FROM
                          (SELECT id FROM nodes WHERE node=%s) AS n
                          CROSS JOIN
                          (SELECT id FROM entities WHERE jid=%s) AS e""",
                       (node_id.encode('utf8'),
                        owner.full().encode('utf8')))

        return None

    def get_affiliations(self, entity):
        """Return a Deferred firing with (node, affiliation) rows for
        entity's full JID."""
        return self.dbpool.runQuery("""SELECT node, affiliation FROM entities
                                       JOIN affiliations ON
                                       (affiliations.entity_id=entities.id)
                                       JOIN nodes ON
                                       (nodes.id=affiliations.node_id)
                                       WHERE jid=%s""",
                                    (entity.full().encode('utf8'),))

    def get_subscriptions(self, entity):
        """Return a Deferred firing with (node, JID, subscription) tuples
        for the subscriptions stored under entity's full JID."""
        d = self.dbpool.runQuery("""SELECT node, jid, resource, subscription
                                    FROM entities JOIN subscriptions ON
                                    (subscriptions.entity_id=entities.id)
                                    JOIN nodes ON
                                    (nodes.id=subscriptions.node_id)
                                    WHERE jid=%s""",
                                 (entity.full().encode('utf8'),))
        d.addCallback(self._convert_subscription_jids)
        return d

    def _convert_subscription_jids(self, subscriptions):
        """Replace the (jid, resource) columns of each row by a JID."""
        return [(node,
                 jid.JID(_format_jid(subscriber, resource)),
                 subscription)
                for node, subscriber, resource, subscription in subscriptions]

    def get_node_type(self, node_id):
        """Return a Deferred firing with the node type; always 'leaf'."""
        return self.dbpool.runInteraction(self._get_node_type, node_id)

    def _get_node_type(self, cursor, node_id):
        self._check_node_exists(cursor, node_id)
        return 'leaf'

    def get_nodes(self):
        """Return a Deferred firing with the list of all node names."""
        d = self.dbpool.runQuery("""SELECT node from nodes""")
        d.addCallback(lambda results: [r[0] for r in results])
        return d

    def is_subscribed(self, node_id, subscriber):
        """Return a Deferred firing with True if a subscription row exists
        for subscriber (bare JID + resource) on node_id, in any state."""
        return self.dbpool.runInteraction(self._is_subscribed, node_id,
                                          subscriber)

    def _is_subscribed(self, cursor, node_id, subscriber):
        self._check_node_exists(cursor, node_id)

        userhost = subscriber.userhost()
        resource = subscriber.resource or ''

        cursor.execute("""SELECT 1 FROM entities
                          JOIN subscriptions ON
                          (entities.id=subscriptions.entity_id)
                          JOIN nodes ON
                          (nodes.id=subscriptions.node_id)
                          WHERE entities.jid=%s
                          AND resource=%s AND node=%s""",
                       (userhost.encode('utf8'),
                        resource.encode('utf8'),
                        node_id.encode('utf8')))

        return cursor.fetchone() is not None

    def get_items_by_ids(self, node_id, item_ids):
        """Return a Deferred firing with the XML data of the requested
        items that exist, in the order of item_ids."""
        return self.dbpool.runInteraction(self._get_items_by_ids, node_id,
                                          item_ids)

    def _get_items_by_ids(self, cursor, node_id, item_ids):
        self._check_node_exists(cursor, node_id)
        items = []
        for item_id in item_ids:
            cursor.execute("""SELECT data FROM nodes JOIN items ON
                              (nodes.id=items.node_id)
                              WHERE node=%s AND item=%s""",
                           (node_id.encode('utf8'),
                            item_id.encode('utf8')))
            result = cursor.fetchone()
            if result:
                items.append(result[0])
        return items

    def get_items(self, node_id, max_items=None):
        """Return a Deferred firing with the XML data of node_id's items,
        ordered by date, most recent first; at most max_items entries
        when max_items is given and non-zero."""
        return self.dbpool.runInteraction(self._get_items, node_id,
                                          max_items)

    def _get_items(self, cursor, node_id, max_items):
        self._check_node_exists(cursor, node_id)
        query = """SELECT data FROM nodes JOIN items ON
                   (nodes.id=items.node_id)
                   WHERE node=%s ORDER BY date DESC"""
        params = (node_id.encode('utf8'),)
        if max_items:
            query += " LIMIT %s"
            params += (max_items,)
        cursor.execute(query, params)
        return [row[0] for row in cursor.fetchall()]

    def remove_items(self, node_id, item_ids):
        """Delete the given items; the Deferred fires with the ids that
        were actually deleted."""
        return self.dbpool.runInteraction(self._remove_items, node_id,
                                          item_ids)

    def _remove_items(self, cursor, node_id, item_ids):
        self._check_node_exists(cursor, node_id)

        deleted = []

        for item_id in item_ids:
            cursor.execute("""DELETE FROM items WHERE
                              node_id=(SELECT id FROM nodes WHERE node=%s) AND
                              item=%s""",
                           (node_id.encode('utf8'),
                            item_id.encode('utf8')))

            if cursor.rowcount:
                deleted.append(item_id)

        return deleted

    def purge_node(self, node_id):
        """Delete every item of node_id, keeping the node itself."""
        return self.dbpool.runInteraction(self._purge_node, node_id)

    def _purge_node(self, cursor, node_id):
        self._check_node_exists(cursor, node_id)

        cursor.execute("""DELETE FROM items WHERE
                          node_id=(SELECT id FROM nodes WHERE node=%s)""",
                       (node_id.encode('utf8'),))

    def delete_node(self, node_id):
        """Delete node_id from the nodes table."""
        return self.dbpool.runInteraction(self._delete_node, node_id)

    def _delete_node(self, cursor, node_id):
        self._check_node_exists(cursor, node_id)

        cursor.execute("""DELETE FROM nodes WHERE node=%s""",
                       (node_id.encode('utf8'),))

    def set_node_configuration(self, node_id, options):
        """Update the persistence and payload-delivery settings of a node.

        options must hold 'pubsub#persist_items' and
        'pubsub#deliver_payloads' as strings (they are UTF-8 encoded
        before use). The Deferred fails with backend.Error when no single
        node row was updated.
        """
        return self.dbpool.runInteraction(self._set_node_configuration,
                                          node_id,
                                          options)

    def _set_node_configuration(self, cursor, node_id, options):
        cursor.execute("""UPDATE nodes SET persistent=%s, deliver_payload=%s
                          WHERE node=%s""",
                       (options["pubsub#persist_items"].encode('utf8'),
                        options["pubsub#deliver_payloads"].encode('utf8'),
                        node_id.encode('utf8')))
        if cursor.rowcount != 1:
            raise backend.Error


class BackendService(backend.BackendService):
    """ PostgreSQL backend Service for a JEP-0060 pubsub service """


class NodeCreationService(backend.NodeCreationService):
    pass


class PublishService(backend.PublishService):
    pass


class NotificationService(backend.NotificationService):
    pass


class SubscriptionService(backend.SubscriptionService):
    pass


class AffiliationsService(backend.AffiliationsService):
    pass


class ItemRetrievalService(backend.ItemRetrievalService):
    pass


class RetractionService(backend.RetractionService):
    pass


class NodeDeletionService(backend.NodeDeletionService):
    pass