Mercurial > libervia-backend
comparison sat/memory/sqla.py @ 3537:f9a5b810f14d
core (memory/storage): backend storage is now based on SQLAlchemy
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 03 Jun 2021 15:20:47 +0200 |
parents | |
children | 71516731d0aa |
comparison
equal
deleted
inserted
replaced
3536:0985c47ffd96 | 3537:f9a5b810f14d |
---|---|
1 #!/usr/bin/env python3 | |
2 | |
3 # Libervia: an XMPP client | |
4 # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) | |
5 | |
6 # This program is free software: you can redistribute it and/or modify | |
7 # it under the terms of the GNU Affero General Public License as published by | |
8 # the Free Software Foundation, either version 3 of the License, or | |
9 # (at your option) any later version. | |
10 | |
11 # This program is distributed in the hope that it will be useful, | |
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
14 # GNU Affero General Public License for more details. | |
15 | |
16 # You should have received a copy of the GNU Affero General Public License | |
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. | |
18 | |
19 import time | |
20 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional | |
21 from urllib.parse import quote | |
22 from pathlib import Path | |
23 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine | |
24 from sqlalchemy.exc import IntegrityError, NoResultFound | |
25 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager | |
26 from sqlalchemy.future import select | |
27 from sqlalchemy.engine import Engine | |
28 from sqlalchemy import update, delete, and_, or_, event | |
29 from sqlalchemy.sql.functions import coalesce, sum as sum_ | |
30 from sqlalchemy.dialects.sqlite import insert | |
31 from twisted.internet import defer | |
32 from twisted.words.protocols.jabber import jid | |
33 from sat.core.i18n import _ | |
34 from sat.core import exceptions | |
35 from sat.core.log import getLogger | |
36 from sat.core.constants import Const as C | |
37 from sat.core.core_types import SatXMPPEntity | |
38 from sat.tools.utils import aio | |
39 from sat.memory.sqla_mapping import ( | |
40 NOT_IN_EXTRA, | |
41 Base, | |
42 Profile, | |
43 Component, | |
44 History, | |
45 Message, | |
46 Subject, | |
47 Thread, | |
48 ParamGen, | |
49 ParamInd, | |
50 PrivateGen, | |
51 PrivateInd, | |
52 PrivateGenBin, | |
53 PrivateIndBin, | |
54 File | |
55 ) | |
56 | |
57 | |
58 log = getLogger(__name__) | |
59 | |
60 | |
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Turn on foreign key enforcement for each new DBAPI connection.

    Registered as a "connect" listener on Engine, so the pragma is sent every
    time a new connection is opened.

    @param dbapi_connection: raw DBAPI connection which has just been opened
    @param connection_record: not used here
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # release the cursor even if the pragma statement fails
        cursor.close()
