changeset 278:8a71486c3e95

implements RSM (XEP-0059)
author souliane <souliane@mailoo.org>
date Mon, 13 Oct 2014 14:53:42 +0200
parents e749401be529
children 7c820a8e4b00
files sat_pubsub/backend.py sat_pubsub/const.py sat_pubsub/iidavoll.py sat_pubsub/pgsql_storage.py sat_pubsub/tap.py
diffstat 5 files changed, 213 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/sat_pubsub/backend.py	Mon Dec 15 13:14:53 2014 +0100
+++ b/sat_pubsub/backend.py	Mon Oct 13 14:53:42 2014 +0200
@@ -73,7 +73,7 @@
 from twisted.words.protocols.jabber.jid import JID, InvalidFormat
 from twisted.words.xish import utility
 
-from wokkel import disco, data_form
+from wokkel import disco, data_form, rsm
 from wokkel.iwokkel import IPubSubResource
 from wokkel.pubsub import PubSubResource, PubSubError, Subscription
 
@@ -500,10 +500,13 @@
 
 
     def getItems(self, nodeIdentifier, requestor, maxItems=None,
-                       itemIdentifiers=None):
+                       itemIdentifiers=None, ext_data=None):
+        if ext_data is None:
+            ext_data = {}
         d = self.storage.getNode(nodeIdentifier)
         d.addCallback(_getAffiliation, requestor)
-        d.addCallback(self._doGetItems, requestor, maxItems, itemIdentifiers)
+        d.addCallback(self._doGetItems, requestor, maxItems, itemIdentifiers,
+                      ext_data)
         return d
 
     def checkGroup(self, roster_groups, entity):
@@ -527,7 +530,8 @@
         d.addCallback(lambda groups: (roster, groups))
         return d
 
-    def _doGetItems(self, result, requestor, maxItems, itemIdentifiers):
+    def _doGetItems(self, result, requestor, maxItems, itemIdentifiers,
+                    ext_data):
         node, affiliation = result
 
         def append_item_config(items_data):
@@ -559,16 +563,24 @@
 
             roster_item = roster.get(requestor.userhostJID())
             authorized_groups = tuple(roster_item.groups) if roster_item else tuple()
+            unrestricted = affiliation == 'owner'
 
             if itemIdentifiers:
-                return node.getItemsById(authorized_groups, affiliation == 'owner', itemIdentifiers)
+                d = node.getItemsById(authorized_groups, unrestricted, itemIdentifiers)
             else:
-                if affiliation == 'owner':
-                    d = node.getItems(authorized_groups, True, maxItems)
-                    return d.addCallback(append_item_config)
-                else:
-                    return node.getItems(authorized_groups, False, maxItems)
+                d = node.getItems(authorized_groups, unrestricted, maxItems,
+                                  ext_data)
+                if unrestricted:
+                    d.addCallback(append_item_config)
 
+            for extension in ext_data:
+                if ext_data[extension] is not None:
+                    if hasattr(self, '_items_%s' % extension):
+                        method = getattr(self, '_items_%s' % extension)
+                        d.addCallback(method, node, authorized_groups,
+                                      unrestricted, maxItems, itemIdentifiers,
+                                      ext_data[extension])
+            return d
 
         if not ILeafNode.providedBy(node):
             return []
@@ -579,7 +591,7 @@
         access_model = node.getConfiguration()["pubsub#access_model"]
         d = node.getNodeOwner()
         d.addCallback(self.roster.getRoster)
-        
+
         if access_model == 'open' or affiliation == 'owner':
             d.addCallback(lambda roster: (True, roster))
             d.addCallback(access_checked)
@@ -590,6 +602,51 @@
 
         return d
 
+    def _items_rsm(self, elts, node, authorized_groups, unrestricted, maxItems,
+                   itemIdentifiers, request):
+        response = rsm.RSMResponse()
+
+        d_count = node.countItems(authorized_groups, unrestricted)
+        d_count.addCallback(lambda count: setattr(response, 'count', count))
+        d_list = [d_count]
+
+        if request.index is not None:
+            response.index = request.index
+        elif request.before is not None:
+            if request.before != '':
+                # XXX: getIndex starts with index 1, RSM starts with 0
+                d_index = node.getIndex(authorized_groups, unrestricted, request.before)
+                d_index.addCallback(lambda index: setattr(response, 'index', max(index - request.max - 1, 0)))
+                d_list.append(d_index)
+        elif request.after is not None:
+            d_index = node.getIndex(authorized_groups, unrestricted, request.after)
+            d_index.addCallback(lambda index: setattr(response, 'index', index))
+            d_list.append(d_index)
+        elif itemIdentifiers:
+            d_index = node.getIndex(authorized_groups, unrestricted, itemIdentifiers[0])
+            d_index.addCallback(lambda index: setattr(response, 'index', index - 1))
+            d_list.append(d_index)
+
+
+        def render(result):
+            items = [elt for elt in elts if elt.name == 'item']
+            if len(items) > 0:
+                if response.index is None:
+                    if request.before == '': # last page
+                        response.index = response.count - request.max
+                    else:  # first page
+                        response.index = 0
+                response.first = items[0]['id']
+                response.last = items[len(items) - 1]['id']
+                if request.before is not None:
+                    response.first, response.last = response.last, response.first
+            else:
+                response.index = None
+            elts.append(response.render())
+            return elts
+
+        return defer.DeferredList(d_list).addCallback(render)
+
     def retractItem(self, nodeIdentifier, itemIdentifiers, requestor):
         d = self.storage.getNode(nodeIdentifier)
         d.addCallback(_getAffiliation, requestor)
