comparison libervia/backend/memory/sqla.py @ 4130:02f0adc745c6

core: notifications implementation, first draft: add a new table for notifications, and methods/bridge methods to manipulate them.
author Goffi <goffi@goffi.org>
date Mon, 16 Oct 2023 17:29:31 +0200
parents 74c66c0d93f3
children 6a0066ea5c97
comparison
equal deleted inserted replaced
4129:51744ad00a42 4130:02f0adc745c6
60 Component, 60 Component,
61 File, 61 File,
62 History, 62 History,
63 Message, 63 Message,
64 NOT_IN_EXTRA, 64 NOT_IN_EXTRA,
65 Notification,
66 NotificationPriority,
67 NotificationStatus,
68 NotificationType,
65 ParamGen, 69 ParamGen,
66 ParamInd, 70 ParamInd,
67 PrivateGen, 71 PrivateGen,
68 PrivateGenBin, 72 PrivateGenBin,
69 PrivateInd, 73 PrivateInd,
72 PubsubItem, 76 PubsubItem,
73 PubsubNode, 77 PubsubNode,
74 Subject, 78 Subject,
75 SyncState, 79 SyncState,
76 Thread, 80 Thread,
81 get_profile_by_id,
82 profiles,
77 ) 83 )
78 from libervia.backend.tools.common import uri 84 from libervia.backend.tools.common import uri
79 from libervia.backend.tools.utils import aio, as_future 85 from libervia.backend.tools.utils import aio, as_future
80 86
81 87
115 class Storage: 121 class Storage:
116 122
117 def __init__(self): 123 def __init__(self):
118 self.initialized = defer.Deferred() 124 self.initialized = defer.Deferred()
119 # we keep cache for the profiles (key: profile name, value: profile id) 125 # we keep cache for the profiles (key: profile name, value: profile id)
120 # profile id to name
121 self.profiles: Dict[str, int] = {}
122 # profile id to component entry point 126 # profile id to component entry point
123 self.components: Dict[int, str] = {} 127 self.components: Dict[int, str] = {}
124 128
129 @property
130 def profiles(self):
131 return profiles
132
125 def get_profile_by_id(self, profile_id): 133 def get_profile_by_id(self, profile_id):
126 return self.profiles.get(profile_id) 134 return get_profile_by_id(profile_id)
127 135
128 async def migrate_apply(self, *args: str, log_output: bool = False) -> None: 136 async def migrate_apply(self, *args: str, log_output: bool = False) -> None:
129 """Do a migration command 137 """Do a migration command
130 138
131 Commands are applied by running Alembic in a subprocess. 139 Commands are applied by running Alembic in a subprocess.
137 @raise exceptions.DatabaseError: something went wrong while running the 145 @raise exceptions.DatabaseError: something went wrong while running the
138 process 146 process
139 """ 147 """
140 stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) 148 stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,)
141 proc = await asyncio.create_subprocess_exec( 149 proc = await asyncio.create_subprocess_exec(
142 sys.executable, "-m", "alembic", *args, 150 sys.executable,
143 stdout=stdout, stderr=stderr, cwd=migration_path 151 "-m",
152 "alembic",
153 *args,
154 stdout=stdout,
155 stderr=stderr,
156 cwd=migration_path,
144 ) 157 )
145 log_out, log_err = await proc.communicate() 158 log_out, log_err = await proc.communicate()
146 if proc.returncode != 0: 159 if proc.returncode != 0:
147 msg = _( 160 msg = _("Can't {operation} database (exit code {exit_code})").format(
148 "Can't {operation} database (exit code {exit_code})" 161 operation=args[0], exit_code=proc.returncode
149 ).format(
150 operation=args[0],
151 exit_code=proc.returncode
152 ) 162 )
153 if log_out or log_err: 163 if log_out or log_err:
154 msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" 164 msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}"
155 log.error(msg) 165 log.error(msg)
156 166
214 await self.check_and_update_db(engine, db_config) 224 await self.check_and_update_db(engine, db_config)
215 225
216 async with engine.connect() as conn: 226 async with engine.connect() as conn:
217 await conn.run_sync(self._sqlite_set_journal_mode_wal) 227 await conn.run_sync(self._sqlite_set_journal_mode_wal)
218 228
219 self.session = sessionmaker( 229 self.session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
220 engine, expire_on_commit=False, class_=AsyncSession
221 )
222 230
223 async with self.session() as session: 231 async with self.session() as session:
224 result = await session.execute(select(Profile)) 232 result = await session.execute(select(Profile))
225 for p in result.scalars(): 233 for p in result.scalars():
226 self.profiles[p.name] = p.id 234 self.profiles[p.name] = p.id
237 self, 245 self,
238 client: SatXMPPEntity, 246 client: SatXMPPEntity,
239 db_cls: DeclarativeMeta, 247 db_cls: DeclarativeMeta,
240 db_id_col: Mapped, 248 db_id_col: Mapped,
241 id_value: Any, 249 id_value: Any,
242 joined_loads = None 250 joined_loads=None,
243 ) -> Optional[DeclarativeMeta]: 251 ) -> Optional[DeclarativeMeta]:
244 stmt = select(db_cls).where(db_id_col==id_value) 252 stmt = select(db_cls).where(db_id_col == id_value)
245 if client is not None: 253 if client is not None:
246 stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) 254 stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
247 if joined_loads is not None: 255 if joined_loads is not None:
248 for joined_load in joined_loads: 256 for joined_load in joined_loads:
249 stmt = stmt.options(joinedload(joined_load)) 257 stmt = stmt.options(joinedload(joined_load))
262 270
263 @aio 271 @aio
264 async def delete( 272 async def delete(
265 self, 273 self,
266 db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]], 274 db_obj: Union[DeclarativeMeta, List[DeclarativeMeta]],
267 session_add: Optional[List[DeclarativeMeta]] = None 275 session_add: Optional[List[DeclarativeMeta]] = None,
268 ) -> None: 276 ) -> None:
269 """Delete an object from database 277 """Delete an object from database
270 278
271 @param db_obj: object to delete or list of objects to delete 279 @param db_obj: object to delete or list of objects to delete
272 @param session_add: other objects to add to session. 280 @param session_add: other objects to add to session.
287 await session.commit() 295 await session.commit()
288 296
289 ## Profiles 297 ## Profiles
290 298
291 def get_profiles_list(self) -> List[str]: 299 def get_profiles_list(self) -> List[str]:
292 """"Return list of all registered profiles""" 300 """ "Return list of all registered profiles"""
293 return list(self.profiles.keys()) 301 return list(self.profiles.keys())
294 302
295 def has_profile(self, profile_name: str) -> bool: 303 def has_profile(self, profile_name: str) -> bool:
296 """return True if profile_name exists 304 """return True if profile_name exists
297 305
307 315
308 def get_entry_point(self, profile_name: str) -> str: 316 def get_entry_point(self, profile_name: str) -> str:
309 try: 317 try:
310 return self.components[self.profiles[profile_name]] 318 return self.components[self.profiles[profile_name]]
311 except KeyError: 319 except KeyError:
312 raise exceptions.NotFound("the requested profile doesn't exists or is not a component") 320 raise exceptions.NotFound(
321 "the requested profile doesn't exists or is not a component"
322 )
313 323
314 @aio 324 @aio
315 async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None: 325 async def create_profile(self, name: str, component_ep: Optional[str] = None) -> None:
316 """Create a new profile 326 """Create a new profile
317 327
342 await session.delete(profile) 352 await session.delete(profile)
343 await session.commit() 353 await session.commit()
344 del self.profiles[profile.name] 354 del self.profiles[profile.name]
345 if profile.id in self.components: 355 if profile.id in self.components:
346 del self.components[profile.id] 356 del self.components[profile.id]
347 log.info(_("Profile {name!r} deleted").format(name = name)) 357 log.info(_("Profile {name!r} deleted").format(name=name))
348 358
349 ## Params 359 ## Params
350 360
351 @aio 361 @aio
352 async def load_gen_params(self, params_gen: dict) -> None: 362 async def load_gen_params(self, params_gen: dict) -> None:
374 ) 384 )
375 for p in result.scalars(): 385 for p in result.scalars():
376 params_ind[(p.category, p.name)] = p.value 386 params_ind[(p.category, p.name)] = p.value
377 387
378 @aio 388 @aio
379 async def get_ind_param(self, category: str, name: str, profile: str) -> Optional[str]: 389 async def get_ind_param(
390 self, category: str, name: str, profile: str
391 ) -> Optional[str]:
380 """Ask database for the value of one specific individual parameter 392 """Ask database for the value of one specific individual parameter
381 393
382 @param category: category of the parameter 394 @param category: category of the parameter
383 @param name: name of the parameter 395 @param name: name of the parameter
384 @param profile: %(doc_profile)s 396 @param profile: %(doc_profile)s
385 """ 397 """
386 async with self.session() as session: 398 async with self.session() as session:
387 result = await session.execute( 399 result = await session.execute(
388 select(ParamInd.value) 400 select(ParamInd.value).filter_by(
389 .filter_by( 401 category=category, name=name, profile_id=self.profiles[profile]
390 category=category,
391 name=name,
392 profile_id=self.profiles[profile]
393 ) 402 )
394 ) 403 )
395 return result.scalar_one_or_none() 404 return result.scalar_one_or_none()
396 405
397 @aio 406 @aio
403 @return dict: profile => value map 412 @return dict: profile => value map
404 """ 413 """
405 async with self.session() as session: 414 async with self.session() as session:
406 result = await session.execute( 415 result = await session.execute(
407 select(ParamInd) 416 select(ParamInd)
408 .filter_by( 417 .filter_by(category=category, name=name)
409 category=category,
410 name=name
411 )
412 .options(subqueryload(ParamInd.profile)) 418 .options(subqueryload(ParamInd.profile))
413 ) 419 )
414 return {param.profile.name: param.value for param in result.scalars()} 420 return {param.profile.name: param.value for param in result.scalars()}
415 421
416 @aio 422 @aio
420 @param category: category of the parameter 426 @param category: category of the parameter
421 @param name: name of the parameter 427 @param name: name of the parameter
422 @param value: value to set 428 @param value: value to set
423 """ 429 """
424 async with self.session() as session: 430 async with self.session() as session:
425 stmt = insert(ParamGen).values( 431 stmt = (
426 category=category, 432 insert(ParamGen)
427 name=name, 433 .values(category=category, name=name, value=value)
428 value=value 434 .on_conflict_do_update(
429 ).on_conflict_do_update( 435 index_elements=(ParamGen.category, ParamGen.name),
430 index_elements=(ParamGen.category, ParamGen.name), 436 set_={ParamGen.value: value},
431 set_={ 437 )
432 ParamGen.value: value
433 }
434 ) 438 )
435 await session.execute(stmt) 439 await session.execute(stmt)
436 await session.commit() 440 await session.commit()
437 441
438 @aio 442 @aio
439 async def set_ind_param( 443 async def set_ind_param(
440 self, 444 self, category: str, name: str, value: Optional[str], profile: str
441 category:str,
442 name: str,
443 value: Optional[str],
444 profile: str
445 ) -> None: 445 ) -> None:
446 """Save the individual parameters in database 446 """Save the individual parameters in database
447 447
448 @param category: category of the parameter 448 @param category: category of the parameter
449 @param name: name of the parameter 449 @param name: name of the parameter
450 @param value: value to set 450 @param value: value to set
451 @param profile: a profile which *must* exist 451 @param profile: a profile which *must* exist
452 """ 452 """
453 async with self.session() as session: 453 async with self.session() as session:
454 stmt = insert(ParamInd).values( 454 stmt = (
455 category=category, 455 insert(ParamInd)
456 name=name, 456 .values(
457 profile_id=self.profiles[profile], 457 category=category,
458 value=value 458 name=name,
459 ).on_conflict_do_update( 459 profile_id=self.profiles[profile],
460 index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), 460 value=value,
461 set_={ 461 )
462 ParamInd.value: value 462 .on_conflict_do_update(
463 } 463 index_elements=(
464 ParamInd.category,
465 ParamInd.name,
466 ParamInd.profile_id,
467 ),
468 set_={ParamInd.value: value},
469 )
464 ) 470 )
465 await session.execute(stmt) 471 await session.execute(stmt)
466 await session.commit() 472 await session.commit()
467 473
468 def _jid_filter(self, jid_: jid.JID, dest: bool = False): 474 def _jid_filter(self, jid_: jid.JID, dest: bool = False):
472 @param jid_: JID to filter by 478 @param jid_: JID to filter by
473 """ 479 """
474 if jid_.resource: 480 if jid_.resource:
475 if dest: 481 if dest:
476 return and_( 482 return and_(
477 History.dest == jid_.userhost(), 483 History.dest == jid_.userhost(), History.dest_res == jid_.resource
478 History.dest_res == jid_.resource
479 ) 484 )
480 else: 485 else:
481 return and_( 486 return and_(
482 History.source == jid_.userhost(), 487 History.source == jid_.userhost(), History.source_res == jid_.resource
483 History.source_res == jid_.resource
484 ) 488 )
485 else: 489 else:
486 if dest: 490 if dest:
487 return History.dest == jid_.userhost() 491 return History.dest == jid_.userhost()
488 else: 492 else:
495 to_jid: Optional[jid.JID], 499 to_jid: Optional[jid.JID],
496 limit: Optional[int] = None, 500 limit: Optional[int] = None,
497 between: bool = True, 501 between: bool = True,
498 filters: Optional[Dict[str, str]] = None, 502 filters: Optional[Dict[str, str]] = None,
499 profile: Optional[str] = None, 503 profile: Optional[str] = None,
500 ) -> List[Tuple[ 504 ) -> List[Tuple[str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]]:
501 str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
502 ]:
503 """Retrieve messages in history 505 """Retrieve messages in history
504 506
505 @param from_jid: source JID (full, or bare for catchall) 507 @param from_jid: source JID (full, or bare for catchall)
506 @param to_jid: dest JID (full, or bare for catchall) 508 @param to_jid: dest JID (full, or bare for catchall)
507 @param limit: maximum number of messages to get: 509 @param limit: maximum number of messages to get:
521 if filters is None: 523 if filters is None:
522 filters = {} 524 filters = {}
523 525
524 stmt = ( 526 stmt = (
525 select(History) 527 select(History)
526 .filter_by( 528 .filter_by(profile_id=self.profiles[profile])
527 profile_id=self.profiles[profile]
528 )
529 .outerjoin(History.messages) 529 .outerjoin(History.messages)
530 .outerjoin(History.subjects) 530 .outerjoin(History.subjects)
531 .outerjoin(History.thread) 531 .outerjoin(History.thread)
532 .options( 532 .options(
533 contains_eager(History.messages), 533 contains_eager(History.messages),
538 # timestamp may be identical for 2 close messages (specially when delay is 538 # timestamp may be identical for 2 close messages (specially when delay is
539 # used) that's why we order ties by received_timestamp. We'll reverse the 539 # used) that's why we order ties by received_timestamp. We'll reverse the
540 # order when returning the result. We use DESC here so LIMIT keep the last 540 # order when returning the result. We use DESC here so LIMIT keep the last
541 # messages 541 # messages
542 History.timestamp.desc(), 542 History.timestamp.desc(),
543 History.received_timestamp.desc() 543 History.received_timestamp.desc(),
544 ) 544 )
545 ) 545 )
546
547 546
548 if not from_jid and not to_jid: 547 if not from_jid and not to_jid:
549 # no jid specified, we want all one2one communications 548 # no jid specified, we want all one2one communications
550 pass 549 pass
551 elif between: 550 elif between:
552 if not from_jid or not to_jid: 551 if not from_jid or not to_jid:
553 # we only have one jid specified, we check all messages 552 # we only have one jid specified, we check all messages
554 # from or to this jid 553 # from or to this jid
555 jid_ = from_jid or to_jid 554 jid_ = from_jid or to_jid
556 stmt = stmt.where( 555 stmt = stmt.where(
557 or_( 556 or_(self._jid_filter(jid_), self._jid_filter(jid_, dest=True))
558 self._jid_filter(jid_),
559 self._jid_filter(jid_, dest=True)
560 )
561 ) 557 )
562 else: 558 else:
563 # we have 2 jids specified, we check all communications between 559 # we have 2 jids specified, we check all communications between
564 # those 2 jids 560 # those 2 jids
565 stmt = stmt.where( 561 stmt = stmt.where(
569 self._jid_filter(to_jid, dest=True), 565 self._jid_filter(to_jid, dest=True),
570 ), 566 ),
571 and_( 567 and_(
572 self._jid_filter(to_jid), 568 self._jid_filter(to_jid),
573 self._jid_filter(from_jid, dest=True), 569 self._jid_filter(from_jid, dest=True),
574 ) 570 ),
575 ) 571 )
576 ) 572 )
577 else: 573 else:
578 # we want one communication in specific direction (from somebody or 574 # we want one communication in specific direction (from somebody or
579 # to somebody). 575 # to somebody).
581 stmt = stmt.where(self._jid_filter(from_jid)) 577 stmt = stmt.where(self._jid_filter(from_jid))
582 if to_jid is not None: 578 if to_jid is not None:
583 stmt = stmt.where(self._jid_filter(to_jid, dest=True)) 579 stmt = stmt.where(self._jid_filter(to_jid, dest=True))
584 580
585 if filters: 581 if filters:
586 if 'timestamp_start' in filters: 582 if "timestamp_start" in filters:
587 stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) 583 stmt = stmt.where(History.timestamp >= float(filters["timestamp_start"]))
588 if 'before_uid' in filters: 584 if "before_uid" in filters:
589 # orignially this query was using SQLITE's rowid. This has been changed 585 # orignially this query was using SQLITE's rowid. This has been changed
590 # to use coalesce(received_timestamp, timestamp) to be SQL engine independant 586 # to use coalesce(received_timestamp, timestamp) to be SQL engine independant
591 stmt = stmt.where( 587 stmt = stmt.where(
592 coalesce( 588 coalesce(History.received_timestamp, History.timestamp)
593 History.received_timestamp, 589 < (
594 History.timestamp 590 select(
595 ) < ( 591 coalesce(History.received_timestamp, History.timestamp)
596 select(coalesce(History.received_timestamp, History.timestamp)) 592 ).filter_by(uid=filters["before_uid"])
597 .filter_by(uid=filters["before_uid"])
598 ).scalar_subquery() 593 ).scalar_subquery()
599 ) 594 )
600 if 'body' in filters: 595 if "body" in filters:
601 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html 596 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html
602 stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) 597 stmt = stmt.where(Message.message.like(f"%{filters['body']}%"))
603 if 'search' in filters: 598 if "search" in filters:
604 search_term = f"%{filters['search']}%" 599 search_term = f"%{filters['search']}%"
605 stmt = stmt.where(or_( 600 stmt = stmt.where(
606 Message.message.like(search_term), 601 or_(
607 History.source_res.like(search_term) 602 Message.message.like(search_term),
608 )) 603 History.source_res.like(search_term),
609 if 'types' in filters: 604 )
610 types = filters['types'].split() 605 )
606 if "types" in filters:
607 types = filters["types"].split()
611 stmt = stmt.where(History.type.in_(types)) 608 stmt = stmt.where(History.type.in_(types))
612 if 'not_types' in filters: 609 if "not_types" in filters:
613 types = filters['not_types'].split() 610 types = filters["not_types"].split()
614 stmt = stmt.where(History.type.not_in(types)) 611 stmt = stmt.where(History.type.not_in(types))
615 if 'last_stanza_id' in filters: 612 if "last_stanza_id" in filters:
616 # this request get the last message with a "stanza_id" that we 613 # this request get the last message with a "stanza_id" that we
617 # have in history. This is mainly used to retrieve messages sent 614 # have in history. This is mainly used to retrieve messages sent
618 # while we were offline, using MAM (XEP-0313). 615 # while we were offline, using MAM (XEP-0313).
619 if (filters['last_stanza_id'] is not True 616 if filters["last_stanza_id"] is not True or limit != 1:
620 or limit != 1):
621 raise ValueError("Unexpected values for last_stanza_id filter") 617 raise ValueError("Unexpected values for last_stanza_id filter")
622 stmt = stmt.where(History.stanza_id.is_not(None)) 618 stmt = stmt.where(History.stanza_id.is_not(None))
623 if 'origin_id' in filters: 619 if "origin_id" in filters:
624 stmt = stmt.where(History.origin_id == filters["origin_id"]) 620 stmt = stmt.where(History.origin_id == filters["origin_id"])
625 621
626 if limit is not None: 622 if limit is not None:
627 stmt = stmt.limit(limit) 623 stmt = stmt.limit(limit)
628 624
638 """Store a new message in history 634 """Store a new message in history
639 635
640 @param data: message data as build by SatMessageProtocol.onMessage 636 @param data: message data as build by SatMessageProtocol.onMessage
641 """ 637 """
642 extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} 638 extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
643 messages = [Message(message=mess, language=lang) 639 messages = [
644 for lang, mess in data["message"].items()] 640 Message(message=mess, language=lang) for lang, mess in data["message"].items()
645 subjects = [Subject(subject=mess, language=lang) 641 ]
646 for lang, mess in data["subject"].items()] 642 subjects = [
643 Subject(subject=mess, language=lang) for lang, mess in data["subject"].items()
644 ]
647 if "thread" in data["extra"]: 645 if "thread" in data["extra"]:
648 thread = Thread(thread_id=data["extra"]["thread"], 646 thread = Thread(
649 parent_id=data["extra"].get["thread_parent"]) 647 thread_id=data["extra"]["thread"],
648 parent_id=data["extra"].get["thread_parent"],
649 )
650 else: 650 else:
651 thread = None 651 thread = None
652 try: 652 try:
653 async with self.session() as session: 653 async with self.session() as session:
654 async with session.begin(): 654 async with session.begin():
655 session.add(History( 655 session.add(
656 uid=data["uid"], 656 History(
657 origin_id=data["extra"].get("origin_id"), 657 uid=data["uid"],
658 stanza_id=data["extra"].get("stanza_id"), 658 origin_id=data["extra"].get("origin_id"),
659 update_uid=data["extra"].get("update_uid"), 659 stanza_id=data["extra"].get("stanza_id"),
660 profile_id=self.profiles[profile], 660 update_uid=data["extra"].get("update_uid"),
661 source_jid=data["from"], 661 profile_id=self.profiles[profile],
662 dest_jid=data["to"], 662 source_jid=data["from"],
663 timestamp=data["timestamp"], 663 dest_jid=data["to"],
664 received_timestamp=data.get("received_timestamp"), 664 timestamp=data["timestamp"],
665 type=data["type"], 665 received_timestamp=data.get("received_timestamp"),
666 extra=extra, 666 type=data["type"],
667 messages=messages, 667 extra=extra,
668 subjects=subjects, 668 messages=messages,
669 thread=thread, 669 subjects=subjects,
670 )) 670 thread=thread,
671 )
672 )
671 except IntegrityError as e: 673 except IntegrityError as e:
672 if "unique" in str(e.orig).lower(): 674 if "unique" in str(e.orig).lower():
673 log.debug( 675 log.debug(
674 f"message {data['uid']!r} is already in history, not storing it again" 676 f"message {data['uid']!r} is already in history, not storing it again"
675 ) 677 )
687 if profile is None: 689 if profile is None:
688 return PrivateGenBin if binary else PrivateGen 690 return PrivateGenBin if binary else PrivateGen
689 else: 691 else:
690 return PrivateIndBin if binary else PrivateInd 692 return PrivateIndBin if binary else PrivateInd
691 693
692
693 @aio 694 @aio
694 async def get_privates( 695 async def get_privates(
695 self, 696 self,
696 namespace:str, 697 namespace: str,
697 keys: Optional[Iterable[str]] = None, 698 keys: Optional[Iterable[str]] = None,
698 binary: bool = False, 699 binary: bool = False,
699 profile: Optional[str] = None 700 profile: Optional[str] = None,
700 ) -> Dict[str, Any]: 701 ) -> Dict[str, Any]:
701 """Get private value(s) from databases 702 """Get private value(s) from databases
702 703
703 @param namespace: namespace of the values 704 @param namespace: namespace of the values
704 @param keys: keys of the values to get None to get all keys/values 705 @param keys: keys of the values to get None to get all keys/values
726 727
727 @aio 728 @aio
728 async def set_private_value( 729 async def set_private_value(
729 self, 730 self,
730 namespace: str, 731 namespace: str,
731 key:str, 732 key: str,
732 value: Any, 733 value: Any,
733 binary: bool = False, 734 binary: bool = False,
734 profile: Optional[str] = None 735 profile: Optional[str] = None,
735 ) -> None: 736 ) -> None:
736 """Set a private value in database 737 """Set a private value in database
737 738
738 @param namespace: namespace of the values 739 @param namespace: namespace of the values
739 @param key: key of the value to set 740 @param key: key of the value to set
743 @param profile: profile to use for individual value 744 @param profile: profile to use for individual value
744 if None, it's a general value 745 if None, it's a general value
745 """ 746 """
746 cls = self._get_private_class(binary, profile) 747 cls = self._get_private_class(binary, profile)
747 748
748 values = { 749 values = {"namespace": namespace, "key": key, "value": value}
749 "namespace": namespace,
750 "key": key,
751 "value": value
752 }
753 index_elements = [cls.namespace, cls.key] 750 index_elements = [cls.namespace, cls.key]
754 751
755 if profile is not None: 752 if profile is not None:
756 values["profile_id"] = self.profiles[profile] 753 values["profile_id"] = self.profiles[profile]
757 index_elements.append(cls.profile_id) 754 index_elements.append(cls.profile_id)
758 755
759 async with self.session() as session: 756 async with self.session() as session:
760 await session.execute( 757 await session.execute(
761 insert(cls).values(**values).on_conflict_do_update( 758 insert(cls)
762 index_elements=index_elements, 759 .values(**values)
763 set_={ 760 .on_conflict_do_update(
764 cls.value: value 761 index_elements=index_elements, set_={cls.value: value}
765 }
766 ) 762 )
767 ) 763 )
768 await session.commit() 764 await session.commit()
769 765
770 @aio 766 @aio
771 async def del_private_value( 767 async def del_private_value(
772 self, 768 self,
773 namespace: str, 769 namespace: str,
774 key: str, 770 key: str,
775 binary: bool = False, 771 binary: bool = False,
776 profile: Optional[str] = None 772 profile: Optional[str] = None,
777 ) -> None: 773 ) -> None:
778 """Delete private value from database 774 """Delete private value from database
779 775
780 @param category: category of the privateeter 776 @param category: category of the privateeter
781 @param key: key of the private value 777 @param key: key of the private value
794 await session.execute(stmt) 790 await session.execute(stmt)
795 await session.commit() 791 await session.commit()
796 792
797 @aio 793 @aio
798 async def del_private_namespace( 794 async def del_private_namespace(
799 self, 795 self, namespace: str, binary: bool = False, profile: Optional[str] = None
800 namespace: str,
801 binary: bool = False,
802 profile: Optional[str] = None
803 ) -> None: 796 ) -> None:
804 """Delete all data from a private namespace 797 """Delete all data from a private namespace
805 798
806 Be really cautious when you use this method, as all data with given namespace are 799 Be really cautious when you use this method, as all data with given namespace are
807 removed. 800 removed.
823 @aio 816 @aio
824 async def get_files( 817 async def get_files(
825 self, 818 self,
826 client: Optional[SatXMPPEntity], 819 client: Optional[SatXMPPEntity],
827 file_id: Optional[str] = None, 820 file_id: Optional[str] = None,
828 version: Optional[str] = '', 821 version: Optional[str] = "",
829 parent: Optional[str] = None, 822 parent: Optional[str] = None,
830 type_: Optional[str] = None, 823 type_: Optional[str] = None,
831 file_hash: Optional[str] = None, 824 file_hash: Optional[str] = None,
832 hash_algo: Optional[str] = None, 825 hash_algo: Optional[str] = None,
833 name: Optional[str] = None, 826 name: Optional[str] = None,
835 mime_type: Optional[str] = None, 828 mime_type: Optional[str] = None,
836 public_id: Optional[str] = None, 829 public_id: Optional[str] = None,
837 owner: Optional[jid.JID] = None, 830 owner: Optional[jid.JID] = None,
838 access: Optional[dict] = None, 831 access: Optional[dict] = None,
839 projection: Optional[List[str]] = None, 832 projection: Optional[List[str]] = None,
840 unique: bool = False 833 unique: bool = False,
841 ) -> List[dict]: 834 ) -> List[dict]:
842 """Retrieve files with with given filters 835 """Retrieve files with with given filters
843 836
844 @param file_id: id of the file 837 @param file_id: id of the file
845 None to ignore 838 None to ignore
855 other params are the same as for [set_file] 848 other params are the same as for [set_file]
856 @return: files corresponding to filters 849 @return: files corresponding to filters
857 """ 850 """
858 if projection is None: 851 if projection is None:
859 projection = [ 852 projection = [
860 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', 853 "id",
861 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', 854 "version",
862 'created', 'modified', 'owner', 'access', 'extra' 855 "parent",
856 "type",
857 "file_hash",
858 "hash_algo",
859 "name",
860 "size",
861 "namespace",
862 "media_type",
863 "media_subtype",
864 "public_id",
865 "created",
866 "modified",
867 "owner",
868 "access",
869 "extra",
863 ] 870 ]
864 871
865 stmt = select(*[getattr(File, f) for f in projection]) 872 stmt = select(*[getattr(File, f) for f in projection])
866 873
867 if unique: 874 if unique:
889 if name is not None: 896 if name is not None:
890 stmt = stmt.filter_by(name=name) 897 stmt = stmt.filter_by(name=name)
891 if namespace is not None: 898 if namespace is not None:
892 stmt = stmt.filter_by(namespace=namespace) 899 stmt = stmt.filter_by(namespace=namespace)
893 if mime_type is not None: 900 if mime_type is not None:
894 if '/' in mime_type: 901 if "/" in mime_type:
895 media_type, media_subtype = mime_type.split("/", 1) 902 media_type, media_subtype = mime_type.split("/", 1)
896 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) 903 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
897 else: 904 else:
898 stmt = stmt.filter_by(media_type=mime_type) 905 stmt = stmt.filter_by(media_type=mime_type)
899 if public_id is not None: 906 if public_id is not None:
900 stmt = stmt.filter_by(public_id=public_id) 907 stmt = stmt.filter_by(public_id=public_id)
901 if owner is not None: 908 if owner is not None:
902 stmt = stmt.filter_by(owner=owner) 909 stmt = stmt.filter_by(owner=owner)
903 if access is not None: 910 if access is not None:
904 raise NotImplementedError('Access check is not implemented yet') 911 raise NotImplementedError("Access check is not implemented yet")
905 # a JSON comparison is needed here 912 # a JSON comparison is needed here
906 913
907 async with self.session() as session: 914 async with self.session() as session:
908 result = await session.execute(stmt) 915 result = await session.execute(stmt)
909 916
910 return [r._asdict() for r in result] 917 return [r._asdict() for r in result]
926 public_id: Optional[str] = None, 933 public_id: Optional[str] = None,
927 created: Optional[float] = None, 934 created: Optional[float] = None,
928 modified: Optional[float] = None, 935 modified: Optional[float] = None,
929 owner: Optional[jid.JID] = None, 936 owner: Optional[jid.JID] = None,
930 access: Optional[dict] = None, 937 access: Optional[dict] = None,
931 extra: Optional[dict] = None 938 extra: Optional[dict] = None,
932 ) -> None: 939 ) -> None:
933 """Set a file metadata 940 """Set a file metadata
934 941
935 @param client: client owning the file 942 @param client: client owning the file
936 @param name: name of the file (must not contain "/") 943 @param name: name of the file (must not contain "/")
956 @param extra: serialisable dictionary of any extra data 963 @param extra: serialisable dictionary of any extra data
957 will be encoded to json in database 964 will be encoded to json in database
958 """ 965 """
959 if mime_type is None: 966 if mime_type is None:
960 media_type = media_subtype = None 967 media_type = media_subtype = None
961 elif '/' in mime_type: 968 elif "/" in mime_type:
962 media_type, media_subtype = mime_type.split('/', 1) 969 media_type, media_subtype = mime_type.split("/", 1)
963 else: 970 else:
964 media_type, media_subtype = mime_type, None 971 media_type, media_subtype = mime_type, None
965 972
966 async with self.session() as session: 973 async with self.session() as session:
967 async with session.begin(): 974 async with session.begin():
968 session.add(File( 975 session.add(
969 id=file_id, 976 File(
970 version=version.strip(), 977 id=file_id,
971 parent=parent, 978 version=version.strip(),
972 type=type_, 979 parent=parent,
973 file_hash=file_hash, 980 type=type_,
974 hash_algo=hash_algo, 981 file_hash=file_hash,
975 name=name, 982 hash_algo=hash_algo,
976 size=size, 983 name=name,
977 namespace=namespace, 984 size=size,
978 media_type=media_type, 985 namespace=namespace,
979 media_subtype=media_subtype, 986 media_type=media_type,
980 public_id=public_id, 987 media_subtype=media_subtype,
981 created=time.time() if created is None else created, 988 public_id=public_id,
982 modified=modified, 989 created=time.time() if created is None else created,
983 owner=owner, 990 modified=modified,
984 access=access, 991 owner=owner,
985 extra=extra, 992 access=access,
986 profile_id=self.profiles[client.profile] 993 extra=extra,
987 )) 994 profile_id=self.profiles[client.profile],
995 )
996 )
988 997
989 @aio 998 @aio
990 async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int: 999 async def file_get_used_space(self, client: SatXMPPEntity, owner: jid.JID) -> int:
991 async with self.session() as session: 1000 async with self.session() as session:
992 result = await session.execute( 1001 result = await session.execute(
993 select(sum_(File.size)).filter_by( 1002 select(sum_(File.size)).filter_by(
994 owner=owner, 1003 owner=owner,
995 type=C.FILE_TYPE_FILE, 1004 type=C.FILE_TYPE_FILE,
996 profile_id=self.profiles[client.profile] 1005 profile_id=self.profiles[client.profile],
997 )) 1006 )
1007 )
998 return result.scalar_one_or_none() or 0 1008 return result.scalar_one_or_none() or 0
999 1009
1000 @aio 1010 @aio
1001 async def file_delete(self, file_id: str) -> None: 1011 async def file_delete(self, file_id: str) -> None:
1002 """Delete file metadata from the database 1012 """Delete file metadata from the database
1009 await session.execute(delete(File).filter_by(id=file_id)) 1019 await session.execute(delete(File).filter_by(id=file_id))
1010 await session.commit() 1020 await session.commit()
1011 1021
1012 @aio 1022 @aio
1013 async def file_update( 1023 async def file_update(
1014 self, 1024 self, file_id: str, column: str, update_cb: Callable[[dict], None]
1015 file_id: str,
1016 column: str,
1017 update_cb: Callable[[dict], None]
1018 ) -> None: 1025 ) -> None:
1019 """Update a column value using a method to avoid race conditions 1026 """Update a column value using a method to avoid race conditions
1020 1027
1021 the older value will be retrieved from database, then update_cb will be applied to 1028 the older value will be retrieved from database, then update_cb will be applied to
1022 update it, and file will be updated checking that older value has not been changed 1029 update it, and file will be updated checking that older value has not been changed
1027 the method will take older value as argument, and must update it in place 1034 the method will take older value as argument, and must update it in place
1028 update_cb must not care about serialization, 1035 update_cb must not care about serialization,
1029 it get the deserialized data (i.e. a Python object) directly 1036 it get the deserialized data (i.e. a Python object) directly
1030 @raise exceptions.NotFound: there is not file with this id 1037 @raise exceptions.NotFound: there is not file with this id
1031 """ 1038 """
1032 if column not in ('access', 'extra'): 1039 if column not in ("access", "extra"):
1033 raise exceptions.InternalError('bad column name') 1040 raise exceptions.InternalError("bad column name")
1034 orm_col = getattr(File, column) 1041 orm_col = getattr(File, column)
1035 1042
1036 for i in range(5): 1043 for i in range(5):
1037 async with self.session() as session: 1044 async with self.session() as session:
1038 try: 1045 try:
1039 value = (await session.execute( 1046 value = (
1040 select(orm_col).filter_by(id=file_id) 1047 await session.execute(select(orm_col).filter_by(id=file_id))
1041 )).scalar_one() 1048 ).scalar_one()
1042 except NoResultFound: 1049 except NoResultFound:
1043 raise exceptions.NotFound 1050 raise exceptions.NotFound
1044 old_value = copy.deepcopy(value) 1051 old_value = copy.deepcopy(value)
1045 update_cb(value) 1052 update_cb(value)
1046 stmt = update(File).filter_by(id=file_id).values({column: value}) 1053 stmt = update(File).filter_by(id=file_id).values({column: value})
1055 1062
1056 if result.rowcount == 1: 1063 if result.rowcount == 1:
1057 break 1064 break
1058 1065
1059 log.warning( 1066 log.warning(
1060 _("table not updated, probably due to race condition, trying again " 1067 _(
1061 "({tries})").format(tries=i+1) 1068 "table not updated, probably due to race condition, trying again "
1069 "({tries})"
1070 ).format(tries=i + 1)
1062 ) 1071 )
1063 1072
1064 else: 1073 else:
1065 raise exceptions.DatabaseError( 1074 raise exceptions.DatabaseError(
1066 _("Can't update file {file_id} due to race condition") 1075 _("Can't update file {file_id} due to race condition").format(
1067 .format(file_id=file_id) 1076 file_id=file_id
1077 )
1068 ) 1078 )
1069 1079
1070 @aio 1080 @aio
1071 async def get_pubsub_node( 1081 async def get_pubsub_node(
1072 self, 1082 self,
1074 service: jid.JID, 1084 service: jid.JID,
1075 name: str, 1085 name: str,
1076 with_items: bool = False, 1086 with_items: bool = False,
1077 with_subscriptions: bool = False, 1087 with_subscriptions: bool = False,
1078 create: bool = False, 1088 create: bool = False,
1079 create_kwargs: Optional[dict] = None 1089 create_kwargs: Optional[dict] = None,
1080 ) -> Optional[PubsubNode]: 1090 ) -> Optional[PubsubNode]:
1081 """Retrieve a PubsubNode from DB 1091 """Retrieve a PubsubNode from DB
1082 1092
1083 @param service: service hosting the node 1093 @param service: service hosting the node
1084 @param name: node's name 1094 @param name: node's name
1087 @param create: if the node doesn't exist in DB, create it 1097 @param create: if the node doesn't exist in DB, create it
1088 @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node 1098 @param create_kwargs: keyword arguments to use with ``set_pubsub_node`` if the node
1089 needs to be created. 1099 needs to be created.
1090 """ 1100 """
1091 async with self.session() as session: 1101 async with self.session() as session:
1092 stmt = ( 1102 stmt = select(PubsubNode).filter_by(
1093 select(PubsubNode) 1103 service=service,
1094 .filter_by( 1104 name=name,
1095 service=service, 1105 profile_id=self.profiles[client.profile],
1096 name=name,
1097 profile_id=self.profiles[client.profile],
1098 )
1099 ) 1106 )
1100 if with_items: 1107 if with_items:
1101 stmt = stmt.options( 1108 stmt = stmt.options(joinedload(PubsubNode.items))
1102 joinedload(PubsubNode.items)
1103 )
1104 if with_subscriptions: 1109 if with_subscriptions:
1105 stmt = stmt.options( 1110 stmt = stmt.options(joinedload(PubsubNode.subscriptions))
1106 joinedload(PubsubNode.subscriptions)
1107 )
1108 result = await session.execute(stmt) 1111 result = await session.execute(stmt)
1109 ret = result.unique().scalar_one_or_none() 1112 ret = result.unique().scalar_one_or_none()
1110 if ret is None and create: 1113 if ret is None and create:
1111 # we auto-create the node 1114 # we auto-create the node
1112 if create_kwargs is None: 1115 if create_kwargs is None:
1113 create_kwargs = {} 1116 create_kwargs = {}
1114 try: 1117 try:
1115 return await as_future(self.set_pubsub_node( 1118 return await as_future(
1116 client, service, name, **create_kwargs 1119 self.set_pubsub_node(client, service, name, **create_kwargs)
1117 )) 1120 )
1118 except IntegrityError as e: 1121 except IntegrityError as e:
1119 if "unique" in str(e.orig).lower(): 1122 if "unique" in str(e.orig).lower():
1120 # the node may already exist, if it has been created just after 1123 # the node may already exist, if it has been created just after
1121 # get_pubsub_node above 1124 # get_pubsub_node above
1122 log.debug("ignoring UNIQUE constraint error") 1125 log.debug("ignoring UNIQUE constraint error")
1123 cached_node = await as_future(self.get_pubsub_node( 1126 cached_node = await as_future(
1124 client, 1127 self.get_pubsub_node(
1125 service, 1128 client,
1126 name, 1129 service,
1127 with_items=with_items, 1130 name,
1128 with_subscriptions=with_subscriptions 1131 with_items=with_items,
1129 )) 1132 with_subscriptions=with_subscriptions,
1133 )
1134 )
1130 else: 1135 else:
1131 raise e 1136 raise e
1132 else: 1137 else:
1133 return ret 1138 return ret
1134 1139
1158 session.add(node) 1163 session.add(node)
1159 return node 1164 return node
1160 1165
1161 @aio 1166 @aio
1162 async def update_pubsub_node_sync_state( 1167 async def update_pubsub_node_sync_state(
1163 self, 1168 self, node: PubsubNode, state: SyncState
1164 node: PubsubNode,
1165 state: SyncState
1166 ) -> None: 1169 ) -> None:
1167 async with self.session() as session: 1170 async with self.session() as session:
1168 async with session.begin(): 1171 async with session.begin():
1169 await session.execute( 1172 await session.execute(
1170 update(PubsubNode) 1173 update(PubsubNode)
1178 @aio 1181 @aio
1179 async def delete_pubsub_node( 1182 async def delete_pubsub_node(
1180 self, 1183 self,
1181 profiles: Optional[List[str]], 1184 profiles: Optional[List[str]],
1182 services: Optional[List[jid.JID]], 1185 services: Optional[List[jid.JID]],
1183 names: Optional[List[str]] 1186 names: Optional[List[str]],
1184 ) -> None: 1187 ) -> None:
1185 """Delete items cached for a node 1188 """Delete items cached for a node
1186 1189
1187 @param profiles: profile names from which nodes must be deleted. 1190 @param profiles: profile names from which nodes must be deleted.
1188 None to remove nodes from ALL profiles 1191 None to remove nodes from ALL profiles
1192 None to remove ALL nodes whatever is their names 1195 None to remove ALL nodes whatever is their names
1193 """ 1196 """
1194 stmt = delete(PubsubNode) 1197 stmt = delete(PubsubNode)
1195 if profiles is not None: 1198 if profiles is not None:
1196 stmt = stmt.where( 1199 stmt = stmt.where(
1197 PubsubNode.profile.in_( 1200 PubsubNode.profile.in_([self.profiles[p] for p in profiles])
1198 [self.profiles[p] for p in profiles]
1199 )
1200 ) 1201 )
1201 if services is not None: 1202 if services is not None:
1202 stmt = stmt.where(PubsubNode.service.in_(services)) 1203 stmt = stmt.where(PubsubNode.service.in_(services))
1203 if names is not None: 1204 if names is not None:
1204 stmt = stmt.where(PubsubNode.name.in_(names)) 1205 stmt = stmt.where(PubsubNode.name.in_(names))
1221 ) 1222 )
1222 async with self.session() as session: 1223 async with self.session() as session:
1223 async with session.begin(): 1224 async with session.begin():
1224 for idx, item in enumerate(items): 1225 for idx, item in enumerate(items):
1225 parsed = parsed_items[idx] if parsed_items else None 1226 parsed = parsed_items[idx] if parsed_items else None
1226 stmt = insert(PubsubItem).values( 1227 stmt = (
1227 node_id = node.id, 1228 insert(PubsubItem)
1228 name = item["id"], 1229 .values(
1229 data = item, 1230 node_id=node.id,
1230 parsed = parsed, 1231 name=item["id"],
1231 ).on_conflict_do_update( 1232 data=item,
1232 index_elements=(PubsubItem.node_id, PubsubItem.name), 1233 parsed=parsed,
1233 set_={ 1234 )
1234 PubsubItem.data: item, 1235 .on_conflict_do_update(
1235 PubsubItem.parsed: parsed, 1236 index_elements=(PubsubItem.node_id, PubsubItem.name),
1236 PubsubItem.updated: now() 1237 set_={
1237 } 1238 PubsubItem.data: item,
1239 PubsubItem.parsed: parsed,
1240 PubsubItem.updated: now(),
1241 },
1242 )
1238 ) 1243 )
1239 await session.execute(stmt) 1244 await session.execute(stmt)
1240 await session.commit() 1245 await session.commit()
1241 1246
1242 @aio 1247 @aio
1243 async def delete_pubsub_items( 1248 async def delete_pubsub_items(
1244 self, 1249 self, node: PubsubNode, items_names: Optional[List[str]] = None
1245 node: PubsubNode,
1246 items_names: Optional[List[str]] = None
1247 ) -> None: 1250 ) -> None:
1248 """Delete items cached for a node 1251 """Delete items cached for a node
1249 1252
1250 @param node: node from which items must be deleted 1253 @param node: node from which items must be deleted
1251 @param items_names: names of items to delete 1254 @param items_names: names of items to delete
1294 sub_q = select(PubsubNode.id) 1297 sub_q = select(PubsubNode.id)
1295 for col, values in node_fields.items(): 1298 for col, values in node_fields.items():
1296 if values is None: 1299 if values is None:
1297 continue 1300 continue
1298 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) 1301 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
1299 stmt = ( 1302 stmt = stmt.where(PubsubItem.node_id.in_(sub_q)).execution_options(
1300 stmt 1303 synchronize_session=False
1301 .where(PubsubItem.node_id.in_(sub_q))
1302 .execution_options(synchronize_session=False)
1303 ) 1304 )
1304 1305
1305 if created_before is not None: 1306 if created_before is not None:
1306 stmt = stmt.where(PubsubItem.created < created_before) 1307 stmt = stmt.where(PubsubItem.created < created_before)
1307 1308
1365 if force_rsm and not use_rsm: 1366 if force_rsm and not use_rsm:
1366 # 1367 #
1367 use_rsm = True 1368 use_rsm = True
1368 from_index = 0 1369 from_index = 0
1369 1370
1370 stmt = ( 1371 stmt = select(PubsubItem).filter_by(node_id=node.id).limit(max_items)
1371 select(PubsubItem)
1372 .filter_by(node_id=node.id)
1373 .limit(max_items)
1374 )
1375 1372
1376 if item_ids is not None: 1373 if item_ids is not None:
1377 stmt = stmt.where(PubsubItem.name.in_(item_ids)) 1374 stmt = stmt.where(PubsubItem.name.in_(item_ids))
1378 1375
1379 if not order_by: 1376 if not order_by:
1400 # CTE to have result row numbers 1397 # CTE to have result row numbers
1401 row_num_q = select( 1398 row_num_q = select(
1402 PubsubItem.id, 1399 PubsubItem.id,
1403 PubsubItem.name, 1400 PubsubItem.name,
1404 # row_number starts from 1, but RSM index must start from 0 1401 # row_number starts from 1, but RSM index must start from 0
1405 (func.row_number().over(order_by=order)-1).label("item_index") 1402 (func.row_number().over(order_by=order) - 1).label("item_index"),
1406 ).filter_by(node_id=node.id) 1403 ).filter_by(node_id=node.id)
1407 1404
1408 row_num_cte = row_num_q.cte() 1405 row_num_cte = row_num_q.cte()
1409 1406
1410 if max_items > 0: 1407 if max_items > 0:
1411 # as we can't simply use PubsubItem.id when we order by modification, 1408 # as we can't simply use PubsubItem.id when we order by modification,
1412 # we need to use row number 1409 # we need to use row number
1413 item_name = before or after 1410 item_name = before or after
1414 row_num_limit_q = ( 1411 row_num_limit_q = (
1415 select(row_num_cte.c.item_index) 1412 select(row_num_cte.c.item_index).where(
1416 .where(row_num_cte.c.name==item_name) 1413 row_num_cte.c.name == item_name
1414 )
1417 ).scalar_subquery() 1415 ).scalar_subquery()
1418 1416
1419 stmt = ( 1417 stmt = (
1420 select(row_num_cte.c.item_index, PubsubItem) 1418 select(row_num_cte.c.item_index, PubsubItem)
1421 .join(row_num_cte, PubsubItem.id == row_num_cte.c.id) 1419 .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
1422 .limit(max_items) 1420 .limit(max_items)
1423 ) 1421 )
1424 if before: 1422 if before:
1425 stmt = ( 1423 stmt = stmt.where(
1426 stmt 1424 row_num_cte.c.item_index < row_num_limit_q
1427 .where(row_num_cte.c.item_index<row_num_limit_q) 1425 ).order_by(row_num_cte.c.item_index.desc())
1428 .order_by(row_num_cte.c.item_index.desc())
1429 )
1430 elif after: 1426 elif after:
1431 stmt = ( 1427 stmt = stmt.where(
1432 stmt 1428 row_num_cte.c.item_index > row_num_limit_q
1433 .where(row_num_cte.c.item_index>row_num_limit_q) 1429 ).order_by(row_num_cte.c.item_index.asc())
1434 .order_by(row_num_cte.c.item_index.asc())
1435 )
1436 else: 1430 else:
1437 stmt = ( 1431 stmt = stmt.where(row_num_cte.c.item_index >= from_index).order_by(
1438 stmt 1432 row_num_cte.c.item_index.asc()
1439 .where(row_num_cte.c.item_index>=from_index)
1440 .order_by(row_num_cte.c.item_index.asc())
1441 ) 1433 )
1442 # from_index is used 1434 # from_index is used
1443 1435
1444 async with self.session() as session: 1436 async with self.session() as session:
1445 if max_items == 0: 1437 if max_items == 0:
1466 last = None 1458 last = None
1467 else: 1459 else:
1468 last = result[-1][1].name 1460 last = result[-1][1].name
1469 1461
1470 metadata["rsm"] = { 1462 metadata["rsm"] = {
1471 k: v for k, v in { 1463 k: v
1464 for k, v in {
1472 "index": index, 1465 "index": index,
1473 "count": rows_count, 1466 "count": rows_count,
1474 "first": first, 1467 "first": first,
1475 "last": last, 1468 "last": last,
1476 }.items() if v is not None 1469 }.items()
1470 if v is not None
1477 } 1471 }
1478 metadata["complete"] = (index or 0) + len(result) == rows_count 1472 metadata["complete"] = (index or 0) + len(result) == rows_count
1479 1473
1480 return items, metadata 1474 return items, metadata
1481 1475
1485 result = result.scalars().all() 1479 result = result.scalars().all()
1486 if desc: 1480 if desc:
1487 result.reverse() 1481 result.reverse()
1488 return result, metadata 1482 return result, metadata
1489 1483
1490 def _get_sqlite_path( 1484 def _get_sqlite_path(self, path: List[Union[str, int]]) -> str:
1491 self,
1492 path: List[Union[str, int]]
1493 ) -> str:
1494 """generate path suitable to query JSON element with SQLite""" 1485 """generate path suitable to query JSON element with SQLite"""
1495 return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}" 1486 return f"${''.join(f'[{p}]' if isinstance(p, int) else f'.{p}' for p in path)}"
1496 1487
1497 @aio 1488 @aio
1498 async def search_pubsub_items( 1489 async def search_pubsub_items(
1555 stmt = select(PubsubItem) 1546 stmt = select(PubsubItem)
1556 1547
1557 # Full-Text Search 1548 # Full-Text Search
1558 fts = query.get("fts") 1549 fts = query.get("fts")
1559 if fts: 1550 if fts:
1560 fts_select = text( 1551 fts_select = (
1561 "SELECT rowid, rank FROM pubsub_items_fts(:fts_query)" 1552 text("SELECT rowid, rank FROM pubsub_items_fts(:fts_query)")
1562 ).bindparams(fts_query=fts).columns(rowid=Integer).subquery() 1553 .bindparams(fts_query=fts)
1563 stmt = ( 1554 .columns(rowid=Integer)
1564 stmt 1555 .subquery()
1565 .select_from(fts_select) 1556 )
1566 .outerjoin(PubsubItem, fts_select.c.rowid == PubsubItem.id) 1557 stmt = stmt.select_from(fts_select).outerjoin(
1558 PubsubItem, fts_select.c.rowid == PubsubItem.id
1567 ) 1559 )
1568 1560
1569 # node related filters 1561 # node related filters
1570 profiles = query.get("profiles") 1562 profiles = query.get("profiles")
1571 if (profiles 1563 if profiles or any(
1572 or any(query.get(k) for k in ("nodes", "services", "types", "subtypes")) 1564 query.get(k) for k in ("nodes", "services", "types", "subtypes")
1573 ): 1565 ):
1574 stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node)) 1566 stmt = stmt.join(PubsubNode).options(contains_eager(PubsubItem.node))
1575 if profiles: 1567 if profiles:
1576 try: 1568 try:
1577 stmt = stmt.where( 1569 stmt = stmt.where(
1583 ) 1575 )
1584 for key, attr in ( 1576 for key, attr in (
1585 ("nodes", "name"), 1577 ("nodes", "name"),
1586 ("services", "service"), 1578 ("services", "service"),
1587 ("types", "type_"), 1579 ("types", "type_"),
1588 ("subtypes", "subtype") 1580 ("subtypes", "subtype"),
1589 ): 1581 ):
1590 value = query.get(key) 1582 value = query.get(key)
1591 if not value: 1583 if not value:
1592 continue 1584 continue
1593 if key in ("types", "subtypes") and None in value: 1585 if key in ("types", "subtypes") and None in value:
1594 # NULL can't be used with SQL's IN, so we have to add a condition with 1586 # NULL can't be used with SQL's IN, so we have to add a condition with
1595 # IS NULL, and use a OR if there are other values to check 1587 # IS NULL, and use a OR if there are other values to check
1596 value.remove(None) 1588 value.remove(None)
1597 condition = getattr(PubsubNode, attr).is_(None) 1589 condition = getattr(PubsubNode, attr).is_(None)
1598 if value: 1590 if value:
1599 condition = or_( 1591 condition = or_(getattr(PubsubNode, attr).in_(value), condition)
1600 getattr(PubsubNode, attr).in_(value),
1601 condition
1602 )
1603 else: 1592 else:
1604 condition = getattr(PubsubNode, attr).in_(value) 1593 condition = getattr(PubsubNode, attr).in_(value)
1605 stmt = stmt.where(condition) 1594 stmt = stmt.where(condition)
1606 else: 1595 else:
1607 stmt = stmt.options(selectinload(PubsubItem.node)) 1596 stmt = stmt.options(selectinload(PubsubItem.node))
1642 stmt = stmt.where(condition) 1631 stmt = stmt.where(condition)
1643 elif operator == "between": 1632 elif operator == "between":
1644 try: 1633 try:
1645 left, right = value 1634 left, right = value
1646 except (ValueError, TypeError): 1635 except (ValueError, TypeError):
1647 raise ValueError(_( 1636 raise ValueError(
1648 "invalid value for \"between\" filter, you must use a 2 items " 1637 _(
1649 "array: {value!r}" 1638 'invalid value for "between" filter, you must use a 2 items '
1650 ).format(value=value)) 1639 "array: {value!r}"
1640 ).format(value=value)
1641 )
1651 col = func.json_extract(PubsubItem.parsed, sqlite_path) 1642 col = func.json_extract(PubsubItem.parsed, sqlite_path)
1652 stmt = stmt.where(col.between(left, right)) 1643 stmt = stmt.where(col.between(left, right))
1653 else: 1644 else:
1654 # we use func.json_extract instead of generic JSON way because SQLAlchemy 1645 # we use func.json_extract instead of generic JSON way because SQLAlchemy
1655 # add a JSON_QUOTE to the value, and we want SQL value 1646 # add a JSON_QUOTE to the value, and we want SQL value
1660 order_by = query.get("order-by") or [{"order": "creation"}] 1651 order_by = query.get("order-by") or [{"order": "creation"}]
1661 1652
1662 for order_data in order_by: 1653 for order_data in order_by:
1663 order, path = order_data.get("order"), order_data.get("path") 1654 order, path = order_data.get("order"), order_data.get("path")
1664 if order and path: 1655 if order and path:
1665 raise ValueError(_( 1656 raise ValueError(
1666 '"order" and "path" can\'t be used at the same time in ' 1657 _(
1667 '"order-by" data' 1658 '"order" and "path" can\'t be used at the same time in '
1668 )) 1659 '"order-by" data'
1660 )
1661 )
1669 if order: 1662 if order:
1670 if order == "creation": 1663 if order == "creation":
1671 col = PubsubItem.id 1664 col = PubsubItem.id
1672 elif order == "modification": 1665 elif order == "modification":
1673 col = PubsubItem.updated 1666 col = PubsubItem.updated
1700 1693
1701 async with self.session() as session: 1694 async with self.session() as session:
1702 result = await session.execute(stmt) 1695 result = await session.execute(stmt)
1703 1696
1704 return result.scalars().all() 1697 return result.scalars().all()
1698
1699 # Notifications
1700
1701 @aio
1702 async def add_notification(
1703 self,
1704 client: Optional[SatXMPPEntity],
1705 type_: NotificationType,
1706 body_plain: str,
1707 body_rich: Optional[str] = None,
1708 title: Optional[str] = None,
1709 requires_action: bool = False,
1710 priority: NotificationPriority = NotificationPriority.MEDIUM,
1711 expire_at: Optional[float] = None,
1712 extra: Optional[dict] = None,
1713 ) -> Notification:
1714 """Add a new notification to the DB.
1715
1716 @param client: client associated with the notification. If None, the notification
1717 will be global.
1718 @param type_: type of the notification.
1719 @param body_plain: plain text body.
1720 @param body_rich: rich text (XHTML) body.
1721 @param title: optional title.
1722 @param requires_action: True if the notification requires user action (e.g. a
1723 dialog need to be answered).
1724 @priority: how urgent the notification is
1725 @param expire_at: expiration timestamp for the notification.
1726 @param extra: additional data.
1727 @return: created Notification
1728 """
1729 profile_id = self.profiles[client.profile] if client else None
1730 notification = Notification(
1731 profile_id=profile_id,
1732 type=type_,
1733 body_plain=body_plain,
1734 body_rich=body_rich,
1735 requires_action=requires_action,
1736 priority=priority,
1737 expire_at=expire_at,
1738 title=title,
1739 extra_data=extra,
1740 status=NotificationStatus.new,
1741 )
1742 async with self.session() as session:
1743 async with session.begin():
1744 session.add(notification)
1745 return notification
1746
1747 @aio
1748 async def update_notification(
1749 self, client: SatXMPPEntity, notification_id: int, **kwargs
1750 ) -> None:
1751 """Update an existing notification.
1752
1753 @param client: client associated with the notification.
1754 @param notification_id: ID of the notification to update.
1755 """
1756 profile_id = self.profiles[client.profile]
1757 async with self.session() as session:
1758 await session.execute(
1759 update(Notification)
1760 .where(
1761 and_(
1762 Notification.profile_id == profile_id,
1763 Notification.id == notification_id,
1764 )
1765 )
1766 .values(**kwargs)
1767 )
1768 await session.commit()
1769
1770 @aio
1771 async def get_notifications(
1772 self,
1773 client: SatXMPPEntity,
1774 type_: Optional[NotificationType] = None,
1775 status: Optional[NotificationStatus] = None,
1776 requires_action: Optional[bool] = None,
1777 min_priority: Optional[int] = None
1778 ) -> List[Notification]:
1779 """Retrieve all notifications for a given profile with optional filters.
1780
1781 @param client: client associated with the notifications.
1782 @param type_: filter by type of the notification.
1783 @param status: filter by status of the notification.
1784 @param requires_action: filter by notifications that require user action.
1785 @param min_priority: filter by minimum priority value.
1786 @return: list of matching Notification instances.
1787 """
1788 profile_id = self.profiles[client.profile]
1789 filters = [or_(Notification.profile_id == profile_id, Notification.profile_id.is_(None))]
1790
1791 if type_:
1792 filters.append(Notification.type == type_)
1793 if status:
1794 filters.append(Notification.status == status)
1795 if requires_action is not None:
1796 filters.append(Notification.requires_action == requires_action)
1797 if min_priority:
1798 filters.append(Notification.priority >= min_priority)
1799
1800 async with self.session() as session:
1801 result = await session.execute(
1802 select(Notification)
1803 .where(and_(*filters))
1804 .order_by(Notification.id)
1805 )
1806 return result.scalars().all()
1807
1808 @aio
1809 async def delete_notification(
1810 self, client: Optional[SatXMPPEntity], notification_id: str
1811 ) -> None:
1812 """Delete a notification by its profile and id.
1813
1814 @param client: client associated with the notification. If None, profile_id will be NULL.
1815 @param notification_id: ID of the notification to delete.
1816 """
1817 profile_id = self.profiles[client.profile] if client else None
1818 async with self.session() as session:
1819 await session.execute(
1820 delete(Notification).where(
1821 and_(
1822 Notification.profile_id == profile_id,
1823 Notification.id == int(notification_id),
1824 )
1825 )
1826 )
1827 await session.commit()
1828
1829 @aio
1830 async def clean_expired_notifications(
1831 self, client: Optional[SatXMPPEntity], limit_timestamp: Optional[float] = None
1832 ) -> None:
1833 """Cleans expired notifications and older profile-specific notifications.
1834
1835 - Removes all notifications where the expiration timestamp has passed,
1836 irrespective of their profile.
1837 - If a limit_timestamp is provided, removes older notifications with a profile set
1838 (i.e., not global notifications) that do not require user action. If client is
1839 provided, only remove notification for this profile.
1840
1841 @param client: if provided, only expire notification for this client (in addition
1842 to truly expired notifications for everybody).
1843 @param limit_timestamp: Timestamp limit for older notifications. If None, only
1844 truly expired notifications are removed.
1845 """
1846
1847 # Delete truly expired notifications
1848 expired_condition = Notification.expire_at < time.time()
1849
1850 # Delete older profile-specific notifications (created before the limit_timestamp)
1851 if client is None:
1852 profile_condition = Notification.profile_id.isnot(None)
1853 else:
1854 profile_condition = Notification.profile_id == self.profiles[client.profile]
1855 older_condition = and_(
1856 profile_condition,
1857 Notification.timestamp < limit_timestamp if limit_timestamp else False,
1858 Notification.requires_action == False,
1859 )
1860
1861 # Combine the conditions
1862 conditions = or_(expired_condition, older_condition)
1863
1864 async with self.session() as session:
1865 await session.execute(delete(Notification).where(conditions))
1866 await session.commit()