changeset 318:d13526c0eb32

RSM improvements/refactoring: - a warning message is displayed if maxItems == 0 in getItems, and an empty list is returned in this case - use the new container.ItemData instead of doing tuple (un)packing - the list of ItemData => list of domish.Element conversion is done at the end of the workflow - rsm request is checked in self._items_rsm directly - better handling of Response.index in _items_rsm - itemIdentifiers can't be used with RSM (the latter will be ignored if this happens) - don't do approximate unpacking anymore in _items_rsm - countItems and getIndex have been refactored and renamed getItemsCount and getItemsIndex, and don't use query duplication anymore - cleaned query handling in getItems - /!\ mam module is temporarily broken
author Goffi <goffi@goffi.org>
date Sun, 03 Jan 2016 18:33:22 +0100
parents 34adc4a8aa64
children a51947371625
files sat_pubsub/backend.py sat_pubsub/pgsql_storage.py
diffstat 2 files changed, 183 insertions(+), 184 deletions(-) [+]
line wrap: on
line diff
--- a/sat_pubsub/backend.py	Sun Jan 03 18:33:22 2016 +0100
+++ b/sat_pubsub/backend.py	Sun Jan 03 18:33:22 2016 +0100
@@ -597,28 +597,28 @@
     def _doGetItems(self, result, requestor, maxItems, itemIdentifiers,
                     ext_data):
         node, affiliation = result
+        if maxItems == 0:
+            log.msg("WARNING: maxItems=0 on items retrieval")
+            return []
 
         def append_item_config(items_data):
-            ret = []
+            """Add item config data form to items with roster access model"""
             for item_data in items_data:
-                item, access_model, access_list = item_data.item, item_data.access_model, item_data.config
-                if access_model == const.VAL_AMODEL_OPEN:
+                if item_data.access_model == const.VAL_AMODEL_OPEN:
                     pass
-                elif access_model == const.VAL_AMODEL_ROSTER:
+                elif item_data.access_model == const.VAL_AMODEL_ROSTER:
                     form = data_form.Form('submit', formNamespace=const.NS_ITEM_CONFIG)
                     access = data_form.Field(None, const.OPT_ACCESS_MODEL, value=const.VAL_AMODEL_ROSTER)
-                    allowed = data_form.Field(None, const.OPT_ROSTER_GROUPS_ALLOWED, values=access_list[const.OPT_ROSTER_GROUPS_ALLOWED])
+                    allowed = data_form.Field(None, const.OPT_ROSTER_GROUPS_ALLOWED, values=item_data.config[const.OPT_ROSTER_GROUPS_ALLOWED])
                     form.addField(access)
                     form.addField(allowed)
-                    item.addChild(form.toElement())
+                    item_data.item.addChild(form.toElement())
                 elif access_model == const.VAL_AMODEL_JID:
                     #FIXME: manage jid
                     raise NotImplementedError
                 else:
                     raise error.BadAccessTypeError(access_model)
-
-                ret.append(item)
-            return ret
+            return items_data
 
         def access_checked(access_data):
             authorized, roster = access_data
@@ -627,24 +627,18 @@
 
             roster_item = roster.get(requestor.userhostJID())
             authorized_groups = tuple(roster_item.groups) if roster_item else tuple()
-            unrestricted = affiliation == 'owner'
+            owner = affiliation == 'owner'
 
             if itemIdentifiers:
-                d = node.getItemsById(authorized_groups, unrestricted, itemIdentifiers)
+                d = node.getItemsById(authorized_groups, owner, itemIdentifiers)
             else:
-                d = node.getItems(authorized_groups, unrestricted, maxItems, ext_data)
-                if unrestricted:
+                d = node.getItems(authorized_groups, owner, maxItems, ext_data)
+                if owner:
                     d.addCallback(append_item_config)
 
-            try:
-                rsm_data = ext_data['rsm']
-            except KeyError:
-                pass
-            else:
-                if rsm_data is not None:
-                    d.addCallback(self._items_rsm, node, authorized_groups,
-                                  unrestricted, maxItems, itemIdentifiers,
-                                  rsm_data)
+            d.addCallback(self._items_rsm, node, authorized_groups,
+                          owner, itemIdentifiers,
+                          ext_data)
             return d
 
         if not iidavoll.ILeafNode.providedBy(node):
