diff sat_pubsub/pgsql_storage.py @ 367:a772f7dac930

backend, storage(pgsql): creation/update date + serial ids: /!\ this patch updates the pgsql schema /!\ Two features had to be put in the same patch, to avoid updating the schema twice. 1) creation/last modification date: the column keeping the creation date of items is renamed from "date" to "created"; the date of last modification of items is saved in the new "updated" column. 2) serial ids: this experimental feature allows ids to be in series (i.e. 1, 2, 3, etc.) instead of UUIDs. This is a convenience feature and there are some drawbacks: - PostgreSQL sequences are used, so gaps can happen (see PostgreSQL documentation for more details) - if somebody creates an item with a future id in the series, the series will adapt, which can have undesired effects, and may lead to item creation failure if several items are created at the same time. For instance, if the next id in the series is "8", and somebody has already created items "8" and "256", the item will be created with the biggest value in items + 1 (i.e. 257). If 2 people want to create an item in this situation, the second will fail with a conflict error.
author Goffi <goffi@goffi.org>
date Sat, 04 Nov 2017 21:31:32 +0100
parents 81e6d4a516c3
children 618a92080812
line wrap: on
line diff
--- a/sat_pubsub/pgsql_storage.py	Sat Nov 04 21:17:12 2017 +0100
+++ b/sat_pubsub/pgsql_storage.py	Sat Nov 04 21:31:32 2017 +0100
@@ -56,6 +56,7 @@
 from zope.interface import implements
 
 from twisted.internet import reactor
+from twisted.internet import defer
 from twisted.words.protocols.jabber import jid
 from twisted.python import log
 
@@ -66,6 +67,8 @@
 from sat_pubsub import iidavoll
 from sat_pubsub import const
 from sat_pubsub import container
+from sat_pubsub import exceptions
+import uuid
 import psycopg2
 import psycopg2.extensions
 # we wants psycopg2 to return us unicode, not str
@@ -74,8 +77,11 @@
 
 # parseXml manage str, but we get unicode
 parseXml = lambda unicode_data: generic.parseXml(unicode_data.encode('utf-8'))
+ITEMS_SEQ_NAME = u'node_{node_id}_seq'
 PEP_COL_NAME = 'pep'
-CURRENT_VERSION = '3'
+CURRENT_VERSION = '4'
+# retrieve the maximum integer item id + 1
+NEXT_ITEM_ID_QUERY = r"SELECT COALESCE(max(item::integer)+1,1) as val from items where node_id={node_id} and item ~ E'^\\d+$'"
 
 
 def withPEP(query, values, pep, recipient):
@@ -107,6 +113,7 @@
                 const.OPT_SEND_LAST_PUBLISHED_ITEM: 'on_sub',
                 const.OPT_ACCESS_MODEL: const.VAL_AMODEL_DEFAULT,
                 const.OPT_PUBLISH_MODEL: const.VAL_PMODEL_DEFAULT,
+                const.OPT_SERIAL_IDS: False,
             },
             'collection': {
                 const.OPT_DELIVER_PAYLOADS: True,
@@ -146,8 +153,9 @@
                     'pubsub#send_last_published_item': row[5],
                     const.OPT_ACCESS_MODEL:row[6],
                     const.OPT_PUBLISH_MODEL:row[7],
+                    const.OPT_SERIAL_IDS:row[8],
                     }
-            schema = row[8]
+            schema = row[9]
             if schema is not None:
                 schema = parseXml(schema)
             node = LeafNode(row[0], row[1], configuration, schema)
@@ -182,6 +190,7 @@
                                  send_last_published_item,
                                  access_model,
                                  publish_model,
+                                 serial_ids,
                                  schema::text,
                                  pep
                             FROM nodes
@@ -202,6 +211,7 @@
                                           send_last_published_item,
                                           access_model,
                                           publish_model,
+                                          serial_ids,
                                           schema::text,
                                           pep
                                    FROM nodes
