Mercurial > libervia-backend
comparison libervia/backend/memory/sqla.py @ 4130:02f0adc745c6
core: notifications implementation, first draft:
add a new table for notifications, and methods/bridge methods to manipulate them.
author | Goffi <goffi@goffi.org> |
---|---|
date | Mon, 16 Oct 2023 17:29:31 +0200 |
parents | 74c66c0d93f3 |
children | 6a0066ea5c97 |
comparison
equal
deleted
inserted
replaced
4129:51744ad00a42 | 4130:02f0adc745c6 |
---|---|
60 Component, | 60 Component, |
61 File, | 61 File, |
62 History, | 62 History, |
63 Message, | 63 Message, |
64 NOT_IN_EXTRA, | 64 NOT_IN_EXTRA, |
65 Notification, | |
66 NotificationPriority, | |
67 NotificationStatus, | |
68 NotificationType, | |
65 ParamGen, | 69 ParamGen, |
66 ParamInd, | 70 ParamInd, |
67 PrivateGen, | 71 PrivateGen, |
68 PrivateGenBin, | 72 PrivateGenBin, |
69 PrivateInd, | 73 PrivateInd, |
72 PubsubItem, | 76 PubsubItem, |
73 PubsubNode, | 77 PubsubNode, |
74 Subject, | 78 Subject, |
75 SyncState, | 79 SyncState, |
76 Thread, | 80 Thread, |
81 get_profile_by_id, | |
82 profiles, | |
77 ) | 83 ) |
78 from libervia.backend.tools.common import uri | 84 from libervia.backend.tools.common import uri |
79 from libervia.backend.tools.utils import aio, as_future | 85 from libervia.backend.tools.utils import aio, as_future |
80 | 86 |
81 | 87 |
115 class Storage: | 121 class Storage: |
116 | 122 |
117 def __init__(self): | 123 def __init__(self): |
118 self.initialized = defer.Deferred() | 124 self.initialized = defer.Deferred() |
119 # we keep cache for the profiles (key: profile name, value: profile id) | 125 # we keep cache for the profiles (key: profile name, value: profile id) |
120 # profile id to name | |
121 self.profiles: Dict[str, int] = {} | |
122 # profile id to component entry point | 126 # profile id to component entry point |
123 self.components: Dict[int, str] = {} | 127 self.components: Dict[int, str] = {} |
124 | 128 |
129 @property | |
130 def profiles(self): | |
131 return profiles | |
132 | |
125 def get_profile_by_id(self, profile_id): | 133 def get_profile_by_id(self, profile_id): |
126 return self.profiles.get(profile_id) | 134 return get_profile_by_id(profile_id) |
127 | 135 |
128 async def migrate_apply(self, *args: str, log_output: bool = False) -> None: | 136 async def migrate_apply(self, *args: str, log_output: bool = False) -> None: |
129 """Do a migration command | 137 """Do a migration command |
130 | 138 |
131 Commands are applied by running Alembic in a subprocess. | 139 Commands are applied by running Alembic in a subprocess. |
137 @raise exceptions.DatabaseError: something went wrong while running the | 145 @raise exceptions.DatabaseError: something went wrong while running the |
138 process | 146 process |
139 """ | 147 """ |
140 stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) | 148 stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) |
141 proc = await asyncio.create_subprocess_exec( | 149 proc = await asyncio.create_subprocess_exec( |
142 sys.executable, "-m", "alembic", *args, | 150 sys.executable, |
143 stdout=stdout, stderr=stderr, cwd=migration_path | 151 "-m", |
152 "alembic", | |
153 *args, | |
154 stdout=stdout, | |
155 stderr=stderr, | |
156 cwd=migration_path, | |
144 ) | 157 ) |
145 log_out, log_err = await proc.communicate() | 158 log_out, log_err = await proc.communicate() |
146 if proc.returncode != 0: | 159 if proc.returncode != 0: |
147 msg = _( | 160 msg = _("Can't {operation} database (exit code {exit_code})").format( |
148 "Can't {operation} database (exit code {exit_code})" | 161 operation=args[0], exit_code=proc.returncode |
149 ).format( | |
150 operation=args[0], | |
151 exit_code=proc.returncode | |
152 ) | 162 ) |
153 if log_out or log_err: | 163 if log_out or log_err: |
154 msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" | 164 msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" |
155 log.error(msg) | 165 log.error(msg) |
156 | 166 |
214 await self.check_and_update_db(engine, db_config) | 224 await self.check_and_update_db(engine, db_config) |
215 | 225 |
216 async with engine.connect() as conn: | 226 async with engine.connect() as conn: |
217 await conn.run_sync(self._sqlite_set_journal_mode_wal) | 227 await conn.run_sync(self._sqlite_set_journal_mode_wal) |
218 | 228 |
219 self.session = sessionmaker( | 229 self.session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) |
220 engine, expire_on_commit=False, class_=AsyncSession | |
221 ) | |
222 | 230 |
223 async with self.session() as session: | 231 async with self.session() as session: |
224 result = await session.execute(select(Profile)) | 232 result = await session.execute(select(Profile)) |
225 for p in result.scalars(): | 233 for p in result.scalars(): |
226 self.profiles[p.name] = p.id | 234 self.profiles[p.name] = p.id |
237 self, | 245 self, |
238 client: SatXMPPEntity, | 246 client: SatXMPPEntity, |
239 db_cls: DeclarativeMeta, | 247 db_cls: DeclarativeMeta, |
240 db_id_col: Mapped, | 248 db_id_col: Mapped, |
241 id_value: Any, | 249 id_value: Any, |
242 joined_loads = None | 250 joined_loads=None, |
243 ) -> Optional[DeclarativeMeta]: | 251 ) -> Optional[DeclarativeMeta]: |
244 stmt = select(db_cls).where(db_id_col==id_value) | 252 stmt = select(db_cls).where(db_id_col == id_value) |
245 if client is not None: | 253 if client is not None: |
246 stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) | 254 stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) |
247 if joined_loads is not None: | 255 if joined_loads is not None: |
248 for joined_load in joined_loads: | 256 for joined_load in joined_loads: |
249 stmt = stmt.options(joinedload(joined_load)) | 257 stmt = stmt.options(joinedload(joined_load)) |
262 | 270 |
263 @aio | 271 @aio |
264 async def delete( | 272 async def delete( |
265 self, | 273 self, |
266 db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], | 274 db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], |
267 session_add: Optional[List[DeclarativeMeta]] = None | 275 session_add: Optional[List[DeclarativeMeta]] = None, |
268 ) -> None: | 276 ) -> None: |
269 """Delete an object from database | 277 """Delete an object from database |
270 | 278 |
271 @param db_obj: object to delete or list of objects to delete | 279 @param db_obj: object to delete or list of objects to delete |
272 @param session_add: other objects to add to session. | 280 @param session_add: other objects to add to session. |
287 await session.commit() | 295 await session.commit() |
288 | 296 |
289 ## Profiles | 297 ## Profiles |
290 | 298 |
291 def get_profiles_list(self) -> List[str]: | 299 def get_profiles_list(self) -> List[str]: |
292 """"Return list of all registered profiles""" | 300 """ "Return list of all registered profiles""" |
293 return list(self.profiles.keys()) | 301 return list(self.profiles.keys()) |
294 | 302 |
295 def has_profile(self, profile_name: str) -> bool: | 303 def has_profile(self, profile_name: str) -> bool: |
296 """return True if profile_name exists | 304 """return True if profile_name exists |
297 | 305 |
307 | 315 |
308 def get_entry_point(self, profile_name: str) -> str: | 316 def get_entry_point(self, profile_name: str) -> str: |
309 try: | 317 try: |
310 return self.components[self.profiles[profile_name]] | 318 return self.components[self.profiles[profile_name]] |
311 except KeyError: | 319 except KeyError: |
312 raise exceptions.NotFound("the requested profile doesn't exists or is not a component") | 320 raise exceptions.NotFound( |
321 "the requested profile doesn't exists or is not a component" | |
322 ) | |
313 | 323 |
314 @aio | 324 @aio |
315 async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: | 325 async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: |
316 """Create a new profile | 326 """Create a new profile |
317 | 327 |
342 await session.delete(profile) | 352 await session.delete(profile) |
343 await session.commit() | 353 await session.commit() |
344 del self.profiles[profile.name] | 354 del self.profiles[profile.name] |
345 if profile.id in self.components: | 355 if profile.id in self.components: |
346 del self.components[profile.id] | 356 del self.components[profile.id] |
347 log.info(_("Profile {name!r} deleted").format(name = name)) | 357 log.info(_("Profile {name!r} deleted").format(name=name)) |
348 | 358 |
349 ## Params | 359 ## Params |
350 | 360 |
351 @aio | 361 @aio |
352 async def load_gen_params(self, params_gen: dict) -> None: | 362 async def load_gen_params(self, params_gen: dict) -> None: |
374 ) | 384 ) |
375 for p in result.scalars(): | 385 for p in result.scalars(): |
376 params_ind[(p.category, p.name)] = p.value | 386 params_ind[(p.category, p.name)] = p.value |
377 | 387 |
378 @aio | 388 @aio |
379 async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]: | 389 async def get_ind_param( |
390 self, category: str, name: str, profile: str | |
391 ) -> Optional[str]: | |
380 """Ask database for the value of one specific individual parameter | 392 """Ask database for the value of one specific individual parameter |
381 | 393 |
382 @param category: category of the parameter | 394 @param category: category of the parameter |
383 @param name: name of the parameter | 395 @param name: name of the parameter |
384 @param profile: %(doc_profile)s | 396 @param profile: %(doc_profile)s |
385 """ | 397 """ |
386 async with self.session() as session: | 398 async with self.session() as session: |
387 result = await session.execute( | 399 result = await session.execute( |
388 select(ParamInd.value) | 400 select(ParamInd.value).filter_by( |
389 .filter_by( | 401 category=category, name=name, profile_id=self.profiles[profile] |
390 category=category, | |
391 name=name, | |
392 profile_id=self.profiles[profile] | |
393 ) | 402 ) |
394 ) | 403 ) |
395 return result.scalar_one_or_none() | 404 return result.scalar_one_or_none() |
396 | 405 |
397 @aio | 406 @aio |
403 @return dict: profile => value map | 412 @return dict: profile => value map |
404 """ | 413 """ |
405 async with self.session() as session: | 414 async with self.session() as session: |
406 result = await session.execute( | 415 result = await session.execute( |
407 select(ParamInd) | 416 select(ParamInd) |
408 .filter_by( | 417 .filter_by(category=category, name=name) |
409 category=category, | |
410 name=name | |
411 ) | |
412 .options(subqueryload(ParamInd.profile)) | 418 .options(subqueryload(ParamInd.profile)) |
413 ) | 419 ) |
414 return {param.profile.name: param.value for param in result.scalars()} | 420 return {param.profile.name: param.value for param in result.scalars()} |
415 | 421 |
416 @aio | 422 @aio |
420 @param category: category of the parameter | 426 @param category: category of the parameter |
421 @param name: name of the parameter | 427 @param name: name of the parameter |
422 @param value: value to set | 428 @param value: value to set |
423 """ | 429 """ |
424 async with self.session() as session: | 430 async with self.session() as session: |
425 stmt = insert(ParamGen).values( | 431 stmt = ( |
426 category=category, | 432 insert(ParamGen) |
427 name=name, | 433 .values(category=category, name=name, value=value) |
428 value=value | 434 .on_conflict_do_update( |
429 ).on_conflict_do_update( | 435 index_elements=(ParamGen.category, ParamGen.name), |
430 index_elements=(ParamGen.category, ParamGen.name), | 436 set_={ParamGen.value: value}, |
431 set_={ | 437 ) |
432 ParamGen.value: value | |
433 } | |
434 ) | 438 ) |
435 await session.execute(stmt) | 439 await session.execute(stmt) |
436 await session.commit() | 440 await session.commit() |
437 | 441 |
438 @aio | 442 @aio |
439 async def set_ind_param( | 443 async def set_ind_param( |
440 self, | 444 self, category: str, name: str, value: Optional[str], profile: str |
441 category:str, | |
442 name: str, | |
443 value: Optional[str], | |
444 profile: str | |
445 ) -> None: | 445 ) -> None: |
446 """Save the individual parameters in database | 446 """Save the individual parameters in database |
447 | 447 |
448 @param category: category of the parameter | 448 @param category: category of the parameter |
449 @param name: name of the parameter | 449 @param name: name of the parameter |
450 @param value: value to set | 450 @param value: value to set |
451 @param profile: a profile which *must* exist | 451 @param profile: a profile which *must* exist |
452 """ | 452 """ |
453 async with self.session() as session: | 453 async with self.session() as session: |
454 stmt = insert(ParamInd).values( | 454 stmt = ( |
455 category=category, | 455 insert(ParamInd) |
456 name=name, | 456 .values( |
457 profile_id=self.profiles[profile], | 457 category=category, |
458 value=value | 458 name=name, |
459 ).on_conflict_do_update( | 459 profile_id=self.profiles[profile], |
460 index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), | 460 value=value, |
461 set_={ | 461 ) |
462 ParamInd.value: value | 462 .on_conflict_do_update( |
463 } | 463 index_elements=( |
464 ParamInd.category, | |
465 ParamInd.name, | |
466 ParamInd.profile_id, | |
467 ), | |
468 set_={ParamInd.value: value}, | |
469 ) | |
464 ) | 470 ) |
465 await session.execute(stmt) | 471 await session.execute(stmt) |
466 await session.commit() | 472 await session.commit() |
467 | 473 |
468 def _jid_filter(self, jid_: jid.JID, dest: bool = False): | 474 def _jid_filter(self, jid_: jid.JID, dest: bool = False): |
472 @param jid_: JID to filter by | 478 @param jid_: JID to filter by |
473 """ | 479 """ |
474 if jid_.resource: | 480 if jid_.resource: |
475 if dest: | 481 if dest: |
476 return and_( | 482 return and_( |
477 History.dest == jid_.userhost(), | 483 History.dest == jid_.userhost(), History.dest_res == jid_.resource |
478 History.dest_res == jid_.resource | |
479 ) | 484 ) |
480 else: | 485 else: |
481 return and_( | 486 return and_( |
482 History.source == jid_.userhost(), | 487 History.source == jid_.userhost(), History.source_res == jid_.resource |
483 History.source_res == jid_.resource | |
484 ) | 488 ) |
485 else: | 489 else: |
486 if dest: | 490 if dest: |
487 return History.dest == jid_.userhost() | 491 return History.dest == jid_.userhost() |
488 else: | 492 else: |
495 to_jid: Optional[jid.JID], | 499 to_jid: Optional[jid.JID], |
496 limit: Optional[int] = None, | 500 limit: Optional[int] = None, |
497 between: bool = True, | 501 between: bool = True, |
498 filters: Optional[Dict[str, str]] = None, | 502 filters: Optional[Dict[str, str]] = None, |
499 profile: Optional[str] = None, | 503 profile: Optional[str] = None, |
500 ) -> List[Tuple[ | 504 ) -> List[Tuple[str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]]: |
501 str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] | |
502 ]: | |
503 """Retrieve messages in history | 505 """Retrieve messages in history |
504 | 506 |
505 @param from_jid: source JID (full, or bare for catchall) | 507 @param from_jid: source JID (full, or bare for catchall) |
506 @param to_jid: dest JID (full, or bare for catchall) | 508 @param to_jid: dest JID (full, or bare for catchall) |
507 @param limit: maximum number of messages to get: | 509 @param limit: maximum number of messages to get: |
521 if filters is None: | 523 if filters is None: |
522 filters = {} | 524 filters = {} |
523 | 525 |
524 stmt = ( | 526 stmt = ( |
525 select(History) | 527 select(History) |
526 .filter_by( | 528 .filter_by(profile_id=self.profiles[profile]) |
527 profile_id=self.profiles[profile] | |
528 ) | |
529 .outerjoin(History.messages) | 529 .outerjoin(History.messages) |
530 .outerjoin(History.subjects) | 530 .outerjoin(History.subjects) |
531 .outerjoin(History.thread) | 531 .outerjoin(History.thread) |
532 .options( | 532 .options( |
533 contains_eager(History.messages), | 533 contains_eager(History.messages), |
538 # timestamp may be identical for 2 close messages (specially when delay is | 538 # timestamp may be identical for 2 close messages (specially when delay is |
539 # used) that's why we order ties by received_timestamp. We'll reverse the | 539 # used) that's why we order ties by received_timestamp. We'll reverse the |
540 # order when returning the result. We use DESC here so LIMIT keep the last | 540 # order when returning the result. We use DESC here so LIMIT keep the last |
541 # messages | 541 # messages |
542 History.timestamp.desc(), | 542 History.timestamp.desc(), |
543 History.received_timestamp.desc() | 543 History.received_timestamp.desc(), |
544 ) | 544 ) |
545 ) | 545 ) |
546 | |
547 | 546 |
548 if not from_jid and not to_jid: | 547 if not from_jid and not to_jid: |
549 # no jid specified, we want all one2one communications | 548 # no jid specified, we want all one2one communications |
550 pass | 549 pass |
551 elif between: | 550 elif between: |
552 if not from_jid or not to_jid: | 551 if not from_jid or not to_jid: |
553 # we only have one jid specified, we check all messages | 552 # we only have one jid specified, we check all messages |
554 # from or to this jid | 553 # from or to this jid |
555 jid_ = from_jid or to_jid | 554 jid_ = from_jid or to_jid |
556 stmt = stmt.where( | 555 stmt = stmt.where( |
557 or_( | 556 or_(self._jid_filter(jid_), self._jid_filter(jid_, dest=True)) |
558 self._jid_filter(jid_), | |
559 self._jid_filter(jid_, dest=True) | |
560 ) | |
561 ) | 557 ) |
562 else: | 558 else: |
563 # we have 2 jids specified, we check all communications between | 559 # we have 2 jids specified, we check all communications between |
564 # those 2 jids | 560 # those 2 jids |
565 stmt = stmt.where( | 561 stmt = stmt.where( |
569 self._jid_filter(to_jid, dest=True), | 565 self._jid_filter(to_jid, dest=True), |
570 ), | 566 ), |
571 and_( | 567 and_( |
572 self._jid_filter(to_jid), | 568 self._jid_filter(to_jid), |
573 self._jid_filter(from_jid, dest=True), | 569 self._jid_filter(from_jid, dest=True), |
574 ) | 570 ), |
575 ) | 571 ) |
576 ) | 572 ) |
577 else: | 573 else: |
578 # we want one communication in specific direction (from somebody or | 574 # we want one communication in specific direction (from somebody or |
579 # to somebody). | 575 # to somebody). |
581 stmt = stmt.where(self._jid_filter(from_jid)) | 577 stmt = stmt.where(self._jid_filter(from_jid)) |
582 if to_jid is not None: | 578 if to_jid is not None: |
583 stmt = stmt.where(self._jid_filter(to_jid, dest=True)) | 579 stmt = stmt.where(self._jid_filter(to_jid, dest=True)) |
584 | 580 |
585 if filters: | 581 if filters: |
586 if 'timestamp_start' in filters: | 582 if "timestamp_start" in filters: |
587 stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) | 583 stmt = stmt.where(History.timestamp >= float(filters["timestamp_start"])) |
588 if 'before_uid' in filters: | 584 if "before_uid" in filters: |
589 # orignially this query was using SQLITE's rowid. This has been changed | 585 # orignially this query was using SQLITE's rowid. This has been changed |
590 # to use coalesce(received_timestamp, timestamp) to be SQL engine independant | 586 # to use coalesce(received_timestamp, timestamp) to be SQL engine independant |
591 stmt = stmt.where( | 587 stmt = stmt.where( |
592 coalesce( | 588 coalesce(History.received_timestamp, History.timestamp) |
593 History.received_timestamp, | 589 < ( |
594 History.timestamp | 590 select( |
595 ) < ( | 591 coalesce(History.received_timestamp, History.timestamp) |
596 select(coalesce(History.received_timestamp, History.timestamp)) | 592 ).filter_by(uid=filters["before_uid"]) |
597 .filter_by(uid=filters["before_uid"]) | |
598 ).scalar_subquery() | 593 ).scalar_subquery() |
599 ) | 594 ) |
600 if 'body' in filters: | 595 if "body" in filters: |
601 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html | 596 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html |
602 stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) | 597 stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) |
603 if 'search' in filters: | 598 if "search" in filters: |
604 search_term = f"%{filters['search']}%" | 599 search_term = f"%{filters['search']}%" |
605 stmt = stmt.where(or_( | 600 stmt = stmt.where( |
606 Message.message.like(search_term), | 601 or_( |
607 History.source_res.like(search_term) | 602 Message.message.like(search_term), |
608 )) | 603 History.source_res.like(search_term), |
609 if 'types' in filters: | 604 ) |
610 types = filters['types'].split() | 605 ) |
606 if "types" in filters: | |
607 types = filters["types"].split() | |
611 stmt = stmt.where(History.type.in_(types)) | 608 stmt = stmt.where(History.type.in_(types)) |
612 if 'not_types' in filters: | 609 if "not_types" in filters: |
613 types = filters['not_types'].split() | 610 types = filters["not_types"].split() |
614 stmt = stmt.where(History.type.not_in(types)) | 611 stmt = stmt.where(History.type.not_in(types)) |
615 if 'last_stanza_id' in filters: | 612 if "last_stanza_id" in filters: |
616 # this request get the last message with a "stanza_id" that we | 613 # this request get the last message with a "stanza_id" that we |
617 # have in history. This is mainly used to retrieve messages sent | 614 # have in history. This is mainly used to retrieve messages sent |
618 # while we were offline, using MAM (XEP-0313). | 615 # while we were offline, using MAM (XEP-0313). |
619 if (filters['last_stanza_id'] is not True | 616 if filters["last_stanza_id"] is not True or limit != 1: |
620 or limit != 1): | |
621 raise ValueError("Unexpected values for last_stanza_id filter") | 617 raise ValueError("Unexpected values for last_stanza_id filter") |
622 stmt = stmt.where(History.stanza_id.is_not(None)) | 618 stmt = stmt.where(History.stanza_id.is_not(None)) |
623 if 'origin_id' in filters: | 619 if "origin_id" in filters: |
624 stmt = stmt.where(History.origin_id == filters["origin_id"]) | 620 stmt = stmt.where(History.origin_id == filters["origin_id"]) |
625 | 621 |
626 if limit is not None: | 622 if limit is not None: |
627 stmt = stmt.limit(limit) | 623 stmt = stmt.limit(limit) |
628 | 624 |
638 """Store a new message in history | 634 """Store a new message in history |
639 | 635 |
640 @param data: message data as build by SatMessageProtocol.onMessage | 636 @param data: message data as build by SatMessageProtocol.onMessage |
641 """ | 637 """ |
642 extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} | 638 extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} |
643 messages = [Message(message=mess, language=lang) | 639 messages = [ |
644 for lang, mess in data["message"].items()] | 640 Message(message=mess, language=lang) for lang, mess in data["message"].items() |
645 subjects = [Subject(subject=mess, language=lang) | 641 ] |
646 for lang, mess in data["subject"].items()] | 642 subjects = [ |
643 Subject(subject=mess, language=lang) for lang, mess in data["subject"].items() | |
644 ] | |
647 if "thread" in data["extra"]: | 645 if "thread" in data["extra"]: |
648 thread = Thread(thread_id=data["extra"]["thread"], | 646 thread = Thread( |
649 parent_id=data["extra"].get["thread_parent"]) | 647 thread_id=data["extra"]["thread"], |
648 parent_id=data["extra"].get["thread_parent"], | |
649 ) | |
650 else: | 650 else: |
651 thread = None | 651 thread = None |
652 try: | 652 try: |
653 async with self.session() as session: | 653 async with self.session() as session: |
654 async with session.begin(): | 654 async with session.begin(): |
655 session.add(History( | 655 session.add( |
656 uid=data["uid"], | 656 History( |
657 origin_id=data["extra"].get("origin_id"), | 657 uid=data["uid"], |
658 stanza_id=data["extra"].get("stanza_id"), | 658 origin_id=data["extra"].get("origin_id"), |
659 update_uid=data["extra"].get("update_uid"), | 659 stanza_id=data["extra"].get("stanza_id"), |
660 profile_id=self.profiles[profile], | 660 update_uid=data["extra"].get("update_uid"), |
661 source_jid=data["from"], | 661 profile_id=self.profiles[profile], |
662 dest_jid=data["to"], | 662 source_jid=data["from"], |
663 timestamp=data["timestamp"], | 663 dest_jid=data["to"], |
664 received_timestamp=data.get("received_timestamp"), | 664 timestamp=data["timestamp"], |
665 type=data["type"], | 665 received_timestamp=data.get("received_timestamp"), |
666 extra=extra, | 666 type=data["type"], |
667 messages=messages, | 667 extra=extra, |
668 subjects=subjects, | 668 messages=messages, |
669 thread=thread, | 669 subjects=subjects, |
670 )) | 670 thread=thread, |
671 ) | |
672 ) | |
671 except IntegrityError as e: | 673 except IntegrityError as e: |
672 if "unique" in str(e.orig).lower(): | 674 if "unique" in str(e.orig).lower(): |
673 log.debug( | 675 log.debug( |
674 f"message {data['uid']!r} is already in history, not storing it again" | 676 f"message {data['uid']!r} is already in history, not storing it again" |
675 ) | 677 ) |
687 if profile is None: | 689 if profile is None: |
688 return PrivateGenBin if binary else PrivateGen | 690 return PrivateGenBin if binary else PrivateGen |
689 else: | 691 else: |
690 return PrivateIndBin if binary else PrivateInd | 692 return PrivateIndBin if binary else PrivateInd |
691 | 693 |
692 | |
693 @aio | 694 @aio |
694 async def get_privates( | 695 async def get_privates( |
695 self, | 696 self, |
696 namespace:str, | 697 namespace: str, |
697 keys: Optional[Iterable[str]] = None, | 698 keys: Optional[Iterable[str]] = None, |
698 binary: bool = False, | 699 binary: bool = False, |
699 profile: Optional[str] = None | 700 profile: Optional[str] = None, |
700 ) -> Dict[str, Any]: | 701 ) -> Dict[str, Any]: |
701 """Get private value(s) from databases | 702 """Get private value(s) from databases |
702 | 703 |
703 @param namespace: namespace of the values | 704 @param namespace: namespace of the values |
704 @param keys: keys of the values to get None to get all keys/values | 705 @param keys: keys of the values to get None to get all keys/values |
726 | 727 |
727 @aio | 728 @aio |
728 async def set_private_value( | 729 async def set_private_value( |
729 self, | 730 self, |
730 namespace: str, | 731 namespace: str, |
731 key:str, | 732 key: str, |
732 value: Any, | 733 value: Any, |
733 binary: bool = False, | 734 binary: bool = False, |
734 profile: Optional[str] = None | 735 profile: Optional[str] = None, |
735 ) -> None: | 736 ) -> None: |
736 """Set a private value in database | 737 """Set a private value in database |
737 | 738 |
738 @param namespace: namespace of the values | 739 @param namespace: namespace of the values |
739 @param key: key of the value to set | 740 @param key: key of the value to set |
743 @param profile: profile to use for individual value | 744 @param profile: profile to use for individual value |
744 if None, it's a general value | 745 if None, it's a general value |
745 """ | 746 """ |
746 cls = self._get_private_class(binary, profile) | 747 cls = self._get_private_class(binary, profile) |
747 | 748 |
748 values = { | 749 values = {"namespace": namespace, "key": key, "value": value} |
749 "namespace": namespace, | |
750 "key": key, | |
751 "value": value | |
752 } | |
753 index_elements = [cls.namespace, cls.key] | 750 index_elements = [cls.namespace, cls.key] |
754 | 751 |
755 if profile is not None: | 752 if profile is not None: |
756 values["profile_id"] = self.profiles[profile] | 753 values["profile_id"] = self.profiles[profile] |
757 index_elements.append(cls.profile_id) | 754 index_elements.append(cls.profile_id) |
758 | 755 |
759 async with self.session() as session: | 756 async with self.session() as session: |
760 await session.execute( | 757 await session.execute( |
761 insert(cls).values(**values).on_conflict_do_update( | 758 insert(cls) |
762 index_elements=index_elements, | 759 .values(**values) |
763 set_={ | 760 .on_conflict_do_update( |
764 cls.value: value | 761 index_elements=index_elements, set_={cls.value: value} |
765 } | |
766 ) | 762 ) |
767 ) | 763 ) |
768 await session.commit() | 764 await session.commit() |
769 | 765 |
770 @aio | 766 @aio |
771 async def del_private_value( | 767 async def del_private_value( |
772 self, | 768 self, |
773 namespace: str, | 769 namespace: str, |
774 key: str, | 770 key: str, |
775 binary: bool = False, | 771 binary: bool = False, |
776 profile: Optional[str] = None | 772 profile: Optional[str] = None, |
777 ) -> None: | 773 ) -> None: |
778 """Delete private value from database | 774 """Delete private value from database |
779 | 775 |
780 @param category: category of the privateeter | 776 @param category: category of the privateeter |
781 @param key: key of the private value | 777 @param key: key of the private value |
794 await session.execute(stmt) | 790 await session.execute(stmt) |
795 await session.commit() | 791 await session.commit() |
796 | 792 |
797 @aio | 793 @aio |
798 async def del_private_namespace( | 794 async def del_private_namespace( |
799 self, | 795 self, namespace: str, binary: bool = False, profile: Optional[str] = None |
800 namespace: str, | |
801 binary: bool = False, | |
802 profile: Optional[str] = None | |
803 ) -> None: | 796 ) -> None: |
804 """Delete all data from a private namespace | 797 """Delete all data from a private namespace |
805 | 798 |
806 Be really cautious when you use this method, as all data with given namespace are | 799 Be really cautious when you use this method, as all data with given namespace are |
807 removed. | 800 removed. |
823 @aio | 816 @aio |
824 async def get_files( | 817 async def get_files( |
825 self, | 818 self, |
826 client: Optional[SatXMPPEntity], | 819 client: Optional[SatXMPPEntity], |
827 file_id: Optional[str] = None, | 820 file_id: Optional[str] = None, |
828 version: Optional[str] = '', | 821 version: Optional[str] = "", |
829 parent: Optional[str] = None, | 822 parent: Optional[str] = None, |
830 type_: Optional[str] = None, | 823 type_: Optional[str] = None, |
831 file_hash: Optional[str] = None, | 824 file_hash: Optional[str] = None, |
832 hash_algo: Optional[str] = None, | 825 hash_algo: Optional[str] = None, |
833 name: Optional[str] = None, | 826 name: Optional[str] = None, |
835 mime_type: Optional[str] = None, | 828 mime_type: Optional[str] = None, |
836 public_id: Optional[str] = None, | 829 public_id: Optional[str] = None, |
837 owner: Optional[jid.JID] = None, | 830 owner: Optional[jid.JID] = None, |
838 access: Optional[dict] = None, | 831 access: Optional[dict] = None, |
839 projection: Optional[List[str]] = None, | 832 projection: Optional[List[str]] = None, |
840 unique: bool = False | 833 unique: bool = False, |
841 ) -> List[dict]: | 834 ) -> List[dict]: |
842 """Retrieve files with with given filters | 835 """Retrieve files with with given filters |
843 | 836 |
844 @param file_id: id of the file | 837 @param file_id: id of the file |
845 None to ignore | 838 None to ignore |
855 other params are the same as for [set_file] | 848 other params are the same as for [set_file] |
856 @return: files corresponding to filters | 849 @return: files corresponding to filters |
857 """ | 850 """ |
858 if projection is None: | 851 if projection is None: |
859 projection = [ | 852 projection = [ |
860 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', | 853 "id", |
861 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', | 854 "version", |
862 'created', 'modified', 'owner', 'access', 'extra' | 855 "parent", |
856 "type", | |
857 "file_hash", | |
858 "hash_algo", | |
859 "name", | |
860 "size", | |
861 "namespace", | |
862 "media_type", | |
863 "media_subtype", | |
864 "public_id", | |
865 "created", | |
866 "modified", | |
867 "owner", | |
868 "access", | |
869 "extra", | |
863 ] | 870 ] |
864 | 871 |
865 stmt = select(*[getattr(File, f) for f in projection]) | 872 stmt = select(*[getattr(File, f) for f in projection]) |
866 | 873 |
867 if unique: | 874 if unique: |
889 if name is not None: | 896 if name is not None: |
890 stmt = stmt.filter_by(name=name) | 897 stmt = stmt.filter_by(name=name) |
891 if namespace is not None: | 898 if namespace is not None: |
892 stmt = stmt.filter_by(namespace=namespace) | 899 stmt = stmt.filter_by(namespace=namespace) |
893 if mime_type is not None: | 900 if mime_type is not None: |
894 if '/' in mime_type: | 901 if "/" in mime_type: |
895 media_type, media_subtype = mime_type.split("/", 1) | 902 media_type, media_subtype = mime_type.split("/", 1) |
896 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) | 903 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) |
897 else: | 904 else: |
898 stmt = stmt.filter_by(media_type=mime_type) | 905 stmt = stmt.filter_by(media_type=mime_type) |
899 if public_id is not None: | 906 if public_id is not None: |
900 stmt = stmt.filter_by(public_id=public_id) | 907 stmt = stmt.filter_by(public_id=public_id) |
901 if owner is not None: | 908 if owner is not None: |
902 stmt = stmt.filter_by(owner=owner) | 909 stmt = stmt.filter_by(owner=owner) |
903 if access is not None: | 910 if access is not None: |
904 raise NotImplementedError('Access check is not implemented yet') | 911 raise NotImplementedError("Access check is not implemented yet") |
905 # a JSON comparison is needed here | 912 # a JSON comparison is needed here |
906 | 913 |
907 async with self.session() as session: | 914 async with self.session() as session: |
908 result = await session.execute(stmt) | 915 result = await session.execute(stmt) |
909 | 916 |
910 return [r._asdict() for r in result] | 917 return [r._asdict() for r in result] |
926 public_id: Optional[str] = None, | 933 public_id: Optional[str] = None, |
927 created: Optional[float] = None, | 934 created: Optional[float] = None, |
928 modified: Optional[float] = None, | 935 modified: Optional[float] = None, |
929 owner: Optional[jid.JID] = None, | 936 owner: Optional[jid.JID] = None, |
930 access: Optional[dict] = None, | 937 access: Optional[dict] = None, |
931 extra: Optional[dict] = None | 938 extra: Optional[dict] = None, |
932 ) -> None: | 939 ) -> None: |
933 """Set a file metadata | 940 """Set a file metadata |
934 | 941 |
935 @param client: client owning the file | 942 @param client: client owning the file |
936 @param name: name of the file (must not contain "/") | 943 @param name: name of the file (must not contain "/") |
956 @param extra: serialisable dictionary of any extra data | 963 @param extra: serialisable dictionary of any extra data |
957 will be encoded to json in database | 964 will be encoded to json in database |
958 """ | 965 """ |
959 if mime_type is None: | 966 if mime_type is None: |
960 media_type = media_subtype = None | 967 media_type = media_subtype = None |
961 elif '/' in mime_type: | 968 elif "/" in mime_type: |
962 media_type, media_subtype = mime_type.split('/', 1) | 969 media_type, media_subtype = mime_type.split("/", 1) |
963 else: | 970 else: |
964 media_type, media_subtype = mime_type, None | 971 media_type, media_subtype = mime_type, None |
965 | 972 |
966 async with self.session() as session: | 973 async with self.session() as session: |
967 async with session.begin(): | 974 async with session.begin(): |
968 session.add(File( | 975 session.add( |
969 id=file_id, | 976 File( |
970 version=version.strip(), | 977 id=file_id, |
971 parent=parent, | 978 version=version.strip(), |
972 type=type_, | 979 parent=parent, |
973 file_hash=file_hash, | 980 type=type_, |
974 hash_algo=hash_algo, | 981 file_hash=file_hash, |
975 name=name, | 982 hash_algo=hash_algo, |
976 size=size, | 983 name=name, |
977 namespace=namespace, | 984 size=size, |
978 media_type=media_type, | 985 namespace=namespace, |
979 media_subtype=media_subtype, | 986 media_type=media_type, |
980 public_id=public_id, | 987 media_subtype=media_subtype, |
981 created=time.time() if created is None else created, | 988 public_id=public_id, |
982 modified=modified, | 989 created=time.time() if created is None else created, |
983 owner=owner, | 990 modified=modified, |
984 access=access, | 991 owner=owner, |
985 extra=extra, | 992 access=access, |
986 profile_id=self.profiles[client.profile] | 993 extra=extra, |
987 )) | 994 profile_id=self.profiles[client.profile], |
995 ) | |
996 ) | |
988 | 997 |
989 @aio | 998 @aio |
990 async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: | 999 async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: |
991 async with self.session() as session: | 1000 async with self.session() as session: |
992 result = await session.execute( | 1001 result = await session.execute( |
993 select(sum_(File.size)).filter_by( | 1002 select(sum_(File.size)).filter_by( |
994 owner=owner, | 1003 owner=owner, |
995 type=C.FILE_TYPE_FILE, | 1004 type=C.FILE_TYPE_FILE, |
996 profile_id=self.profiles[client.profile] | 1005 profile_id=self.profiles[client.profile], |
997 )) | 1006 ) |
1007 ) | |
998 return result.scalar_one_or_none() or 0 | 1008 return result.scalar_one_or_none() or 0 |
999 | 1009 |
1000 @aio | 1010 @aio |
1001 async def file_delete(self, file_id: str) -> None: | 1011 async def file_delete(self, file_id: str) -> None: |
1002 """Delete file metadata from the database | 1012 """Delete file metadata from the database |
1009 await session.execute(delete(File).filter_by(id=file_id)) | 1019 await session.execute(delete(File).filter_by(id=file_id)) |
1010 await session.commit() | 1020 await session.commit() |
1011 | 1021 |
1012 @aio | 1022 @aio |
1013 async def file_update( | 1023 async def file_update( |
1014 self, | 1024 self, file_id: str, column: str, update_cb: Callable[[dict], None] |
1015 file_id: str, | |
1016 column: str, | |
1017 update_cb: Callable[[dict], None] | |
1018 ) -> None: | 1025 ) -> None: |
1019 """Update a column value using a method to avoid race conditions | 1026 """Update a column value using a method to avoid race conditions |
1020 | 1027 |
1021 the older value will be retrieved from database, then update_cb will be applied to | 1028 the older value will be retrieved from database, then update_cb will be applied to |
1022 update it, and file will be updated checking that older value has not been changed | 1029 update it, and file will be updated checking that older value has not been changed |
1027 the method will take older value as argument, and must update it in place | 1034 the method will take older value as argument, and must update it in place |
1028 update_cb must not care about serialization, | 1035 update_cb must not care about serialization, |
1029 it get the deserialized data (i.e. a Python object) directly | 1036 it get the deserialized data (i.e. a Python object) directly |
1030 @raise exceptions.NotFound: there is not file with this id | 1037 @raise exceptions.NotFound: there is not file with this id |
1031 """ | 1038 """ |
1032 if column not in ('access', 'extra'): | 1039 if column not in ("access", "extra"): |
1033 raise exceptions.InternalError('bad column name') | 1040 raise exceptions.InternalError("bad column name") |
1034 orm_col = getattr(File, column) | 1041 orm_col = getattr(File, column) |
1035 | 1042 |
1036 for i in range(5): | 1043 for i in range(5): |
1037 async with self.session() as session: | 1044 async with self.session() as session: |
1038 try: | 1045 try: |
1039 value = (await session.execute( | 1046 value = ( |
1040 select(orm_col).filter_by(id=file_id) | 1047 await session.execute(select(orm_col).filter_by(id=file_id)) |
1041 )).scalar_one() | 1048 ).scalar_one() |
1042 except NoResultFound: | 1049 except NoResultFound: |
1043 raise exceptions.NotFound | 1050 raise exceptions.NotFound |
1044 old_value = copy.deepcopy(value) | 1051 old_value = copy.deepcopy(value) |
1045 update_cb(value) | 1052 update_cb(value) |
1046 stmt = update(File).filter_by(id=file_id).values({column: value}) | 1053 stmt = update(File).filter_by(id=file_id).values({column: value}) |
1055 | 1062 |
1056 if result.rowcount == 1: | 1063 if result.rowcount == 1: |
1057 break | 1064 break |
1058 | 1065 |
1059 log.warning( | 1066 log.warning( |
1060 _("table not updated, probably due to race condition, trying again " | 1067 _( |
1061 "({tries})").format(tries=i+1) | 1068 "table not updated, probably due to race condition, trying again " |
1069 "({tries})" | |
1070 ).format(tries=i + 1) | |
1062 ) | 1071 ) |
1063 | 1072 |
1064 else: | 1073 else: |
1065 raise exceptions.DatabaseError( | 1074 raise exceptions.DatabaseError( |
1066 _("Can't update file {file_id} due to race condition") | 1075 _("Can't update file {file_id} due to race condition").format( |
1067 .format(file_id=file_id) | 1076 file_id=file_id |
1077 ) | |
1068 ) | 1078 ) |
1069 | 1079 |
1070 @aio | 1080 @aio |
1071 async def get_pubsub_node( | 1081 async def get_pubsub_node( |
1072 self, | 1082 self, |
1074 service: jid.JID, | 1084 service: jid.JID, |
1075 name: str, | 1085 name: str, |
1076 with_items: bool = False, | 1086 with_items: bool = False, |
1077 with_subscriptions: bool = False, | 1087 with_subscriptions: bool = False, |
1078 create: bool = False, | 1088 create: bool = False, |
1079 create_kwargs: Optional[dict] = None | 1089 create_kwargs: Optional[dict] = None, |
1080 ) -> Optional[PubsubNode]: | 1090 ) -> Optional[PubsubNode]: |
1081 """Retrieve a PubsubNode from DB | 1091 """Retrieve a PubsubNode from DB |
1082 | 1092 |
1083 @param service: service hosting the node | 1093 @param service: service hosting the node |
1084 @param name: node's name | 1094 @param name: node's name |
1087 @param create: if the node doesn't exist in DB, create it | 1097 @param create: if the node doesn't exist in DB, create it |
1088 @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node | 1098 @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node |
1089 needs to be created. | 1099 needs to be created. |
1090 """ | 1100 """ |
1091 async with self.session() as session: | 1101 async with self.session() as session: |
1092 stmt = ( | 1102 stmt = select(PubsubNode).filter_by( |
1093 select(PubsubNode) | 1103 service=service, |
1094 .filter_by( | 1104 name=name, |
1095 service=service, | 1105 profile_id=self.profiles[client.profile], |
1096 name=name, | |
1097 profile_id=self.profiles[client.profile], | |
1098 ) | |
1099 ) | 1106 ) |
1100 if with_items: | 1107 if with_items: |
1101 stmt = stmt.options( | 1108 stmt = stmt.options(joinedload(PubsubNode.items)) |
1102 joinedload(PubsubNode.items) | |
1103 ) | |
1104 if with_subscriptions: | 1109 if with_subscriptions: |
1105 stmt = stmt.options( | 1110 stmt = stmt.options(joinedload(PubsubNode.subscriptions)) |
1106 joinedload(PubsubNode.subscriptions) | |
1107 ) | |
1108 result = await session.execute(stmt) | 1111 result = await session.execute(stmt) |
1109 ret = result.unique().scalar_one_or_none() | 1112 ret = result.unique().scalar_one_or_none() |
1110 if ret is None and create: | 1113 if ret is None and create: |
1111 # we auto-create the node | 1114 # we auto-create the node |
1112 if create_kwargs is None: | 1115 if create_kwargs is None: |
1113 create_kwargs = {} | 1116 create_kwargs = {} |
1114 try: | 1117 try: |
1115 return await as_future(self.set_pubsub_node( | 1118 return await as_future( |
1116 client, service, name, **create_kwargs | 1119 self.set_pubsub_node(client, service, name, **create_kwargs) |
1117 )) | 1120 ) |
1118 except IntegrityError as e: | 1121 except IntegrityError as e: |
1119 if "unique" in str(e.orig).lower(): | 1122 if "unique" in str(e.orig).lower(): |
1120 # the node may already exist, if it has been created just after | 1123 # the node may already exist, if it has been created just after |
1121 # get_pubsub_node above | 1124 # get_pubsub_node above |
1122 log.debug("ignoring UNIQUE constraint error") | 1125 log.debug("ignoring UNIQUE constraint error") |
1123 cached_node = await as_future(self.get_pubsub_node( | 1126 cached_node = await as_future( |
1124 client, | 1127 self.get_pubsub_node( |
1125 service, | 1128 client, |
1126 name, | 1129 service, |
1127 with_items=with_items, | 1130 name, |
1128 with_subscriptions=with_subscriptions | 1131 with_items=with_items, |
1129 )) | 1132 with_subscriptions=with_subscriptions, |
1133 ) | |
1134 ) | |
1130 else: | 1135 else: |
1131 raise e | 1136 raise e |
1132 else: | 1137 else: |
1133 return ret | 1138 return ret |
1134 | 1139 |
1158 session.add(node) | 1163 session.add(node) |
1159 return node | 1164 return node |
1160 | 1165 |
1161 @aio | 1166 @aio |
1162 async def update_pubsub_node_sync_state( | 1167 async def update_pubsub_node_sync_state( |
1163 self, | 1168 self, node: PubsubNode, state: SyncState |
1164 node: PubsubNode, | |
1165 state: SyncState | |
1166 ) -> None: | 1169 ) -> None: |
1167 async with self.session() as session: | 1170 async with self.session() as session: |
1168 async with session.begin(): | 1171 async with session.begin(): |
1169 await session.execute( | 1172 await session.execute( |
1170 update(PubsubNode) | 1173 update(PubsubNode) |
1178 @aio | 1181 @aio |
1179 async def delete_pubsub_node( | 1182 async def delete_pubsub_node( |
1180 self, | 1183 self, |
1181 profiles: Optional[List[str]], | 1184 profiles: Optional[List[str]], |
1182 services: Optional[List[jid.JID]], | 1185 services: Optional[List[jid.JID]], |
1183 names: Optional[List[str]] | 1186 names: Optional[List[str]], |
1184 ) -> None: | 1187 ) -> None: |
1185 """Delete items cached for a node | 1188 """Delete items cached for a node |
1186 | 1189 |
1187 @param profiles: profile names from which nodes must be deleted. | 1190 @param profiles: profile names from which nodes must be deleted. |
1188 None to remove nodes from ALL profiles | 1191 None to remove nodes from ALL profiles |
1192 None to remove ALL nodes whatever is their names | 1195 None to remove ALL nodes whatever is their names |
1193 """ | 1196 """ |
1194 stmt = delete(PubsubNode) | 1197 stmt = delete(PubsubNode) |
1195 if profiles is not None: | 1198 if profiles is not None: |
1196 stmt = stmt.where( | 1199 stmt = stmt.where( |
1197 PubsubNode.profile.in_( | 1200 PubsubNode.profile.in_([self.profiles[p] for p in profiles]) |
1198 [self.profiles[p] for p in profiles] | |
1199 ) | |
1200 ) | 1201 ) |
1201 if services is not None: | 1202 if services is not None: |
1202 stmt = stmt.where(PubsubNode.service.in_(services)) | 1203 stmt = stmt.where(PubsubNode.service.in_(services)) |
1203 if names is not None: | 1204 if names is not None: |
1204 stmt = stmt.where(PubsubNode.name.in_(names)) | 1205 stmt = stmt.where(PubsubNode.name.in_(names)) |
1221 ) | 1222 ) |
1222 async with self.session() as session: | 1223 async with self.session() as session: |
1223 async with session.begin(): | 1224 async with session.begin(): |
1224 for idx, item in enumerate(items): | 1225 for idx, item in enumerate(items): |
1225 parsed = parsed_items[idx] if parsed_items else None | 1226 parsed = parsed_items[idx] if parsed_items else None |
1226 stmt = insert(PubsubItem).values( | 1227 stmt = ( |
1227 node_id = node.id, | 1228 insert(PubsubItem) |
1228 name = item["id"], | 1229 .values( |
1229 data = item, | 1230 node_id=node.id, |
1230 parsed = parsed, | 1231 name=item["id"], |
1231 ).on_conflict_do_update( | 1232 data=item, |
1232 index_elements=(PubsubItem.node_id, PubsubItem.name), | 1233 parsed=parsed, |
1233 set_={ | 1234 ) |
1234 PubsubItem.data: item, | 1235 .on_conflict_do_update( |
1235 PubsubItem.parsed: parsed, | 1236 index_elements=(PubsubItem.node_id, PubsubItem.name), |
1236 PubsubItem.updated: now() | 1237 set_={ |
1237 } | 1238 PubsubItem.data: item, |
1239 PubsubItem.parsed: parsed, | |
1240 PubsubItem.updated: now(), | |
1241 }, | |
1242 ) | |
1238 ) | 1243 ) |
1239 await session.execute(stmt) | 1244 await session.execute(stmt) |
1240 await session.commit() | 1245 await session.commit() |
1241 | 1246 |
1242 @aio | 1247 @aio |
1243 async def delete_pubsub_items( | 1248 async def delete_pubsub_items( |
1244 self, | 1249 self, node: PubsubNode, items_names: Optional[List[str]] = None |
1245 node: PubsubNode, | |
1246 items_names: Optional[List[str]] = None | |
1247 ) -> None: | 1250 ) -> None: |
1248 """Delete items cached for a node | 1251 """Delete items cached for a node |
1249 | 1252 |
1250 @param node: node from which items must be deleted | 1253 @param node: node from which items must be deleted |
1251 @param items_names: names of items to delete | 1254 @param items_names: names of items to delete |
1294 sub_q = select(PubsubNode.id) | 1297 sub_q = select(PubsubNode.id) |
1295 for col, values in node_fields.items(): | 1298 for col, values in node_fields.items(): |
1296 if values is None: | 1299 if values is None: |
1297 continue | 1300 continue |
1298 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) | 1301 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) |
1299 stmt = ( | 1302 stmt = stmt.where(PubsubItem.node_id.in_(sub_q)).execution_options( |
1300 stmt | 1303 synchronize_session=False |
1301 .where(PubsubItem.node_id.in_(sub_q)) | |
1302 .execution_options(synchronize_session=False) | |
1303 ) | 1304 ) |
1304 | 1305 |
1305 if created_before is not None: | 1306 if created_before is not None: |
1306 stmt = stmt.where(PubsubItem.created < created_before) | 1307 stmt = stmt.where(PubsubItem.created < created_before) |
1307 | 1308 |
1365 if force_rsm and not use_rsm: | 1366 if force_rsm and not use_rsm: |
1366 # | 1367 # |
1367 use_rsm = True | 1368 use_rsm = True |
1368 from_index = 0 | 1369 from_index = 0 |
1369 | 1370 |
1370 stmt = ( | 1371 stmt = select(PubsubItem).filter_by(node_id=node.id).limit(max_items) |
1371 select(PubsubItem) | |
1372 .filter_by(node_id=node.id) | |
1373 .limit(max_items) | |
1374 ) | |
1375 | 1372 |
1376 if item_ids is not None: | 1373 if item_ids is not None: |
1377 stmt = stmt.where(PubsubItem.name.in_(item_ids)) | 1374 stmt = stmt.where(PubsubItem.name.in_(item_ids)) |
1378 | 1375 |
1379 if not order_by: | 1376 if not order_by: |
1400 # CTE to have result row numbers | 1397 # CTE to have result row numbers |
1401 row_num_q = select( | 1398 row_num_q = select( |
1402 PubsubItem.id, | 1399 PubsubItem.id, |
1403 PubsubItem.name, | 1400 PubsubItem.name, |
1404 # row_number starts from 1, but RSM index must start from 0 | 1401 # row_number starts from 1, but RSM index must start from 0 |
1405 (func.row_number().over(order_by=order)-1).label("item_index") | 1402 (func.row_number().over(order_by=order) - 1).label("item_index"), |
1406 ).filter_by(node_id=node.id) | 1403 ).filter_by(node_id=node.id) |
1407 | 1404 |
1408 row_num_cte = row_num_q.cte() | 1405 row_num_cte = row_num_q.cte() |
1409 | 1406 |
1410 if max_items > 0: | 1407 if max_items > 0: |
1411 # as we can't simply use PubsubItem.id when we order by modification, | 1408 # as we can't simply use PubsubItem.id when we order by modification, |
1412 # we need to use row number | 1409 # we need to use row number |
1413 item_name = before or after | 1410 item_name = before or after |
1414 row_num_limit_q = ( | 1411 row_num_limit_q = ( |
1415 select(row_num_cte.c.item_index) | 1412 select(row_num_cte.c.item_index).where( |
1416 .where(row_num_cte.c.name==item_name) | 1413 row_num_cte.c.name == item_name |
1414 ) | |
1417 ).scalar_subquery() | 1415 ).scalar_subquery() |
1418 | 1416 |
1419 stmt = ( | 1417 stmt = ( |
1420 select(row_num_cte.c.item_index, PubsubItem) | 1418 select(row_num_cte.c.item_index, PubsubItem) |
1421 .join(row_num_cte, PubsubItem.id == row_num_cte.c.id) | 1419 .join(row_num_cte, PubsubItem.id == row_num_cte.c.id) |
1422 .limit(max_items) | 1420 .limit(max_items) |
1423 ) | 1421 ) |
1424 if before: | 1422 if before: |
1425 stmt = ( | 1423 stmt = stmt.where( |
1426 stmt | 1424 row_num_cte.c.item_index < row_num_limit_q |
1427 .where(row_num_cte.c.item_index<row_num_limit_q) | 1425 ).order_by(row_num_cte.c.item_index.desc()) |
1428 .order_by(row_num_cte.c.item_index.desc()) | |
1429 ) | |
1430 elif after: | 1426 elif after: |
1431 stmt = ( | 1427 stmt = stmt.where( |
1432 stmt | 1428 row_num_cte.c.item_index > row_num_limit_q |
1433 .where(row_num_cte.c.item_index>row_num_limit_q) | 1429 ).order_by(row_num_cte.c.item_index.asc()) |
1434 .order_by(row_num_cte.c.item_index.asc()) | |
1435 ) | |
1436 else: | 1430 else: |
1437 stmt = ( | 1431 stmt = stmt.where(row_num_cte.c.item_index >= from_index).order_by( |
1438 stmt | 1432 row_num_cte.c.item_index.asc() |
1439 .where(row_num_cte.c.item_index>=from_index) | |
1440 .order_by(row_num_cte.c.item_index.asc()) | |
1441 ) | 1433 ) |
1442 # from_index is used | 1434 # from_index is used |
1443 | 1435 |
1444 async with self.session() as session: | 1436 async with self.session() as session: |
1445 if max_items == 0: | 1437 if max_items == 0: |
1466 last = None | 1458 last = None |
1467 else: | 1459 else: |
1468 last = result[-1][1].name | 1460 last = result[-1][1].name |
1469 | 1461 |
1470 metadata["rsm"] = { | 1462 metadata["rsm"] = { |
1471 k: v for k, v in { | 1463 k: v |
1464 for k, v in { | |
1472 "index": index, | 1465 "index": index, |
1473 "count": rows_count, | 1466 "count": rows_count, |
1474 "first": first, | 1467 "first": first, |
1475 "last": last, | 1468 "last": last, |
1476 }.items() if v is not None | 1469 }.items() |
1470 if v is not None | |
1477 } | 1471 } |
1478 metadata["complete"] = (index or 0) + len(result) == rows_count | 1472 metadata["complete"] = (index or 0) + len(result) == rows_count |
1479 | 1473 |
1480 return items, metadata | 1474 return items, metadata |
1481 | 1475 |
1485 result = result.scalars().all() | 1479 result = result.scalars().all() |
1486 if desc: | 1480 if desc: |
1487 result.reverse() | 1481 result.reverse() |
1488 return result, metadata | 1482 return result, metadata |
1489 | 1483 |
1490 def _get_sqlite_path( | 1484 def _get_sqlite_path(self, path: List[Union[str, int]]) -> str: |
1491 self, | |
1492 path: List[Union[str, int]] | |
1493 ) -> str: | |
1494 """generate path suitable to query JSON element with SQLite""" | 1485 """generate path suitable to query JSON element with SQLite""" |
1495 return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" | 1486 return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" |
1496 | 1487 |
1497 @aio | 1488 @aio |
1498 async def search_pubsub_items( | 1489 async def search_pubsub_items( |
1555 stmt = select(PubsubItem) | 1546 stmt = select(PubsubItem) |
1556 | 1547 |
1557 # Full-Text Search | 1548 # Full-Text Search |
1558 fts = query.get("fts") | 1549 fts = query.get("fts") |
1559 if fts: | 1550 if fts: |
1560 fts_select = text( | 1551 fts_select = ( |
1561 "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)" | 1552 text("SELECT rowid, rank FROM pubsub_items_fts(:fts_query)") |
1562 ).bindparams(fts_query=fts).columns(rowid=Integer).subquery() | 1553 .bindparams(fts_query=fts) |
1563 stmt = ( | 1554 .columns(rowid=Integer) |
1564 stmt | 1555 .subquery() |
1565 .select_from(fts_select) | 1556 ) |
1566 .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id) | 1557 stmt = stmt.select_from(fts_select).outerjoin( |
1558 PubsubItem, fts_select.c.rowid == PubsubItem.id | |
1567 ) | 1559 ) |
1568 | 1560 |
1569 # node related filters | 1561 # node related filters |
1570 profiles = query.get("profiles") | 1562 profiles = query.get("profiles") |
1571 if (profiles | 1563 if profiles or any( |
1572 or any(query.get(k) for k in ("nodes", "services", "types", "subtypes")) | 1564 query.get(k) for k in ("nodes", "services", "types", "subtypes") |
1573 ): | 1565 ): |
1574 stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) | 1566 stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) |
1575 if profiles: | 1567 if profiles: |
1576 try: | 1568 try: |
1577 stmt = stmt.where( | 1569 stmt = stmt.where( |
1583 ) | 1575 ) |
1584 for key, attr in ( | 1576 for key, attr in ( |
1585 ("nodes", "name"), | 1577 ("nodes", "name"), |
1586 ("services", "service"), | 1578 ("services", "service"), |
1587 ("types", "type_"), | 1579 ("types", "type_"), |
1588 ("subtypes", "subtype") | 1580 ("subtypes", "subtype"), |
1589 ): | 1581 ): |
1590 value = query.get(key) | 1582 value = query.get(key) |
1591 if not value: | 1583 if not value: |
1592 continue | 1584 continue |
1593 if key in ("types", "subtypes") and None in value: | 1585 if key in ("types", "subtypes") and None in value: |
1594 # NULL can't be used with SQL's IN, so we have to add a condition with | 1586 # NULL can't be used with SQL's IN, so we have to add a condition with |
1595 # IS NULL, and use a OR if there are other values to check | 1587 # IS NULL, and use a OR if there are other values to check |
1596 value.remove(None) | 1588 value.remove(None) |
1597 condition = getattr(PubsubNode, attr).is_(None) | 1589 condition = getattr(PubsubNode, attr).is_(None) |
1598 if value: | 1590 if value: |
1599 condition = or_( | 1591 condition = or_(getattr(PubsubNode, attr).in_(value), condition) |
1600 getattr(PubsubNode, attr).in_(value), | |
1601 condition | |
1602 ) | |
1603 else: | 1592 else: |
1604 condition = getattr(PubsubNode, attr).in_(value) | 1593 condition = getattr(PubsubNode, attr).in_(value) |
1605 stmt = stmt.where(condition) | 1594 stmt = stmt.where(condition) |
1606 else: | 1595 else: |
1607 stmt = stmt.options(selectinload(PubsubItem.node)) | 1596 stmt = stmt.options(selectinload(PubsubItem.node)) |
1642 stmt = stmt.where(condition) | 1631 stmt = stmt.where(condition) |
1643 elif operator == "between": | 1632 elif operator == "between": |
1644 try: | 1633 try: |
1645 left, right = value | 1634 left, right = value |
1646 except (ValueError, TypeError): | 1635 except (ValueError, TypeError): |
1647 raise ValueError(_( | 1636 raise ValueError( |
1648 "invalid value for \"between\" filter, you must use a 2 items " | 1637 _( |
1649 "array: {value!r}" | 1638 'invalid value for "between" filter, you must use a 2 items ' |
1650 ).format(value=value)) | 1639 "array: {value!r}" |
1640 ).format(value=value) | |
1641 ) | |
1651 col = func.json_extract(PubsubItem.parsed, sqlite_path) | 1642 col = func.json_extract(PubsubItem.parsed, sqlite_path) |
1652 stmt = stmt.where(col.between(left, right)) | 1643 stmt = stmt.where(col.between(left, right)) |
1653 else: | 1644 else: |
1654 # we use func.json_extract instead of generic JSON way because SQLAlchemy | 1645 # we use func.json_extract instead of generic JSON way because SQLAlchemy |
1655 # add a JSON_QUOTE to the value, and we want SQL value | 1646 # add a JSON_QUOTE to the value, and we want SQL value |
1660 order_by = query.get("order-by") or [{"order": "creation"}] | 1651 order_by = query.get("order-by") or [{"order": "creation"}] |
1661 | 1652 |
1662 for order_data in order_by: | 1653 for order_data in order_by: |
1663 order, path = order_data.get("order"), order_data.get("path") | 1654 order, path = order_data.get("order"), order_data.get("path") |
1664 if order and path: | 1655 if order and path: |
1665 raise ValueError(_( | 1656 raise ValueError( |
1666 '"order" and "path" can\'t be used at the same time in ' | 1657 _( |
1667 '"order-by" data' | 1658 '"order" and "path" can\'t be used at the same time in ' |
1668 )) | 1659 '"order-by" data' |
1660 ) | |
1661 ) | |
1669 if order: | 1662 if order: |
1670 if order == "creation": | 1663 if order == "creation": |
1671 col = PubsubItem.id | 1664 col = PubsubItem.id |
1672 elif order == "modification": | 1665 elif order == "modification": |
1673 col = PubsubItem.updated | 1666 col = PubsubItem.updated |
1700 | 1693 |
1701 async with self.session() as session: | 1694 async with self.session() as session: |
1702 result = await session.execute(stmt) | 1695 result = await session.execute(stmt) |
1703 | 1696 |
1704 return result.scalars().all() | 1697 return result.scalars().all() |
1698 | |
1699 # Notifications | |
1700 | |
    @aio
    async def add_notification(
        self,
        client: Optional[SatXMPPEntity],
        type_: NotificationType,
        body_plain: str,
        body_rich: Optional[str] = None,
        title: Optional[str] = None,
        requires_action: bool = False,
        priority: NotificationPriority = NotificationPriority.MEDIUM,
        expire_at: Optional[float] = None,
        extra: Optional[dict] = None,
    ) -> Notification:
        """Add a new notification to the DB.

        @param client: client associated with the notification. If None, the notification
            will be global.
        @param type_: type of the notification.
        @param body_plain: plain text body.
        @param body_rich: rich text (XHTML) body.
        @param title: optional title.
        @param requires_action: True if the notification requires user action (e.g. a
            dialog need to be answered).
        @param priority: how urgent the notification is
        @param expire_at: expiration timestamp for the notification.
        @param extra: additional data.
        @return: created Notification
        """
        # a None profile_id marks the notification as global (visible to all profiles)
        profile_id = self.profiles[client.profile] if client else None
        notification = Notification(
            profile_id=profile_id,
            type=type_,
            body_plain=body_plain,
            body_rich=body_rich,
            requires_action=requires_action,
            priority=priority,
            expire_at=expire_at,
            title=title,
            extra_data=extra,
            status=NotificationStatus.new,
        )
        async with self.session() as session:
            async with session.begin():
                session.add(notification)
        return notification
