diff sat_pubsub/pgsql_storage.py @ 318:d13526c0eb32

RSM improvements/refactoring: - a warning message is displayed if maxItems == 0 in getItems, and an empty list is returned in this case - use the new container.ItemData instead of doing tuple (un)packing - the list of ItemData => list of domish.Element conversion is done at the end of the workflow - rsm request is checked in self._items_rsm directly - better handling of Response.index in _items_rsm - itemsIdentifiers can't be used with RSM (the latter will be ignored if this happens) - don't do approximate unpacking anymore in _items_rsm - countItems and getIndex have been refactored and renamed getItemsCount and getItemsIndex, don't use query duplications anymore - cleaned query handling in getItems - /!\ mam module is temporarily broken
author Goffi <goffi@goffi.org>
date Sun, 03 Jan 2016 18:33:22 +0100
parents 34adc4a8aa64
children a51947371625
line wrap: on
line diff
--- a/sat_pubsub/pgsql_storage.py	Sun Jan 03 18:33:22 2016 +0100
+++ b/sat_pubsub/pgsql_storage.py	Sun Jan 03 18:33:22 2016 +0100
@@ -602,13 +602,11 @@
     def storeItems(self, item_data, publisher):
         return self.dbpool.runInteraction(self._storeItems, item_data, publisher)
 
-
     def _storeItems(self, cursor, items_data, publisher):
         self._checkNodeExists(cursor)
         for item_data in items_data:
             self._storeItem(cursor, item_data, publisher)
 
-
     def _storeItem(self, cursor, item_data, publisher):
         item, access_model, item_config = item_data.item, item_data.access_model, item_data.config
         data = item.toXml()
@@ -665,7 +663,6 @@
     def removeItems(self, itemIdentifiers):
         return self.dbpool.runInteraction(self._removeItems, itemIdentifiers)
 
-
     def _removeItems(self, cursor, itemIdentifiers):
         self._checkNodeExists(cursor)
 
@@ -683,14 +680,12 @@
 
         return deleted
 
-
     def getItems(self, authorized_groups, unrestricted, maxItems=None, ext_data=None):
         """ Get all authorised items
         @param authorized_groups: we want to get items that these groups can access
         @param unrestricted: if true, don't check permissions (i.e.: get all items)
-        @param maxItems: nb of items we want to tget
-        @param rsm_data: options for RSM feature handling (XEP-0059) as a
-                         dictionnary of C{unicode} to C{unicode}.
+        @param maxItems: nb of items we want to get
+        @param ext_data: options for extra features like RSM and MAM
 
         @return: list of container.ItemData
             if unrestricted is False, access_model and config will be None
@@ -699,70 +694,98 @@
             ext_data = {}
         return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems, ext_data)
 
-    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data):
-        #  FIXME: simplify the query construction
-        self._checkNodeExists(cursor)
+    def _appendSourcesAndFilters(self, query, args, authorized_groups, unrestricted, ext_data):
+        """Append FROM/WHERE sources and MAM filters to an items query, and return its ORDER BY
 
