comparison libervia/backend/memory/sqla.py @ 4071:4b842c1fb686

refactoring: renamed `sat` package to `libervia.backend`
author Goffi <goffi@goffi.org>
date Fri, 02 Jun 2023 11:49:51 +0200
parents sat/memory/sqla.py@524856bd7b19
children 74c66c0d93f3
1 #!/usr/bin/env python3
2
3 # Libervia: an XMPP client
4 # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)
5
6 # This program is free software: you can redistribute it and/or modify
7 # it under the terms of the GNU Affero General Public License as published by
8 # the Free Software Foundation, either version 3 of the License, or
9 # (at your option) any later version.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU Affero General Public License for more details.
15
16 # You should have received a copy of the GNU Affero General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18
19 import asyncio
20 from asyncio.subprocess import PIPE
21 import copy
22 from datetime import datetime
23 from pathlib import Path
24 import sys
25 import time
26 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
27
28 from alembic import config as al_config, script as al_script
29 from alembic.runtime import migration as al_migration
30 from sqlalchemy import and_, delete, event, func, or_, update
31 from sqlalchemy import Integer, literal_column, text
32 from sqlalchemy.dialects.sqlite import insert
33 from sqlalchemy.engine import Connection, Engine
34 from sqlalchemy.exc import IntegrityError, NoResultFound
35 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
36 from sqlalchemy.future import select
37 from sqlalchemy.orm import (
38 contains_eager,
39 joinedload,
40 selectinload,
41 sessionmaker,
42 subqueryload,
43 )
44 from sqlalchemy.orm.attributes import Mapped
45 from sqlalchemy.orm.decl_api import DeclarativeMeta
46 from sqlalchemy.sql.functions import coalesce, count, now, sum as sum_
47 from twisted.internet import defer
48 from twisted.words.protocols.jabber import jid
49 from twisted.words.xish import domish
50
51 from libervia.backend.core import exceptions
52 from libervia.backend.core.constants import Const as C
53 from libervia.backend.core.core_types import SatXMPPEntity
54 from libervia.backend.core.i18n import _
55 from libervia.backend.core.log import getLogger
56 from libervia.backend.memory import migration
57 from libervia.backend.memory import sqla_config
58 from libervia.backend.memory.sqla_mapping import (
59 Base,
60 Component,
61 File,
62 History,
63 Message,
64 NOT_IN_EXTRA,
65 ParamGen,
66 ParamInd,
67 PrivateGen,
68 PrivateGenBin,
69 PrivateInd,
70 PrivateIndBin,
71 Profile,
72 PubsubItem,
73 PubsubNode,
74 Subject,
75 SyncState,
76 Thread,
77 )
78 from libervia.backend.tools.common import uri
79 from libervia.backend.tools.utils import aio, as_future
80
81
82 log = getLogger(__name__)
83 migration_path = Path(migration.__file__).parent
84 #: mapping of Libervia search query operators to SQLAlchemy method name
85 OP_MAP = {
86 "==": "__eq__",
87 "eq": "__eq__",
88 "!=": "__ne__",
89 "ne": "__ne__",
90 ">": "__gt__",
91 "gt": "__gt__",
92 "<": "__le__",
93 "le": "__le__",
94 "between": "between",
95 "in": "in_",
96 "not_in": "not_in",
97 "overlap": "in_",
98 "ioverlap": "in_",
99 "disjoint": "in_",
100 "idisjoint": "in_",
101 "like": "like",
102 "ilike": "ilike",
103 "not_like": "notlike",
104 "not_ilike": "notilike",
105 }
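# e.g. a "parsed" filter {"path": ["title"], "op": "ilike", "value": "%news%"}
# (hypothetical values, for illustration only) is resolved through OP_MAP to
# col.ilike("%news%") in search_pubsub_items() below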
106
107
108 @event.listens_for(Engine, "connect")
109 def set_sqlite_pragma(dbapi_connection, connection_record):
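"""Enable SQLite foreign-key enforcement on each new connection (SQLite keeps it disabled by default)."""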
110 cursor = dbapi_connection.cursor()
111 cursor.execute("PRAGMA foreign_keys=ON")
112 cursor.close()
113
114
115 class Storage:
116
117 def __init__(self):
118 self.initialized = defer.Deferred()
119 # we keep cache for the profiles (key: profile name, value: profile id)
120 self.profiles: Dict[str, int] = {}
122 # profile id to component entry point
123 self.components: Dict[int, str] = {}
124
125 def get_profile_by_id(self, profile_id):
126 # self.profiles maps profile names to ids, so a reverse lookup is needed
127 return next((n for n, i in self.profiles.items() if i == profile_id), None)
127
128 async def migrate_apply(self, *args: str, log_output: bool = False) -> None:
129 """Do a migration command
130
131 Commands are applied by running Alembic in a subprocess.
132 Arguments are alembic executables commands
133
134 @param log_output: manage stdout and stderr:
135 - if False, stdout and stderr are buffered, and logged only in case of error
136 - if True, stdout and stderr will be logged during the command execution
137 @raise exceptions.DatabaseError: something went wrong while running the
138 process
139 """
140 stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
141 proc = await asyncio.create_subprocess_exec(
142 sys.executable, "-m", "alembic", *args,
143 stdout=stdout, stderr=stderr, cwd=migration_path
144 )
145 log_out, log_err = await proc.communicate()
146 if proc.returncode != 0:
147 msg = _(
148 "Can't {operation} database (exit code {exit_code})"
149 ).format(
150 operation=args[0],
151 exit_code=proc.returncode
152 )
153 if log_out or log_err:
154 msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}"
155 log.error(msg)
156
157 raise exceptions.DatabaseError(msg)
158
159 async def create_db(self, engine: AsyncEngine, db_config: dict) -> None:
160 """Create a new database
161
162 The database is generated from SQLAlchemy model, then stamped by Alembic
163 """
164 # the dir may not exist if it's not the XDG recommended one
165 db_config["path"].parent.mkdir(0o700, True, True)
166 async with engine.begin() as conn:
167 await conn.run_sync(Base.metadata.create_all)
168
169 log.debug("stamping the database")
170 await self.migrate_apply("stamp", "head")
171 log.debug("stamping done")
172
173 def _check_db_is_up_to_date(self, conn: Connection) -> bool:
174 al_ini_path = migration_path / "alembic.ini"
175 al_cfg = al_config.Config(al_ini_path)
176 directory = al_script.ScriptDirectory.from_config(al_cfg)
177 context = al_migration.MigrationContext.configure(conn)
178 return set(context.get_current_heads()) == set(directory.get_heads())
179
180 def _sqlite_set_journal_mode_wal(self, conn: Connection) -> None:
181 """Check if journal mode is WAL, and set it if necesssary"""
182 result = conn.execute(text("PRAGMA journal_mode"))
183 if result.scalar() != "wal":
184 log.info("WAL mode not activated, activating it")
185 conn.execute(text("PRAGMA journal_mode=WAL"))
186
187 async def check_and_update_db(self, engine: AsyncEngine, db_config: dict) -> None:
188 """Check that database is up-to-date, and update if necessary"""
189 async with engine.connect() as conn:
190 up_to_date = await conn.run_sync(self._check_db_is_up_to_date)
191 if up_to_date:
192 log.debug("Database is up-to-date")
193 else:
194 log.info("Database needs to be updated")
195 log.info("updating…")
196 await self.migrate_apply("upgrade", "head", log_output=True)
197 log.info("Database is now up-to-date")
198
199 @aio
200 async def initialise(self) -> None:
201 log.info(_("Connecting database"))
202
203 db_config = sqla_config.get_db_config()
204 engine = create_async_engine(
205 db_config["url"],
206 future=True,
207 )
208
209 new_base = not db_config["path"].exists()
210 if new_base:
211 log.info(_("The database is new, creating the tables"))
212 await self.create_db(engine, db_config)
213 else:
214 await self.check_and_update_db(engine, db_config)
215
216 async with engine.connect() as conn:
217 await conn.run_sync(self._sqlite_set_journal_mode_wal)
218
219 self.session = sessionmaker(
220 engine, expire_on_commit=False, class_=AsyncSession
221 )
222
223 async with self.session() as session:
224 result = await session.execute(select(Profile))
225 for p in result.scalars():
226 self.profiles[p.name] = p.id
227 result = await session.execute(select(Component))
228 for c in result.scalars():
229 self.components[c.profile_id] = c.entry_point
230
231 self.initialized.callback(None)
232
233 ## Generic
234
235 @aio
236 async def get(
237 self,
238 client: SatXMPPEntity,
239 db_cls: DeclarativeMeta,
240 db_id_col: Mapped,
241 id_value: Any,
242 joined_loads=None
243 ) -> Optional[DeclarativeMeta]:
244 stmt = select(db_cls).where(db_id_col == id_value)
245 if client is not None:
246 stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
247 if joined_loads is not None:
248 for joined_load in joined_loads:
249 stmt = stmt.options(joinedload(joined_load))
250 async with self.session() as session:
251 result = await session.execute(stmt)
252 if joined_loads is not None:
253 result = result.unique()
254 return result.scalar_one_or_none()
255
256 @aio
257 async def add(self, db_obj: DeclarativeMeta) -> None:
258 """Add an object to database"""
259 async with self.session() as session:
260 async with session.begin():
261 session.add(db_obj)
262
263 @aio
264 async def delete(
265 self,
266 db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]],
267 session_add: Optional[List[DeclarativeMeta]] = None
268 ) -> None:
269 """Delete an object from database
270
271 @param db_obj: object to delete or list of objects to delete
272 @param session_add: other objects to add to session.
273 This is useful when parents of deleted objects need to be updated too, or if
274 other objects need to be updated in the same transaction.
275 """
276 if not db_obj:
277 return
278 if not isinstance(db_obj, list):
279 db_obj = [db_obj]
280 async with self.session() as session:
281 async with session.begin():
282 if session_add is not None:
283 for obj in session_add:
284 session.add(obj)
285 for obj in db_obj:
286 await session.delete(obj)
287 await session.commit()
288
289 ## Profiles
290
291 def get_profiles_list(self) -> List[str]:
292 """"Return list of all registered profiles"""
293 return list(self.profiles.keys())
294
295 def has_profile(self, profile_name: str) -> bool:
296 """return True if profile_name exists
297
298 @param profile_name: name of the profile to check
299 """
300 return profile_name in self.profiles
301
302 def profile_is_component(self, profile_name: str) -> bool:
303 try:
304 return self.profiles[profile_name] in self.components
305 except KeyError:
306 raise exceptions.NotFound("the requested profile doesn't exist")
307
308 def get_entry_point(self, profile_name: str) -> str:
309 try:
310 return self.components[self.profiles[profile_name]]
311 except KeyError:
312 raise exceptions.NotFound("the requested profile doesn't exist or is not a component")
313
314 @aio
315 async def create_profile(self, name: str, component_ep: Optional[str] = None) -> Profile:
316 """Create a new profile
317
318 @param name: name of the profile
319 @param component_ep: if not None, must point to a component entry point
320 """
321 async with self.session() as session:
322 profile = Profile(name=name)
323 async with session.begin():
324 session.add(profile)
325 self.profiles[profile.name] = profile.id
326 if component_ep is not None:
327 async with session.begin():
328 component = Component(profile=profile, entry_point=component_ep)
329 session.add(component)
330 self.components[profile.id] = component_ep
331 return profile
332
333 @aio
334 async def delete_profile(self, name: str) -> None:
335 """Delete profile
336
337 @param name: name of the profile
338 """
339 async with self.session() as session:
340 result = await session.execute(select(Profile).where(Profile.name == name))
341 profile = result.scalar()
342 await session.delete(profile)
343 await session.commit()
344 del self.profiles[profile.name]
345 if profile.id in self.components:
346 del self.components[profile.id]
347 log.info(_("Profile {name!r} deleted").format(name = name))
348
349 ## Params
350
351 @aio
352 async def load_gen_params(self, params_gen: dict) -> None:
353 """Load general parameters
354
355 @param params_gen: dictionary to fill
356 """
357 log.debug(_("loading general parameters from database"))
358 async with self.session() as session:
359 result = await session.execute(select(ParamGen))
360 for p in result.scalars():
361 params_gen[(p.category, p.name)] = p.value
362
363 @aio
364 async def load_ind_params(self, params_ind: dict, profile: str) -> None:
365 """Load individual parameters
366
367 @param params_ind: dictionary to fill
368 @param profile: a profile which *must* exist
369 """
370 log.debug(_("loading individual parameters from database"))
371 async with self.session() as session:
372 result = await session.execute(
373 select(ParamInd).where(ParamInd.profile_id == self.profiles[profile])
374 )
375 for p in result.scalars():
376 params_ind[(p.category, p.name)] = p.value
377
378 @aio
379 async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]:
380 """Ask database for the value of one specific individual parameter
381
382 @param category: category of the parameter
383 @param name: name of the parameter
384 @param profile: %(doc_profile)s
385 """
386 async with self.session() as session:
387 result = await session.execute(
388 select(ParamInd.value)
389 .filter_by(
390 category=category,
391 name=name,
392 profile_id=self.profiles[profile]
393 )
394 )
395 return result.scalar_one_or_none()
396
397 @aio
398 async def get_ind_param_values(self, category: str, name: str) -> Dict[str, str]:
399 """Ask database for the individual values of a parameter for all profiles
400
401 @param category: category of the parameter
402 @param name: name of the parameter
403 @return dict: profile => value map
404 """
405 async with self.session() as session:
406 result = await session.execute(
407 select(ParamInd)
408 .filter_by(
409 category=category,
410 name=name
411 )
412 .options(subqueryload(ParamInd.profile))
413 )
414 return {param.profile.name: param.value for param in result.scalars()}
415
416 @aio
417 async def set_gen_param(self, category: str, name: str, value: Optional[str]) -> None:
418 """Save the general parameters in database
419
420 @param category: category of the parameter
421 @param name: name of the parameter
422 @param value: value to set
423 """
424 async with self.session() as session:
425 stmt = insert(ParamGen).values(
426 category=category,
427 name=name,
428 value=value
429 ).on_conflict_do_update(
430 index_elements=(ParamGen.category, ParamGen.name),
431 set_={
432 ParamGen.value: value
433 }
434 )
435 await session.execute(stmt)
436 await session.commit()
437
438 @aio
439 async def set_ind_param(
440 self,
441 category: str,
442 name: str,
443 value: Optional[str],
444 profile: str
445 ) -> None:
446 """Save the individual parameters in database
447
448 @param category: category of the parameter
449 @param name: name of the parameter
450 @param value: value to set
451 @param profile: a profile which *must* exist
452 """
453 async with self.session() as session:
454 stmt = insert(ParamInd).values(
455 category=category,
456 name=name,
457 profile_id=self.profiles[profile],
458 value=value
459 ).on_conflict_do_update(
460 index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id),
461 set_={
462 ParamInd.value: value
463 }
464 )
465 await session.execute(stmt)
466 await session.commit()
467
468 def _jid_filter(self, jid_: jid.JID, dest: bool = False):
469 """Generate condition to filter on a JID, using relevant columns
470
471 @param dest: True if it's the destination JID, otherwise it's the source one
472 @param jid_: JID to filter by
473 """
474 if jid_.resource:
475 if dest:
476 return and_(
477 History.dest == jid_.userhost(),
478 History.dest_res == jid_.resource
479 )
480 else:
481 return and_(
482 History.source == jid_.userhost(),
483 History.source_res == jid_.resource
484 )
485 else:
486 if dest:
487 return History.dest == jid_.userhost()
488 else:
489 return History.source == jid_.userhost()
490
491 @aio
492 async def history_get(
493 self,
494 from_jid: Optional[jid.JID],
495 to_jid: Optional[jid.JID],
496 limit: Optional[int] = None,
497 between: bool = True,
498 filters: Optional[Dict[str, str]] = None,
499 profile: Optional[str] = None,
500 ) -> List[Tuple[
501 str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
502 ]:
503 """Retrieve messages in history
504
505 @param from_jid: source JID (full, or bare for catchall)
506 @param to_jid: dest JID (full, or bare for catchall)
507 @param limit: maximum number of messages to get:
508 - 0 for no message (returns the empty list)
509 - None for unlimited
510 @param between: conflate source and dest (i.e. ignore the direction)
511 @param filters: pattern to filter the history results
512 @return: list of messages as in [message_new], minus the profile which is already
513 known.
514 """
515 # we have to set a default value for profile because it's the last argument
516 # and thus follows other keyword arguments with default values,
517 # but None should not be used for it
518 assert profile is not None
519 if limit == 0:
520 return []
521 if filters is None:
522 filters = {}
523
524 stmt = (
525 select(History)
526 .filter_by(
527 profile_id=self.profiles[profile]
528 )
529 .outerjoin(History.messages)
530 .outerjoin(History.subjects)
531 .outerjoin(History.thread)
532 .options(
533 contains_eager(History.messages),
534 contains_eager(History.subjects),
535 contains_eager(History.thread),
536 )
537 .order_by(
538 # timestamp may be identical for 2 close messages (especially when delay is
539 # used); that's why we order ties by received_timestamp. We'll reverse the
540 # order when returning the result. We use DESC here so LIMIT keeps the last
541 # messages
542 History.timestamp.desc(),
543 History.received_timestamp.desc()
544 )
545 )
546
547
548 if not from_jid and not to_jid:
549 # no jid specified, we want all one2one communications
550 pass
551 elif between:
552 if not from_jid or not to_jid:
553 # we only have one jid specified, we check all messages
554 # from or to this jid
555 jid_ = from_jid or to_jid
556 stmt = stmt.where(
557 or_(
558 self._jid_filter(jid_),
559 self._jid_filter(jid_, dest=True)
560 )
561 )
562 else:
563 # we have 2 jids specified, we check all communications between
564 # those 2 jids
565 stmt = stmt.where(
566 or_(
567 and_(
568 self._jid_filter(from_jid),
569 self._jid_filter(to_jid, dest=True),
570 ),
571 and_(
572 self._jid_filter(to_jid),
573 self._jid_filter(from_jid, dest=True),
574 )
575 )
576 )
577 else:
578 # we want communication in one specific direction (from somebody or
579 # to somebody).
580 if from_jid is not None:
581 stmt = stmt.where(self._jid_filter(from_jid))
582 if to_jid is not None:
583 stmt = stmt.where(self._jid_filter(to_jid, dest=True))
584
585 if filters:
586 if 'timestamp_start' in filters:
587 stmt = stmt.where(History.timestamp >= float(filters['timestamp_start']))
588 if 'before_uid' in filters:
589 # originally this query was using SQLite's rowid. This has been changed
590 # to use coalesce(received_timestamp, timestamp) to be SQL engine independent
591 stmt = stmt.where(
592 coalesce(
593 History.received_timestamp,
594 History.timestamp
595 ) < (
596 select(coalesce(History.received_timestamp, History.timestamp))
597 .filter_by(uid=filters["before_uid"])
598 ).scalar_subquery()
599 )
600 if 'body' in filters:
601 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html
602 stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
603 if 'search' in filters:
604 search_term = f"%{filters['search']}%"
605 stmt = stmt.where(or_(
606 Message.message.like(search_term),
607 History.source_res.like(search_term)
608 ))
609 if 'types' in filters:
610 types = filters['types'].split()
611 stmt = stmt.where(History.type.in_(types))
612 if 'not_types' in filters:
613 types = filters['not_types'].split()
614 stmt = stmt.where(History.type.not_in(types))
615 if 'last_stanza_id' in filters:
616 # this request gets the last message with a "stanza_id" that we
617 # have in history. This is mainly used to retrieve messages sent
618 # while we were offline, using MAM (XEP-0313).
619 if (filters['last_stanza_id'] is not True
620 or limit != 1):
621 raise ValueError("Unexpected values for last_stanza_id filter")
622 stmt = stmt.where(History.stanza_id.is_not(None))
623 if 'origin_id' in filters:
624 stmt = stmt.where(History.origin_id == filters["origin_id"])
625
626 if limit is not None:
627 stmt = stmt.limit(limit)
628
629 async with self.session() as session:
630 result = await session.execute(stmt)
631
632 result = result.scalars().unique().all()
633 result.reverse()
634 return [h.as_tuple() for h in result]
635
636 @aio
637 async def add_to_history(self, data: dict, profile: str) -> None:
638 """Store a new message in history
639
640 @param data: message data as built by SatMessageProtocol.onMessage
641 """
642 extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
643 messages = [Message(message=mess, language=lang)
644 for lang, mess in data["message"].items()]
645 subjects = [Subject(subject=mess, language=lang)
646 for lang, mess in data["subject"].items()]
647 if "thread" in data["extra"]:
648 thread = Thread(thread_id=data["extra"]["thread"],
649 parent_id=data["extra"].get["thread_parent"])
650 else:
651 thread = None
652 try:
653 async with self.session() as session:
654 async with session.begin():
655 session.add(History(
656 uid=data["uid"],
657 origin_id=data["extra"].get("origin_id"),
658 stanza_id=data["extra"].get("stanza_id"),
659 update_uid=data["extra"].get("update_uid"),
660 profile_id=self.profiles[profile],
661 source_jid=data["from"],
662 dest_jid=data["to"],
663 timestamp=data["timestamp"],
664 received_timestamp=data.get("received_timestamp"),
665 type=data["type"],
666 extra=extra,
667 messages=messages,
668 subjects=subjects,
669 thread=thread,
670 ))
671 except IntegrityError as e:
672 if "unique" in str(e.orig).lower():
673 log.debug(
674 f"message {data['uid']!r} is already in history, not storing it again"
675 )
676 else:
677 log.error(f"Can't store message {data['uid']!r} in history: {e}")
678 except Exception as e:
679 log.critical(
680 f"Can't store message, unexpected exception (uid: {data['uid']}): {e}"
681 )
682
683 ## Private values
684
685 def _get_private_class(self, binary, profile):
686 """Get ORM class to use for private values"""
687 if profile is None:
688 return PrivateGenBin if binary else PrivateGen
689 else:
690 return PrivateIndBin if binary else PrivateInd
691
692
693 @aio
694 async def get_privates(
695 self,
696 namespace: str,
697 keys: Optional[Iterable[str]] = None,
698 binary: bool = False,
699 profile: Optional[str] = None
700 ) -> Dict[str, Any]:
701 """Get private value(s) from databases
702
703 @param namespace: namespace of the values
704 @param keys: keys of the values to get, or None to get all keys/values
705 @param binary: True to deserialise binary values
706 @param profile: profile to use for individual values
707 None to use general values
708 @return: gotten keys/values
709 """
710 if keys is not None:
711 keys = list(keys)
712 log.debug(
713 f"getting {'general' if profile is None else 'individual'}"
714 f"{' binary' if binary else ''} private values from database for namespace "
715 f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}"
716 )
717 cls = self._get_private_class(binary, profile)
718 stmt = select(cls).filter_by(namespace=namespace)
719 if keys is not None:
720 stmt = stmt.where(cls.key.in_(keys))
721 if profile is not None:
722 stmt = stmt.filter_by(profile_id=self.profiles[profile])
723 async with self.session() as session:
724 result = await session.execute(stmt)
725 return {p.key: p.value for p in result.scalars()}
726
727 @aio
728 async def set_private_value(
729 self,
730 namespace: str,
731 key: str,
732 value: Any,
733 binary: bool = False,
734 profile: Optional[str] = None
735 ) -> None:
736 """Set a private value in database
737
738 @param namespace: namespace of the values
739 @param key: key of the value to set
740 @param value: value to set
741 @param binary: True if it's a binary value
742 binary values need to be serialised; used for everything but strings
743 @param profile: profile to use for individual value
744 if None, it's a general value
745 """
746 cls = self._get_private_class(binary, profile)
747
748 values = {
749 "namespace": namespace,
750 "key": key,
751 "value": value
752 }
753 index_elements = [cls.namespace, cls.key]
754
755 if profile is not None:
756 values["profile_id"] = self.profiles[profile]
757 index_elements.append(cls.profile_id)
758
759 async with self.session() as session:
760 await session.execute(
761 insert(cls).values(**values).on_conflict_do_update(
762 index_elements=index_elements,
763 set_={
764 cls.value: value
765 }
766 )
767 )
768 await session.commit()
769
770 @aio
771 async def del_private_value(
772 self,
773 namespace: str,
774 key: str,
775 binary: bool = False,
776 profile: Optional[str] = None
777 ) -> None:
778 """Delete private value from database
779
780 @param namespace: namespace of the private value
781 @param key: key of the private value
782 @param binary: True if it's a binary value
783 @param profile: profile to use for individual value
784 if None, it's a general value
785 """
786 cls = self._get_private_class(binary, profile)
787
788 stmt = delete(cls).filter_by(namespace=namespace, key=key)
789
790 if profile is not None:
791 stmt = stmt.filter_by(profile_id=self.profiles[profile])
792
793 async with self.session() as session:
794 await session.execute(stmt)
795 await session.commit()
796
797 @aio
798 async def del_private_namespace(
799 self,
800 namespace: str,
801 binary: bool = False,
802 profile: Optional[str] = None
803 ) -> None:
804 """Delete all data from a private namespace
805
806 Be really cautious when you use this method, as all data with given namespace are
807 removed.
808 Params are the same as for del_private_value
809 """
810 cls = self._get_private_class(binary, profile)
811
812 stmt = delete(cls).filter_by(namespace=namespace)
813
814 if profile is not None:
815 stmt = stmt.filter_by(profile_id=self.profiles[profile])
816
817 async with self.session() as session:
818 await session.execute(stmt)
819 await session.commit()
820
821 ## Files
822
823 @aio
824 async def get_files(
825 self,
826 client: Optional[SatXMPPEntity],
827 file_id: Optional[str] = None,
828 version: Optional[str] = '',
829 parent: Optional[str] = None,
830 type_: Optional[str] = None,
831 file_hash: Optional[str] = None,
832 hash_algo: Optional[str] = None,
833 name: Optional[str] = None,
834 namespace: Optional[str] = None,
835 mime_type: Optional[str] = None,
836 public_id: Optional[str] = None,
837 owner: Optional[jid.JID] = None,
838 access: Optional[dict] = None,
839 projection: Optional[List[str]] = None,
840 unique: bool = False
841 ) -> List[dict]:
842 """Retrieve files with with given filters
843
844 @param file_id: id of the file
845 None to ignore
846 @param version: version of the file
847 None to ignore
848 empty string to look for current version
849 @param parent: id of the directory containing the files
850 None to ignore
851 empty string to look for root files/directories
852 @param projection: name of columns to retrieve
853 None to retrieve all
854 @param unique: if True will remove duplicates
855 other params are the same as for [set_file]
856 @return: files corresponding to filters
857 """
858 if projection is None:
859 projection = [
860 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
861 'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
862 'created', 'modified', 'owner', 'access', 'extra'
863 ]
864
865 stmt = select(*[getattr(File, f) for f in projection])
866
867 if unique:
868 stmt = stmt.distinct()
869
870 if client is not None:
871 stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
872 else:
873 if public_id is None:
874 raise exceptions.InternalError(
875 "client can only be omitted when public_id is set"
876 )
877 if file_id is not None:
878 stmt = stmt.filter_by(id=file_id)
879 if version is not None:
880 stmt = stmt.filter_by(version=version)
881 if parent is not None:
882 stmt = stmt.filter_by(parent=parent)
883 if type_ is not None:
884 stmt = stmt.filter_by(type=type_)
885 if file_hash is not None:
886 stmt = stmt.filter_by(file_hash=file_hash)
887 if hash_algo is not None:
888 stmt = stmt.filter_by(hash_algo=hash_algo)
889 if name is not None:
890 stmt = stmt.filter_by(name=name)
891 if namespace is not None:
892 stmt = stmt.filter_by(namespace=namespace)
893 if mime_type is not None:
894 if '/' in mime_type:
895 media_type, media_subtype = mime_type.split("/", 1)
896 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
897 else:
898 stmt = stmt.filter_by(media_type=mime_type)
899 if public_id is not None:
900 stmt = stmt.filter_by(public_id=public_id)
901 if owner is not None:
902 stmt = stmt.filter_by(owner=owner)
903 if access is not None:
904 raise NotImplementedError('Access check is not implemented yet')
905 # a JSON comparison is needed here
906
907 async with self.session() as session:
908 result = await session.execute(stmt)
909
910 return [dict(r) for r in result]
911
912 @aio
913 async def set_file(
914 self,
915 client: SatXMPPEntity,
916 name: str,
917 file_id: str,
918 version: str = "",
919 parent: str = "",
920 type_: str = C.FILE_TYPE_FILE,
921 file_hash: Optional[str] = None,
922 hash_algo: Optional[str] = None,
923 size: Optional[int] = None,
924 namespace: Optional[str] = None,
925 mime_type: Optional[str] = None,
926 public_id: Optional[str] = None,
927 created: Optional[float] = None,
928 modified: Optional[float] = None,
929 owner: Optional[jid.JID] = None,
930 access: Optional[dict] = None,
931 extra: Optional[dict] = None
932 ) -> None:
933 """Set a file metadata
934
935 @param client: client owning the file
936 @param name: name of the file (must not contain "/")
937 @param file_id: unique id of the file
938 @param version: version of this file
939 @param parent: id of the directory containing this file
940 Empty string if it is a root file/directory
941 @param type_: one of:
942 - file
943 - directory
944 @param file_hash: unique hash of the payload
945 @param hash_algo: algorithm used for hashing the file (usually sha-256)
946 @param size: size in bytes
947 @param namespace: identifier (human readable is better) to group files
948 for instance, namespace could be used to group files in a specific photo album
949 @param mime_type: media type of the file, or None if not known/guessed
950 @param public_id: ID used to serve the file publicly via HTTP
951 @param created: UNIX time of creation
952 @param modified: UNIX time of last modification, or None to use created date
953 @param owner: jid of the owner of the file (mainly useful for component)
954 @param access: serialisable dictionary with access rules. See [memory.memory] for
955 details
956 @param extra: serialisable dictionary of any extra data
957 will be encoded to json in database
958 """
959 if mime_type is None:
960 media_type = media_subtype = None
961 elif '/' in mime_type:
962 media_type, media_subtype = mime_type.split('/', 1)
963 else:
964 media_type, media_subtype = mime_type, None
965
966 async with self.session() as session:
967 async with session.begin():
968 session.add(File(
969 id=file_id,
970 version=version.strip(),
971 parent=parent,
972 type=type_,
973 file_hash=file_hash,
974 hash_algo=hash_algo,
975 name=name,
976 size=size,
977 namespace=namespace,
978 media_type=media_type,
979 media_subtype=media_subtype,
980 public_id=public_id,
981 created=time.time() if created is None else created,
982 modified=modified,
983 owner=owner,
984 access=access,
985 extra=extra,
986 profile_id=self.profiles[client.profile]
987 ))
988
989 @aio
990 async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int:
991 async with self.session() as session:
992 result = await session.execute(
993 select(sum_(File.size)).filter_by(
994 owner=owner,
995 type=C.FILE_TYPE_FILE,
996 profile_id=self.profiles[client.profile]
997 ))
998 return result.scalar_one_or_none() or 0
999
1000 @aio
1001 async def file_delete(self, file_id: str) -> None:
1002 """Delete file metadata from the database
1003
1004 @param file_id: id of the file to delete
1005 NOTE: the file itself must still be removed; this method only handles metadata in
1006 database
1007 """
1008 async with self.session() as session:
1009 await session.execute(delete(File).filter_by(id=file_id))
1010 await session.commit()
1011
1012 @aio
1013 async def file_update(
1014 self,
1015 file_id: str,
1016 column: str,
1017 update_cb: Callable[[dict], None]
1018 ) -> None:
1019 """Update a column value using a method to avoid race conditions
1020
1021 The old value will be retrieved from database, then update_cb will be applied to
1022 update it, and the file will be updated, checking that the old value has not been
1023 changed meanwhile by another user. If it has changed, it tries again a couple of
1024 times before failing
1025 @param column: column name (only "access" or "extra" are allowed)
1026 @param update_cb: method to update the value of the column
1027 the method will take the old value as argument, and must update it in place
1028 update_cb must not care about serialization,
1029 it gets the deserialized data (i.e. a Python object) directly
1030 @raise exceptions.NotFound: there is no file with this id
1031 """
1032 if column not in ('access', 'extra'):
1033 raise exceptions.InternalError('bad column name')
1034 orm_col = getattr(File, column)
1035
1036 for i in range(5):
1037 async with self.session() as session:
1038 try:
1039 value = (await session.execute(
1040 select(orm_col).filter_by(id=file_id)
1041 )).scalar_one()
1042 except NoResultFound:
1043 raise exceptions.NotFound
1044 old_value = copy.deepcopy(value)
1045 update_cb(value)
1046 stmt = update(File).filter_by(id=file_id).values({column: value})
1047 if not old_value:
1048 # because JsonDefaultDict converts NULL to an empty dict, we have to
1049 # test both for empty dict and None when we have an empty dict
1050 stmt = stmt.where((orm_col == None) | (orm_col == old_value))
1051 else:
1052 stmt = stmt.where(orm_col == old_value)
1053 result = await session.execute(stmt)
1054 await session.commit()
1055
1056 if result.rowcount == 1:
1057 break
1058
1059 log.warning(
1060 _("table not updated, probably due to race condition, trying again "
1061 "({tries})").format(tries=i+1)
1062 )
1063
1064 else:
1065 raise exceptions.DatabaseError(
1066 _("Can't update file {file_id} due to race condition")
1067 .format(file_id=file_id)
1068 )
1069
1070 @aio
1071 async def get_pubsub_node(
1072 self,
1073 client: SatXMPPEntity,
1074 service: jid.JID,
1075 name: str,
1076 with_items: bool = False,
1077 with_subscriptions: bool = False,
1078 create: bool = False,
1079 create_kwargs: Optional[dict] = None
1080 ) -> Optional[PubsubNode]:
1081 """Retrieve a PubsubNode from DB
1082
1083 @param service: service hosting the node
1084 @param name: node's name
1085 @param with_items: retrieve items in the same query
1086 @param with_subscriptions: retrieve subscriptions in the same query
1087 @param create: if the node doesn't exist in DB, create it
1088 @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node
1089 needs to be created.
1090 """
1091 async with self.session() as session:
1092 stmt = (
1093 select(PubsubNode)
1094 .filter_by(
1095 service=service,
1096 name=name,
1097 profile_id=self.profiles[client.profile],
1098 )
1099 )
1100 if with_items:
1101 stmt = stmt.options(
1102 joinedload(PubsubNode.items)
1103 )
1104 if with_subscriptions:
1105 stmt = stmt.options(
1106 joinedload(PubsubNode.subscriptions)
1107 )
1108 result = await session.execute(stmt)
1109 ret = result.unique().scalar_one_or_none()
1110 if ret is None and create:
1111 # we auto-create the node
1112 if create_kwargs is None:
1113 create_kwargs = {}
1114 try:
1115 return await as_future(self.set_pubsub_node(
1116 client, service, name, **create_kwargs
1117 ))
1118 except IntegrityError as e:
1119 if "unique" in str(e.orig).lower():
1120 # the node may already exist, if it has been created just after
1121 # get_pubsub_node above
1122 log.debug("ignoring UNIQUE constraint error")
1123 # return the node created by the concurrent writer
1124 return await as_future(self.get_pubsub_node(
1125 client,
1126 service,
1127 name,
1128 with_items=with_items,
1129 with_subscriptions=with_subscriptions
1130 ))
1130 else:
1131 raise e
1132 else:
1133 return ret
1134
1135 @aio
1136 async def set_pubsub_node(
1137 self,
1138 client: SatXMPPEntity,
1139 service: jid.JID,
1140 name: str,
1141 analyser: Optional[str] = None,
1142 type_: Optional[str] = None,
1143 subtype: Optional[str] = None,
1144 subscribed: bool = False,
1145 ) -> PubsubNode:
1146 node = PubsubNode(
1147 profile_id=self.profiles[client.profile],
1148 service=service,
1149 name=name,
1150 subscribed=subscribed,
1151 analyser=analyser,
1152 type_=type_,
1153 subtype=subtype,
1154 subscriptions=[],
1155 )
1156 async with self.session() as session:
1157 async with session.begin():
1158 session.add(node)
1159 return node
1160
1161 @aio
1162 async def update_pubsub_node_sync_state(
1163 self,
1164 node: PubsubNode,
1165 state: SyncState
1166 ) -> None:
1167 async with self.session() as session:
1168 async with session.begin():
1169 await session.execute(
1170 update(PubsubNode)
1171 .filter_by(id=node.id)
1172 .values(
1173 sync_state=state,
1174 sync_state_updated=time.time(),
1175 )
1176 )
1177
1178 @aio
1179 async def delete_pubsub_node(
1180 self,
1181 profiles: Optional[List[str]],
1182 services: Optional[List[jid.JID]],
1183 names: Optional[List[str]]
1184 ) -> None:
1185 """Delete items cached for a node
1186
1187 @param profiles: profile names from which nodes must be deleted.
1188 None to remove nodes from ALL profiles
1189 @param services: JIDs of pubsub services from which nodes must be deleted.
1190 None to remove nodes from ALL services
1191 @param names: names of nodes which must be deleted.
1192 None to remove ALL nodes whatever their names
1193 """
1194 stmt = delete(PubsubNode)
1195 if profiles is not None:
1196 stmt = stmt.where(
1197 PubsubNode.profile_id.in_(
1198 [self.profiles[p] for p in profiles]
1199 )
1200 )
1201 if services is not None:
1202 stmt = stmt.where(PubsubNode.service.in_(services))
1203 if names is not None:
1204 stmt = stmt.where(PubsubNode.name.in_(names))
1205 async with self.session() as session:
1206 await session.execute(stmt)
1207 await session.commit()
1208
1209 @aio
1210 async def cache_pubsub_items(
1211 self,
1212 client: SatXMPPEntity,
1213 node: PubsubNode,
1214 items: List[domish.Element],
1215 parsed_items: Optional[List[dict]] = None,
1216 ) -> None:
1217 """Add items to database, using an upsert taking care of "updated" field"""
1218 if parsed_items is not None and len(items) != len(parsed_items):
1219 raise exceptions.InternalError(
1220 "parsed_items must have the same lenght as items"
1221 )
1222 async with self.session() as session:
1223 async with session.begin():
1224 for idx, item in enumerate(items):
1225 parsed = parsed_items[idx] if parsed_items else None
1226 stmt = insert(PubsubItem).values(
1227 node_id=node.id,
1228 name=item["id"],
1229 data=item,
1230 parsed=parsed,
1231 ).on_conflict_do_update(
1232 index_elements=(PubsubItem.node_id, PubsubItem.name),
1233 set_={
1234 PubsubItem.data: item,
1235 PubsubItem.parsed: parsed,
1236 PubsubItem.updated: now()
1237 }
1238 )
1239 await session.execute(stmt)
1240 await session.commit()
1241
1242 @aio
1243 async def delete_pubsub_items(
1244 self,
1245 node: Union[PubsubNode, List[PubsubNode], None],
1246 items_names: Optional[List[str]] = None
1247 ) -> None:
1248 """Delete items cached for a node
1249
1250 @param node: node(s) from which items must be deleted
1251 None to delete items from ALL nodes
1252 @param items_names: names of items to delete
1253 if None, ALL items will be deleted
1253 """
1254 stmt = delete(PubsubItem)
1255 if node is not None:
1256 if isinstance(node, list):
1257 stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node]))
1258 else:
1259 stmt = stmt.filter_by(node_id=node.id)
1260 if items_names is not None:
1261 stmt = stmt.where(PubsubItem.name.in_(items_names))
1262 async with self.session() as session:
1263 await session.execute(stmt)
1264 await session.commit()
1265
1266 @aio
1267 async def purge_pubsub_items(
1268 self,
1269 services: Optional[List[jid.JID]] = None,
1270 names: Optional[List[str]] = None,
1271 types: Optional[List[str]] = None,
1272 subtypes: Optional[List[str]] = None,
1273 profiles: Optional[List[str]] = None,
1274 created_before: Optional[datetime] = None,
1275 updated_before: Optional[datetime] = None,
1276 ) -> None:
1277 """Delete items cached for a node
1278
1279 @param node: node from which items must be deleted
1280 @param items_names: names of items to delete
1281 if None, ALL items will be deleted
1282 """
1283 stmt = delete(PubsubItem)
1284 node_fields = {
1285 "service": services,
1286 "name": names,
1287 "type_": types,
1288 "subtype": subtypes,
1289 }
1290 if profiles is not None:
1291 node_fields["profile_id"] = [self.profiles[p] for p in profiles]
1292
1293 if any(x is not None for x in node_fields.values()):
1294 sub_q = select(PubsubNode.id)
1295 for col, values in node_fields.items():
1296 if values is None:
1297 continue
1298 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
1299 stmt = (
1300 stmt
1301 .where(PubsubItem.node_id.in_(sub_q))
1302 .execution_options(synchronize_session=False)
1303 )
1304
1305 if created_before is not None:
1306 stmt = stmt.where(PubsubItem.created < created_before)
1307
1308 if updated_before is not None:
1309 stmt = stmt.where(PubsubItem.updated < updated_before)
1310
1311 async with self.session() as session:
1312 await session.execute(stmt)
1313 await session.commit()
1314
1315 @aio
1316 async def get_items(
1317 self,
1318 node: PubsubNode,
1319 max_items: Optional[int] = None,
1320 item_ids: Optional[List[str]] = None,
1321 before: Optional[str] = None,
1322 after: Optional[str] = None,
1323 from_index: Optional[int] = None,
1324 order_by: Optional[List[str]] = None,
1325 desc: bool = True,
1326 force_rsm: bool = False,
1327 ) -> Tuple[List[PubsubItem], dict]:
1328 """Get Pubsub Items from cache
1329
1330 @param node: retrieve items from this node (must be synchronised)
1331 @param max_items: maximum number of items to retrieve
1332 @param item_ids: if set, only retrieve items with those names
1333 @param before: get items which are before the item with this name in given order
1334 empty string is not managed here, use desc order to reproduce RSM
1335 behaviour.
1336 @param after: get items which are after the item with this name in given order
1337 @param from_index: get items with item index (as defined in RSM spec)
1338 starting from this number
1339 @param order_by: sorting order of items (one of C.ORDER_BY_*)
1340 @param desc: direction of ordering
1341 @param force_rsm: if True, force the use of the RSM workflow.
1342 The RSM workflow is automatically used if any of before, after or
1343 from_index is used, but if only max_items is used (e.g. a max_items
1344 coming from RSM), it won't be used by default. This parameter lets
1345 us use the RSM workflow in that case. Note that in addition to RSM
1346 metadata, the result will not be the same: max_items without RSM
1347 returns the most recent items (i.e. last items in modification
1348 order), while max_items with RSM returns the oldest ones (i.e.
1349 first items in modification order).
1350 """
1351
1352 metadata = {
1353 "service": node.service,
1354 "node": node.name,
1355 "uri": uri.build_xmpp_uri(
1356 "pubsub",
1357 path=node.service.full(),
1358 node=node.name,
1359 ),
1360 }
1361 if max_items is None:
1362 max_items = 20
1363
1364 use_rsm = any((before, after, from_index is not None))
1365 if force_rsm and not use_rsm:
1366 # RSM has been requested explicitly: emulate a first-page request
1367 use_rsm = True
1368 from_index = 0
1369
1370 stmt = (
1371 select(PubsubItem)
1372 .filter_by(node_id=node.id)
1373 .limit(max_items)
1374 )
1375
1376 if item_ids is not None:
1377 stmt = stmt.where(PubsubItem.name.in_(item_ids))
1378
1379 if not order_by:
1380 order_by = [C.ORDER_BY_MODIFICATION]
1381
1382 order = []
1383 for order_type in order_by:
1384 if order_type == C.ORDER_BY_MODIFICATION:
1385 if desc:
1386 order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
1387 else:
1388 order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
1389 elif order_type == C.ORDER_BY_CREATION:
1390 if desc:
1391 order.append(PubsubItem.id.desc())
1392 else:
1393 order.append(PubsubItem.id.asc())
1394 else:
1395 raise exceptions.InternalError(f"Unknown order type {order_type!r}")
1396
1397 stmt = stmt.order_by(*order)
1398
1399 if use_rsm:
1400 # CTE to have result row numbers
1401 row_num_q = select(
1402 PubsubItem.id,
1403 PubsubItem.name,
1404 # row_number starts from 1, but RSM index must start from 0
1405 (func.row_number().over(order_by=order)-1).label("item_index")
1406 ).filter_by(node_id=node.id)
1407
1408 row_num_cte = row_num_q.cte()
1409
1410 if max_items > 0:
1411 # as we can't simply use PubsubItem.id when we order by modification,
1412 # we need to use row number
1413 item_name = before or after
1414 row_num_limit_q = (
1415 select(row_num_cte.c.item_index)
1416 .where(row_num_cte.c.name==item_name)
1417 ).scalar_subquery()
1418
1419 stmt = (
1420 select(row_num_cte.c.item_index, PubsubItem)
1421 .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
1422 .limit(max_items)
1423 )
1424 if before:
1425 stmt = (
1426 stmt
1427 .where(row_num_cte.c.item_index<row_num_limit_q)
1428 .order_by(row_num_cte.c.item_index.desc())
1429 )
1430 elif after:
1431 stmt = (
1432 stmt
1433 .where(row_num_cte.c.item_index>row_num_limit_q)
1434 .order_by(row_num_cte.c.item_index.asc())
1435 )
1436 else:
1437 # neither before nor after is set: from_index is used
1438 stmt = (
1439 stmt
1440 .where(row_num_cte.c.item_index >= from_index)
1441 .order_by(row_num_cte.c.item_index.asc())
1442 )
1443
1444 async with self.session() as session:
1445 if max_items == 0:
1446 items = result = []
1447 else:
1448 result = await session.execute(stmt)
1449 result = result.all()
1450 if before:
1451 result.reverse()
1452 items = [row[-1] for row in result]
1453 rows_count = (
1454 await session.execute(row_num_q.with_only_columns(count()))
1455 ).scalar_one()
1456
1457 try:
1458 index = result[0][0]
1459 except IndexError:
1460 index = None
1461
1462 try:
1463 first = result[0][1].name
1464 except IndexError:
1465 first = None
1466 last = None
1467 else:
1468 last = result[-1][1].name
1469
1470 metadata["rsm"] = {
1471 k: v for k, v in {
1472 "index": index,
1473 "count": rows_count,
1474 "first": first,
1475 "last": last,
1476 }.items() if v is not None
1477 }
1478 metadata["complete"] = (index or 0) + len(result) == rows_count
1479
1480 return items, metadata
1481
1482 async with self.session() as session:
1483 result = await session.execute(stmt)
1484
1485 result = result.scalars().all()
1486 if desc:
1487 result.reverse()
1488 return result, metadata
1489
1490 def _get_sqlite_path(
1491 self,
1492 path: List[Union[str, int]]
1493 ) -> str:
1494 """generate path suitable to query JSON element with SQLite"""
1495 return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}"
1496
1497 @aio
1498 async def search_pubsub_items(
1499 self,
1500 query: dict,
1501 ) -> List[PubsubItem]:
1502 """Search for pubsub items in cache
1503
1504 @param query: search terms. Keys can be:
1505 :fts (str):
1506 Full-Text Search query. Currently the SQLite FTS5 engine is used; its query
1507 syntax can be used, see `FTS5 Query documentation
1508 <https://sqlite.org/fts5.html#full_text_query_syntax>`_
1509 :profiles (list[str]):
1510 filter on nodes linked to those profiles
1511 :nodes (list[str]):
1512 filter on nodes with those names
1513 :services (list[jid.JID]):
1514 filter on nodes from those services
1515 :types (list[str|None]):
1516 filter on nodes with those types. None can be used to filter on nodes with
1517 no type set
1518 :subtypes (list[str|None]):
1519 filter on nodes with those subtypes. None can be used to filter on nodes with
1520 no subtype set
1521 :names (list[str]):
1522 filter on items with those names
1523 :parsed (list[dict]):
1524 Filter on a parsed data field. The dict must contain 3 keys: ``path``,
1525 which is a list of str or int giving the path to the field of interest
1526 (str for a dict key, int for a list index), ``op``, which indicates the
1527 operator to use to check the condition, and ``value``, which depends on
1528 the field type and operator.
1529
1530 See documentation for details on operators (it's currently explained at
1531 ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
1532 documentation).
1533
1534 :order-by (list[dict]):
1535 Indicates how to order results. The dict can contain either an ``order``
1536 for a well-known order or a ``path`` for a parsed data field path
1537 (``order`` and ``path`` can't be used at the same time), and an optional
1538 ``direction`` which can be ``asc`` or ``desc``. See documentation for
1539 details on well-known orders (it's currently explained at
1540 ``doc/libervia-cli/pubsub_cache.rst`` in ``search`` command
1541 documentation).
1542
1543 :index (int):
1544 starting index of items to return from the query result. It's translated
1545 to SQL's OFFSET
1546
1547 :limit (int):
1548 maximum number of items to return. It's translated to SQL's LIMIT.
1549
1550 @return: found items (the ``node`` attribute will be filled with suitable
1551 PubsubNode)
1552 """
1553 # TODO: FTS and parsed data filters use SQLite specific syntax
1554 # when other DB engines will be used, this will have to be adapted
1555 stmt = select(PubsubItem)
1556
1557 # Full-Text Search
1558 fts = query.get("fts")
1559 if fts:
1560 fts_select = text(
1561 "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)"
1562 ).bindparams(fts_query=fts).columns(rowid=Integer).subquery()
1563 stmt = (
1564 stmt
1565 .select_from(fts_select)
1566 .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id)
1567 )
1568
1569 # node related filters
1570 profiles = query.get("profiles")
1571 if (profiles
1572 or any(query.get(k) for k in ("nodes", "services", "types", "subtypes"))
1573 ):
1574 stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node))
1575 if profiles:
1576 try:
1577 stmt = stmt.where(
1578 PubsubNode.profile_id.in_(self.profiles[p] for p in profiles)
1579 )
1580 except KeyError as e:
1581 raise exceptions.ProfileUnknownError(
1582 f"This profile doesn't exist: {e.args[0]!r}"
1583 )
1584 for key, attr in (
1585 ("nodes", "name"),
1586 ("services", "service"),
1587 ("types", "type_"),
1588 ("subtypes", "subtype")
1589 ):
1590 value = query.get(key)
1591 if not value:
1592 continue
1593 if key in ("types", "subtypes") and None in value:
1594 # NULL can't be used with SQL's IN, so we have to add a condition with
1595 # IS NULL, and use an OR if there are other values to check
1596 value.remove(None)
1597 condition = getattr(PubsubNode, attr).is_(None)
1598 if value:
1599 condition = or_(
1600 getattr(PubsubNode, attr).in_(value),
1601 condition
1602 )
1603 else:
1604 condition = getattr(PubsubNode, attr).in_(value)
1605 stmt = stmt.where(condition)
1606 else:
1607 stmt = stmt.options(selectinload(PubsubItem.node))
1608
1609 # names
1610 names = query.get("names")
1611 if names:
1612 stmt = stmt.where(PubsubItem.name.in_(names))
1613
1614 # parsed data filters
1615 parsed = query.get("parsed", [])
1616 for filter_ in parsed:
1617 try:
1618 path = filter_["path"]
1619 operator = filter_["op"]
1620 value = filter_["value"]
1621 except KeyError as e:
1622 raise ValueError(
1623 f'missing mandatory key {e.args[0]!r} in "parsed" filter'
1624 )
1625 try:
1626 op_attr = OP_MAP[operator]
1627 except KeyError:
1628 raise ValueError(f"invalid operator: {operator!r}")
1629 sqlite_path = self._get_sqlite_path(path)
1630 if operator in ("overlap", "ioverlap", "disjoint", "idisjoint"):
1631 col = literal_column("json_each.value")
1632 if operator[0] == "i":
1633 col = func.lower(col)
1634 value = [str(v).lower() for v in value]
1635 condition = (
1636 select(1)
1637 .select_from(func.json_each(PubsubItem.parsed, sqlite_path))
1638 .where(col.in_(value))
1639 ).scalar_subquery()
1640 if operator in ("disjoint", "idisjoint"):
1641 condition = condition.is_(None)
1642 stmt = stmt.where(condition)
1643 elif operator == "between":
1644 try:
1645 left, right = value
1646 except (ValueError, TypeError):
1647 raise ValueError(_(
1648 "invalid value for \"between\" filter, you must use a 2 items "
1649 "array: {value!r}"
1650 ).format(value=value))
1651 col = func.json_extract(PubsubItem.parsed, sqlite_path)
1652 stmt = stmt.where(col.between(left, right))
1653 else:
1654 # we use func.json_extract instead of the generic JSON way because SQLAlchemy
1655 # adds a JSON_QUOTE to the value, and we want the raw SQL value
1656 col = func.json_extract(PubsubItem.parsed, sqlite_path)
1657 stmt = stmt.where(getattr(col, op_attr)(value))
1658
1659 # order
1660 order_by = query.get("order-by") or [{"order": "creation"}]
1661
1662 for order_data in order_by:
1663 order, path = order_data.get("order"), order_data.get("path")
1664 if order and path:
1665 raise ValueError(_(
1666 '"order" and "path" can\'t be used at the same time in '
1667 '"order-by" data'
1668 ))
1669 if order:
1670 if order == "creation":
1671 col = PubsubItem.id
1672 elif order == "modification":
1673 col = PubsubItem.updated
1674 elif order == "item_id":
1675 col = PubsubItem.name
1676 elif order == "rank":
1677 if not fts:
1678 raise ValueError(
1679 "'rank' order can only be used with Full-Text Search (fts)"
1680 )
1681 col = literal_column("rank")
1682 else:
1683 raise NotImplementedError(f"Unknown {order!r} order")
1684 else:
1685 # we have a JSON path
1686 # sqlite_path = self._get_sqlite_path(path)
1687 col = PubsubItem.parsed[path]
1688 direction = order_data.get("direction", "ASC").lower()
1689 if direction not in ("asc", "desc"):
1690 raise ValueError(f"Invalid order-by direction: {direction!r}")
1691 stmt = stmt.order_by(getattr(col, direction)())
1692
1693 # offset, limit
1694 index = query.get("index")
1695 if index:
1696 stmt = stmt.offset(index)
1697 limit = query.get("limit")
1698 if limit:
1699 stmt = stmt.limit(limit)
1700
1701 async with self.session() as session:
1702 result = await session.execute(stmt)
1703
1704 return result.scalars().all()