@@ -671,63 +665,69 @@
             d.addCallback(self.checkGroup, requestor)
             d.addCallback(access_checked)
 
+        d.addCallback(lambda items_data: [item_data.item for item_data in items_data])
+
         return d
 
     def _setCount(self, value, response):
         response.count = value
 
-    def _setIndex(self, value, response):
-        response.index = value
+    def _setIndex(self, value, response, adjust):
+        """Set index in RSM response
+
+        @param value(int): value of the reference index (i.e. before or after item)
+        @param response(RSMResponse): response instance to fill
+        @param adjust(int): adjustement term (i.e. difference between reference index and first item of the result)
+        """
+        response.index = value + adjust
 
-    def _items_rsm(self, elts, node, authorized_groups, unrestricted, maxItems,
-                   itemIdentifiers, request):
-        # FIXME: move this to a separate module ?
+    def _items_rsm(self, items_data, node, authorized_groups, owner,
+                   itemIdentifiers, ext_data):
+        # FIXME: move this to a separate module
+        # TODO: Index can be optimized by keeping a cache of the last RSM request
+        #       An other optimisation would be to look for index first and use it as offset
+        try:
+            rsm_request = ext_data['rsm']
+        except KeyError:
+            # No RSM in this request, nothing to do
+            return items_data
+
+        if itemIdentifiers:
+            log.msg("WARNING, itemIdentifiers used with RSM, ignoring the RSM part")
+            return items_data
+
         response = rsm.RSMResponse()
 
-        d_count = node.countItems(authorized_groups, unrestricted)
+        d_count = node.getItemsCount(authorized_groups, owner, ext_data)
         d_count.addCallback(self._setCount, response)
         d_list = [d_count]
 
-        if request.index is not None:
-            response.index = request.index
-        elif request.before is not None:
-            if request.before != '':
-                # XXX: getIndex starts with index 1, RSM starts with 0
-                d_index = node.getIndex(authorized_groups, unrestricted, request.before)
-                d_index.addCallback(lambda index: max(index - request.max - 1, 0))
-                d_index.addCallback(self._setIndex, response)
+        if items_data:
+            response.first = items_data[0].item['id']
+            response.last = items_data[-1].item['id']
+
+            # index handling
+            if rsm_request.index is not None:
+                response.index = rsm_request.index
+            elif rsm_request.before:
+                # The last page case (before == '') is managed in render method
+                d_index = node.getItemsIndex(rsm_request.before, authorized_groups, owner, ext_data)
+                d_index.addCallback(self._setIndex, response, -len(items_data))
                 d_list.append(d_index)
-        elif request.after is not None:
-            d_index = node.getIndex(authorized_groups, unrestricted, request.after)
-            d_index.addCallback(self._setIndex, response)
-            d_list.append(d_index)
-        elif itemIdentifiers:
-            d_index = node.getIndex(authorized_groups, unrestricted, itemIdentifiers[0])
-            d_index.addCallback(lambda index: index - 1)
-            d_index.addCallback(self._setIndex, response)
-            d_list.append(d_index)
-
+            elif rsm_request.after is not None:
+                d_index = node.getItemsIndex(rsm_request.after, authorized_groups, owner, ext_data)
+                d_index.addCallback(self._setIndex, response, 1)
+                d_list.append(d_index)
+            else:
+                # the first page was requested
+                response.index = 0
 
         def render(result):
-            try:
-                items = [elt for elt in elts if elt.name == 'item']
-            except AttributeError:
-                # XXX: see sat_pubsub.pgsql_storage.LeafNode.getItemsById return value
-                items = [elt[0] for elt in elts if elt[0].name == 'item']
-            if len(items) > 0:
-                if response.index is None:
-                    if request.before == '': # last page
-                        response.index = response.count - request.max
-                    else:  # first page
-                        response.index = 0
-                response.first = items[0]['id']
-                response.last = items[len(items) - 1]['id']
-                if request.before is not None:
-                    response.first, response.last = response.last, response.first
-            else:
-                response.index = None
-            elts.append(response.toElement())
-            return elts
+            if rsm_request.before == '':
+                # the last page was requested
+                response.index = response.count - len(items_data)
+            items_data.append(container.ItemData(response.toElement()))
+            return items_data
 
         return defer.DeferredList(d_list).addCallback(render)
 
