comparison libervia/backend/memory/sqla.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author: Goffi <goffi@goffi.org>
date: Fri, 02 Jun 2023 11:49:51 +0200
parents: sat/memory/sqla.py@524856bd7b19
children: 74c66c0d93f3
comparison: 4070:d10748475025 vs 4071:4b842c1fb686
#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
from asyncio.subprocess import PIPE
import copy
from datetime import datetime
from pathlib import Path
import sys
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from alembic import config as al_config, script as al_script
from alembic.runtime import migration as al_migration
from sqlalchemy import and_, delete, event, func, or_, update
from sqlalchemy import Integer, literal_column, text
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import (
    contains_eager,
    joinedload,
    selectinload,
    sessionmaker,
    subqueryload,
)
from sqlalchemy.orm.attributes import Mapped
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.sql.functions import coalesce, count, now, sum as sum_
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from twisted.words.xish import domish

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.memory import migration
from libervia.backend.memory import sqla_config
from libervia.backend.memory.sqla_mapping import (
    Base,
    Component,
    File,
    History,
    Message,
    NOT_IN_EXTRA,
    ParamGen,
    ParamInd,
    PrivateGen,
    PrivateGenBin,
    PrivateInd,
    PrivateIndBin,
    Profile,
    PubsubItem,
    PubsubNode,
    Subject,
    SyncState,
    Thread,
)
from libervia.backend.tools.common import uri
from libervia.backend.tools.utils import aio, as_future


log = getLogger(__name__)
migration_path = Path(migration.__file__).parent
#: mapping of Libervia search query operators to SQLAlchemy method name
OP_MAP = {
    "==": "__eq__",
    "eq": "__eq__",
    "!=": "__ne__",
    "ne": "__ne__",
    ">": "__gt__",
    "gt": "__gt__",
    "<": "__lt__",
    "lt": "__lt__",
    "between": "between",
    "in": "in_",
    "not_in": "not_in",
    "overlap": "in_",
    "ioverlap": "in_",
    "disjoint": "in_",
    "idisjoint": "in_",
    "like": "like",
    "ilike": "ilike",
    "not_like": "notlike",
    "not_ilike": "notilike",
}
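
# Example (illustrative only, values are hypothetical): an OP_MAP entry is
# applied by looking up the SQLAlchemy comparison method by name, as done in
# search_pubsub_items below:
#     op_attr = OP_MAP[">"]  # "__gt__"
#     condition = getattr(History.timestamp, op_attr)(1685700000.0)
#     stmt = select(History).where(condition)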


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()


class Storage:

    def __init__(self):
        self.initialized = defer.Deferred()
        # cache of the profiles (key: profile name, value: profile id)
        self.profiles: Dict[str, int] = {}
        # profile id to component entry point
        self.components: Dict[int, str] = {}

    def get_profile_by_id(self, profile_id):
        # self.profiles maps names to ids, hence the reverse lookup
        return next(
            (name for name, id_ in self.profiles.items() if id_ == profile_id), None
        )

    async def migrate_apply(self, *args: str, log_output: bool = False) -> None:
        """Run an Alembic migration command

        Commands are applied by running Alembic in a subprocess.
        Arguments are passed to the ``alembic`` executable.

        @param log_output: manage stdout and stderr:
            - if False, stdout and stderr are buffered, and logged only in case of error
            - if True, stdout and stderr will be logged during the command execution
        @raise exceptions.DatabaseError: something went wrong while running the
            process
        """
        stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
        proc = await asyncio.create_subprocess_exec(
            sys.executable, "-m", "alembic", *args,
            stdout=stdout, stderr=stderr, cwd=migration_path
        )
        log_out, log_err = await proc.communicate()
        if proc.returncode != 0:
            msg = _(
                "Can't {operation} database (exit code {exit_code})"
            ).format(
                operation=args[0],
                exit_code=proc.returncode
            )
            if log_out or log_err:
                msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}"
            log.error(msg)

            raise exceptions.DatabaseError(msg)

    async def create_db(self, engine: AsyncEngine, db_config: dict) -> None:
        """Create a new database

        The database is generated from the SQLAlchemy model, then stamped by Alembic
        """
        # the dir may not exist if it's not the XDG recommended one
        db_config["path"].parent.mkdir(0o700, True, True)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        log.debug("stamping the database")
        await self.migrate_apply("stamp", "head")
        log.debug("stamping done")

    def _check_db_is_up_to_date(self, conn: Connection) -> bool:
        al_ini_path = migration_path / "alembic.ini"
        al_cfg = al_config.Config(al_ini_path)
        directory = al_script.ScriptDirectory.from_config(al_cfg)
        context = al_migration.MigrationContext.configure(conn)
        return set(context.get_current_heads()) == set(directory.get_heads())

    def _sqlite_set_journal_mode_wal(self, conn: Connection) -> None:
        """Check if journal mode is WAL, and set it if necessary"""
        result = conn.execute(text("PRAGMA journal_mode"))
        if result.scalar() != "wal":
            log.info("WAL mode not activated, activating it")
            conn.execute(text("PRAGMA journal_mode=WAL"))

    async def check_and_update_db(self, engine: AsyncEngine, db_config: dict) -> None:
        """Check that database is up-to-date, and update if necessary"""
        async with engine.connect() as conn:
            up_to_date = await conn.run_sync(self._check_db_is_up_to_date)
        if up_to_date:
            log.debug("Database is up-to-date")
        else:
            log.info("Database needs to be updated")
            log.info("updating…")
            await self.migrate_apply("upgrade", "head", log_output=True)
            log.info("Database is now up-to-date")

    @aio
    async def initialise(self) -> None:
        log.info(_("Connecting database"))

        db_config = sqla_config.get_db_config()
        engine = create_async_engine(
            db_config["url"],
            future=True,
        )

        new_base = not db_config["path"].exists()
        if new_base:
            log.info(_("The database is new, creating the tables"))
            await self.create_db(engine, db_config)
        else:
            await self.check_and_update_db(engine, db_config)

        async with engine.connect() as conn:
            await conn.run_sync(self._sqlite_set_journal_mode_wal)

        self.session = sessionmaker(
            engine, expire_on_commit=False, class_=AsyncSession
        )

        async with self.session() as session:
            result = await session.execute(select(Profile))
            for p in result.scalars():
                self.profiles[p.name] = p.id
            result = await session.execute(select(Component))
            for c in result.scalars():
                self.components[c.profile_id] = c.entry_point

        self.initialized.callback(None)

    ## Generic

    @aio
    async def get(
        self,
        client: SatXMPPEntity,
        db_cls: DeclarativeMeta,
        db_id_col: Mapped,
        id_value: Any,
        joined_loads=None
    ) -> Optional[DeclarativeMeta]:
        stmt = select(db_cls).where(db_id_col == id_value)
        if client is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
        if joined_loads is not None:
            for joined_load in joined_loads:
                stmt = stmt.options(joinedload(joined_load))
        async with self.session() as session:
            result = await session.execute(stmt)
        if joined_loads is not None:
            result = result.unique()
        return result.scalar_one_or_none()

    @aio
    async def add(self, db_obj: DeclarativeMeta) -> None:
        """Add an object to database"""
        async with self.session() as session:
            async with session.begin():
                session.add(db_obj)

    @aio
    async def delete(
        self,
        db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]],
        session_add: Optional[List[DeclarativeMeta]] = None
    ) -> None:
        """Delete an object from database

        @param db_obj: object to delete or list of objects to delete
        @param session_add: other objects to add to session.
            This is useful when parents of deleted objects need to be updated too,
            or if other objects need to be updated in the same transaction.
        """
        if not db_obj:
            return
        if not isinstance(db_obj, list):
            db_obj = [db_obj]
        async with self.session() as session:
            async with session.begin():
                if session_add is not None:
                    for obj in session_add:
                        session.add(obj)
                for obj in db_obj:
                    await session.delete(obj)
            await session.commit()

    ## Profiles

    def get_profiles_list(self) -> List[str]:
        """Return list of all registered profiles"""
        return list(self.profiles.keys())

    def has_profile(self, profile_name: str) -> bool:
        """return True if profile_name exists

        @param profile_name: name of the profile to check
        """
        return profile_name in self.profiles

    def profile_is_component(self, profile_name: str) -> bool:
        try:
            return self.profiles[profile_name] in self.components
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exist")

    def get_entry_point(self, profile_name: str) -> str:
        try:
            return self.components[self.profiles[profile_name]]
        except KeyError:
            raise exceptions.NotFound(
                "the requested profile doesn't exist or is not a component"
            )

    @aio
    async def create_profile(self, name: str, component_ep: Optional[str] = None) -> Profile:
        """Create a new profile

        @param name: name of the profile
        @param component_ep: if not None, must point to a component entry point
        @return: the created profile
        """
        async with self.session() as session:
            profile = Profile(name=name)
            async with session.begin():
                session.add(profile)
            self.profiles[profile.name] = profile.id
            if component_ep is not None:
                async with session.begin():
                    component = Component(profile=profile, entry_point=component_ep)
                    session.add(component)
                self.components[profile.id] = component_ep
        return profile

    @aio
    async def delete_profile(self, name: str) -> None:
        """Delete profile

        @param name: name of the profile
        """
        async with self.session() as session:
            result = await session.execute(select(Profile).where(Profile.name == name))
            profile = result.scalar()
            await session.delete(profile)
            await session.commit()
        del self.profiles[profile.name]
        if profile.id in self.components:
            del self.components[profile.id]
        log.info(_("Profile {name!r} deleted").format(name=name))

    ## Params

    @aio
    async def load_gen_params(self, params_gen: dict) -> None:
        """Load general parameters

        @param params_gen: dictionary to fill
        """
        log.debug(_("loading general parameters from database"))
        async with self.session() as session:
            result = await session.execute(select(ParamGen))
            for p in result.scalars():
                params_gen[(p.category, p.name)] = p.value

    @aio
    async def load_ind_params(self, params_ind: dict, profile: str) -> None:
        """Load individual parameters

        @param params_ind: dictionary to fill
        @param profile: a profile which *must* exist
        """
        log.debug(_("loading individual parameters from database"))
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd).where(ParamInd.profile_id == self.profiles[profile])
            )
            for p in result.scalars():
                params_ind[(p.category, p.name)] = p.value

    @aio
    async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]:
        """Ask database for the value of one specific individual parameter

        @param category: category of the parameter
        @param name: name of the parameter
        @param profile: %(doc_profile)s
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd.value)
                .filter_by(
                    category=category,
                    name=name,
                    profile_id=self.profiles[profile]
                )
            )
        return result.scalar_one_or_none()

    @aio
    async def get_ind_param_values(self, category: str, name: str) -> Dict[str, str]:
        """Ask database for the individual values of a parameter for all profiles

        @param category: category of the parameter
        @param name: name of the parameter
        @return dict: profile => value map
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd)
                .filter_by(
                    category=category,
                    name=name
                )
                .options(subqueryload(ParamInd.profile))
            )
        return {param.profile.name: param.value for param in result.scalars()}

    @aio
    async def set_gen_param(self, category: str, name: str, value: Optional[str]) -> None:
        """Save the general parameters in database

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        """
        async with self.session() as session:
            stmt = insert(ParamGen).values(
                category=category,
                name=name,
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamGen.category, ParamGen.name),
                set_={
                    ParamGen.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    @aio
    async def set_ind_param(
        self,
        category: str,
        name: str,
        value: Optional[str],
        profile: str
    ) -> None:
        """Save the individual parameters in database

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        @param profile: a profile which *must* exist
        """
        async with self.session() as session:
            stmt = insert(ParamInd).values(
                category=category,
                name=name,
                profile_id=self.profiles[profile],
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id),
                set_={
                    ParamInd.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()
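
    # Usage sketch (hypothetical category/name/value): both setters above rely
    # on SQLite's INSERT … ON CONFLICT DO UPDATE, so the same call either
    # creates or updates the row:
    #     await as_future(storage.set_ind_param(
    #         "General", "Password", "s3cret", profile="default"
    #     ))
    # Calling it again with a new value updates the existing
    # (category, name, profile_id) row instead of raising an IntegrityError.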

    def _jid_filter(self, jid_: jid.JID, dest: bool = False):
        """Generate a condition to filter on a JID, using the relevant columns

        @param dest: True if it's the destination JID, otherwise it's the source one
        @param jid_: JID to filter by
        """
        if jid_.resource:
            if dest:
                return and_(
                    History.dest == jid_.userhost(),
                    History.dest_res == jid_.resource
                )
            else:
                return and_(
                    History.source == jid_.userhost(),
                    History.source_res == jid_.resource
                )
        else:
            if dest:
                return History.dest == jid_.userhost()
            else:
                return History.source == jid_.userhost()

    @aio
    async def history_get(
        self,
        from_jid: Optional[jid.JID],
        to_jid: Optional[jid.JID],
        limit: Optional[int] = None,
        between: bool = True,
        filters: Optional[Dict[str, str]] = None,
        profile: Optional[str] = None,
    ) -> List[
        Tuple[str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
    ]:
        """Retrieve messages in history

        @param from_jid: source JID (full, or bare for catchall)
        @param to_jid: dest JID (full, or bare for catchall)
        @param limit: maximum number of messages to get:
            - 0 for no message (returns the empty list)
            - None for unlimited
        @param between: if True, ignore the direction (match messages both ways
            between the two JIDs)
        @param filters: pattern to filter the history results
        @return: list of messages as in [message_new], minus the profile which is
            already known.
        """
        # we have to set a default value to profile because it's the last argument
        # and thus follows other keyword arguments with default values,
        # but None should not be used for it
        assert profile is not None
        if limit == 0:
            return []
        if filters is None:
            filters = {}

        stmt = (
            select(History)
            .filter_by(
                profile_id=self.profiles[profile]
            )
            .outerjoin(History.messages)
            .outerjoin(History.subjects)
            .outerjoin(History.thread)
            .options(
                contains_eager(History.messages),
                contains_eager(History.subjects),
                contains_eager(History.thread),
            )
            .order_by(
                # timestamp may be identical for 2 close messages (especially when
                # delay is used), that's why we order ties by received_timestamp.
                # We'll reverse the order when returning the result. We use DESC
                # here so LIMIT keeps the last messages
                History.timestamp.desc(),
                History.received_timestamp.desc()
            )
        )

        if not from_jid and not to_jid:
            # no jid specified, we want all one2one communications
            pass
        elif between:
            if not from_jid or not to_jid:
                # we only have one jid specified, we check all messages
                # from or to this jid
                jid_ = from_jid or to_jid
                stmt = stmt.where(
                    or_(
                        self._jid_filter(jid_),
                        self._jid_filter(jid_, dest=True)
                    )
                )
            else:
                # we have 2 jids specified, we check all communications between
                # those 2 jids
                stmt = stmt.where(
                    or_(
                        and_(
                            self._jid_filter(from_jid),
                            self._jid_filter(to_jid, dest=True),
                        ),
                        and_(
                            self._jid_filter(to_jid),
                            self._jid_filter(from_jid, dest=True),
                        )
                    )
                )
        else:
            # we want one communication in a specific direction (from somebody or
            # to somebody).
            if from_jid is not None:
                stmt = stmt.where(self._jid_filter(from_jid))
            if to_jid is not None:
                stmt = stmt.where(self._jid_filter(to_jid, dest=True))

        if filters:
            if 'timestamp_start' in filters:
                stmt = stmt.where(History.timestamp >= float(filters['timestamp_start']))
            if 'before_uid' in filters:
                # originally this query was using SQLite's rowid. This has been
                # changed to use coalesce(received_timestamp, timestamp) to be SQL
                # engine independent
                stmt = stmt.where(
                    coalesce(
                        History.received_timestamp,
                        History.timestamp
                    ) < (
                        select(coalesce(History.received_timestamp, History.timestamp))
                        .filter_by(uid=filters["before_uid"])
                    ).scalar_subquery()
                )
            if 'body' in filters:
                # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html
                stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
            if 'search' in filters:
                search_term = f"%{filters['search']}%"
                stmt = stmt.where(or_(
                    Message.message.like(search_term),
                    History.source_res.like(search_term)
                ))
            if 'types' in filters:
                types = filters['types'].split()
                stmt = stmt.where(History.type.in_(types))
            if 'not_types' in filters:
                types = filters['not_types'].split()
                stmt = stmt.where(History.type.not_in(types))
            if 'last_stanza_id' in filters:
                # this request gets the last message with a "stanza_id" that we
                # have in history. This is mainly used to retrieve messages sent
                # while we were offline, using MAM (XEP-0313).
                if (filters['last_stanza_id'] is not True
                        or limit != 1):
                    raise ValueError("Unexpected values for last_stanza_id filter")
                stmt = stmt.where(History.stanza_id.is_not(None))
            if 'origin_id' in filters:
                stmt = stmt.where(History.origin_id == filters["origin_id"])

        if limit is not None:
            stmt = stmt.limit(limit)

        async with self.session() as session:
            result = await session.execute(stmt)

        result = result.scalars().unique().all()
        result.reverse()
        return [h.as_tuple() for h in result]
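
    # Usage sketch (hypothetical JID and profile): fetch the last 50 messages
    # exchanged with a contact, in chronological order:
    #     history = await as_future(storage.history_get(
    #         from_jid=jid.JID("louise@example.net"),
    #         to_jid=None,
    #         limit=50,
    #         profile="default",
    #     ))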

    @aio
    async def add_to_history(self, data: dict, profile: str) -> None:
        """Store a new message in history

        @param data: message data as built by SatMessageProtocol.onMessage
        """
        extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
        messages = [Message(message=mess, language=lang)
                    for lang, mess in data["message"].items()]
        subjects = [Subject(subject=mess, language=lang)
                    for lang, mess in data["subject"].items()]
        if "thread" in data["extra"]:
            thread = Thread(thread_id=data["extra"]["thread"],
                            parent_id=data["extra"].get("thread_parent"))
        else:
            thread = None
        try:
            async with self.session() as session:
                async with session.begin():
                    session.add(History(
                        uid=data["uid"],
                        origin_id=data["extra"].get("origin_id"),
                        stanza_id=data["extra"].get("stanza_id"),
                        update_uid=data["extra"].get("update_uid"),
                        profile_id=self.profiles[profile],
                        source_jid=data["from"],
                        dest_jid=data["to"],
                        timestamp=data["timestamp"],
                        received_timestamp=data.get("received_timestamp"),
                        type=data["type"],
                        extra=extra,
                        messages=messages,
                        subjects=subjects,
                        thread=thread,
                    ))
        except IntegrityError as e:
            if "unique" in str(e.orig).lower():
                log.debug(
                    f"message {data['uid']!r} is already in history, not storing it again"
                )
            else:
                log.error(f"Can't store message {data['uid']!r} in history: {e}")
        except Exception as e:
            log.critical(
                f"Can't store message, unexpected exception (uid: {data['uid']}): {e}"
            )

    ## Private values

    def _get_private_class(self, binary, profile):
        """Get the ORM class to use for private values"""
        if profile is None:
            return PrivateGenBin if binary else PrivateGen
        else:
            return PrivateIndBin if binary else PrivateInd

    @aio
    async def get_privates(
        self,
        namespace: str,
        keys: Optional[Iterable[str]] = None,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get private value(s) from databases

        @param namespace: namespace of the values
        @param keys: keys of the values to get, None to get all keys/values
        @param binary: True to deserialise binary values
        @param profile: profile to use for individual values
            None to use general values
        @return: gotten keys/values
        """
        if keys is not None:
            keys = list(keys)
        log.debug(
            f"getting {'general' if profile is None else 'individual'}"
            f"{' binary' if binary else ''} private values from database for namespace "
            f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}"
        )
        cls = self._get_private_class(binary, profile)
        stmt = select(cls).filter_by(namespace=namespace)
        if keys:
            stmt = stmt.where(cls.key.in_(keys))
        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])
        async with self.session() as session:
            result = await session.execute(stmt)
        return {p.key: p.value for p in result.scalars()}

    @aio
    async def set_private_value(
        self,
        namespace: str,
        key: str,
        value: Any,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Set a private value in database

        @param namespace: namespace of the values
        @param key: key of the value to set
        @param value: value to set
        @param binary: True if it's a binary value
            binary values need to be serialised, used for everything but strings
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._get_private_class(binary, profile)

        values = {
            "namespace": namespace,
            "key": key,
            "value": value
        }
        index_elements = [cls.namespace, cls.key]

        if profile is not None:
            values["profile_id"] = self.profiles[profile]
            index_elements.append(cls.profile_id)

        async with self.session() as session:
            await session.execute(
                insert(cls).values(**values).on_conflict_do_update(
                    index_elements=index_elements,
                    set_={
                        cls.value: value
                    }
                )
            )
            await session.commit()
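
    # Usage sketch (hypothetical namespace/key): a general value is stored when
    # profile is None, an individual one otherwise; binary=True switches to the
    # binary tables (PrivateGenBin/PrivateIndBin):
    #     await as_future(storage.set_private_value(
    #         "my_plugin", "last_sync", "2023-06-02", profile="default"
    #     ))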

    @aio
    async def del_private_value(
        self,
        namespace: str,
        key: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete a private value from database

        @param namespace: namespace of the value
        @param key: key of the private value
        @param binary: True if it's a binary value
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._get_private_class(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace, key=key)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def del_private_namespace(
        self,
        namespace: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete all data from a private namespace

        Be really cautious when you use this method, as all data with the given
        namespace are removed.
        Params are the same as for del_private_value
        """
        cls = self._get_private_class(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    ## Files

    @aio
    async def get_files(
        self,
        client: Optional[SatXMPPEntity],
        file_id: Optional[str] = None,
        version: Optional[str] = '',
        parent: Optional[str] = None,
        type_: Optional[str] = None,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        name: Optional[str] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        projection: Optional[List[str]] = None,
        unique: bool = False
    ) -> List[dict]:
        """Retrieve files with the given filters

        @param file_id: id of the file
            None to ignore
        @param version: version of the file
            None to ignore
            empty string to look for current version
        @param parent: id of the directory containing the files
            None to ignore
            empty string to look for root files/directories
        @param projection: name of columns to retrieve
            None to retrieve all
        @param unique: if True will remove duplicates
        other params are the same as for [set_file]
        @return: files corresponding to filters
        """
        if projection is None:
            projection = [
                'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
                'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
                'created', 'modified', 'owner', 'access', 'extra'
            ]

        stmt = select(*[getattr(File, f) for f in projection])

        if unique:
            stmt = stmt.distinct()

        if client is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
        else:
            if public_id is None:
                raise exceptions.InternalError(
                    "client can only be omitted when public_id is set"
                )
        if file_id is not None:
            stmt = stmt.filter_by(id=file_id)
        if version is not None:
            stmt = stmt.filter_by(version=version)
        if parent is not None:
            stmt = stmt.filter_by(parent=parent)
        if type_ is not None:
            stmt = stmt.filter_by(type=type_)
        if file_hash is not None:
            stmt = stmt.filter_by(file_hash=file_hash)
        if hash_algo is not None:
            stmt = stmt.filter_by(hash_algo=hash_algo)
        if name is not None:
            stmt = stmt.filter_by(name=name)
        if namespace is not None:
            stmt = stmt.filter_by(namespace=namespace)
        if mime_type is not None:
            if '/' in mime_type:
                media_type, media_subtype = mime_type.split("/", 1)
                stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
            else:
                stmt = stmt.filter_by(media_type=mime_type)
        if public_id is not None:
            stmt = stmt.filter_by(public_id=public_id)
        if owner is not None:
            stmt = stmt.filter_by(owner=owner)
        if access is not None:
            raise NotImplementedError('Access check is not implemented yet')
            # a JSON comparison is needed here

        async with self.session() as session:
            result = await session.execute(stmt)

        return [dict(r) for r in result]

    @aio
    async def set_file(
        self,
        client: SatXMPPEntity,
        name: str,
        file_id: str,
        version: str = "",
        parent: str = "",
        type_: str = C.FILE_TYPE_FILE,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        size: Optional[int] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        created: Optional[float] = None,
        modified: Optional[float] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        extra: Optional[dict] = None
    ) -> None:
        """Set a file metadata

        @param client: client owning the file
        @param name: name of the file (must not contain "/")
        @param file_id: unique id of the file
        @param version: version of this file
        @param parent: id of the directory containing this file
            Empty string if it is a root file/directory
        @param type_: one of:
            - file
            - directory
        @param file_hash: unique hash of the payload
        @param hash_algo: algorithm used for hashing the file (usually sha-256)
        @param size: size in bytes
        @param namespace: identifier (human readable is better) to group files
            for instance, namespace could be used to group files in a specific
            photo album
        @param mime_type: media type of the file, or None if not known/guessed
        @param public_id: ID used to serve the file publicly via HTTP
        @param created: UNIX time of creation
        @param modified: UNIX time of last modification, or None to use created date
        @param owner: jid of the owner of the file (mainly useful for component)
        @param access: serialisable dictionary with access rules. See [memory.memory]
            for details
        @param extra: serialisable dictionary of any extra data
            will be encoded to json in database
        """
        if mime_type is None:
            media_type = media_subtype = None
        elif '/' in mime_type:
            media_type, media_subtype = mime_type.split('/', 1)
        else:
            media_type, media_subtype = mime_type, None

        async with self.session() as session:
            async with session.begin():
                session.add(File(
                    id=file_id,
                    version=version.strip(),
                    parent=parent,
                    type=type_,
                    file_hash=file_hash,
                    hash_algo=hash_algo,
                    name=name,
                    size=size,
                    namespace=namespace,
                    media_type=media_type,
                    media_subtype=media_subtype,
                    public_id=public_id,
                    created=time.time() if created is None else created,
                    modified=modified,
                    owner=owner,
                    access=access,
                    extra=extra,
                    profile_id=self.profiles[client.profile]
                ))

    @aio
    async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int:
        async with self.session() as session:
            result = await session.execute(
                select(sum_(File.size)).filter_by(
                    owner=owner,
                    type=C.FILE_TYPE_FILE,
                    profile_id=self.profiles[client.profile]
                ))
        return result.scalar_one_or_none() or 0

    @aio
    async def file_delete(self, file_id: str) -> None:
        """Delete file metadata from the database

        @param file_id: id of the file to delete
        NOTE: the file itself must still be removed, this method only handles
            metadata in database
        """
        async with self.session() as session:
            await session.execute(delete(File).filter_by(id=file_id))
            await session.commit()

    @aio
    async def file_update(
        self,
        file_id: str,
        column: str,
        update_cb: Callable[[dict], None]
    ) -> None:
        """Update a column value using a method to avoid race conditions

        The older value will be retrieved from database, then update_cb will be
        applied to update it, and the file will be updated checking that the older
        value has not been changed meanwhile by another user. If it has changed,
        it tries again a couple of times before failing
        @param column: column name (only "access" or "extra" are allowed)
        @param update_cb: method to update the value of the column
            the method will take the older value as argument, and must update it
            in place
            update_cb must not care about serialization,
            it gets the deserialized data (i.e. a Python object) directly
        @raise exceptions.NotFound: there is no file with this id
        """
        if column not in ('access', 'extra'):
            raise exceptions.InternalError('bad column name')
        orm_col = getattr(File, column)

        for i in range(5):
            async with self.session() as session:
                try:
                    value = (await session.execute(
                        select(orm_col).filter_by(id=file_id)
                    )).scalar_one()
                except NoResultFound:
                    raise exceptions.NotFound
                old_value = copy.deepcopy(value)
                update_cb(value)
                stmt = update(File).filter_by(id=file_id).values({column: value})
                if not old_value:
                    # because JsonDefaultDict converts NULL to an empty dict, we have
                    # to test both for empty dict and None when we have an empty dict
                    stmt = stmt.where((orm_col == None) | (orm_col == old_value))
                else:
                    stmt = stmt.where(orm_col == old_value)
                result = await session.execute(stmt)
                await session.commit()

            if result.rowcount == 1:
                break

            log.warning(
                _("table not updated, probably due to race condition, trying again "
                  "({tries})").format(tries=i+1)
            )

        else:
            raise exceptions.DatabaseError(
                _("Can't update file {file_id} due to race condition")
                .format(file_id=file_id)
            )
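
    # Usage sketch (hypothetical callback and key): the loop above implements
    # optimistic concurrency, so the callback must mutate the deserialised
    # value in place:
    #     def _set_flag(extra: dict) -> None:
    #         extra["flag"] = True  # "flag" is a hypothetical key
    #     await as_future(storage.file_update(file_id, "extra", _set_flag))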

    @aio
    async def get_pubsub_node(
        self,
        client: SatXMPPEntity,
        service: jid.JID,
        name: str,
        with_items: bool = False,
        with_subscriptions: bool = False,
        create: bool = False,
        create_kwargs: Optional[dict] = None
    ) -> Optional[PubsubNode]:
        """Retrieve a PubsubNode from DB

        @param service: service hosting the node
        @param name: node's name
        @param with_items: retrieve items in the same query
        @param with_subscriptions: retrieve subscriptions in the same query
        @param create: if the node doesn't exist in DB, create it
        @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if
            the node needs to be created.
        """
        async with self.session() as session:
            stmt = (
                select(PubsubNode)
                .filter_by(
                    service=service,
                    name=name,
                    profile_id=self.profiles[client.profile],
                )
            )
            if with_items:
                stmt = stmt.options(
                    joinedload(PubsubNode.items)
                )
            if with_subscriptions:
                stmt = stmt.options(
                    joinedload(PubsubNode.subscriptions)
                )
            result = await session.execute(stmt)
        ret = result.unique().scalar_one_or_none()
        if ret is None and create:
            # we auto-create the node
            if create_kwargs is None:
                create_kwargs = {}
            try:
                return await as_future(self.set_pubsub_node(
                    client, service, name, **create_kwargs
                ))
            except IntegrityError as e:
                if "unique" in str(e.orig).lower():
                    # the node may already exist, if it has been created just after
                    # get_pubsub_node above, so we return the cached one
                    log.debug("ignoring UNIQUE constraint error")
                    return await as_future(self.get_pubsub_node(
                        client,
                        service,
                        name,
                        with_items=with_items,
                        with_subscriptions=with_subscriptions
                    ))
                else:
                    raise e
        else:
            return ret

    @aio
    async def set_pubsub_node(
        self,
        client: SatXMPPEntity,
        service: jid.JID,
        name: str,
        analyser: Optional[str] = None,
        type_: Optional[str] = None,
        subtype: Optional[str] = None,
        subscribed: bool = False,
    ) -> PubsubNode:
        node = PubsubNode(
            profile_id=self.profiles[client.profile],
            service=service,
            name=name,
            subscribed=subscribed,
            analyser=analyser,
            type_=type_,
            subtype=subtype,
            subscriptions=[],
        )
        async with self.session() as session:
            async with session.begin():
                session.add(node)
        return node

    @aio
    async def update_pubsub_node_sync_state(
        self,
        node: PubsubNode,
        state: SyncState
    ) -> None:
        async with self.session() as session:
            async with session.begin():
                await session.execute(
                    update(PubsubNode)
                    .filter_by(id=node.id)
                    .values(
                        sync_state=state,
                        sync_state_updated=time.time(),
                    )
                )

    @aio
    async def delete_pubsub_node(
        self,
        profiles: Optional[List[str]],
        services: Optional[List[jid.JID]],
        names: Optional[List[str]]
    ) -> None:
        """Delete nodes from cache

        @param profiles: profile names from which nodes must be deleted.
            None to remove nodes from ALL profiles
        @param services: JIDs of pubsub services from which nodes must be deleted.
            None to remove nodes from ALL services
        @param names: names of nodes which must be deleted.
            None to remove ALL nodes whatever their name is
        """
        stmt = delete(PubsubNode)
        if profiles is not None:
            stmt = stmt.where(
                PubsubNode.profile_id.in_(
                    [self.profiles[p] for p in profiles]
                )
            )
        if services is not None:
            stmt = stmt.where(PubsubNode.service.in_(services))
        if names is not None:
            stmt = stmt.where(PubsubNode.name.in_(names))
        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def cache_pubsub_items(
        self,
        client: SatXMPPEntity,
        node: PubsubNode,
        items: List[domish.Element],
        parsed_items: Optional[List[dict]] = None,
    ) -> None:
        """Add items to database, using an upsert taking care of the "updated" field"""
        if parsed_items is not None and len(items) != len(parsed_items):
            raise exceptions.InternalError(
                "parsed_items must have the same length as items"
            )
        async with self.session() as session:
            async with session.begin():
                for idx, item in enumerate(items):
                    parsed = parsed_items[idx] if parsed_items else None
                    stmt = insert(PubsubItem).values(
                        node_id=node.id,
                        name=item["id"],
                        data=item,
                        parsed=parsed,
                    ).on_conflict_do_update(
                        index_elements=(PubsubItem.node_id, PubsubItem.name),
                        set_={
                            PubsubItem.data: item,
                            PubsubItem.parsed: parsed,
                            PubsubItem.updated: now()
                        }
                    )
                    await session.execute(stmt)
            await session.commit()

    @aio
    async def delete_pubsub_items(
        self,
        node: Union[PubsubNode, List[PubsubNode]],
        items_names: Optional[List[str]] = None
    ) -> None:
        """Delete items cached for a node

        @param node: node (or list of nodes) from which items must be deleted
        @param items_names: names of items to delete
            if None, ALL items will be deleted
        """
        stmt = delete(PubsubItem)
        if node is not None:
            if isinstance(node, list):
                stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node]))
            else:
                stmt = stmt.filter_by(node_id=node.id)
        if items_names is not None:
            stmt = stmt.where(PubsubItem.name.in_(items_names))
        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def purge_pubsub_items(
        self,
        services: Optional[List[jid.JID]] = None,
        names: Optional[List[str]] = None,
        types: Optional[List[str]] = None,
        subtypes: Optional[List[str]] = None,
        profiles: Optional[List[str]] = None,
        created_before: Optional[datetime] = None,
        updated_before: Optional[datetime] = None,
    ) -> None:
        """Purge cached items, filtering on node fields and item timestamps

        @param services: JIDs of the services from which items must be deleted.
            None to ignore
        @param names: names of the nodes from which items must be deleted.
            None to ignore
        @param types: types of the nodes from which items must be deleted.
            None to ignore
        @param subtypes: subtypes of the nodes from which items must be deleted.
            None to ignore
        @param profiles: names of the profiles from which items must be deleted.
            None to ignore
        @param created_before: only delete items created before this date
        @param updated_before: only delete items last updated before this date
        """
        stmt = delete(PubsubItem)
        node_fields = {
            "service": services,
            "name": names,
            "type_": types,
            "subtype": subtypes,
        }
        if profiles is not None:
            node_fields["profile_id"] = [self.profiles[p] for p in profiles]

        if any(x is not None for x in node_fields.values()):
            sub_q = select(PubsubNode.id)
            for col, values in node_fields.items():
                if values is None:
                    continue
                sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
            stmt = (
                stmt
                .where(PubsubItem.node_id.in_(sub_q))
                .execution_options(synchronize_session=False)
            )

        if created_before is not None:
            stmt = stmt.where(PubsubItem.created < created_before)

        if updated_before is not None:
            stmt = stmt.where(PubsubItem.updated < updated_before)

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()
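
    # Usage sketch (hypothetical values): drop cached blog items untouched for
    # roughly six months:
    #     from datetime import timedelta
    #     await as_future(storage.purge_pubsub_items(
    #         types=["blog"],
    #         updated_before=datetime.now() - timedelta(days=180),
    #     ))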

    @aio
    async def get_items(
        self,
        node: PubsubNode,
        max_items: Optional[int] = None,
        item_ids: Optional[list[str]] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        from_index: Optional[int] = None,
        order_by: Optional[List[str]] = None,
        desc: bool = True,
        force_rsm: bool = False,
    ) -> Tuple[List[PubsubItem], dict]:
        """Get Pubsub Items from cache

        @param node: retrieve items from this node (must be synchronised)
        @param max_items: maximum number of items to retrieve
        @param before: get items which are before the item with this name in given
            order. Empty string is not managed here, use desc order to reproduce
            RSM behaviour.
        @param after: get items which are after the item with this name in given order
        @param from_index: get items with item index (as defined in RSM spec)
            starting from this number
        @param order_by: sorting order of items (one of C.ORDER_BY_*)
        @param desc: direction of ordering
        @param force_rsm: if True, force the use of the RSM workflow.
            The RSM workflow is automatically used if any of before, after or
            from_index is used, but if only max_items is used, it won't be used
            by default. This parameter lets one use the RSM workflow in that
            case, typically when max_items comes from RSM. Note that in addition
            to the RSM metadata, the result will not be the same: max_items
            without RSM returns the most recent items (i.e. the last items in
            modification order), while max_items with RSM returns the oldest
            ones (i.e. the first items in modification order).
        """

        metadata = {
            "service": node.service,
            "node": node.name,
            "uri": uri.build_xmpp_uri(
                "pubsub",
                path=node.service.full(),
                node=node.name,
            ),
        }
        if max_items is None:
            max_items = 20

        use_rsm = any((before, after, from_index is not None))
        if force_rsm and not use_rsm:
            # RSM is requested explicitly, we start from the first index
            use_rsm = True
            from_index = 0

        stmt = (
            select(PubsubItem)
            .filter_by(node_id=node.id)
            .limit(max_items)
        )

        if item_ids is not None:
            stmt = stmt.where(PubsubItem.name.in_(item_ids))

        if not order_by:
            order_by = [C.ORDER_BY_MODIFICATION]

        order = []
        for order_type in order_by:
            if order_type == C.ORDER_BY_MODIFICATION:
                if desc:
                    order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
                else:
                    order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
            elif order_type == C.ORDER_BY_CREATION:
                if desc:
                    order.append(PubsubItem.id.desc())
                else:
                    order.append(PubsubItem.id.asc())
            else:
                raise exceptions.InternalError(f"Unknown order type {order_type!r}")

        stmt = stmt.order_by(*order)

        if use_rsm:
            # CTE to have result row numbers
            row_num_q = select(
                PubsubItem.id,
                PubsubItem.name,
                # row_number starts from 1, but RSM index must start from 0
                (func.row_number().over(order_by=order) - 1).label("item_index")
            ).filter_by(node_id=node.id)

            row_num_cte = row_num_q.cte()

            if max_items > 0:
                # as we can't simply use PubsubItem.id when we order by modification,
                # we need to use row number
                item_name = before or after
                row_num_limit_q = (
                    select(row_num_cte.c.item_index)
                    .where(row_num_cte.c.name == item_name)
                ).scalar_subquery()

                stmt = (
                    select(row_num_cte.c.item_index, PubsubItem)
                    .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
                    .limit(max_items)
                )
                if before:
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index < row_num_limit_q)
                        .order_by(row_num_cte.c.item_index.desc())
                    )
                elif after:
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index > row_num_limit_q)
                        .order_by(row_num_cte.c.item_index.asc())
                    )
                else:
                    # from_index is used
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index >= from_index)
                        .order_by(row_num_cte.c.item_index.asc())
                    )

            async with self.session() as session:
                if max_items == 0:
                    items = result = []
                else:
                    result = await session.execute(stmt)
                    result = result.all()
                    if before:
                        result.reverse()
                    items = [row[-1] for row in result]
                rows_count = (
                    await session.execute(row_num_q.with_only_columns(count()))
                ).scalar_one()

            try:
                index = result[0][0]
            except IndexError:
                index = None

            try:
                first = result[0][1].name
            except IndexError:
                first = None
                last = None
            else:
                last = result[-1][1].name

            metadata["rsm"] = {
                k: v for k, v in {
                    "index": index,
                    "count": rows_count,
                    "first": first,
                    "last": last,
                }.items() if v is not None
            }
            metadata["complete"] = (index or 0) + len(result) == rows_count

            return items, metadata

        async with self.session() as session:
            result = await session.execute(stmt)

        result = result.scalars().all()
        if desc:
            result.reverse()
        return result, metadata
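
    # Usage sketch (hypothetical node): RSM-style pagination; metadata["rsm"]
    # then carries "index", "count", "first" and "last" when available:
    #     items, metadata = await as_future(storage.get_items(
    #         node, max_items=20, from_index=0
    #     ))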

    def _get_sqlite_path(
        self,
        path: List[Union[str, int]]
    ) -> str:
        """Generate a path suitable to query a JSON element with SQLite"""
        return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}"
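
    # For instance, _get_sqlite_path(["versions", 0, "hash"]) returns
    # "$.versions[0].hash", the path syntax expected by SQLite's json_extract()
    # and json_each() used below (the keys here are hypothetical, for
    # illustration only).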

    @aio
    async def search_pubsub_items(
        self,
        query: dict,
    ) -> Tuple[List[PubsubItem]]:
        """Search for pubsub items in cache

        @param query: search terms. Keys can be:
            :fts (str):
                Full-Text Search query. Currently the SQLite FTS5 engine is used,
                so its query syntax can be used, see `FTS5 Query documentation
                <https://sqlite.org/fts5.html#full_text_query_syntax>`_
            :profiles (list[str]):
                filter on nodes linked to those profiles
            :nodes (list[str]):
                filter on nodes with those names
            :services (list[jid.JID]):
                filter on nodes from those services
            :types (list[str|None]):
                filter on nodes with those types. None can be used to filter on
                nodes with no type set
            :subtypes (list[str|None]):
                filter on nodes with those subtypes. None can be used to filter on
                nodes with no subtype set
            :names (list[str]):
                filter on items with those names
            :parsed (list[dict]):
                Filter on a parsed data field. The dict must contain 3 keys:
                ``path``, which is a list of str or int giving the path to the
                field of interest (str for a dict key, int for a list index),
                ``op``, which indicates the operator to use to check the
                condition, and ``value``, which depends on the field type and the
                operator.

                See documentation for details on operators (it's currently
                explained at ``doc/libervia-cli/pubsub_cache.rst`` in ``search``
                command documentation).

            :order-by (list[dict]):
                Indicates how to order results. The dict can contain either an
                ``order`` for a well-known order or a ``path`` for a parsed data
                field path (``order`` and ``path`` can't be used at the same
                time), and an optional ``direction`` which can be ``asc`` or
                ``desc``. See documentation for details on well-known orders
                (it's currently explained at
                ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
                documentation).

            :index (int):
                starting index of items to return from the query result. It's
                translated to SQL's OFFSET

            :limit (int):
                maximum number of items to return. It's translated to SQL's LIMIT.

        @return: found items (the ``node`` attribute will be filled with the
            suitable PubsubNode)
        """
        # TODO: FTS and parsed data filters use SQLite specific syntax
        #   when other DB engines will be used, this will have to be adapted
        stmt = select(PubsubItem)

        # Full-Text Search
        fts = query.get("fts")
        if fts:
            fts_select = text(
                "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)"
            ).bindparams(fts_query=fts).columns(rowid=Integer).subquery()
            stmt = (
                stmt
                .select_from(fts_select)
                .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id)
            )

        # node related filters
        profiles = query.get("profiles")
        if (profiles
            or any(query.get(k) for k in ("nodes", "services", "types", "subtypes"))
        ):
            stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node))
            if profiles:
                try:
                    stmt = stmt.where(
                        PubsubNode.profile_id.in_(self.profiles[p] for p in profiles)
                    )
                except KeyError as e:
                    raise exceptions.ProfileUnknownError(
                        f"This profile doesn't exist: {e.args[0]!r}"
                    )
            for key, attr in (
                ("nodes", "name"),
                ("services", "service"),
                ("types", "type_"),
                ("subtypes", "subtype")
            ):
                value = query.get(key)
                if not value:
                    continue
                if key in ("types", "subtypes") and None in value:
                    # NULL can't be used with SQL's IN, so we have to add a condition
                    # with IS NULL, and use a OR if there are other values to check
                    value.remove(None)
                    condition = getattr(PubsubNode, attr).is_(None)
                    if value:
                        condition = or_(
                            getattr(PubsubNode, attr).in_(value),
                            condition
                        )
                else:
                    condition = getattr(PubsubNode, attr).in_(value)
                stmt = stmt.where(condition)
        else:
            stmt = stmt.options(selectinload(PubsubItem.node))

        # names
        names = query.get("names")
        if names:
            stmt = stmt.where(PubsubItem.name.in_(names))

        # parsed data filters
        parsed = query.get("parsed", [])
        for filter_ in parsed:
            try:
                path = filter_["path"]
                operator = filter_["op"]
                value = filter_["value"]
            except KeyError as e:
                raise ValueError(
                    f'missing mandatory key {e.args[0]!r} in "parsed" filter'
                )
            try:
                op_attr = OP_MAP[operator]
            except KeyError:
                raise ValueError(f"invalid operator: {operator!r}")
            sqlite_path = self._get_sqlite_path(path)
            if operator in ("overlap", "ioverlap", "disjoint", "idisjoint"):
                col = literal_column("json_each.value")
                if operator[0] == "i":
                    col = func.lower(col)
                    value = [str(v).lower() for v in value]
                condition = (
                    select(1)
                    .select_from(func.json_each(PubsubItem.parsed, sqlite_path))
                    .where(col.in_(value))
                ).scalar_subquery()
                if operator in ("disjoint", "idisjoint"):
                    condition = condition.is_(None)
                stmt = stmt.where(condition)
            elif operator == "between":
                try:
                    left, right = value
                except (ValueError, TypeError):
                    raise ValueError(_(
                        'invalid value for "between" filter, you must use a 2 items '
                        "array: {value!r}"
                    ).format(value=value))
                col = func.json_extract(PubsubItem.parsed, sqlite_path)
                stmt = stmt.where(col.between(left, right))
            else:
                # we use func.json_extract instead of the generic JSON way because
                # SQLAlchemy adds a JSON_QUOTE to the value, and we want the SQL value
                col = func.json_extract(PubsubItem.parsed, sqlite_path)
                stmt = stmt.where(getattr(col, op_attr)(value))

        # order
        order_by = query.get("order-by") or [{"order": "creation"}]

        for order_data in order_by:
            order, path = order_data.get("order"), order_data.get("path")
            if order and path:
                raise ValueError(_(
                    '"order" and "path" can\'t be used at the same time in '
                    '"order-by" data'
                ))
            if order:
                if order == "creation":
                    col = PubsubItem.id
                elif order == "modification":
                    col = PubsubItem.updated
                elif order == "item_id":
                    col = PubsubItem.name
                elif order == "rank":
                    if not fts:
                        raise ValueError(
                            "'rank' order can only be used with Full-Text Search (fts)"
                        )
                    col = literal_column("rank")
                else:
                    raise NotImplementedError(f"Unknown {order!r} order")
            else:
                # we have a JSON path
                col = PubsubItem.parsed[path]
            direction = order_data.get("direction", "ASC").lower()
            if direction not in ("asc", "desc"):
                raise ValueError(f"Invalid order-by direction: {direction!r}")
            stmt = stmt.order_by(getattr(col, direction)())

        # offset, limit
        index = query.get("index")
        if index:
            stmt = stmt.offset(index)
        limit = query.get("limit")
        if limit:
            stmt = stmt.limit(limit)

        async with self.session() as session:
            result = await session.execute(stmt)

        return result.scalars().all()
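
    # Usage sketch (hypothetical values): full-text search restricted to blog
    # nodes of one profile, best matches first:
    #     items = await as_future(storage.search_pubsub_items({
    #         "fts": "libervia",
    #         "profiles": ["default"],
    #         "types": ["blog"],
    #         "order-by": [{"order": "rank"}],
    #         "limit": 10,
    #     }))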