+        query (list of SQL fragments) and args (list of query values) are extended in place;
+        authorized_groups, unrestricted and ext_data are the same as for _getItems
+        @return: ORDER BY clause (always "ORDER BY date DESC"), which is not appended here"""
+        # SOURCES: tables to read, restricted to this node (and to accessible items if needed)
         if unrestricted:
-            query = ["SELECT data,items.access_model,item_id"]
-            source = """FROM nodes
-                       INNER JOIN items USING (node_id)
-                       WHERE node_id=%s"""
-            args = [self.nodeDbId]
+            query.append("""FROM nodes
+                INNER JOIN items USING (node_id)
+                WHERE node_id=%s""")
+            args.append(self.nodeDbId)
         else:
-            query = ["SELECT data"]
-            groups = " or (items.access_model='roster' and groupname in %s)" if authorized_groups else ""
-            source = """FROM nodes
-                       INNER JOIN items USING (node_id)
-                       LEFT JOIN item_groups_authorized USING (item_id)
-                       WHERE node_id=%s AND
-                       (items.access_model='open'""" + groups + ")"
+            args.append(self.nodeDbId)  # args must follow the order of the %s in the query below
+            if authorized_groups:
+                get_groups = " or (items.access_model='roster' and groupname in %s)"
+                args.append(authorized_groups)
+            else:
+                get_groups = ""
 
-            args = [self.nodeDbId]
-            if authorized_groups:
-                args.append(authorized_groups)
+            query.append("""FROM nodes
+                INNER JOIN items USING (node_id)
+                LEFT JOIN item_groups_authorized USING (item_id)
+                WHERE node_id=%s AND
+                (items.access_model='open'""" + get_groups + ")")
 
+        # FILTERS: each handled MAM filter adds an AND condition and its value
         if 'filters' in ext_data:  # MAM filters
             for filter_ in ext_data['filters']:
                 if filter_.var == 'start':
-                    source += " AND date>=%s"
+                    query.append("AND date>=%s")
                     args.append(filter_.value)
                 if filter_.var == 'end':
-                    source += " AND date<=%s"
+                    query.append("AND date<=%s")
                     args.append(filter_.value)
                 if filter_.var == 'with':
                     jid_s = filter_.value
                     if '/' in jid_s:
-                        source += " AND publisher=%s"
+                        query.append("AND publisher=%s")
                         args.append(filter_.value)
                     else:
-                        source += " AND publisher LIKE %s"
+                        query.append("AND publisher LIKE %s")
+                        args.append(u"{}%".format(filter_.value))  # no resource: prefix match; NOTE(review): "%" and "_" in the JID are not escaped
 
-        query.append(source)
-        order = "DESC"
+        return "ORDER BY date DESC"
+
+    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data):
+        """Run the getItems query with cursor: return a list of container.ItemData (see getItems)"""
+        self._checkNodeExists(cursor)
+        if maxItems == 0:
+            return []
+
+        args = []
+
+        # SELECT: date is selected too, the outer "ORDER BY date" of the rsm.before sub-select uses it
+        query = ["SELECT data,items.access_model,item_id,date"]
+
+        query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
 
         if 'rsm' in ext_data:
             rsm = ext_data['rsm']
             maxItems = rsm.max
             if rsm.index is not None:
-                query.append("AND date<=(SELECT date " + source + " ORDER BY date DESC LIMIT 1 OFFSET %s)")
-                # FIXME: change the request so source is not used 2 times
-                # there is already a placeholder in source with node_id=%s, so we need to add self.noDbId in args
-                args.append(self.nodeDbId)
-                if authorized_groups:
-                    args.append(authorized_groups)
-                args.append(rsm.index)
+                # We need the date of the item at this index (offset) in the current query,
+                # so we run a copy of the query which only fetches that date
+                tmp_query = query[:]
+                tmp_args = args[:]
+                tmp_query[0] = "SELECT date"
+                tmp_query.append("{} LIMIT 1 OFFSET %s".format(query_order))
+                tmp_args.append(rsm.index)
+                cursor.execute(' '.join(tmp_query), tmp_args)
+                # FIXME: bad index is not managed yet
+                date = cursor.fetchall()[0][0]
+
+                # now that we have the date, we can use it
+                query.append("AND date<=%s")
+                args.append(date)
             elif rsm.before is not None:
-                order = "ASC"
                 if rsm.before != '':
                     query.append("AND date>(SELECT date FROM items WHERE item=%s LIMIT 1)")
                     args.append(rsm.before)
+                if maxItems is not None:
+                    # if we have maxItems (i.e. a limit), we need to reverse order
+                    # in a first query to get the right items
+                    query.insert(0,"SELECT * from (")
+                    query.append("ORDER BY date ASC LIMIT %s) as x")
+                    args.append(maxItems)
             elif rsm.after:
                 query.append("AND date<(SELECT date FROM items WHERE item=%s LIMIT 1)")
                 args.append(rsm.after)
 
-        query.append("ORDER BY date %s" % order)
+        query.append(query_order)
 
         if maxItems is not None:
             query.append("LIMIT %s")
@@ -772,6 +795,7 @@
 
         result = cursor.fetchall()
         if unrestricted:
+            # with an unrestricted query, we need to fill the access_list for items with roster access
             ret = []
             for data in result:
                 item = generic.stripNamespace(parseXml(data[0]))
@@ -784,84 +808,10 @@
 
                 ret.append(container.ItemData(item, access_model, access_list))
             return ret
+
         items_data = [container.ItemData(generic.stripNamespace(parseXml(r[0])), None, None) for r in result]
         return items_data
 
-    def countItems(self, authorized_groups, unrestricted):
-        """ Count the accessible items.
-
-        @param authorized_groups: we want to get items that these groups can access.
-        @param unrestricted: if true, don't check permissions (i.e.: get all items).
-        @return: deferred that fires a C{int}.
-        """
-        return self.dbpool.runInteraction(self._countItems, authorized_groups, unrestricted)
-
-    def _countItems(self, cursor, authorized_groups, unrestricted):
-        # FIXME: should not be a separate method, but should be an option of getItems instead
-        self._checkNodeExists(cursor)
-
-        if unrestricted:
-            query = ["""SELECT count(item_id) FROM nodes
-                       INNER JOIN items USING (node_id)
-                       WHERE node_id=%s"""]
-            args = [self.nodeDbId]
-        else:
-            query = ["""SELECT count(item_id) FROM nodes
-                       INNER  JOIN items USING (node_id)
-                       LEFT JOIN item_groups_authorized USING (item_id)
-                       WHERE node_id=%s AND
-                       (items.access_model='open' """ +
-                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
-                       ")"]
-
-            args = [self.nodeDbId]
-            if authorized_groups:
-                args.append(authorized_groups)
-
-        cursor.execute(' '.join(query), args)
-        return cursor.fetchall()[0][0]
-
-    def getIndex(self, authorized_groups, unrestricted, item):
-        """ Retrieve the index of the given item within the accessible window.
-
-        @param authorized_groups: we want to get items that these groups can access.
-        @param unrestricted: if true, don't check permissions (i.e.: get all items).
-        @param item: item identifier.
-        @return: deferred that fires a C{int}.
-        """
-        return self.dbpool.runInteraction(self._getIndex, authorized_groups, unrestricted, item)
-
-    def _getIndex(self, cursor, authorized_groups, unrestricted, item):
-        self._checkNodeExists(cursor)
-
-        if unrestricted:
-            query = ["""SELECT row_number FROM (
-                       SELECT row_number() OVER (ORDER BY date DESC), item
-                       FROM nodes INNER JOIN items USING (node_id)
-                       WHERE node_id=%s
-                       ) as x
-                       WHERE item=%s LIMIT 1"""]
-            args = [self.nodeDbId]
-        else:
-            query = ["""SELECT row_number FROM (
-                       SELECT row_number() OVER (ORDER BY date DESC), item
-                       FROM nodes INNER JOIN items USING (node_id)
-                       LEFT JOIN item_groups_authorized USING (item_id)
-                       WHERE node_id=%s AND
-                       (items.access_model='open' """ +
-                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
-                       """)) as x
-                       WHERE item=%s LIMIT 1"""]
-
-            args = [self.nodeDbId]
-            if authorized_groups:
-                args.append(authorized_groups)
-
-        args.append(item)
-        cursor.execute(' '.join(query), args)
-
-        return cursor.fetchall()[0][0]
-
     def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers):
         """ Get items which are in the given list
         @param authorized_groups: we want to get items that these groups can access
@@ -873,7 +823,6 @@
         """
         return self.dbpool.runInteraction(self._getItemsById, authorized_groups, unrestricted, itemIdentifiers)
 
-
     def _getItemsById(self, cursor, authorized_groups, unrestricted, itemIdentifiers):
         self._checkNodeExists(cursor)
         ret = []
@@ -916,6 +865,59 @@
 
         return ret
 
+    def getItemsCount(self, authorized_groups, unrestricted, ext_data=None):
+        """Count the items matching a getItems query (the "rsm" part of ext_data is not used)
+
+        @param authorized_groups: we want to get items that these groups can access
+        @param unrestricted: if true, don't check permissions (i.e.: get all items)
+        @param ext_data: options for extra features like RSM and MAM
+        @return: result of dbpool.runInteraction, expected to be a deferred firing the count (C{int})"""
+        if ext_data is None:
+            ext_data = {}
+        return self.dbpool.runInteraction(self._getItemsCount, authorized_groups, unrestricted, ext_data)
+
+    def _getItemsCount(self, cursor, authorized_groups, unrestricted, ext_data):
+        self._checkNodeExists(cursor)
+        args = []
+        # count the rows of the sources/filters shared with _getItems
+        # SELECT
+        query = ["SELECT count(1)"]
+        # the ORDER BY clause returned by the helper is not needed for a count, so it is dropped
+        self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
+        # NOTE(review): the restricted query LEFT JOINs item_groups_authorized — confirm an item can't be counted twice
+        cursor.execute(' '.join(query), args)
+        return cursor.fetchall()[0][0]
+
+    def getItemsIndex(self, item_id, authorized_groups, unrestricted, ext_data=None):
+        """Get the 0-based index of an item among the items matching a getItems query
+
+        @param item_id: id of the item to look for (compared with the "item" column)
+        @param authorized_groups: we want to get items that these groups can access
+        @param unrestricted: if true, don't check permissions (i.e.: get all items)
+        @param ext_data: options for extra features like RSM and MAM
+        @return: result of dbpool.runInteraction, expected to be a deferred firing the index (C{int})"""
+        if ext_data is None:
+            ext_data = {}
+        return self.dbpool.runInteraction(self._getItemsIndex, item_id, authorized_groups, unrestricted, ext_data)
+
+    def _getItemsIndex(self, cursor, item_id, authorized_groups, unrestricted, ext_data):
+        self._checkNodeExists(cursor)
+        args = []
+        # number the matching rows in the query order, then keep the row of the requested item
+        # the SELECT part needs the ORDER BY clause, so it is inserted once sources/filters are added
+        query = []
+
+        query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
+
+        query_select = "SELECT row_number from (SELECT row_number() OVER ({}), item".format(query_order)
+        query.insert(0, query_select)
+        query.append(") as x WHERE item=%s")
+        args.append(item_id)
+
+        # NOTE(review): IndexError is raised below if no row matches item_id — confirm callers handle it
+        cursor.execute(' '.join(query), args)
+        # XXX: row_number starts at 1, but we want the index to start at 0
+        return cursor.fetchall()[0][0] - 1
 
     def getItemsPublishers(self, itemIdentifiers):
         """Get the publishers for all given identifiers
@@ -925,7 +927,6 @@
         """
         return self.dbpool.runInteraction(self._getItemsPublishers, itemIdentifiers)
 
-
     def _getItemsPublishers(self, cursor, itemIdentifiers):
         self._checkNodeExists(cursor)
         ret = {}
@@ -938,11 +939,9 @@
                 ret[itemIdentifier] = jid.JID(result[0])
         return ret
 
-
     def purge(self):
         return self.dbpool.runInteraction(self._purge)
 
-
     def _purge(self, cursor):
         self._checkNodeExists(cursor)