diff src/memory/sqlite.py @ 853:c2f6ada7858f

core (sqlite): automatic database update: - new Updater class checks database consistency (by calculating a hash on the .schema), and updates the base if necessary - database now has a version (1 for current, 0 will be for 0.3's database); for each change this version will be increased - creation statements and update statements are in the form of dict of dict with tuples. There is a help text at the top of the module to explain how it works - if we are on a development version, the updater tries to update the database automatically (without deleting tables or columns). The Updater.generateUpdateData method can be used to ease the creation of update data (i.e. the dictionary at the top, see the one for the key 1 for an example). - if there is an inconsistency, an exception is raised, and a message indicates the SQL statements that should fix the situation. - well... this is rather complicated, a KISS method would maybe have been better. The future will say if we need to simplify it :-/ - new DatabaseError exception
author Goffi <goffi@goffi.org>
date Sun, 23 Feb 2014 23:30:32 +0100
parents f8681a7fd834
children 34dd9287dfe5
line wrap: on
line diff
--- a/src/memory/sqlite.py	Thu Feb 20 13:28:10 2014 +0100
+++ b/src/memory/sqlite.py	Sun Feb 23 23:30:32 2014 +0100
@@ -18,18 +18,63 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from sat.core.i18n import _
+from sat.core import exceptions
 from logging import debug, info, warning, error
 from twisted.enterprise import adbapi
 from twisted.internet import defer
+from collections import OrderedDict
 from time import time
+import re
 import os.path
 import cPickle as pickle
+import hashlib
+
+CURRENT_DB_VERSION = 1
+
+# XXX: DATABASE schemas are used in the following way:
+#      - the 'current' key holds the actual database schema, used to create a new base
+#      - x (int) holds the update needed to go from version x-1 to version x. Every number between y+1 and z is needed to update from y to z
+#        e.g.: if CURRENT_DB_VERSION is 6, 'current' is the actual DB, and to update from version 3, numbers 4, 5 and 6 are needed
+#      the 'current' data dict can contain the keys:
+#      - 'CREATE': an OrderedDict with the tables to create as keys, and a len 2 tuple as value, where value[0] holds the columns definitions and value[1] the table constraints
+#      - 'INSERT': an OrderedDict with the tables to fill as keys, and as value a tuple of tuples, each one containing the values of a row, in the order of the table's columns (#TODO: manage named columns)
+#      an update data dict (the ones with a number) can contain the keys 'create', 'delete', 'cols create', 'cols delete', 'cols modify' or 'insert'. See Updater.generateUpdateData for more information. This method can be used to autogenerate update data, to ease the work of the developers.
+
+DATABASE_SCHEMAS = {
+        "current": {'CREATE': OrderedDict((
+                              ('profiles',        (("id INTEGER PRIMARY KEY ASC", "name TEXT"),
+                                                   ("UNIQUE (name)",))),
+                              ('message_types',   (("type TEXT PRIMARY KEY",),
+                                                   tuple())),
+                              ('history',         (("id INTEGER PRIMARY KEY ASC", "profile_id INTEGER", "source TEXT", "dest TEXT", "source_res TEXT", "dest_res TEXT", "timestamp DATETIME", "message TEXT", "type TEXT", "extra BLOB"),
+                                                   ("FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE", "FOREIGN KEY(type) REFERENCES message_types(type)"))),
+                              ('param_gen',       (("category TEXT", "name TEXT", "value TEXT"),
+                                                   ("PRIMARY KEY (category,name)",))),
+                              ('param_ind',       (("category TEXT", "name TEXT", "profile_id INTEGER", "value TEXT"),
+                                                   ("PRIMARY KEY (category,name,profile_id)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE"))),
+                              ('private_gen',     (("namespace TEXT", "key TEXT", "value TEXT"),
+                                                   ("PRIMARY KEY (namespace, key)",))),
+                              ('private_ind',     (("namespace TEXT", "key TEXT", "profile_id INTEGER", "value TEXT"),
+                                                   ("PRIMARY KEY (namespace, key, profile_id)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE"))),
+                              ('private_gen_bin', (("namespace TEXT", "key TEXT", "value BLOB"),
+                                                   ("PRIMARY KEY (namespace, key)",))),
+
+                              ('private_ind_bin', (("namespace TEXT", "key TEXT", "profile_id INTEGER", "value BLOB"),
+                                                   ("PRIMARY KEY (namespace, key, profile_id)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE")))
+                              )),
+                    'INSERT': OrderedDict((
+                              ('message_types', (("'chat'",), ("'error'",), ("'groupchat'",), ("'headline'",), ("'normal'",))),
+                              )),
+                    },
+        1:         {'cols create': {'history': ('extra BLOB',)}
+                   },
+        }
 
 
 class SqliteStorage(object):
     """This class manage storage with Sqlite database"""
 
-    def __init__(self, db_filename):
+    def __init__(self, db_filename, sat_version):
         """Connect to the given database
         @param db_filename: full path to the Sqlite database"""
         self.initialized = defer.Deferred()  # triggered when memory is fully initialised and ready
@@ -38,31 +83,31 @@
         info(_("Connecting database"))
         new_base = not os.path.exists(db_filename)  # do we have to create the database ?
         self.dbpool = adbapi.ConnectionPool("sqlite3", db_filename, check_same_thread=False)
-        
+
         # init_defer is the initialisation deferred, initialisation is ok when all its callbacks have been done
         init_defer = self.dbpool.runOperation("PRAGMA foreign_keys = ON").addErrback(lambda x: error(_("Can't activate foreign keys")))
 
-        if new_base:
+        def getNewBaseSql():
             info(_("The database is new, creating the tables"))
-            database_creation = [
-                "CREATE TABLE profiles (id INTEGER PRIMARY KEY ASC, name TEXT, UNIQUE (name))",
-                "CREATE TABLE message_types (type TEXT PRIMARY KEY)",
-                "INSERT INTO message_types VALUES ('chat')",
-                "INSERT INTO message_types VALUES ('error')",
-                "INSERT INTO message_types VALUES ('groupchat')",
-                "INSERT INTO message_types VALUES ('headline')",
-                "INSERT INTO message_types VALUES ('normal')",
-                "CREATE TABLE history (id INTEGER PRIMARY KEY ASC, profile_id INTEGER, source TEXT, dest TEXT, source_res TEXT, dest_res TEXT, timestamp DATETIME, message TEXT, type TEXT, extra BLOB, FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE, FOREIGN KEY(type) REFERENCES message_types(type))",
-                "CREATE TABLE param_gen (category TEXT, name TEXT, value TEXT, PRIMARY KEY (category,name))",
-                "CREATE TABLE param_ind (category TEXT, name TEXT, profile_id INTEGER, value TEXT, PRIMARY KEY (category,name,profile_id), FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE)",
-                "CREATE TABLE private_gen (namespace TEXT, key TEXT, value TEXT, PRIMARY KEY (namespace, key))",
-                "CREATE TABLE private_ind (namespace TEXT, key TEXT, profile_id INTEGER, value TEXT, PRIMARY KEY (namespace, key, profile_id), FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE)",
-                "CREATE TABLE private_gen_bin (namespace TEXT, key TEXT, value BLOB, PRIMARY KEY (namespace, key))",
-                "CREATE TABLE private_ind_bin (namespace TEXT, key TEXT, profile_id INTEGER, value BLOB, PRIMARY KEY (namespace, key, profile_id), FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE)",
-            ]
-            for op in database_creation:
-                init_defer.addCallback(lambda ignore, sql: self.dbpool.runOperation(sql), op)
-                init_defer.addErrback(lambda ignore, sql: error(_("Error while creating tables in database [QUERY: %s]") % sql, op))
+            database_creation = ["PRAGMA user_version=%d" % CURRENT_DB_VERSION]
+            database_creation.extend(Updater.createData2Raw(DATABASE_SCHEMAS['current']['CREATE']))
+            database_creation.extend(Updater.insertData2Raw(DATABASE_SCHEMAS['current']['INSERT']))
+            return database_creation
+
+        def getUpdateSql():
+            updater = Updater(self.dbpool, sat_version)
+            return updater.checkUpdates()
+
+        def commitStatements(statements):
+
+            if statements is None:
+                return defer.succeed(None)
+            debug("===== COMMITING STATEMENTS =====\n%s\n============\n\n" % '\n'.join(statements))
+            d = self.dbpool.runInteraction(self._updateDb, tuple(statements))
+            return d
+
+        init_defer.addCallback(lambda ignore: getNewBaseSql() if new_base else getUpdateSql())
+        init_defer.addCallback(commitStatements)
 
         def fillProfileCache(ignore):
             d = self.dbpool.runQuery("SELECT name,id FROM profiles").addCallback(self._profilesCache)
@@ -70,6 +115,11 @@
 
         init_defer.addCallback(fillProfileCache)
 
+    def _updateDb(self, interaction, statements):
+        """Execute each of the given SQL statements, in order, with the given interaction"""
+        for sql in statements:
+            interaction.execute(sql)
+
     #Profiles
     def _profilesCache(self, profiles_result):
         """Fill the profiles cache
@@ -402,3 +452,271 @@
         """Return the first result of a database query
         Useful when we are looking for one specific value"""
         return None if not result else result[0][0]
+
+
+class Updater(object):  # checks the local database schema against DATABASE_SCHEMAS and builds the SQL statements needed to update it
+    stmnt_regex = re.compile(r"(?:[\w ]+(?:\([\w, ]+\))?)+")  # one column definition or table constraint (parenthesised lists are kept whole), see rawStatements2data
+    clean_regex = re.compile(r"^ +|(?<= ) +|(?<=,) +| +$")  # superfluous spaces (leading, trailing, repeated or following a comma), removed by statementHash before hashing
+    CREATE_SQL = "CREATE TABLE %s (%s)"  # (table name, comma separated columns definitions and constraints)
+    INSERT_SQL = "INSERT INTO %s VALUES (%s)"  # (table name, comma separated values)
+    DROP_SQL = "DROP TABLE %s"  # (table name)
+    ALTER_SQL = "ALTER TABLE %s ADD COLUMN %s"  # (table name, column definition)
+    RENAME_TABLE_SQL = "ALTER TABLE %s RENAME TO %s"  # (old table name, new table name)
+
+    CONSTRAINTS = ('PRIMARY', 'UNIQUE', 'CHECK', 'FOREIGN')  # first word of a table constraint, used to tell constraints from columns definitions
+    TMP_TABLE = "tmp_sat_update"  # name temporarily given to a table while its content is transferred to its new version
+
+    def __init__(self, dbpool, sat_version):
+        self.dbpool = dbpool  # pool used to query the local database
+        self._sat_version = sat_version  # version string, a trailing 'D' marks a development version
+
+    def getLocalVersion(self):
+        """ Get the version of the local database, as stored in its "user_version" pragma
+        @return: deferred firing the version (int) """
+        version_d = self.dbpool.runQuery("PRAGMA user_version")
+        version_d.addCallback(lambda rows: int(rows[0][0]))
+        return version_d
+
+    def _setLocalVersion(self, version):
+        """ Store the given version in the "user_version" pragma of the local database
+        @param version: version to set (int)
+        @return: deferred, as returned by dbpool.runOperation
+        """
+        pragma = "PRAGMA user_version=%d" % version
+        return self.dbpool.runOperation(pragma)
+
+    def getLocalSchema(self):
+        """ Get the raw schema of the local database
+        @return: deferred firing a list of strings, the CREATE statements of the tables
+                 of the local database, as found in sqlite_master
+        """
+        query = "select sql from sqlite_master where type = 'table'"
+        schema_d = self.dbpool.runQuery(query)
+        return schema_d.addCallback(lambda rows: [row[0] for row in rows])
+
+    @defer.inlineCallbacks
+    def checkUpdates(self):
+        """ Check if a database schema update is needed, according to DATABASE_SCHEMAS
+        @return: deferred which fires a list of SQL update statements, or None if no update is needed
+        @raise exceptions.DatabaseError: schemas differ although the versions match, on a non development version; exceptions.InternalError: an update definition is missing
+        """
+        local_version = yield self.getLocalVersion()
+        raw_local_sch = yield self.getLocalSchema()
+        local_sch = self.rawStatements2data(raw_local_sch)
+        current_sch = DATABASE_SCHEMAS['current']['CREATE']
+        local_hash = self.statementHash(local_sch)
+        current_hash = self.statementHash(current_sch)
+
+        if local_hash == current_hash:
+            if local_version != CURRENT_DB_VERSION:
+                warning(_("Your local schema is up-to-date, but database versions mismatch, fixing it..."))
+                yield self._setLocalVersion(CURRENT_DB_VERSION)
+        else:
+            # the local schema differs from the expected one, an update is needed
+
+            if local_version == CURRENT_DB_VERSION:
+                # schemas mismatch although the local version is already the latest one
+                if self._sat_version.endswith('D'):
+                    # we are in a development version: the update is generated automatically, in dev mode (no table is dropped)
+                    update_data = self.generateUpdateData(local_sch, current_sch, False)
+                    warning(_("There is a schema mismatch, but as we are on a dev version, database will be updated"))
+                    update_raw = self.update2raw(update_data, True)
+                    defer.returnValue(update_raw)
+                else:
+                    error(_(u"schema version is up-to-date, but local schema differ from expected current schema"))
+                    update_data = self.generateUpdateData(local_sch, current_sch, True)
+                    warning(_(u"Here are the commands that should fix the situation, use at your own risk (do a backup before modifying database), you can go to SàT's MUC room at sat@chat.jabberfr.org for help\n### SQL###\n%s\n### END SQL ###\n") % u'\n'.join(("%s;" % statement for statement in self.update2raw(update_data))))
+                    raise exceptions.DatabaseError("Database mismatch")
+            else:
+                # the local version is not the latest one: the update data of each following version are applied in order
+                info(_("Database schema has changed, local database will be updated"))
+                update_raw = []
+                for version in xrange(local_version+1, CURRENT_DB_VERSION+1):
+                    try:
+                        update_data = DATABASE_SCHEMAS[version]
+                    except KeyError:
+                        raise exceptions.InternalError("Missing update definition (version %d)" % version)
+                    update_raw.extend(self.update2raw(update_data))
+                update_raw.append("PRAGMA user_version=%d" % CURRENT_DB_VERSION)
+                defer.returnValue(update_raw)
+
+    @staticmethod
+    def createData2Raw(data):
+        """ Build the "CREATE TABLE" statements described by creation data
+        @param data: dictionary with table names as keys, and
+                     (columns definitions, table constraints) tuples of tuples as values
+        @return: list of strings, one raw CREATE statement per table
+        """
+        statements = []
+        for table, (col_defs, table_constraints) in data.items():
+            assert isinstance(col_defs, tuple)
+            assert isinstance(table_constraints, tuple)
+            columns_sql = ', '.join(col_defs + table_constraints)
+            statements.append(Updater.CREATE_SQL % (table, columns_sql))
+        return statements
+
+    @staticmethod
+    def insertData2Raw(data):
+        """ Build the "INSERT" statements described by insertion data
+        @param data: dictionary with table names as keys, and a tuple of rows as value,
+                     each row being a tuple of SQL values (strings, already quoted if needed)
+        @return: list of strings, one raw INSERT statement per row
+        """
+        statements = []
+        for table, rows in data.items():
+            assert isinstance(rows, tuple)
+            for row in rows:
+                assert isinstance(row, tuple)
+                row_sql = ', '.join(row)
+                statements.append(Updater.INSERT_SQL % (table, row_sql))
+        return statements
+
+    def statementHash(self, data):
+        """ Compute a hash of creation data, useful to compare two schemas
+
+        Tables and statements are sorted before hashing, and superfluous spaces
+        (see clean_regex) are removed from each statement
+        @param data: dictionary of "CREATE" statement, with tables names as key,
+                     and tuples of (col_defs, constraints) as values
+        @return: hash as string (raw SHA-1 digest)
+        """
+        def normalise(stmts):
+            cleaned = [self.clean_regex.sub('', stmt) for stmt in sorted(stmts)]
+            return ','.join(cleaned)
+
+        sha1 = hashlib.sha1()
+        for table in sorted(data.keys()):
+            col_defs, col_constr = data[table]
+            signature = (table, normalise(col_defs), normalise(col_constr))
+            sha1.update("%s:%s:%s" % signature)
+        return sha1.digest()
+
+    def rawStatements2data(self, raw_statements):
+        """ Split raw "CREATE TABLE" statements into dictionary/tuples data
+        @param raw_statements: list of CREATE statements as strings (as returned by getLocalSchema)
+        @return: dictionary with table names as key, and a (col_defs, constraints) tuple as value;
+                 statements which can't be parsed are ignored with a warning instead of raising
+        """
+        schema_dict = {}
+        for create_statement in raw_statements:
+            if not create_statement or not create_statement.startswith("CREATE TABLE "):
+                warning("Unexpected statement, ignoring it")
+                continue
+            # partition never raises: without space before the parenthesis (e.g. "CREATE TABLE name(cols)") raw_col_stats is empty
+            table, _sep, raw_col_stats = create_statement[13:].partition(' ')
+            if not (raw_col_stats.startswith('(') and raw_col_stats.endswith(')')):
+                warning("Unexpected statement structure, ignoring it")
+                continue
+            col_stats = [stmt.strip() for stmt in self.stmnt_regex.findall(raw_col_stats[1:-1])]
+            col_defs = []
+            constraints = []
+            for col_stat in col_stats:
+                name = col_stat.split(' ',1)[0]
+                if name in self.CONSTRAINTS:
+                    constraints.append(col_stat)
+                else:
+                    col_defs.append(col_stat)
+            schema_dict[table] = (tuple(col_defs), tuple(constraints))
+        return schema_dict
+
+    def generateUpdateData(self, old_data, new_data, modify=False):
+        """ Generate data for automatic update between two schema data
+        @param old_data: data of the former schema (which must be updated)
+        @param new_data: data of the current schema
+        @param modify: if True, always use "cols modify" table, else try to ALTER tables
+        @return: update data, a dictionary with:
+                 - 'create': dictionary of tables to create
+                 - 'delete': tuple of tables to delete
+                 - 'cols create': dictionary of columns to create (table as key, tuple of columns to create as value)
+                 - 'cols delete': dictionary of columns to delete (table as key, tuple of columns to delete as value)
+                 - 'cols modify': dictionary of columns to modify (table as key, tuple of old columns to transfer as value). With this table, a new table will be created, and content from the old table will be transferred to it, only cols specified in the tuple will be transferred.
+        @raise exceptions.InternalError: columns definitions or constraints of a table are not tuples
+        """
+
+        create_tables_data = {}
+        create_cols_data = {}
+        modify_cols_data = {}
+        delete_cols_data = {}
+        old_tables = set(old_data.keys())
+        new_tables = set(new_data.keys())
+
+        def getChanges(set_olds, set_news):
+            to_create = set_news.difference(set_olds)
+            to_delete = set_olds.difference(set_news)
+            to_check = set_news.intersection(set_olds)
+            return tuple(to_create), tuple(to_delete), tuple(to_check)
+
+        tables_to_create, tables_to_delete, tables_to_check = getChanges(old_tables, new_tables)
+
+        for table in tables_to_create:
+            create_tables_data[table] = new_data[table]
+
+        for table in tables_to_check:
+            old_col_defs, old_constraints = old_data[table]
+            new_col_defs, new_constraints = new_data[table]
+            for obj in old_col_defs, old_constraints, new_col_defs, new_constraints:
+                if not isinstance(obj, tuple):
+                    raise exceptions.InternalError("Columns definitions must be tuples")
+            defs_create, defs_delete, ignore = getChanges(set(old_col_defs), set(new_col_defs))
+            constraints_create, constraints_delete, ignore = getChanges(set(old_constraints), set(new_constraints))
+            created_col_names = set([name.split(' ',1)[0] for name in defs_create])
+            deleted_col_names = set([name.split(' ',1)[0] for name in defs_delete])
+            if (created_col_names.intersection(deleted_col_names) or constraints_create or constraints_delete or
+                (modify and (defs_create or defs_delete))):
+                # a column definition or a constraint has changed (a simple ADD COLUMN is not enough), or modify is set:
+                # the table must be transferred, so we determine which columns are in both schemas to keep their content
+                old_names = set([name.split(' ',1)[0] for name in old_col_defs])
+                new_names = set([name.split(' ',1)[0] for name in new_col_defs])
+                modify_cols_data[table] = tuple(old_names.intersection(new_names))
+            else:
+                if defs_create:
+                    create_cols_data[table] = defs_create
+                if defs_delete:  # constraints are unchanged when this branch is reached
+                    delete_cols_data[table] = defs_delete
+
+        return {'create': create_tables_data,
+                'delete': tables_to_delete,
+                'cols create': create_cols_data,
+                'cols delete': delete_cols_data,
+                'cols modify': modify_cols_data
+                }
+
+    def update2raw(self, update, dev_version=False):
+        """ Convert update data into the raw SQLite statements applying it
+        @param update: update data as returned by generateUpdateData
+        @param dev_version: if True, update is done in dev mode: tables are not dropped, the DROP statements are only
+                            logged. This prevents accidental loss of data while working on the code/database.
+        @return: list of strings with the SQL statements needed to update the base
+        """
+        statements = self.createData2Raw(update.get('create', {}))
+
+        # tables which are not needed anymore
+        drop_statements = [self.DROP_SQL % table for table in update.get('delete', tuple())]
+        if not dev_version:
+            statements.extend(drop_statements)
+        elif drop_statements:
+            info("Dev version, SQL NOT EXECUTED:\n--\n%s\n--\n" % "\n".join(drop_statements))
+
+        # new columns are added to the existing tables
+        for table, new_col_defs in update.get('cols create', {}).items():
+            statements.extend(self.ALTER_SQL % (table, col_def) for col_def in new_col_defs)
+
+        # obsolete columns are not removed here, they are only reported
+        for table, old_col_defs in update.get('cols delete', {}).items():
+            info("Following columns in table [%s] are not needed anymore, but are kept for dev version: %s" % (table, ", ".join(old_col_defs)))
+
+        # modified tables are rebuilt from the 'current' schema, the content of the common columns is transferred
+        for table, kept_cols in update.get('cols modify', {}).items():
+            col_defs, constraints = DATABASE_SCHEMAS['current']['CREATE'][table]
+            common_cols = ', '.join(kept_cols)
+            statements.extend((
+                self.RENAME_TABLE_SQL % (table, self.TMP_TABLE),
+                self.CREATE_SQL % (table, ', '.join(col_defs + constraints)),
+                "INSERT INTO %s (%s) SELECT %s FROM %s" % (table, common_cols, common_cols, self.TMP_TABLE),
+                self.DROP_SQL % self.TMP_TABLE,
+            ))
+
+        # values to insert come last, once every table is in its final state
+        statements.extend(self.insertData2Raw(update.get('insert', {})))
+
+        return statements