comparison sat/memory/sqla.py @ 3582:71516731d0aa

core (memory/sqla): database migration using Alembic: the Alembic database migration tool, which is the one recommended for SQLAlchemy, has been integrated. When a database is created, Alembic is used to stamp it with the current (head) revision; otherwise, the database is checked to see whether it needs to be updated, and an upgrade is triggered if necessary.
author Goffi <goffi@goffi.org>
date Fri, 25 Jun 2021 17:55:23 +0200
parents f9a5b810f14d
children 16ade4ad63f3
comparison
equal deleted inserted replaced
3581:84ea57a8d6b3 3582:71516731d0aa
14 # GNU Affero General Public License for more details. 14 # GNU Affero General Public License for more details.
15 15
16 # You should have received a copy of the GNU Affero General Public License 16 # You should have received a copy of the GNU Affero General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. 17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 18
19 import sys
19 import time 20 import time
21 import asyncio
22 from asyncio.subprocess import PIPE
23 from pathlib import Path
20 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional 24 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
21 from urllib.parse import quote 25 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
22 from pathlib import Path
23 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
24 from sqlalchemy.exc import IntegrityError, NoResultFound 26 from sqlalchemy.exc import IntegrityError, NoResultFound
25 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager 27 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
26 from sqlalchemy.future import select 28 from sqlalchemy.future import select
27 from sqlalchemy.engine import Engine 29 from sqlalchemy.engine import Engine, Connection
28 from sqlalchemy import update, delete, and_, or_, event 30 from sqlalchemy import update, delete, and_, or_, event
29 from sqlalchemy.sql.functions import coalesce, sum as sum_ 31 from sqlalchemy.sql.functions import coalesce, sum as sum_
30 from sqlalchemy.dialects.sqlite import insert 32 from sqlalchemy.dialects.sqlite import insert
33 from alembic import script as al_script, config as al_config
34 from alembic.runtime import migration as al_migration
31 from twisted.internet import defer 35 from twisted.internet import defer
32 from twisted.words.protocols.jabber import jid 36 from twisted.words.protocols.jabber import jid
33 from sat.core.i18n import _ 37 from sat.core.i18n import _
34 from sat.core import exceptions 38 from sat.core import exceptions
35 from sat.core.log import getLogger 39 from sat.core.log import getLogger
36 from sat.core.constants import Const as C 40 from sat.core.constants import Const as C
37 from sat.core.core_types import SatXMPPEntity 41 from sat.core.core_types import SatXMPPEntity
38 from sat.tools.utils import aio 42 from sat.tools.utils import aio
43 from sat.memory import migration
44 from sat.memory import sqla_config
39 from sat.memory.sqla_mapping import ( 45 from sat.memory.sqla_mapping import (
40 NOT_IN_EXTRA, 46 NOT_IN_EXTRA,
41 Base, 47 Base,
42 Profile, 48 Profile,
43 Component, 49 Component,
54 File 60 File
55 ) 61 )
56 62
57 63
log = getLogger(__name__)

# Directory of the ``migration`` package: it contains ``alembic.ini`` (read by
# _checkDBIsUpToDate) and is used as working directory for Alembic subprocesses.
migration_path: Path = Path(migration.__file__).parent
59 66
60 67
61 @event.listens_for(Engine, "connect") 68 @event.listens_for(Engine, "connect")
62 def set_sqlite_pragma(dbapi_connection, connection_record): 69 def set_sqlite_pragma(dbapi_connection, connection_record):
63 cursor = dbapi_connection.cursor() 70 cursor = dbapi_connection.cursor()
65 cursor.close() 72 cursor.close()
66 73
67 74
68 class Storage: 75 class Storage:
69 76
def __init__(self):
    # Deferred exposed to callers waiting for the storage to be ready
    # (where it gets fired is outside of this view — TODO confirm)
    self.initialized = defer.Deferred()
    # In-memory caches mirroring database content.
    # profile name to profile id (filled by initialise() from the Profile table)
    self.profiles: Dict[str, int] = {}
    # profile id to component entry point
    self.components: Dict[int, str] = {}