1746 | |
    @aio
    async def update_notification(
        self, client: SatXMPPEntity, notification_id: int, **kwargs
    ) -> None:
        """Update an existing notification.

        @param client: client associated with the notification.
        @param notification_id: ID of the notification to update.
        @param kwargs: Notification columns to update, mapping column name to its
            new value (e.g. ``status=NotificationStatus.read``).
        """
        profile_id = self.profiles[client.profile]
        async with self.session() as session:
            # profile_id is part of the WHERE clause so a client can only modify
            # its own notifications (global ones, with NULL profile_id, are
            # excluded on purpose)
            await session.execute(
                update(Notification)
                .where(
                    and_(
                        Notification.profile_id == profile_id,
                        Notification.id == notification_id,
                    )
                )
                .values(**kwargs)
            )
            await session.commit()
1769 | |
1770 @aio | |
1771 async def get_notifications( | |
1772 self, | |
1773 client: SatXMPPEntity, | |
1774 type_: Optional[NotificationType] = None, | |
1775 status: Optional[NotificationStatus] = None, | |
1776 requires_action: Optional[bool] = None, | |
1777 min_priority: Optional[int] = None | |
1778 ) -> List[Notification]: | |
1779 """Retrieve all notifications for a given profile with optional filters. | |
1780 | |
1781 @param client: client associated with the notifications. | |
1782 @param type_: filter by type of the notification. | |
1783 @param status: filter by status of the notification. | |
1784 @param requires_action: filter by notifications that require user action. | |
1785 @param min_priority: filter by minimum priority value. | |
1786 @return: list of matching Notification instances. | |
1787 """ | |
1788 profile_id = self.profiles[client.profile] | |
1789 filters = [or_(Notification.profile_id == profile_id, Notification.profile_id.is_(None))] | |
1790 | |
1791 if type_: | |
1792 filters.append(Notification.type == type_) | |
1793 if status: | |
1794 filters.append(Notification.status == status) | |
1795 if requires_action is not None: | |
1796 filters.append(Notification.requires_action == requires_action) | |
1797 if min_priority: | |
1798 filters.append(Notification.priority >= min_priority) | |
1799 | |
1800 async with self.session() as session: | |
1801 result = await session.execute( | |
1802 select(Notification) | |
1803 .where(and_(*filters)) | |
1804 .order_by(Notification.id) | |
1805 ) | |
1806 return result.scalars().all() | |
1807 | |
1808 @aio | |
1809 async def delete_notification( | |
1810 self, client: Optional[SatXMPPEntity], notification_id: str | |
1811 ) -> None: | |
1812 """Delete a notification by its profile and id. | |
1813 | |
1814 @param client: client associated with the notification. If None, profile_id will be NULL. | |
1815 @param notification_id: ID of the notification to delete. | |
1816 """ | |
1817 profile_id = self.profiles[client.profile] if client else None | |
1818 async with self.session() as session: | |
1819 await session.execute( | |
1820 delete(Notification).where( | |
1821 and_( | |
1822 Notification.profile_id == profile_id, | |
1823 Notification.id == int(notification_id), | |
1824 ) | |
1825 ) | |
1826 ) | |
1827 await session.commit() | |
1828 | |
1829 @aio | |
1830 async def clean_expired_notifications( | |
1831 self, client: Optional[SatXMPPEntity], limit_timestamp: Optional[float] = None | |
1832 ) -> None: | |
1833 """Cleans expired notifications and older profile-specific notifications. | |
1834 | |
1835 - Removes all notifications where the expiration timestamp has passed, | |
1836 irrespective of their profile. | |
1837 - If a limit_timestamp is provided, removes older notifications with a profile set | |
1838 (i.e., not global notifications) that do not require user action. If client is | |
1839 provided, only remove notification for this profile. | |
1840 | |
1841 @param client: if provided, only expire notification for this client (in addition | |
1842 to truly expired notifications for everybody). | |
1843 @param limit_timestamp: Timestamp limit for older notifications. If None, only | |
1844 truly expired notifications are removed. | |
1845 """ | |
1846 | |
1847 # Delete truly expired notifications | |
1848 expired_condition = Notification.expire_at < time.time() | |
1849 | |
1850 # Delete older profile-specific notifications (created before the limit_timestamp) | |
1851 if client is None: | |
1852 profile_condition = Notification.profile_id.isnot(None) | |
1853 else: | |
1854 profile_condition = Notification.profile_id == self.profiles[client.profile] | |
1855 older_condition = and_( | |
1856 profile_condition, | |
1857 Notification.timestamp < limit_timestamp if limit_timestamp else False, | |
1858 Notification.requires_action == False, | |
1859 ) | |
1860 | |
1861 # Combine the conditions | |
1862 conditions = or_(expired_condition, older_condition) | |
1863 | |
1864 async with self.session() as session: | |
1865 await session.execute(delete(Notification).where(conditions)) | |
1866 await session.commit() |