--- a/sat_pubsub/pgsql_storage.py	Sun Jan 03 18:33:22 2016 +0100
+++ b/sat_pubsub/pgsql_storage.py	Sun Jan 03 18:33:22 2016 +0100
@@ -602,13 +602,11 @@
     def storeItems(self, item_data, publisher):
         return self.dbpool.runInteraction(self._storeItems, item_data, publisher)
 
-
     def _storeItems(self, cursor, items_data, publisher):
         self._checkNodeExists(cursor)
         for item_data in items_data:
             self._storeItem(cursor, item_data, publisher)
 
-
     def _storeItem(self, cursor, item_data, publisher):
         item, access_model, item_config = item_data.item, item_data.access_model, item_data.config
         data = item.toXml()
@@ -665,7 +663,6 @@
     def removeItems(self, itemIdentifiers):
         return self.dbpool.runInteraction(self._removeItems, itemIdentifiers)
 
-
     def _removeItems(self, cursor, itemIdentifiers):
         self._checkNodeExists(cursor)
 
@@ -683,14 +680,12 @@
 
         return deleted
 
-
     def getItems(self, authorized_groups, unrestricted, maxItems=None, ext_data=None):
         """ Get all authorised items
         @param authorized_groups: we want to get items that these groups can access
         @param unrestricted: if true, don't check permissions (i.e.: get all items)
-        @param maxItems: nb of items we want to tget
-        @param rsm_data: options for RSM feature handling (XEP-0059) as a
-                         dictionnary of C{unicode} to C{unicode}.
+        @param maxItems: nb of items we want to get
+        @param ext_data: options for extra features like RSM and MAM
 
         @return: list of container.ItemData
             if unrestricted is False, access_model and config will be None
@@ -699,70 +694,98 @@
             ext_data = {}
         return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems, ext_data)
 
-    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data):
-        #  FIXME: simplify the query construction
-        self._checkNodeExists(cursor)
+    def _appendSourcesAndFilters(self, query, args, authorized_groups, unrestricted, ext_data):
+        """append sources and filters to sql query requesting items and return ORDER BY
 
