diff sat_pubsub/pgsql_storage.py @ 278:8a71486c3e95

implements RSM (XEP-0059)
author souliane <souliane@mailoo.org>
date Mon, 13 Oct 2014 14:53:42 +0200
parents b757c29b20d7
children 7d54ff2eeaf2
line wrap: on
line diff
--- a/sat_pubsub/pgsql_storage.py	Mon Dec 15 13:14:53 2014 +0100
+++ b/sat_pubsub/pgsql_storage.py	Mon Oct 13 14:53:42 2014 +0200
@@ -608,36 +608,65 @@
         return deleted
 
 
-    def getItems(self, authorized_groups, unrestricted, maxItems=None):
+    def getItems(self, authorized_groups, unrestricted, maxItems=None, ext_data=None):
         """ Get all authorised items
         @param authorized_groups: we want to get items that these groups can access
         @param unrestricted: if true, don't check permissions (i.e.: get all items)
         @param maxItems: nb of items we want to tget
+        @param rsm_data: options for RSM feature handling (XEP-0059) as a
+                         dictionnary of C{unicode} to C{unicode}.
+
         @return: list of (item, access_model, id) if unrestricted is True, else list of items
         """
-        return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems)
+        if ext_data is None:
+            ext_data = {}
+        return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems, ext_data)
 
-
-    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems):
+    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data):
         self._checkNodeExists(cursor)
+
         if unrestricted:
-            query = ["""SELECT data,items.access_model,item_id FROM nodes
+            query = ["SELECT data,items.access_model,item_id"]
+            source = """FROM nodes
                        INNER JOIN items USING (node_id)
-                       WHERE node=%s ORDER BY date DESC"""]
+                       WHERE node=%s"""
             args = [self.nodeIdentifier]
         else:
-            query = ["""SELECT data FROM nodes
-                       INNER  JOIN items USING (node_id)
+            query = ["SELECT data"]
+            groups = " or (items.access_model='roster' and groupname in %s)" if authorized_groups else ""
+            source = """FROM nodes
+                       INNER JOIN items USING (node_id)
                        LEFT JOIN item_groups_authorized USING (item_id)
                        WHERE node=%s AND
-                       (items.access_model='open' """ +
-                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
-                       """)
-                       ORDER BY date DESC"""]
+                       (items.access_model='open'""" + groups + ")"
+
             args = [self.nodeIdentifier]
             if authorized_groups:
                 args.append(authorized_groups)
 
+        query.append(source)
+        order = "DESC"
+
+        if 'rsm' in ext_data:
+            rsm = ext_data['rsm']
+            maxItems = rsm.max
+            if rsm.index is not None:
+                query.append("AND date<=(SELECT date " + source + " ORDER BY date DESC LIMIT 1 OFFSET %s)")
+                args.append(self.nodeIdentifier)
+                if authorized_groups:
+                    args.append(authorized_groups)
+                args.append(rsm.index)
+            elif rsm.before is not None:
+                order = "ASC"
+                if rsm.before != '':
+                    query.append("AND date>(SELECT date FROM items WHERE item=%s LIMIT 1)")
+                    args.append(rsm.before)
+            elif rsm.after:
+                query.append("AND date<(SELECT date FROM items WHERE item=%s LIMIT 1)")
+                args.append(rsm.after)
+
+        query.append("ORDER BY date %s" % order)
+
         if maxItems:
             query.append("LIMIT %s")
             args.append(maxItems)
@@ -662,6 +691,79 @@
         items = [generic.stripNamespace(parseXml(r[0])) for r in result]
         return items
 
+    def countItems(self, authorized_groups, unrestricted):
+        """ Count the accessible items.
+
+        @param authorized_groups: we want to get items that these groups can access.
+        @param unrestricted: if true, don't check permissions (i.e.: get all items).
+        @return: deferred that fires a C{int}.
+        """
+        return self.dbpool.runInteraction(self._countItems, authorized_groups, unrestricted)
+
+    def _countItems(self, cursor, authorized_groups, unrestricted):
+        self._checkNodeExists(cursor)
+
+        if unrestricted:
+            query = ["""SELECT count(item_id) FROM nodes
+                       INNER JOIN items USING (node_id)
+                       WHERE node=%s"""]
+            args = [self.nodeIdentifier]
+        else:
+            query = ["""SELECT count(item_id) FROM nodes
+                       INNER  JOIN items USING (node_id)
+                       LEFT JOIN item_groups_authorized USING (item_id)
+                       WHERE node=%s AND
+                       (items.access_model='open' """ +
+                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
+                       ")"]
+
+            args = [self.nodeIdentifier]
+            if authorized_groups:
+                args.append(authorized_groups)
+
+        cursor.execute(' '.join(query), args)
+        return cursor.fetchall()[0][0]
+
+    def getIndex(self, authorized_groups, unrestricted, item):
+        """ Retrieve the index of the given item within the accessible window.
+
+        @param authorized_groups: we want to get items that these groups can access.
+        @param unrestricted: if true, don't check permissions (i.e.: get all items).
+        @param item: item identifier.
+        @return: deferred that fires a C{int}.
+        """
+        return self.dbpool.runInteraction(self._getIndex, authorized_groups, unrestricted, item)
+
+    def _getIndex(self, cursor, authorized_groups, unrestricted, item):
+        self._checkNodeExists(cursor)
+
+        if unrestricted:
+            query = ["""SELECT row_number FROM (
+                       SELECT row_number() OVER (ORDER BY date DESC), item
+                       FROM nodes INNER JOIN items USING (node_id)
+                       WHERE node=%s
+                       ) as x
+                       WHERE item=%s LIMIT 1"""]
+            args = [self.nodeIdentifier]
+        else:
+            query = ["""SELECT row_number FROM (
+                       SELECT row_number() OVER (ORDER BY date DESC), item
+                       FROM nodes INNER JOIN items USING (node_id)
+                       LEFT JOIN item_groups_authorized USING (item_id)
+                       WHERE node=%s AND
+                       (items.access_model='open' """ +
+                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
+                       """)) as x
+                       WHERE item=%s LIMIT 1"""]
+
+            args = [self.nodeIdentifier]
+            if authorized_groups:
+                args.append(authorized_groups)
+
+        args.append(item)
+        cursor.execute(' '.join(query), args)
+
+        return cursor.fetchall()[0][0]
 
     def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers):
         """ Get items which are in the given list