78 84
async def migrateApply(self, *args: str, log_output: bool = False) -> None:
    """Do a migration command

    Commands are applied by running Alembic in a subprocess (``python -m alembic``)
    whose working directory is the migration package directory.
    Arguments are the Alembic command-line arguments (e.g. ``"upgrade", "head"``).

    @param log_output: manage stdout and stderr:
        - if False, stdout and stderr are captured, and logged only in case of error
        - if True, stdout and stderr are not captured: they are inherited from this
          process, so Alembic output is shown while the command is running
    @raise exceptions.DatabaseError: something went wrong while running the
        process
    """
    # None means "inherit the parent's stream", PIPE means "capture it"
    stream = None if log_output else PIPE
    proc = await asyncio.create_subprocess_exec(
        sys.executable, "-m", "alembic", *args,
        stdout=stream, stderr=stream, cwd=migration_path
    )
    log_out, log_err = await proc.communicate()
    if proc.returncode != 0:
        msg = _(
            "Can't {operation} database (exit code {exit_code})"
        ).format(
            # without arguments Alembic fails with a usage error: don't let an
            # IndexError hide it
            operation=args[0] if args else "migrate",
            exit_code=proc.returncode
        )
        if log_out or log_err:
            # captured output is not guaranteed to be valid UTF-8; a decoding
            # error here would mask the actual failure
            out_txt = (log_out or b"").decode(errors="replace")
            err_txt = (log_err or b"").decode(errors="replace")
            msg += f":\nstdout: {out_txt}\nstderr: {err_txt}"
        log.error(msg)

        raise exceptions.DatabaseError(msg)
115
async def createDB(self, engine: AsyncEngine, db_config: dict) -> None:
    """Create a brand new database

    All tables are created from the SQLAlchemy model metadata, then the database
    is stamped with the Alembic "head" revision so later migrations start from it.

    @param engine: engine bound to the database to create
    @param db_config: database configuration; its "path" entry is the database file
    """
    db_dir = db_config["path"].parent
    # a non-XDG location may not exist yet: create it with owner-only permissions
    db_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    log.debug("stamping the database")
    await self.migrateApply("stamp", "head")
    log.debug("stamping done")
129
def _checkDBIsUpToDate(self, conn: Connection) -> bool:
    """Tell whether the database is at the latest migration revision(s)

    This is synchronous code, meant to be run through ``AsyncConnection.run_sync``.

    @param conn: synchronous connection to the database to check
    @return: True if the revision(s) stored in the database are exactly the
        head(s) of the migration scripts
    """
    # NOTE(review): if alembic.ini uses a relative ``script_location``, Alembic
    # resolves it from the current working directory (not migration_path, which
    # is only used as cwd by migrateApply) — confirm the ini uses ``%(here)s``
    alembic_cfg = al_config.Config(migration_path / "alembic.ini")
    scripts = al_script.ScriptDirectory.from_config(alembic_cfg)
    mig_ctx = al_migration.MigrationContext.configure(conn)
    db_heads = set(mig_ctx.get_current_heads())
    script_heads = set(scripts.get_heads())
    return db_heads == script_heads
136
async def checkAndUpdateDB(self, engine: AsyncEngine, db_config: dict) -> None:
    """Check that database is up-to-date, and upgrade it if necessary

    @param engine: engine bound to the existing database
    @param db_config: database configuration (not used here)
    """
    async with engine.connect() as conn:
        up_to_date = await conn.run_sync(self._checkDBIsUpToDate)

    # the connection is released before running the upgrade, which is done by
    # a separate Alembic process working on the same database
    if not up_to_date:
        log.info("Database needs to be updated")
        log.info("updating…")
        await self.migrateApply("upgrade", "head", log_output=True)
        log.info("Database is now up-to-date")
    else:
        log.debug("Database is up-to-date")
148
149 @aio
150 async def initialise(self) -> None:
81 log.info(_("Connecting database")) 151 log.info(_("Connecting database"))
152 db_config = sqla_config.getDbConfig()
82 engine = create_async_engine( 153 engine = create_async_engine(
83 f"sqlite+aiosqlite:///{quote(str(self.filename))}", 154 db_config["url"],
84 future=True 155 future=True
85 ) 156 )
86 self.session = sessionmaker( 157 self.session = sessionmaker(
87 engine, expire_on_commit=False, class_=AsyncSession 158 engine, expire_on_commit=False, class_=AsyncSession
88 ) 159 )
89 new_base = not self.filename.exists() 160 new_base = not db_config["path"].exists()
90 if new_base: 161 if new_base:
91 log.info(_("The database is new, creating the tables")) 162 log.info(_("The database is new, creating the tables"))
92 # the dir may not exist if it's not the XDG recommended one 163 await self.createDB(engine, db_config)
93 self.filename.parent.mkdir(0o700, True, True) 164 else:
94 async with engine.begin() as conn: 165 await self.checkAndUpdateDB(engine, db_config)
95 await conn.run_sync(Base.metadata.create_all)
96 166
97 async with self.session() as session: 167 async with self.session() as session:
98 result = await session.execute(select(Profile)) 168 result = await session.execute(select(Profile))
99 for p in result.scalars(): 169 for p in result.scalars():
100 self.profiles[p.name] = p.id 170 self.profiles[p.name] = p.id