+        arguments query, args, authorized_groups, unrestricted and ext_data are the same as for
+        _getItems
+        """
+        # SOURCES
         if unrestricted:
-            query = ["SELECT data,items.access_model,item_id"]
-            source = """FROM nodes
-                       INNER JOIN items USING (node_id)
-                       WHERE node_id=%s"""
-            args = [self.nodeDbId]
+            query.append("""FROM nodes
+                INNER JOIN items USING (node_id)
+                WHERE node_id=%s""")
+            args.append(self.nodeDbId)
         else:
-            query = ["SELECT data"]
-            groups = " or (items.access_model='roster' and groupname in %s)" if authorized_groups else ""
-            source = """FROM nodes
-                       INNER JOIN items USING (node_id)
-                       LEFT JOIN item_groups_authorized USING (item_id)
-                       WHERE node_id=%s AND
-                       (items.access_model='open'""" + groups + ")"
+            args.append(self.nodeDbId)
+            if authorized_groups:
+                get_groups = " or (items.access_model='roster' and groupname in %s)"
+                args.append(authorized_groups)
+            else:
+                get_groups = ""
 
-            args = [self.nodeDbId]
-            if authorized_groups:
-                args.append(authorized_groups)
+            query.append("""FROM nodes
+                INNER JOIN items USING (node_id)
+                LEFT JOIN item_groups_authorized USING (item_id)
+                WHERE node_id=%s AND
+                (items.access_model='open'""" + get_groups + ")")
 
+        # FILTERS
         if 'filters' in ext_data:  # MAM filters
             for filter_ in ext_data['filters']:
                 if filter_.var == 'start':
-                    source += " AND date>=%s"
+                    query.append("AND date>=%s")
                     args.append(filter_.value)
                 if filter_.var == 'end':
-                    source += " AND date<=%s"
+                    query.append("AND date<=%s")
                     args.append(filter_.value)
                 if filter_.var == 'with':
                     jid_s = filter_.value
                     if '/' in jid_s:
-                        source += " AND publisher=%s"
+                        query.append("AND publisher=%s")
                         args.append(filter_.value)
                     else:
-                        source += " AND publisher LIKE %s"
+                        query.append("AND publisher LIKE %s")
                         args.append(u"{}%".format(filter_.value))
 
-        query.append(source)
-        order = "DESC"
+        return "ORDER BY date DESC"
+
+    def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data):
+        self._checkNodeExists(cursor)
+
+        if maxItems == 0:
+            return []
+
+        args = []
+
+        # SELECT
+        query = ["SELECT data,items.access_model,item_id,date"]
+
+        query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
 
         if 'rsm' in ext_data:
             rsm = ext_data['rsm']
             maxItems = rsm.max
             if rsm.index is not None:
-                query.append("AND date<=(SELECT date " + source + " ORDER BY date DESC LIMIT 1 OFFSET %s)")
-                # FIXME: change the request so source is not used 2 times
-                # there is already a placeholder in source with node_id=%s, so we need to add self.noDbId in args
-                args.append(self.nodeDbId)
-                if authorized_groups:
-                    args.append(authorized_groups)
-                args.append(rsm.index)
+                # We need to know the date of corresponding to the index (offset) of the current query
+                # so we execute the query to look for the date
+                tmp_query = query[:]
+                tmp_args = args[:]
+                tmp_query[0] = "SELECT date"
+                tmp_query.append("{} LIMIT 1 OFFSET %s".format(query_order))
+                tmp_args.append(rsm.index)
+                cursor.execute(' '.join(query), args)
+                # FIXME: bad index is not managed yet
+                date = cursor.fetchall()[0][0]
+
+                # now that we have the date, we can use it
+                query.append("AND date<=%s")
+                args.append(date)
             elif rsm.before is not None:
-                order = "ASC"
                 if rsm.before != '':
                     query.append("AND date>(SELECT date FROM items WHERE item=%s LIMIT 1)")
                     args.append(rsm.before)
+                if maxItems is not None:
+                    # if we have maxItems (i.e. a limit), we need to reverse order
+                    # in a first query to get the right items
+                    query.insert(0,"SELECT * from (")
+                    query.append("ORDER BY date ASC LIMIT %s) as x")
+                    args.append(maxItems)
             elif rsm.after:
                 query.append("AND date<(SELECT date FROM items WHERE item=%s LIMIT 1)")
                 args.append(rsm.after)
 
-        query.append("ORDER BY date %s" % order)
+        query.append(query_order)
 
         if maxItems is not None:
             query.append("LIMIT %s")
@@ -772,6 +795,7 @@
 
         result = cursor.fetchall()
         if unrestricted:
+            # with unrestricted query, we need to fill the access_list for a roster access items
             ret = []
             for data in result:
                 item = generic.stripNamespace(parseXml(data[0]))
@@ -784,84 +808,10 @@
 
                 ret.append(container.ItemData(item, access_model, access_list))
             return ret
+
         items_data = [container.ItemData(generic.stripNamespace(parseXml(r[0])), None, None) for r in result]
         return items_data
 
-    def countItems(self, authorized_groups, unrestricted):
-        """ Count the accessible items.
-
-        @param authorized_groups: we want to get items that these groups can access.
-        @param unrestricted: if true, don't check permissions (i.e.: get all items).
-        @return: deferred that fires a C{int}.
-        """
-        return self.dbpool.runInteraction(self._countItems, authorized_groups, unrestricted)
-
-    def _countItems(self, cursor, authorized_groups, unrestricted):
-        # FIXME: should not be a separate method, but should be an option of getItems instead
-        self._checkNodeExists(cursor)
-
-        if unrestricted:
-            query = ["""SELECT count(item_id) FROM nodes
-                       INNER JOIN items USING (node_id)
-                       WHERE node_id=%s"""]
-            args = [self.nodeDbId]
-        else:
-            query = ["""SELECT count(item_id) FROM nodes
-                       INNER  JOIN items USING (node_id)
-                       LEFT JOIN item_groups_authorized USING (item_id)
-                       WHERE node_id=%s AND
-                       (items.access_model='open' """ +
-                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
-                       ")"]
-
-            args = [self.nodeDbId]
-            if authorized_groups:
-                args.append(authorized_groups)
-
-        cursor.execute(' '.join(query), args)
-        return cursor.fetchall()[0][0]
-
-    def getIndex(self, authorized_groups, unrestricted, item):
-        """ Retrieve the index of the given item within the accessible window.
-
-        @param authorized_groups: we want to get items that these groups can access.
-        @param unrestricted: if true, don't check permissions (i.e.: get all items).
-        @param item: item identifier.
-        @return: deferred that fires a C{int}.
-        """
-        return self.dbpool.runInteraction(self._getIndex, authorized_groups, unrestricted, item)
-
-    def _getIndex(self, cursor, authorized_groups, unrestricted, item):
-        self._checkNodeExists(cursor)
-
-        if unrestricted:
-            query = ["""SELECT row_number FROM (
-                       SELECT row_number() OVER (ORDER BY date DESC), item
-                       FROM nodes INNER JOIN items USING (node_id)
-                       WHERE node_id=%s
-                       ) as x
-                       WHERE item=%s LIMIT 1"""]
-            args = [self.nodeDbId]
-        else:
-            query = ["""SELECT row_number FROM (
-                       SELECT row_number() OVER (ORDER BY date DESC), item
-                       FROM nodes INNER JOIN items USING (node_id)
-                       LEFT JOIN item_groups_authorized USING (item_id)
-                       WHERE node_id=%s AND
-                       (items.access_model='open' """ +
-                       ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') +
-                       """)) as x
-                       WHERE item=%s LIMIT 1"""]
-
-            args = [self.nodeDbId]
-            if authorized_groups:
-                args.append(authorized_groups)
-
-        args.append(item)
-        cursor.execute(' '.join(query), args)
-
-        return cursor.fetchall()[0][0]
-
     def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers):
         """ Get items which are in the given list
         @param authorized_groups: we want to get items that these groups can access
@@ -873,7 +823,6 @@
         """
         return self.dbpool.runInteraction(self._getItemsById, authorized_groups, unrestricted, itemIdentifiers)
 
-
     def _getItemsById(self, cursor, authorized_groups, unrestricted, itemIdentifiers):
         self._checkNodeExists(cursor)
         ret = []
@@ -916,6 +865,59 @@
 
         return ret
 
+    def getItemsCount(self, authorized_groups, unrestricted, ext_data=None):
+        """Count expected number of items in a getItems query
+
+        @param authorized_groups: we want to get items that these groups can access
+        @param unrestricted: if true, don't check permissions (i.e.: get all items)
+        @param ext_data: options for extra features like RSM and MAM
+        """
+        if ext_data is None:
+            ext_data = {}
+        return self.dbpool.runInteraction(self._getItemsCount, authorized_groups, unrestricted, ext_data)
+
+    def _getItemsCount(self, cursor, authorized_groups, unrestricted, ext_data):
+        self._checkNodeExists(cursor)
+        args = []
+
+        # SELECT
+        query = ["SELECT count(1)"]
+
+        self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
+
+        cursor.execute(' '.join(query), args)
+        return cursor.fetchall()[0][0]
+
+    def getItemsIndex(self, item_id, authorized_groups, unrestricted, ext_data=None):
+        """Get expected index of first item in the window of a getItems query
+
+        @param item_id: id of the item
+        @param authorized_groups: we want to get items that these groups can access
+        @param unrestricted: if true, don't check permissions (i.e.: get all items)
+        @param ext_data: options for extra features like RSM and MAM
+        """
+        if ext_data is None:
+            ext_data = {}
+        return self.dbpool.runInteraction(self._getItemsIndex, item_id, authorized_groups, unrestricted, ext_data)
+
+    def _getItemsIndex(self, cursor, item_id, authorized_groups, unrestricted, ext_data):
+        self._checkNodeExists(cursor)
+        args = []
+
+        # SELECT
+        query = []
+
+        query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
+
+        query_select = "SELECT row_number from (SELECT row_number() OVER ({}), item".format(query_order)
+        query.insert(0, query_select)
+        query.append(") as x WHERE item=%s")
+        args.append(item_id)
+
+
+        cursor.execute(' '.join(query), args)
+        # XXX: row_number start at 1, but we want that index start at 0
+        return cursor.fetchall()[0][0] - 1
 
     def getItemsPublishers(self, itemIdentifiers):
         """Get the publishers for all given identifiers
@@ -925,7 +927,6 @@
         """
         return self.dbpool.runInteraction(self._getItemsPublishers, itemIdentifiers)
 
-
     def _getItemsPublishers(self, cursor, itemIdentifiers):
         self._checkNodeExists(cursor)
         ret = {}
@@ -938,11 +939,9 @@
                 ret[itemIdentifier] = jid.JID(result[0])
         return ret
 
-
     def purge(self):
         return self.dbpool.runInteraction(self._purge)
 
-
     def _purge(self, cursor):
         self._checkNodeExists(cursor)