66 | |
67 | |
68 class Storage: | |
69 | |
70 def __init__(self, db_filename, sat_version): | |
71 self.initialized = defer.Deferred() | |
72 self.filename = Path(db_filename) | |
73 # we keep cache for the profiles (key: profile name, value: profile id) | |
74 # profile id to name | |
75 self.profiles: Dict[int, str] = {} | |
76 # profile id to component entry point | |
77 self.components: Dict[int, str] = {} | |
78 | |
79 @aio | |
80 async def initialise(self): | |
81 log.info(_("Connecting database")) | |
82 engine = create_async_engine( | |
83 f"sqlite+aiosqlite:///{quote(str(self.filename))}", | |
84 future=True | |
85 ) | |
86 self.session = sessionmaker( | |
87 engine, expire_on_commit=False, class_=AsyncSession | |
88 ) | |
89 new_base = not self.filename.exists() | |
90 if new_base: | |
91 log.info(_("The database is new, creating the tables")) | |
92 # the dir may not exist if it's not the XDG recommended one | |
93 self.filename.parent.mkdir(0o700, True, True) | |
94 async with engine.begin() as conn: | |
95 await conn.run_sync(Base.metadata.create_all) | |
96 | |
97 async with self.session() as session: | |
98 result = await session.execute(select(Profile)) | |
99 for p in result.scalars(): | |
100 self.profiles[p.name] = p.id | |
101 result = await session.execute(select(Component)) | |
102 for c in result.scalars(): | |
103 self.components[c.profile_id] = c.entry_point | |
104 | |
105 self.initialized.callback(None) | |
106 | |
107 ## Profiles | |
108 | |
109 def getProfilesList(self) -> List[str]: | |
110 """"Return list of all registered profiles""" | |
111 return list(self.profiles.keys()) | |
112 | |
113 def hasProfile(self, profile_name: str) -> bool: | |
114 """return True if profile_name exists | |
115 | |
116 @param profile_name: name of the profile to check | |
117 """ | |
118 return profile_name in self.profiles | |
119 | |
120 def profileIsComponent(self, profile_name: str) -> bool: | |
121 try: | |
122 return self.profiles[profile_name] in self.components | |
123 except KeyError: | |
124 raise exceptions.NotFound("the requested profile doesn't exists") | |
125 | |
126 def getEntryPoint(self, profile_name: str) -> str: | |
127 try: | |
128 return self.components[self.profiles[profile_name]] | |
129 except KeyError: | |
130 raise exceptions.NotFound("the requested profile doesn't exists or is not a component") | |
131 | |
132 @aio | |
133 async def createProfile(self, name: str, component_ep: Optional[str] = None) -> None: | |
134 """Create a new profile | |
135 | |
136 @param name: name of the profile | |
137 @param component: if not None, must point to a component entry point | |
138 """ | |
139 async with self.session() as session: | |
140 profile = Profile(name=name) | |
141 async with session.begin(): | |
142 session.add(profile) | |
143 self.profiles[profile.id] = profile.name | |
144 if component_ep is not None: | |
145 async with session.begin(): | |
146 component = Component(profile=profile, entry_point=component_ep) | |
147 session.add(component) | |
148 self.components[profile.id] = component_ep | |
149 return profile | |
150 | |
151 @aio | |
152 async def deleteProfile(self, name: str) -> None: | |
153 """Delete profile | |
154 | |
155 @param name: name of the profile | |
156 """ | |
157 async with self.session() as session: | |
158 result = await session.execute(select(Profile).where(Profile.name == name)) | |
159 profile = result.scalar() | |
160 await session.delete(profile) | |
161 await session.commit() | |
162 del self.profiles[profile.id] | |
163 if profile.id in self.components: | |
164 del self.components[profile.id] | |
165 log.info(_("Profile {name!r} deleted").format(name = name)) | |
166 | |
167 ## Params | |
168 | |
169 @aio | |
170 async def loadGenParams(self, params_gen: dict) -> None: | |
171 """Load general parameters | |
172 | |
173 @param params_gen: dictionary to fill | |
174 """ | |
175 log.debug(_("loading general parameters from database")) | |
176 async with self.session() as session: | |
177 result = await session.execute(select(ParamGen)) | |
178 for p in result.scalars(): | |
179 params_gen[(p.category, p.name)] = p.value | |
180 | |
181 @aio | |
182 async def loadIndParams(self, params_ind: dict, profile: str) -> None: | |
183 """Load individual parameters | |
184 | |
185 @param params_ind: dictionary to fill | |
186 @param profile: a profile which *must* exist | |
187 """ | |
188 log.debug(_("loading individual parameters from database")) | |
189 async with self.session() as session: | |
190 result = await session.execute( | |
191 select(ParamInd).where(ParamInd.profile_id == self.profiles[profile]) | |
192 ) | |
193 for p in result.scalars(): | |
194 params_ind[(p.category, p.name)] = p.value | |
195 | |
196 @aio | |
197 async def getIndParam(self, category: str, name: str, profile: str) -> Optional[str]: | |
198 """Ask database for the value of one specific individual parameter | |
199 | |
200 @param category: category of the parameter | |
201 @param name: name of the parameter | |
202 @param profile: %(doc_profile)s | |
203 """ | |
204 async with self.session() as session: | |
205 result = await session.execute( | |
206 select(ParamInd.value) | |
207 .filter_by( | |
208 category=category, | |
209 name=name, | |
210 profile_id=self.profiles[profile] | |
211 ) | |
212 ) | |
213 return result.scalar_one_or_none() | |
214 | |
215 @aio | |
216 async def getIndParamValues(self, category: str, name: str) -> Dict[str, str]: | |
217 """Ask database for the individual values of a parameter for all profiles | |
218 | |
219 @param category: category of the parameter | |
220 @param name: name of the parameter | |
221 @return dict: profile => value map | |
222 """ | |
223 async with self.session() as session: | |
224 result = await session.execute( | |
225 select(ParamInd) | |
226 .filter_by( | |
227 category=category, | |
228 name=name | |
229 ) | |
230 .options(subqueryload(ParamInd.profile)) | |
231 ) | |
232 return {param.profile.name: param.value for param in result.scalars()} | |
233 | |
234 @aio | |
235 async def setGenParam(self, category: str, name: str, value: Optional[str]) -> None: | |
236 """Save the general parameters in database | |
237 | |
238 @param category: category of the parameter | |
239 @param name: name of the parameter | |
240 @param value: value to set | |
241 """ | |
242 async with self.session() as session: | |
243 stmt = insert(ParamGen).values( | |
244 category=category, | |
245 name=name, | |
246 value=value | |
247 ).on_conflict_do_update( | |
248 index_elements=(ParamGen.category, ParamGen.name), | |
249 set_={ | |
250 ParamGen.value: value | |
251 } | |
252 ) | |
253 await session.execute(stmt) | |
254 await session.commit() | |
255 | |
256 @aio | |
257 async def setIndParam( | |
258 self, | |
259 category:str, | |
260 name: str, | |
261 value: Optional[str], | |
262 profile: str | |
263 ) -> None: | |
264 """Save the individual parameters in database | |
265 | |
266 @param category: category of the parameter | |
267 @param name: name of the parameter | |
268 @param value: value to set | |
269 @param profile: a profile which *must* exist | |
270 """ | |
271 async with self.session() as session: | |
272 stmt = insert(ParamInd).values( | |
273 category=category, | |
274 name=name, | |
275 profile_id=self.profiles[profile], | |
276 value=value | |
277 ).on_conflict_do_update( | |
278 index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), | |
279 set_={ | |
280 ParamInd.value: value | |
281 } | |
282 ) | |
283 await session.execute(stmt) | |
284 await session.commit() | |
285 | |
286 def _jid_filter(self, jid_: jid.JID, dest: bool = False): | |
287 """Generate condition to filter on a JID, using relevant columns | |
288 | |
289 @param dest: True if it's the destinee JID, otherwise it's the source one | |
290 @param jid_: JID to filter by | |
291 """ | |
292 if jid_.resource: | |
293 if dest: | |
294 return and_( | |
295 History.dest == jid_.userhost(), | |
296 History.dest_res == jid_.resource | |
297 ) | |
298 else: | |
299 return and_( | |
300 History.source == jid_.userhost(), | |
301 History.source_res == jid_.resource | |
302 ) | |
303 else: | |
304 if dest: | |
305 return History.dest == jid_.userhost() | |
306 else: | |
307 return History.source == jid_.userhost() | |
308 | |
309 @aio | |
310 async def historyGet( | |
311 self, | |
312 from_jid: Optional[jid.JID], | |
313 to_jid: Optional[jid.JID], | |
314 limit: Optional[int] = None, | |
315 between: bool = True, | |
316 filters: Optional[Dict[str, str]] = None, | |
317 profile: Optional[str] = None, | |
318 ) -> List[Tuple[ | |
319 str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] | |
320 ]: | |
321 """Retrieve messages in history | |
322 | |
323 @param from_jid: source JID (full, or bare for catchall) | |
324 @param to_jid: dest JID (full, or bare for catchall) | |
325 @param limit: maximum number of messages to get: | |
326 - 0 for no message (returns the empty list) | |
327 - None for unlimited | |
328 @param between: confound source and dest (ignore the direction) | |
329 @param filters: pattern to filter the history results | |
330 @return: list of messages as in [messageNew], minus the profile which is already | |
331 known. | |
332 """ | |
333 # we have to set a default value to profile because it's last argument | |
334 # and thus follow other keyword arguments with default values | |
335 # but None should not be used for it | |
336 assert profile is not None | |
337 if limit == 0: | |
338 return [] | |
339 if filters is None: | |
340 filters = {} | |
341 | |
342 stmt = ( | |
343 select(History) | |
344 .filter_by( | |
345 profile_id=self.profiles[profile] | |
346 ) | |
347 .outerjoin(History.messages) | |
348 .outerjoin(History.subjects) | |
349 .outerjoin(History.thread) | |
350 .options( | |
351 contains_eager(History.messages), | |
352 contains_eager(History.subjects), | |
353 contains_eager(History.thread), | |
354 ) | |
355 .order_by( | |
356 # timestamp may be identical for 2 close messages (specially when delay is | |
357 # used) that's why we order ties by received_timestamp. We'll reverse the | |
358 # order when returning the result. We use DESC here so LIMIT keep the last | |
359 # messages | |
360 History.timestamp.desc(), | |
361 History.received_timestamp.desc() | |
362 ) | |
363 ) | |
364 | |
365 | |
366 if not from_jid and not to_jid: | |
367 # no jid specified, we want all one2one communications | |
368 pass | |
369 elif between: | |
370 if not from_jid or not to_jid: | |
371 # we only have one jid specified, we check all messages | |
372 # from or to this jid | |
373 jid_ = from_jid or to_jid | |
374 stmt = stmt.where( | |
375 or_( | |
376 self._jid_filter(jid_), | |
377 self._jid_filter(jid_, dest=True) | |
378 ) | |
379 ) | |
380 else: | |
381 # we have 2 jids specified, we check all communications between | |
382 # those 2 jids | |
383 stmt = stmt.where( | |
384 or_( | |
385 and_( | |
386 self._jid_filter(from_jid), | |
387 self._jid_filter(to_jid, dest=True), | |
388 ), | |
389 and_( | |
390 self._jid_filter(to_jid), | |
391 self._jid_filter(from_jid, dest=True), | |
392 ) | |
393 ) | |
394 ) | |
395 else: | |
396 # we want one communication in specific direction (from somebody or | |
397 # to somebody). | |
398 if from_jid is not None: | |
399 stmt = stmt.where(self._jid_filter(from_jid)) | |
400 if to_jid is not None: | |
401 stmt = stmt.where(self._jid_filter(to_jid, dest=True)) | |
402 | |
403 if filters: | |
404 if 'timestamp_start' in filters: | |
405 stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) | |
406 if 'before_uid' in filters: | |
407 # orignially this query was using SQLITE's rowid. This has been changed | |
408 # to use coalesce(received_timestamp, timestamp) to be SQL engine independant | |
409 stmt = stmt.where( | |
410 coalesce( | |
411 History.received_timestamp, | |
412 History.timestamp | |
413 ) < ( | |
414 select(coalesce(History.received_timestamp, History.timestamp)) | |
415 .filter_by(uid=filters["before_uid"]) | |
416 ).scalar_subquery() | |
417 ) | |
418 if 'body' in filters: | |
419 # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html | |
420 stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) | |
421 if 'search' in filters: | |
422 search_term = f"%{filters['search']}%" | |
423 stmt = stmt.where(or_( | |
424 Message.message.like(search_term), | |
425 History.source_res.like(search_term) | |
426 )) | |
427 if 'types' in filters: | |
428 types = filters['types'].split() | |
429 stmt = stmt.where(History.type.in_(types)) | |
430 if 'not_types' in filters: | |
431 types = filters['not_types'].split() | |
432 stmt = stmt.where(History.type.not_in(types)) | |
433 if 'last_stanza_id' in filters: | |
434 # this request get the last message with a "stanza_id" that we | |
435 # have in history. This is mainly used to retrieve messages sent | |
436 # while we were offline, using MAM (XEP-0313). | |
437 if (filters['last_stanza_id'] is not True | |
438 or limit != 1): | |
439 raise ValueError("Unexpected values for last_stanza_id filter") | |
440 stmt = stmt.where(History.stanza_id.is_not(None)) | |
441 | |
442 if limit is not None: | |
443 stmt = stmt.limit(limit) | |
444 | |
445 async with self.session() as session: | |
446 result = await session.execute(stmt) | |
447 | |
448 result = result.scalars().unique().all() | |
449 result.reverse() | |
450 return [h.as_tuple() for h in result] | |
451 | |
452 @aio | |
453 async def addToHistory(self, data: dict, profile: str) -> None: | |
454 """Store a new message in history | |
455 | |
456 @param data: message data as build by SatMessageProtocol.onMessage | |
457 """ | |
458 extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} | |
459 messages = [Message(message=mess, language=lang) | |
460 for lang, mess in data["message"].items()] | |
461 subjects = [Subject(subject=mess, language=lang) | |
462 for lang, mess in data["subject"].items()] | |
463 if "thread" in data["extra"]: | |
464 thread = Thread(thread_id=data["extra"]["thread"], | |
465 parent_id=data["extra"].get["thread_parent"]) | |
466 else: | |
467 thread = None | |
468 try: | |
469 async with self.session() as session: | |
470 async with session.begin(): | |
471 session.add(History( | |
472 uid=data["uid"], | |
473 stanza_id=data["extra"].get("stanza_id"), | |
474 update_uid=data["extra"].get("update_uid"), | |
475 profile_id=self.profiles[profile], | |
476 source_jid=data["from"], | |
477 dest_jid=data["to"], | |
478 timestamp=data["timestamp"], | |
479 received_timestamp=data.get("received_timestamp"), | |
480 type=data["type"], | |
481 extra=extra, | |
482 messages=messages, | |
483 subjects=subjects, | |
484 thread=thread, | |
485 )) | |
486 except IntegrityError as e: | |
487 if "unique" in str(e.orig).lower(): | |
488 log.debug( | |
489 f"message {data['uid']!r} is already in history, not storing it again" | |
490 ) | |
491 else: | |
492 log.error(f"Can't store message {data['uid']!r} in history: {e}") | |
493 except Exception as e: | |
494 log.critical( | |
495 f"Can't store message, unexpected exception (uid: {data['uid']}): {e}" | |
496 ) | |
497 | |
498 ## Private values | |
499 | |
500 def _getPrivateClass(self, binary, profile): | |
501 """Get ORM class to use for private values""" | |
502 if profile is None: | |
503 return PrivateGenBin if binary else PrivateGen | |
504 else: | |
505 return PrivateIndBin if binary else PrivateInd | |
506 | |
507 | |
508 @aio | |
509 async def getPrivates( | |
510 self, | |
511 namespace:str, | |
512 keys: Optional[Iterable[str]] = None, | |
513 binary: bool = False, | |
514 profile: Optional[str] = None | |
515 ) -> Dict[str, Any]: | |
516 """Get private value(s) from databases | |
517 | |
518 @param namespace: namespace of the values | |
519 @param keys: keys of the values to get None to get all keys/values | |
520 @param binary: True to deserialise binary values | |
521 @param profile: profile to use for individual values | |
522 None to use general values | |
523 @return: gotten keys/values | |
524 """ | |
525 if keys is not None: | |
526 keys = list(keys) | |
527 log.debug( | |
528 f"getting {'general' if profile is None else 'individual'}" | |
529 f"{' binary' if binary else ''} private values from database for namespace " | |
530 f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}" | |
531 ) | |
532 cls = self._getPrivateClass(binary, profile) | |
533 stmt = select(cls).filter_by(namespace=namespace) | |
534 if keys: | |
535 stmt = stmt.where(cls.key.in_(list(keys))) | |
536 if profile is not None: | |
537 stmt = stmt.filter_by(profile_id=self.profiles[profile]) | |
538 async with self.session() as session: | |
539 result = await session.execute(stmt) | |
540 return {p.key: p.value for p in result.scalars()} | |
541 | |
542 @aio | |
543 async def setPrivateValue( | |
544 self, | |
545 namespace: str, | |
546 key:str, | |
547 value: Any, | |
548 binary: bool = False, | |
549 profile: Optional[str] = None | |
550 ) -> None: | |
551 """Set a private value in database | |
552 | |
553 @param namespace: namespace of the values | |
554 @param key: key of the value to set | |
555 @param value: value to set | |
556 @param binary: True if it's a binary values | |
557 binary values need to be serialised, used for everything but strings | |
558 @param profile: profile to use for individual value | |
559 if None, it's a general value | |
560 """ | |
561 cls = self._getPrivateClass(binary, profile) | |
562 | |
563 values = { | |
564 "namespace": namespace, | |
565 "key": key, | |
566 "value": value | |
567 } | |
568 index_elements = [cls.namespace, cls.key] | |
569 | |
570 if profile is not None: | |
571 values["profile_id"] = self.profiles[profile] | |
572 index_elements.append(cls.profile_id) | |
573 | |
574 async with self.session() as session: | |
575 await session.execute( | |
576 insert(cls).values(**values).on_conflict_do_update( | |
577 index_elements=index_elements, | |
578 set_={ | |
579 cls.value: value | |
580 } | |
581 ) | |
582 ) | |
583 await session.commit() | |
584 | |
585 @aio | |
586 async def delPrivateValue( | |
587 self, | |
588 namespace: str, | |
589 key: str, | |
590 binary: bool = False, | |
591 profile: Optional[str] = None | |
592 ) -> None: | |
593 """Delete private value from database | |
594 | |
595 @param category: category of the privateeter | |
596 @param key: key of the private value | |
597 @param binary: True if it's a binary values | |
598 @param profile: profile to use for individual value | |
599 if None, it's a general value | |
600 """ | |
601 cls = self._getPrivateClass(binary, profile) | |
602 | |
603 stmt = delete(cls).filter_by(namespace=namespace, key=key) | |
604 | |
605 if profile is not None: | |
606 stmt = stmt.filter_by(profile_id=self.profiles[profile]) | |
607 | |
608 async with self.session() as session: | |
609 await session.execute(stmt) | |
610 await session.commit() | |
611 | |
612 @aio | |
613 async def delPrivateNamespace( | |
614 self, | |
615 namespace: str, | |
616 binary: bool = False, | |
617 profile: Optional[str] = None | |
618 ) -> None: | |
619 """Delete all data from a private namespace | |
620 | |
621 Be really cautious when you use this method, as all data with given namespace are | |
622 removed. | |
623 Params are the same as for delPrivateValue | |
624 """ | |
625 cls = self._getPrivateClass(binary, profile) | |
626 | |
627 stmt = delete(cls).filter_by(namespace=namespace) | |
628 | |
629 if profile is not None: | |
630 stmt = stmt.filter_by(profile_id=self.profiles[profile]) | |
631 | |
632 async with self.session() as session: | |
633 await session.execute(stmt) | |
634 await session.commit() | |
635 | |
636 ## Files | |
637 | |
638 @aio | |
639 async def getFiles( | |
640 self, | |
641 client: Optional[SatXMPPEntity], | |
642 file_id: Optional[str] = None, | |
643 version: Optional[str] = '', | |
644 parent: Optional[str] = None, | |
645 type_: Optional[str] = None, | |
646 file_hash: Optional[str] = None, | |
647 hash_algo: Optional[str] = None, | |
648 name: Optional[str] = None, | |
649 namespace: Optional[str] = None, | |
650 mime_type: Optional[str] = None, | |
651 public_id: Optional[str] = None, | |
652 owner: Optional[jid.JID] = None, | |
653 access: Optional[dict] = None, | |
654 projection: Optional[List[str]] = None, | |
655 unique: bool = False | |
656 ) -> List[dict]: | |
657 """Retrieve files with with given filters | |
658 | |
659 @param file_id: id of the file | |
660 None to ignore | |
661 @param version: version of the file | |
662 None to ignore | |
663 empty string to look for current version | |
664 @param parent: id of the directory containing the files | |
665 None to ignore | |
666 empty string to look for root files/directories | |
667 @param projection: name of columns to retrieve | |
668 None to retrieve all | |
669 @param unique: if True will remove duplicates | |
670 other params are the same as for [setFile] | |
671 @return: files corresponding to filters | |
672 """ | |
673 if projection is None: | |
674 projection = [ | |
675 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', | |
676 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', | |
677 'created', 'modified', 'owner', 'access', 'extra' | |
678 ] | |
679 | |
680 stmt = select(*[getattr(File, f) for f in projection]) | |
681 | |
682 if unique: | |
683 stmt = stmt.distinct() | |
684 | |
685 if client is not None: | |
686 stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) | |
687 else: | |
688 if public_id is None: | |
689 raise exceptions.InternalError( | |
690 "client can only be omitted when public_id is set" | |
691 ) | |
692 if file_id is not None: | |
693 stmt = stmt.filter_by(id=file_id) | |
694 if version is not None: | |
695 stmt = stmt.filter_by(version=version) | |
696 if parent is not None: | |
697 stmt = stmt.filter_by(parent=parent) | |
698 if type_ is not None: | |
699 stmt = stmt.filter_by(type=type_) | |
700 if file_hash is not None: | |
701 stmt = stmt.filter_by(file_hash=file_hash) | |
702 if hash_algo is not None: | |
703 stmt = stmt.filter_by(hash_algo=hash_algo) | |
704 if name is not None: | |
705 stmt = stmt.filter_by(name=name) | |
706 if namespace is not None: | |
707 stmt = stmt.filter_by(namespace=namespace) | |
708 if mime_type is not None: | |
709 if '/' in mime_type: | |
710 media_type, media_subtype = mime_type.split("/", 1) | |
711 stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) | |
712 else: | |
713 stmt = stmt.filter_by(media_type=mime_type) | |
714 if public_id is not None: | |
715 stmt = stmt.filter_by(public_id=public_id) | |
716 if owner is not None: | |
717 stmt = stmt.filter_by(owner=owner) | |
718 if access is not None: | |
719 raise NotImplementedError('Access check is not implemented yet') | |
720 # a JSON comparison is needed here | |
721 | |
722 async with self.session() as session: | |
723 result = await session.execute(stmt) | |
724 | |
725 return [dict(r) for r in result] | |
726 | |
727 @aio | |
728 async def setFile( | |
729 self, | |
730 client: SatXMPPEntity, | |
731 name: str, | |
732 file_id: str, | |
733 version: str = "", | |
734 parent: str = "", | |
735 type_: str = C.FILE_TYPE_FILE, | |
736 file_hash: Optional[str] = None, | |
737 hash_algo: Optional[str] = None, | |
738 size: int = None, | |
739 namespace: Optional[str] = None, | |
740 mime_type: Optional[str] = None, | |
741 public_id: Optional[str] = None, | |
742 created: Optional[float] = None, | |
743 modified: Optional[float] = None, | |
744 owner: Optional[jid.JID] = None, | |
745 access: Optional[dict] = None, | |
746 extra: Optional[dict] = None | |
747 ) -> None: | |
748 """Set a file metadata | |
749 | |
750 @param client: client owning the file | |
751 @param name: name of the file (must not contain "/") | |
752 @param file_id: unique id of the file | |
753 @param version: version of this file | |
754 @param parent: id of the directory containing this file | |
755 Empty string if it is a root file/directory | |
756 @param type_: one of: | |
757 - file | |
758 - directory | |
759 @param file_hash: unique hash of the payload | |
760 @param hash_algo: algorithm used for hashing the file (usually sha-256) | |
761 @param size: size in bytes | |
762 @param namespace: identifier (human readable is better) to group files | |
763 for instance, namespace could be used to group files in a specific photo album | |
764 @param mime_type: media type of the file, or None if not known/guessed | |
765 @param public_id: ID used to server the file publicly via HTTP | |
766 @param created: UNIX time of creation | |
767 @param modified: UNIX time of last modification, or None to use created date | |
768 @param owner: jid of the owner of the file (mainly useful for component) | |
769 @param access: serialisable dictionary with access rules. See [memory.memory] for details | |
770 @param extra: serialisable dictionary of any extra data | |
771 will be encoded to json in database | |
772 """ | |
773 if mime_type is None: | |
774 media_type = media_subtype = None | |
775 elif '/' in mime_type: | |
776 media_type, media_subtype = mime_type.split('/', 1) | |
777 else: | |
778 media_type, media_subtype = mime_type, None | |
779 | |
780 async with self.session() as session: | |
781 async with session.begin(): | |
782 session.add(File( | |
783 id=file_id, | |
784 version=version.strip(), | |
785 parent=parent, | |
786 type=type_, | |
787 file_hash=file_hash, | |
788 hash_algo=hash_algo, | |
789 name=name, | |
790 size=size, | |
791 namespace=namespace, | |
792 media_type=media_type, | |
793 media_subtype=media_subtype, | |
794 public_id=public_id, | |
795 created=time.time() if created is None else created, | |
796 modified=modified, | |
797 owner=owner, | |
798 access=access, | |
799 extra=extra, | |
800 profile_id=self.profiles[client.profile] | |
801 )) | |
802 | |
803 @aio | |
804 async def fileGetUsedSpace(self, client: SatXMPPEntity, owner: jid.JID) -> int: | |
805 async with self.session() as session: | |
806 result = await session.execute( | |
807 select(sum_(File.size)).filter_by( | |
808 owner=owner, | |
809 type=C.FILE_TYPE_FILE, | |
810 profile_id=self.profiles[client.profile] | |
811 )) | |
812 return result.scalar_one_or_none() or 0 | |
813 | |
814 @aio | |
815 async def fileDelete(self, file_id: str) -> None: | |
816 """Delete file metadata from the database | |
817 | |
818 @param file_id: id of the file to delete | |
819 NOTE: file itself must still be removed, this method only handle metadata in | |
820 database | |
821 """ | |
822 async with self.session() as session: | |
823 await session.execute(delete(File).filter_by(id=file_id)) | |
824 await session.commit() | |
825 | |
826 @aio | |
827 async def fileUpdate( | |
828 self, | |
829 file_id: str, | |
830 column: str, | |
831 update_cb: Callable[[dict], None] | |
832 ) -> None: | |
833 """Update a column value using a method to avoid race conditions | |
834 | |
835 the older value will be retrieved from database, then update_cb will be applied to | |
836 update it, and file will be updated checking that older value has not been changed | |
837 meanwhile by an other user. If it has changed, it tries again a couple of times | |
838 before failing | |
839 @param column: column name (only "access" or "extra" are allowed) | |
840 @param update_cb: method to update the value of the colum | |
841 the method will take older value as argument, and must update it in place | |
842 update_cb must not care about serialization, | |
843 it get the deserialized data (i.e. a Python object) directly | |
844 @raise exceptions.NotFound: there is not file with this id | |
845 """ | |
846 if column not in ('access', 'extra'): | |
847 raise exceptions.InternalError('bad column name') | |
848 orm_col = getattr(File, column) | |
849 | |
850 for i in range(5): | |
851 async with self.session() as session: | |
852 try: | |
853 value = (await session.execute( | |
854 select(orm_col).filter_by(id=file_id) | |
855 )).scalar_one() | |
856 except NoResultFound: | |
857 raise exceptions.NotFound | |
858 update_cb(value) | |
859 stmt = update(orm_col).filter_by(id=file_id) | |
860 if not value: | |
861 # because JsonDefaultDict convert NULL to an empty dict, we have to | |
862 # test both for empty dict and None when we have and empty dict | |
863 stmt = stmt.where((orm_col == None) | (orm_col == value)) | |
864 else: | |
865 stmt = stmt.where(orm_col == value) | |
866 result = await session.execute(stmt) | |
867 await session.commit() | |
868 | |
869 if result.rowcount == 1: | |
870 break | |
871 | |
872 log.warning( | |
873 _("table not updated, probably due to race condition, trying again " | |
874 "({tries})").format(tries=i+1) | |
875 ) | |
876 | |
877 else: | |
878 raise exceptions.DatabaseError( | |
879 _("Can't update file {file_id} due to race condition") | |
880 .format(file_id=file_id) | |
881 ) |