@@ -1017,13 +1074,19 @@
 
 
     def items(self, request):
+        ext_data = {}
+        if const.FLAG_ENABLE_RSM:
+            rsm_ = rsm.RSMRequest.parse(request.element.pubsub)
+            if not rsm_:
+                rsm_ = rsm.RSMRequest(const.VAL_RSM_MAX_DEFAULT)
+            ext_data['rsm'] = rsm_
         d = self.backend.getItems(request.nodeIdentifier,
                                   request.sender,
                                   request.maxItems,
-                                  request.itemIdentifiers)
+                                  request.itemIdentifiers,
+                                  ext_data)
         return d.addErrback(self._mapErrors)
 
-
     def retract(self, request):
         d = self.backend.retractItem(request.nodeIdentifier,
                                      request.itemIdentifiers,
--- a/sat_pubsub/const.py	Mon Dec 15 13:14:53 2014 +0100
+++ b/sat_pubsub/const.py	Mon Oct 13 14:53:42 2014 +0200
@@ -68,3 +68,5 @@
 VAL_PMODEL_DEFAULT = VAL_PMODEL_PUBLISHERS
 
 FLAG_RETRACT_ALLOW_PUBLISHER = True  # XXX: see the method BackendService._doRetractAllowPublisher
+FLAG_ENABLE_RSM = True
+VAL_RSM_MAX_DEFAULT = 10
--- a/sat_pubsub/iidavoll.py	Mon Dec 15 13:14:53 2014 +0100
+++ b/sat_pubsub/iidavoll.py	Mon Oct 13 14:53:42 2014 +0200
@@ -553,6 +553,24 @@
         """
 
 
+    def countItems(authorized_groups, unrestricted):
+        """ Count the accessible items.
+
+        @param authorized_groups: we want to get items that these groups can access.
+        @param unrestricted: if true, don't check permissions (i.e.: get all items).
+        @return: deferred that fires a C{int}.
+        """
+
+
+    def getIndex(authorized_groups, unrestricted, item):
+        """ Retrieve the index of the given item within the accessible window.
+
+        @param authorized_groups: we want to get items that these groups can access.
+        @param unrestricted: if true, don't check permissions (i.e.: get all items).
+        @param item: item identifier.
+        @return: deferred that fires a C{int}.
+        """
+
     def getItemsById(authorized_groups, unrestricted, itemIdentifiers):
         """
         Get items by item id.
--- a/sat_pubsub/pgsql_storage.py	Mon Dec 15 13:14:53 2014 +0100
+++ b/sat_pubsub/pgsql_storage.py	Mon Oct 13 14:53:42 2014 +0200
@@ -608,36 +608,65 @@
         return deleted
 
 
-    def getItems(self, authorized_groups, unrestricted, maxItems=None):
+    def getItems(self, authorized_groups, unrestricted, maxItems=None, ext_data=None):
         """ Get all authorised items
         @param authorized_groups: we want to get items that these groups can access
         @param unrestricted: if true, don't check permissions (i.e.: get all items)
         @param maxItems: nb of items we want to tget
+        @param rsm_data: options for RSM feature handling (XEP-0059) as a
+                         dictionnary of C{unicode} to C{unicode}.
+
         @return: list of (item, access_model, id) if unrestricted is True, else list of items
         """
-        return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems)
+        if ext_data is None:
+            ext_data = {}
+        return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems, ext_data)
 
-
-    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems):
+    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data):
         self._checkNodeExists(cursor)
+
         if unrestricted:
-            query = ["""SELECT data,items.access_model,item_id FROM nodes
+            query = ["SELECT data,items.access_model,item_id"]
+            source = """FROM nodes
                        INNER JOIN items USING (node_id)
-                       WHERE node=%s ORDER BY date DESC"""]
+                       WHERE node=%s"""
             args = [self.nodeIdentifier]
         else:
-            query = ["""SELECT data FROM nodes
-                       INNER  JOIN items USING (node_id)
+            query = ["SELECT data"]
+            groups = " or (items.access_model='roster' and groupname in %s)" if authorized_groups else ""
+            source = """FROM nodes
+                       INNER JOIN items USING (node_id)
                        LEFT JOIN item_groups_authorized USING (item_id)
                        WHERE node=%s AND
-                       (items.access_model='open' """ +
-                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
-                       """)
-                       ORDER BY date DESC"""]
+                       (items.access_model='open'""" + groups + ")"
+
             args = [self.nodeIdentifier]
             if authorized_groups:
                 args.append(authorized_groups)
 
+        query.append(source)
+        order = "DESC"
+
+        if 'rsm' in ext_data:
+            rsm = ext_data['rsm']
+            maxItems = rsm.max
+            if rsm.index is not None:
+                query.append("AND date<=(SELECT date " + source + " ORDER BY date DESC LIMIT 1 OFFSET %s)")
+                args.append(self.nodeIdentifier)
+                if authorized_groups:
+                    args.append(authorized_groups)
+                args.append(rsm.index)
+            elif rsm.before is not None:
+                order = "ASC"
+                if rsm.before != '':
+                    query.append("AND date>(SELECT date FROM items WHERE item=%s LIMIT 1)")
+                    args.append(rsm.before)
+            elif rsm.after:
+                query.append("AND date<(SELECT date FROM items WHERE item=%s LIMIT 1)")
+                args.append(rsm.after)
+
+        query.append("ORDER BY date %s" % order)
+
         if maxItems:
             query.append("LIMIT %s")
             args.append(maxItems)
@@ -662,6 +691,79 @@
         items = [generic.stripNamespace(parseXml(r[0])) for r in result]
         return items
 
+    def countItems(self, authorized_groups, unrestricted):
+        """ Count the accessible items.
+
+        @param authorized_groups: we want to get items that these groups can access.
+        @param unrestricted: if true, don't check permissions (i.e.: get all items).
+        @return: deferred that fires a C{int}.
+        """
+        return self.dbpool.runInteraction(self._countItems, authorized_groups, unrestricted)
+
+    def _countItems(self, cursor, authorized_groups, unrestricted):
+        self._checkNodeExists(cursor)
+
+        if unrestricted:
+            query = ["""SELECT count(item_id) FROM nodes
+                       INNER JOIN items USING (node_id)
+                       WHERE node=%s"""]
+            args = [self.nodeIdentifier]
+        else:
+            query = ["""SELECT count(item_id) FROM nodes
+                       INNER  JOIN items USING (node_id)
+                       LEFT JOIN item_groups_authorized USING (item_id)
+                       WHERE node=%s AND
+                       (items.access_model='open' """ +
+                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
+                       ")"]
+
+            args = [self.nodeIdentifier]
+            if authorized_groups:
+                args.append(authorized_groups)
+
+        cursor.execute(' '.join(query), args)
+        return cursor.fetchall()[0][0]
+
+    def getIndex(self, authorized_groups, unrestricted, item):
+        """ Retrieve the index of the given item within the accessible window.
+
+        @param authorized_groups: we want to get items that these groups can access.
+        @param unrestricted: if true, don't check permissions (i.e.: get all items).
+        @param item: item identifier.
+        @return: deferred that fires a C{int}.
+        """
+        return self.dbpool.runInteraction(self._getIndex, authorized_groups, unrestricted, item)
+
+    def _getIndex(self, cursor, authorized_groups, unrestricted, item):
+        self._checkNodeExists(cursor)
+
+        if unrestricted:
+            query = ["""SELECT row_number FROM (
+                       SELECT row_number() OVER (ORDER BY date DESC), item
+                       FROM nodes INNER JOIN items USING (node_id)
+                       WHERE node=%s
+                       ) as x
+                       WHERE item=%s LIMIT 1"""]
+            args = [self.nodeIdentifier]
+        else:
+            query = ["""SELECT row_number FROM (
+                       SELECT row_number() OVER (ORDER BY date DESC), item
+                       FROM nodes INNER JOIN items USING (node_id)
+                       LEFT JOIN item_groups_authorized USING (item_id)
+                       WHERE node=%s AND
+                       (items.access_model='open' """ +
+                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
+                       """)) as x
+                       WHERE item=%s LIMIT 1"""]
+
+            args = [self.nodeIdentifier]
+            if authorized_groups:
+                args.append(authorized_groups)
+
+        args.append(item)
+        cursor.execute(' '.join(query), args)
+
+        return cursor.fetchall()[0][0]
 
     def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers):
         """ Get items which are in the given list
--- a/sat_pubsub/tap.py	Mon Dec 15 13:14:53 2014 +0100
+++ b/sat_pubsub/tap.py	Mon Oct 13 14:53:42 2014 +0200
@@ -60,9 +60,9 @@
 from wokkel.disco import DiscoHandler
 from wokkel.generic import FallbackHandler, VersionHandler
 from wokkel.iwokkel import IPubSubResource
-from wokkel.pubsub import PubSubService
+from wokkel import pubsub, rsm
 
-from sat_pubsub import __version__
+from sat_pubsub import __version__, const
 from sat_pubsub.backend import BackendService
 from sat_pubsub.remote_roster import RosterClient
 
@@ -141,7 +141,7 @@
     resource.hideNodes = config["hide-nodes"]
     resource.serviceJID = config["jid"]
 
-    ps = PubSubService(resource)
+    ps = (rsm if const.FLAG_ENABLE_RSM else pubsub).PubSubService(resource)
     ps.setHandlerParent(cs)
     resource.pubsubService = ps