@@ -247,16 +257,25 @@
 
         try:
             cursor.execute("""INSERT INTO nodes
-                              (node, node_type, persist_items,
-                               deliver_payloads, send_last_published_item, access_model, publish_model, schema, pep)
+                              (node,
+                               node_type,
+                               persist_items,
+                               deliver_payloads,
+                               send_last_published_item,
+                               access_model,
+                               publish_model,
+                               serial_ids,
+                               schema,
+                               pep)
                               VALUES
-                              (%s, 'leaf', %s, %s, %s, %s, %s, %s, %s)""",
+                              (%s, 'leaf', %s, %s, %s, %s, %s, %s, %s, %s)""",
                            (nodeIdentifier,
                             config['pubsub#persist_items'],
                             config['pubsub#deliver_payloads'],
                             config['pubsub#send_last_published_item'],
                             config[const.OPT_ACCESS_MODEL],
                             config[const.OPT_PUBLISH_MODEL],
+                            config[const.OPT_SERIAL_IDS],
                             schema,
                             recipient.userhost() if pep else None
                             )
@@ -282,12 +301,12 @@
             # "WHERE NOT EXISTS" but none of them worked, so the following solution
             # looks like the sole - unless you have auto-commit on. More info
             # about this issue: http://cssmay.com/question/tag/tag-psycopg2
-            cursor._connection.commit()
+            cursor.connection.commit()
             try:
                 cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
                                (owner,))
             except psycopg2.IntegrityError as e:
-                cursor._connection.rollback()
+                cursor.connection.rollback()
                 logging.warning("during node creation: %s" % e.message)
 
         cursor.execute("""INSERT INTO affiliations
@@ -309,6 +328,12 @@
         # XXX: affiliations can't be set on during node creation (at least not with XEP-0060 alone)
         #      so whitelist affiliations need to be done afterward
 
+        # now we may have to do extra things, depending on config options
+        default_conf = self.defaultConfig['leaf']
+        # XXX: trigger works on node creation because OPT_SERIAL_IDS is False in defaultConfig
+        #      if this value is changed, the _configurationTriggers method should be adapted.
+        Node._configurationTriggers(cursor, node_id, default_conf, config)
+
     def deleteNodeByDbId(self, db_id):
         """Delete a node using directly its database id"""
         return self.dbpool.runInteraction(self._deleteNodeByDbId, db_id)
@@ -448,6 +473,49 @@
     def getConfiguration(self):
         return self._config
 
+    def getNextId(self):
+        """return XMPP item id usable for next item to publish
+
+        the return value will be the next integer if serial_ids is set,
+        else a UUID will be returned
+        """
+        if self._config[const.OPT_SERIAL_IDS]:
+            d = self.dbpool.runQuery("SELECT nextval('{seq_name}')".format(
+                seq_name = ITEMS_SEQ_NAME.format(node_id=self.nodeDbId)))
+            d.addCallback(lambda rows: unicode(rows[0][0]))
+            return d
+        else:
+            return defer.succeed(unicode(uuid.uuid4()))
+
+    @staticmethod
+    def _configurationTriggers(cursor, node_id, old_config, new_config):
+        """trigger database relative actions needed when a config is changed
+
+        @param cursor(): current db cursor
+        @param node_id(unicode): database ID of the node
+        @param old_config(dict): config of the node before the change
+        @param new_config(dict): new options that will be changed
+        """
+        serial_ids = new_config[const.OPT_SERIAL_IDS]
+        if serial_ids != old_config[const.OPT_SERIAL_IDS]:
+            # serial_ids option has been modified,
+            # we need to handle corresponding sequence
+
+            # XXX: we use .format in the following queries because the values
+            #      are generated by ourselves
+            seq_name = ITEMS_SEQ_NAME.format(node_id=node_id)
+            if serial_ids:
+                # the next query gets the max value + 1 of all XMPP item ids
+                # which are integers, and defaults to 1
+                cursor.execute(NEXT_ITEM_ID_QUERY.format(node_id=node_id))
+                next_val = cursor.fetchone()[0]
+                cursor.execute("DROP SEQUENCE IF EXISTS {seq_name}".format(seq_name = seq_name))
+                cursor.execute("CREATE SEQUENCE {seq_name} START {next_val} OWNED BY nodes.node_id".format(
+                    seq_name = seq_name,
+                    next_val = next_val))
+            else:
+                cursor.execute("DROP SEQUENCE IF EXISTS {seq_name}".format(seq_name = seq_name))
+
     def setConfiguration(self, options):
         config = copy.copy(self._config)
 
@@ -461,17 +529,20 @@
 
     def _setConfiguration(self, cursor, config):
         self._checkNodeExists(cursor)
+        self._configurationTriggers(cursor, self.nodeDbId, self._config, config)
         cursor.execute("""UPDATE nodes SET persist_items=%s,
                                            deliver_payloads=%s,
                                            send_last_published_item=%s,
                                            access_model=%s,
-                                           publish_model=%s
+                                           publish_model=%s,
+                                           serial_ids=%s
                           WHERE node_id=%s""",
                        (config[const.OPT_PERSIST_ITEMS],
                         config[const.OPT_DELIVER_PAYLOADS],
                         config[const.OPT_SEND_LAST_PUBLISHED_ITEM],
                         config[const.OPT_ACCESS_MODEL],
                         config[const.OPT_PUBLISH_MODEL],
+                        config[const.OPT_SERIAL_IDS],
                         self.nodeDbId))
 
     def _setCachedConfiguration(self, void, config):
@@ -596,7 +667,7 @@
             cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
                            (userhost,))
         except cursor._pool.dbapi.IntegrityError:
-            cursor._connection.rollback()
+            cursor.connection.rollback()
 
         try:
             cursor.execute("""INSERT INTO subscriptions
@@ -772,33 +843,70 @@
             self._storeItem(cursor, item_data, publisher)
 
     def _storeItem(self, cursor, item_data, publisher):
+        # first try to insert the item
+        # - if it fails (conflict), and the item is new and we have the serial_ids option,
+        #   the current id will be recomputed using the next item id query (note that this is not
+        #   perfect, as the table is not locked and this can fail if two items are added at the same
+        #   time, but this can only happen with serial_ids and if future ids have been set by a client;
+        #   this case should be rare enough to consider this situation acceptable)
+        # - if item insertion fails and the item is not new, we do an update
+        # - in other cases, the exception is raised
         item, access_model, item_config = item_data.item, item_data.access_model, item_data.config
         data = item.toXml()
 
-        cursor.execute("""UPDATE items SET date=now(), publisher=%s, data=%s
-                          FROM nodes
-                          WHERE nodes.node_id = items.node_id AND
-                                nodes.node_id = %s and items.item=%s
-                          RETURNING item_id""",
-                       (publisher.full(),
-                        data,
-                        self.nodeDbId,
-                        item["id"]))
-        if cursor.rowcount == 1:
-            item_id = cursor.fetchone()[0];
-            self._storeCategories(cursor, item_id, item_data.categories, update=True)
-            return
+        insert_query = """INSERT INTO items (node_id, item, publisher, data, access_model)
+                                             SELECT %s, %s, %s, %s, %s FROM nodes
+                                                                        WHERE node_id=%s
+                                                                        RETURNING item_id"""
+        insert_data = [self.nodeDbId,
+                       item["id"],
+                       publisher.full(),
+                       data,
+                       access_model,
+                       self.nodeDbId]
 
-        cursor.execute("""INSERT INTO items (node_id, item, publisher, data, access_model)
-                          SELECT %s, %s, %s, %s, %s FROM nodes
-                                                     WHERE node_id=%s
-                                                     RETURNING item_id""",
-                       (self.nodeDbId,
-                        item["id"],
-                        publisher.full(),
-                        data,
-                        access_model,
-                        self.nodeDbId))
+        try:
+            cursor.execute(insert_query, insert_data)
+        except cursor._pool.dbapi.IntegrityError as e:
+            if e.pgcode != "23505":
+                # we only handle unique_violation, every other exception must be raised
+                raise e
+            cursor.connection.rollback()
+            # the item already exists
+            if item_data.new:
+                # the item is new
+                if self._config[const.OPT_SERIAL_IDS]:
+                    # this can happen with serial_ids, if an item has been stored
+                    # with a future id (generated by an XMPP client)
+                    cursor.execute(NEXT_ITEM_ID_QUERY.format(node_id=self.nodeDbId))
+                    next_id = cursor.fetchone()[0]
+                    # we update the sequence, so we can skip conflicting ids
+                    cursor.execute(u"SELECT setval('{seq_name}', %s)".format(
+                        seq_name = ITEMS_SEQ_NAME.format(node_id=self.nodeDbId)), [next_id])
+                    # and now we can retry the query with the new id
+                    item['id'] = insert_data[1] = unicode(next_id)
+                    # item saved in DB must also be updated with the new id
+                    insert_data[3] = item.toXml()
+                    cursor.execute(insert_query, insert_data)
+                else:
+                    # but if serial_ids is not set, we have a real problem
+                    raise e
+            else:
+                # this is an update
+                cursor.execute("""UPDATE items SET updated=now(), publisher=%s, data=%s
+                                  FROM nodes
+                                  WHERE nodes.node_id = items.node_id AND
+                                        nodes.node_id = %s and items.item=%s
+                                  RETURNING item_id""",
+                               (publisher.full(),
+                                data,
+                                self.nodeDbId,
+                                item["id"]))
+                if cursor.rowcount != 1:
+                    raise exceptions.InternalError("item has not been updated correctly")
+                item_id = cursor.fetchone()[0];
+                self._storeCategories(cursor, item_id, item_data.categories, update=True)
+                return
 
         item_id = cursor.fetchone()[0];
         self._storeCategories(cursor, item_id, item_data.categories)
@@ -901,10 +1009,10 @@
         if 'filters' in ext_data:  # MAM filters
             for filter_ in ext_data['filters']:
                 if filter_.var == 'start':
-                    query_filters.append("AND date>=%s")
+                    query_filters.append("AND created>=%s")
                     args.append(filter_.value)
                 elif filter_.var == 'end':
-                    query_filters.append("AND date<=%s")
+                    query_filters.append("AND created<=%s")
                     args.append(filter_.value)
                 elif filter_.var == 'with':
                     jid_s = filter_.value
@@ -937,7 +1045,7 @@
         if ids_only:
             query = ["SELECT item"]
         else:
-            query = ["SELECT data::text,items.access_model,item_id,date"]
+            query = ["SELECT data::text,items.access_model,item_id,created,updated"]
 
         query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
 
@@ -989,20 +1097,21 @@
                 item = generic.stripNamespace(parseXml(item_data.data))
                 access_model = item_data.access_model
                 item_id = item_data.item_id
-                date = item_data.date
+                created = item_data.created
+                updated = item_data.updated
                 access_list = {}
                 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER:
                     cursor.execute('SELECT groupname FROM item_groups_authorized WHERE item_id=%s', (item_id,))
                     access_list[const.OPT_ROSTER_GROUPS_ALLOWED] = [r.groupname for r in cursor.fetchall()]
 
-                ret.append(container.ItemData(item, access_model, access_list, date=date))
+                ret.append(container.ItemData(item, access_model, access_list, created=created, updated=updated))
                 # TODO: whitelist item access model
             return ret
 
         if ids_only:
             return [r.item for r in result]
         else:
-            items_data = [container.ItemData(generic.stripNamespace(parseXml(r.data)), r.access_model, date=r.date) for r in result]
+            items_data = [container.ItemData(generic.stripNamespace(parseXml(r.data)), r.access_model, created=r.created, updated=r.updated) for r in result]
         return items_data
 
     def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers):
@@ -1022,7 +1131,7 @@
         ret = []
         if unrestricted: #we get everything without checking permissions
             for itemIdentifier in itemIdentifiers:
-                cursor.execute("""SELECT data::text,items.access_model,item_id,date FROM nodes
+                cursor.execute("""SELECT data::text,items.access_model,item_id,created,updated FROM nodes
                                   INNER JOIN items USING (node_id)
                                   WHERE node_id=%s AND item=%s""",
                                (self.nodeDbId,
@@ -1034,20 +1143,21 @@
                 item = generic.stripNamespace(parseXml(result[0]))
                 access_model = result[1]
                 item_id = result[2]
-                date= result[3]
+                created= result[3]
+                updated= result[4]
                 access_list = {}
                 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER:
                     cursor.execute('SELECT groupname FROM item_groups_authorized WHERE item_id=%s', (item_id,))
                     access_list[const.OPT_ROSTER_GROUPS_ALLOWED] = [r[0] for r in cursor.fetchall()]
                  #TODO: WHITELIST access_model
 
-                ret.append(container.ItemData(item, access_model, access_list, date=date))
+                ret.append(container.ItemData(item, access_model, access_list, created=created, updated=updated))
         else: #we check permission before returning items
             for itemIdentifier in itemIdentifiers:
                 args = [self.nodeDbId, itemIdentifier]
                 if authorized_groups:
                     args.append(authorized_groups)
-                cursor.execute("""SELECT data::text, date FROM nodes
+                cursor.execute("""SELECT data::text, created, updated FROM nodes
                            INNER  JOIN items USING (node_id)
                            LEFT JOIN item_groups_authorized USING (item_id)
                            WHERE node_id=%s AND item=%s AND
@@ -1057,7 +1167,7 @@
 
                 result = cursor.fetchone()
                 if result:
-                    ret.append(container.ItemData(generic.stripNamespace(parseXml(result[0])), date=result[1]))
+                    ret.append(container.ItemData(generic.stripNamespace(parseXml(result[0])), created=result[1], updated=result[2]))
 
         return ret