Mercurial > libervia-backend
diff sat/memory/sqla.py @ 3582:71516731d0aa
core (memory/sqla): database migration using Alembic:
Alembic database migration tool, which is the recommended one for SQLAlchemy has been
integrated. When a database is created, it will be used to stamp to current (head)
revision, otherwise, DB will be checked to see if it needs to be updated, and upgrade will
be triggered if necessary.
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 25 Jun 2021 17:55:23 +0200 |
parents | f9a5b810f14d |
children | 16ade4ad63f3 |
line wrap: on
line diff
--- a/sat/memory/sqla.py Fri Jun 25 10:17:34 2021 +0200 +++ b/sat/memory/sqla.py Fri Jun 25 17:55:23 2021 +0200 @@ -16,18 +16,22 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +import sys import time +import asyncio +from asyncio.subprocess import PIPE +from pathlib import Path from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional -from urllib.parse import quote -from pathlib import Path -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager from sqlalchemy.future import select -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Connection from sqlalchemy import update, delete, and_, or_, event from sqlalchemy.sql.functions import coalesce, sum as sum_ from sqlalchemy.dialects.sqlite import insert +from alembic import script as al_script, config as al_config +from alembic.runtime import migration as al_migration from twisted.internet import defer from twisted.words.protocols.jabber import jid from sat.core.i18n import _ @@ -36,6 +40,8 @@ from sat.core.constants import Const as C from sat.core.core_types import SatXMPPEntity from sat.tools.utils import aio +from sat.memory import migration +from sat.memory import sqla_config from sat.memory.sqla_mapping import ( NOT_IN_EXTRA, Base, @@ -56,6 +62,7 @@ log = getLogger(__name__) +migration_path = Path(migration.__file__).parent @event.listens_for(Engine, "connect") @@ -67,32 +74,95 @@ class Storage: - def __init__(self, db_filename, sat_version): + def __init__(self): self.initialized = defer.Deferred() - self.filename = Path(db_filename) # we keep cache for the profiles (key: profile name, value: profile id) # profile id to name self.profiles: Dict[int, str] = {} # profile id to component entry point self.components: Dict[int, str] = {} + async def migrateApply(self, *args: str, log_output: bool = False) -> None: + """Do a migration command + + Commands are applied by running Alembic in a subprocess. + Arguments are alembic executables commands + + @param log_output: manage stdout and stderr: + - if False, stdout and stderr are buffered, and logged only in case of error + - if True, stdout and stderr will be logged during the command execution + @raise exceptions.DatabaseError: something went wrong while running the + process + """ + stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) + proc = await asyncio.create_subprocess_exec( + sys.executable, "-m", "alembic", *args, + stdout=stdout, stderr=stderr, cwd=migration_path + ) + log_out, log_err = await proc.communicate() + if proc.returncode != 0: + msg = _( + "Can't {operation} database (exit code {exit_code})" + ).format( + operation=args[0], + exit_code=proc.returncode + ) + if log_out or log_err: + msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" + log.error(msg) + + raise exceptions.DatabaseError(msg) + + async def createDB(self, engine: AsyncEngine, db_config: dict) -> None: + """Create a new database + + The database is generated from SQLAlchemy model, then stamped by Alembic + """ + # the dir may not exist if it's not the XDG recommended one + db_config["path"].parent.mkdir(0o700, True, True) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + log.debug("stamping the database") + await self.migrateApply("stamp", "head") + log.debug("stamping done") + + def _checkDBIsUpToDate(self, conn: Connection) -> bool: + al_ini_path = migration_path / "alembic.ini" + al_cfg = al_config.Config(al_ini_path) + directory = al_script.ScriptDirectory.from_config(al_cfg) + context = al_migration.MigrationContext.configure(conn) + return set(context.get_current_heads()) == set(directory.get_heads()) + + async def checkAndUpdateDB(self, engine: AsyncEngine, db_config: dict) -> None: + """Check that database is up-to-date, and update if necessary""" + async with engine.connect() as conn: + up_to_date = await conn.run_sync(self._checkDBIsUpToDate) + if up_to_date: + log.debug("Database is up-to-date") + else: + log.info("Database needs to be updated") + log.info("updating…") + await self.migrateApply("upgrade", "head", log_output=True) + log.info("Database is now up-to-date") + @aio - async def initialise(self): + async def initialise(self) -> None: log.info(_("Connecting database")) + db_config = sqla_config.getDbConfig() engine = create_async_engine( - f"sqlite+aiosqlite:///{quote(str(self.filename))}", + db_config["url"], future=True ) self.session = sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) - new_base = not self.filename.exists() + new_base = not db_config["path"].exists() if new_base: log.info(_("The database is new, creating the tables")) - # the dir may not exist if it's not the XDG recommended one - self.filename.parent.mkdir(0o700, True, True) - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) + await self.createDB(engine, db_config) + else: + await self.checkAndUpdateDB(engine, db_config) async with self.session() as session: result = await session.execute(select(Profile))