diff sat/memory/sqla.py @ 3582:71516731d0aa

core (memory/sqla): database migration using Alembic: Alembic database migration tool, which is the recommended one for SQLAlchemy has been integrated. When a database is created, it will be used to stamp to current (head) revision, otherwise, DB will be checked to see if it needs to be updated, and upgrade will be triggered if necessary.
author Goffi <goffi@goffi.org>
date Fri, 25 Jun 2021 17:55:23 +0200
parents f9a5b810f14d
children 16ade4ad63f3
line wrap: on
line diff
--- a/sat/memory/sqla.py	Fri Jun 25 10:17:34 2021 +0200
+++ b/sat/memory/sqla.py	Fri Jun 25 17:55:23 2021 +0200
@@ -16,18 +16,22 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import sys
 import time
+import asyncio
+from asyncio.subprocess import PIPE
+from pathlib import Path
 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
-from urllib.parse import quote
-from pathlib import Path
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
+from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
 from sqlalchemy.exc import IntegrityError, NoResultFound
 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
 from sqlalchemy.future import select
-from sqlalchemy.engine import Engine
+from sqlalchemy.engine import Engine, Connection
 from sqlalchemy import update, delete, and_, or_, event
 from sqlalchemy.sql.functions import coalesce, sum as sum_
 from sqlalchemy.dialects.sqlite import insert
+from alembic import script as al_script, config as al_config
+from alembic.runtime import migration as al_migration
 from twisted.internet import defer
 from twisted.words.protocols.jabber import jid
 from sat.core.i18n import _
@@ -36,6 +40,8 @@
 from sat.core.constants import Const as C
 from sat.core.core_types import SatXMPPEntity
 from sat.tools.utils import aio
+from sat.memory import migration
+from sat.memory import sqla_config
 from sat.memory.sqla_mapping import (
     NOT_IN_EXTRA,
     Base,
@@ -56,6 +62,7 @@
 
 
 log = getLogger(__name__)
+migration_path = Path(migration.__file__).parent
 
 
 @event.listens_for(Engine, "connect")
@@ -67,32 +74,95 @@
 
 class Storage:
 
-    def __init__(self, db_filename, sat_version):
+    def __init__(self):
         self.initialized = defer.Deferred()
-        self.filename = Path(db_filename)
         # we keep cache for the profiles (key: profile name, value: profile id)
         # profile id to name
         self.profiles: Dict[int, str] = {}
         # profile id to component entry point
         self.components: Dict[int, str] = {}
 
+    async def migrateApply(self, *args: str, log_output: bool = False) -> None:
+        """Do a migration command
+
+        Commands are applied by running Alembic in a subprocess.
+        Arguments are alembic executables commands
+
+        @param log_output: manage stdout and stderr:
+            - if False, stdout and stderr are buffered, and logged only in case of error
+            - if True, stdout and stderr will be logged during the command execution
+        @raise exceptions.DatabaseError: something went wrong while running the
+            process
+        """
+        stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
+        proc = await asyncio.create_subprocess_exec(
+            sys.executable, "-m", "alembic", *args,
+            stdout=stdout, stderr=stderr, cwd=migration_path
+        )
+        log_out, log_err = await proc.communicate()
+        if proc.returncode != 0:
+            msg = _(
+                "Can't {operation} database (exit code {exit_code})"
+            ).format(
+                operation=args[0],
+                exit_code=proc.returncode
+            )
+            if log_out or log_err:
+                msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}"
+            log.error(msg)
+
+            raise exceptions.DatabaseError(msg)
+
+    async def createDB(self, engine: AsyncEngine, db_config: dict) -> None:
+        """Create a new database
+
+        The database is generated from SQLAlchemy model, then stamped by Alembic
+        """
+        # the dir may not exist if it's not the XDG recommended one
+        db_config["path"].parent.mkdir(0o700, True, True)
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+
+        log.debug("stamping the database")
+        await self.migrateApply("stamp", "head")
+        log.debug("stamping done")
+
+    def _checkDBIsUpToDate(self, conn: Connection) -> bool:
+        al_ini_path = migration_path / "alembic.ini"
+        al_cfg = al_config.Config(al_ini_path)
+        directory = al_script.ScriptDirectory.from_config(al_cfg)
+        context = al_migration.MigrationContext.configure(conn)
+        return set(context.get_current_heads()) == set(directory.get_heads())
+
+    async def checkAndUpdateDB(self, engine: AsyncEngine, db_config: dict) -> None:
+        """Check that database is up-to-date, and update if necessary"""
+        async with engine.connect() as conn:
+            up_to_date = await conn.run_sync(self._checkDBIsUpToDate)
+        if up_to_date:
+            log.debug("Database is up-to-date")
+        else:
+            log.info("Database needs to be updated")
+            log.info("updating…")
+            await self.migrateApply("upgrade", "head", log_output=True)
+            log.info("Database is now up-to-date")
+
     @aio
-    async def initialise(self):
+    async def initialise(self) -> None:
         log.info(_("Connecting database"))
+        db_config = sqla_config.getDbConfig()
         engine = create_async_engine(
-            f"sqlite+aiosqlite:///{quote(str(self.filename))}",
+            db_config["url"],
             future=True
         )
         self.session = sessionmaker(
             engine, expire_on_commit=False, class_=AsyncSession
         )
-        new_base = not self.filename.exists()
+        new_base = not db_config["path"].exists()
         if new_base:
             log.info(_("The database is new, creating the tables"))
-            # the dir may not exist if it's not the XDG recommended one
-            self.filename.parent.mkdir(0o700, True, True)
-            async with engine.begin() as conn:
-                await conn.run_sync(Base.metadata.create_all)
+            await self.createDB(engine, db_config)
+        else:
+            await self.checkAndUpdateDB(engine, db_config)
 
         async with self.session() as session:
             result = await session.execute(select(Profile))