comparison sat/memory/sqla.py @ 3537:f9a5b810f14d

core (memory/storage): backend storage is now based on SQLAlchemy
author Goffi <goffi@goffi.org>
date Thu, 03 Jun 2021 15:20:47 +0200
parents
children 71516731d0aa
comparison
equal deleted inserted replaced
3536:0985c47ffd96 3537:f9a5b810f14d
1 #!/usr/bin/env python3
2
3 # Libervia: an XMPP client
4 # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)
5
6 # This program is free software: you can redistribute it and/or modify
7 # it under the terms of the GNU Affero General Public License as published by
8 # the Free Software Foundation, either version 3 of the License, or
9 # (at your option) any later version.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU Affero General Public License for more details.
15
16 # You should have received a copy of the GNU Affero General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18
19 import time
20 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
21 from urllib.parse import quote
22 from pathlib import Path
23 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
24 from sqlalchemy.exc import IntegrityError, NoResultFound
25 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager
26 from sqlalchemy.future import select
27 from sqlalchemy.engine import Engine
28 from sqlalchemy import update, delete, and_, or_, event
29 from sqlalchemy.sql.functions import coalesce, sum as sum_
30 from sqlalchemy.dialects.sqlite import insert
31 from twisted.internet import defer
32 from twisted.words.protocols.jabber import jid
33 from sat.core.i18n import _
34 from sat.core import exceptions
35 from sat.core.log import getLogger
36 from sat.core.constants import Const as C
37 from sat.core.core_types import SatXMPPEntity
38 from sat.tools.utils import aio
39 from sat.memory.sqla_mapping import (
40 NOT_IN_EXTRA,
41 Base,
42 Profile,
43 Component,
44 History,
45 Message,
46 Subject,
47 Thread,
48 ParamGen,
49 ParamInd,
50 PrivateGen,
51 PrivateInd,
52 PrivateGenBin,
53 PrivateIndBin,
54 File
55 )
56
57
log = getLogger(__name__)


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Enable foreign key constraints on every newly opened DBAPI connection

    SQLite does not enforce foreign keys by default and the setting is
    per-connection, so the PRAGMA is issued each time a connection is opened.

    @param dbapi_connection: raw DBAPI connection which has just been opened
    @param connection_record: pool record of the connection (unused)
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # always release the cursor, even if the PRAGMA fails
        cursor.close()
66
67
class Storage:
    """Persistent storage backend, based on SQLAlchemy with an async SQLite engine

    Profile names/ids and component entry points are cached in memory by
    [initialise], so most methods can map a profile name to its database id
    without issuing a query.
    """

    def __init__(self, db_filename, sat_version):
        """
        @param db_filename: path of the SQLite database file
        @param sat_version: version of the backend (not used by this constructor)
        """
        # fired by [initialise] once the database is connected and caches are filled
        self.initialized = defer.Deferred()
        self.filename = Path(db_filename)
        # cache of profiles: profile name => profile id
        self.profiles: Dict[str, int] = {}
        # cache of components: profile id => component entry point
        self.components: Dict[int, str] = {}

    @aio
    async def initialise(self):
        """Connect to the database, create the tables if needed, and fill the caches

        [initialized] is fired with None when this is done.
        """
        log.info(_("Connecting database"))
        engine = create_async_engine(
            f"sqlite+aiosqlite:///{quote(str(self.filename))}",
            future=True
        )
        # expire_on_commit=False keeps loaded attributes usable after a commit
        # (e.g. profile.id in createProfile)
        self.session = sessionmaker(
            engine, expire_on_commit=False, class_=AsyncSession
        )
        new_base = not self.filename.exists()
        if new_base:
            log.info(_("The database is new, creating the tables"))
            # the dir may not exist if it's not the XDG recommended one
            self.filename.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)

        async with self.session() as session:
            result = await session.execute(select(Profile))
            for p in result.scalars():
                self.profiles[p.name] = p.id
            result = await session.execute(select(Component))
            for c in result.scalars():
                self.components[c.profile_id] = c.entry_point

        self.initialized.callback(None)

    ## Profiles

    def getProfilesList(self) -> List[str]:
        """Return list of all registered profiles"""
        return list(self.profiles.keys())

    def hasProfile(self, profile_name: str) -> bool:
        """return True if profile_name exists

        @param profile_name: name of the profile to check
        """
        return profile_name in self.profiles

    def profileIsComponent(self, profile_name: str) -> bool:
        """Tell if a profile is a component

        @param profile_name: name of the profile to check
        @raise exceptions.NotFound: the profile doesn't exist
        """
        try:
            return self.profiles[profile_name] in self.components
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists")

    def getEntryPoint(self, profile_name: str) -> str:
        """Return the component entry point of a profile

        @param profile_name: name of the component profile
        @raise exceptions.NotFound: the profile doesn't exist or is not a component
        """
        try:
            return self.components[self.profiles[profile_name]]
        except KeyError:
            raise exceptions.NotFound("the requested profile doesn't exists or is not a component")

    @aio
    async def createProfile(
        self, name: str, component_ep: Optional[str] = None
    ) -> Profile:
        """Create a new profile

        @param name: name of the profile
        @param component_ep: if not None, must point to a component entry point, and
            the profile is then registered as a component
        @return: the newly created profile
        """
        async with self.session() as session:
            profile = Profile(name=name)
            async with session.begin():
                session.add(profile)
            # the transaction has been committed, so profile.id is set
            self.profiles[profile.name] = profile.id
            if component_ep is not None:
                async with session.begin():
                    component = Component(profile=profile, entry_point=component_ep)
                    session.add(component)
                self.components[profile.id] = component_ep
        return profile

    @aio
    async def deleteProfile(self, name: str) -> None:
        """Delete profile

        @param name: name of the profile
        @raise exceptions.NotFound: there is no profile with this name
        """
        async with self.session() as session:
            result = await session.execute(select(Profile).where(Profile.name == name))
            profile = result.scalar()
            if profile is None:
                raise exceptions.NotFound(f"profile {name!r} doesn't exist")
            await session.delete(profile)
            await session.commit()
        # the profiles cache is keyed by name, the components cache by id
        self.profiles.pop(profile.name, None)
        self.components.pop(profile.id, None)
        log.info(_("Profile {name!r} deleted").format(name = name))

    ## Params

    @aio
    async def loadGenParams(self, params_gen: dict) -> None:
        """Load general parameters

        @param params_gen: dictionary to fill, keys are (category, name) tuples
        """
        log.debug(_("loading general parameters from database"))
        async with self.session() as session:
            result = await session.execute(select(ParamGen))
            for p in result.scalars():
                params_gen[(p.category, p.name)] = p.value

    @aio
    async def loadIndParams(self, params_ind: dict, profile: str) -> None:
        """Load individual parameters

        @param params_ind: dictionary to fill, keys are (category, name) tuples
        @param profile: a profile which *must* exist
        """
        log.debug(_("loading individual parameters from database"))
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd).where(ParamInd.profile_id == self.profiles[profile])
            )
            for p in result.scalars():
                params_ind[(p.category, p.name)] = p.value

    @aio
    async def getIndParam(self, category: str, name: str, profile: str) -> Optional[str]:
        """Ask database for the value of one specific individual parameter

        @param category: category of the parameter
        @param name: name of the parameter
        @param profile: %(doc_profile)s
        @return: value of the parameter, or None if it is not set
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd.value)
                .filter_by(
                    category=category,
                    name=name,
                    profile_id=self.profiles[profile]
                )
            )
        return result.scalar_one_or_none()

    @aio
    async def getIndParamValues(self, category: str, name: str) -> Dict[str, str]:
        """Ask database for the individual values of a parameter for all profiles

        @param category: category of the parameter
        @param name: name of the parameter
        @return dict: profile => value map
        """
        async with self.session() as session:
            result = await session.execute(
                select(ParamInd)
                .filter_by(
                    category=category,
                    name=name
                )
                .options(subqueryload(ParamInd.profile))
            )
        return {param.profile.name: param.value for param in result.scalars()}

    @aio
    async def setGenParam(self, category: str, name: str, value: Optional[str]) -> None:
        """Save the general parameters in database

        The parameter is inserted, or updated if it already exists.

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        """
        async with self.session() as session:
            stmt = insert(ParamGen).values(
                category=category,
                name=name,
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamGen.category, ParamGen.name),
                set_={
                    ParamGen.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    @aio
    async def setIndParam(
        self,
        category:str,
        name: str,
        value: Optional[str],
        profile: str
    ) -> None:
        """Save the individual parameters in database

        The parameter is inserted, or updated if it already exists.

        @param category: category of the parameter
        @param name: name of the parameter
        @param value: value to set
        @param profile: a profile which *must* exist
        """
        async with self.session() as session:
            stmt = insert(ParamInd).values(
                category=category,
                name=name,
                profile_id=self.profiles[profile],
                value=value
            ).on_conflict_do_update(
                index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id),
                set_={
                    ParamInd.value: value
                }
            )
            await session.execute(stmt)
            await session.commit()

    def _jid_filter(self, jid_: jid.JID, dest: bool = False):
        """Generate condition to filter on a JID, using relevant columns

        A full JID also filters on the resource column; a bare JID matches every
        resource.

        @param dest: True if it's the destinee JID, otherwise it's the source one
        @param jid_: JID to filter by
        """
        if jid_.resource:
            if dest:
                return and_(
                    History.dest == jid_.userhost(),
                    History.dest_res == jid_.resource
                )
            else:
                return and_(
                    History.source == jid_.userhost(),
                    History.source_res == jid_.resource
                )
        else:
            if dest:
                return History.dest == jid_.userhost()
            else:
                return History.source == jid_.userhost()

    @aio
    async def historyGet(
        self,
        from_jid: Optional[jid.JID],
        to_jid: Optional[jid.JID],
        limit: Optional[int] = None,
        between: bool = True,
        filters: Optional[Dict[str, Any]] = None,
        profile: Optional[str] = None,
    ) -> List[Tuple[
        str, int, str, str, Dict[str, str], Dict[str, str], str, str, str]
    ]:
        """Retrieve messages in history

        @param from_jid: source JID (full, or bare for catchall)
        @param to_jid: dest JID (full, or bare for catchall)
        @param limit: maximum number of messages to get:
            - 0 for no message (returns the empty list)
            - None for unlimited
        @param between: confound source and dest (ignore the direction)
        @param filters: pattern to filter the history results. Handled keys are
            "timestamp_start", "before_uid", "body", "search", "types", "not_types"
            and "last_stanza_id" (which must be True, with limit set to 1)
        @return: list of messages as in [messageNew], minus the profile which is already
            known.
        @raise ValueError: invalid use of the "last_stanza_id" filter
        """
        # we have to set a default value to profile because it's last argument
        # and thus follow other keyword arguments with default values
        # but None should not be used for it
        assert profile is not None
        if limit == 0:
            return []
        if filters is None:
            filters = {}

        # conditions are gathered in a list so they can be used both by the main
        # statement and by the one restricting the number of entries (see below)
        conditions = [History.profile_id == self.profiles[profile]]

        if not from_jid and not to_jid:
            # no jid specified, we want all one2one communications
            pass
        elif between:
            if not from_jid or not to_jid:
                # we only have one jid specified, we check all messages
                # from or to this jid
                jid_ = from_jid or to_jid
                conditions.append(
                    or_(
                        self._jid_filter(jid_),
                        self._jid_filter(jid_, dest=True)
                    )
                )
            else:
                # we have 2 jids specified, we check all communications between
                # those 2 jids
                conditions.append(
                    or_(
                        and_(
                            self._jid_filter(from_jid),
                            self._jid_filter(to_jid, dest=True),
                        ),
                        and_(
                            self._jid_filter(to_jid),
                            self._jid_filter(from_jid, dest=True),
                        )
                    )
                )
        else:
            # we want one communication in specific direction (from somebody or
            # to somebody).
            if from_jid is not None:
                conditions.append(self._jid_filter(from_jid))
            if to_jid is not None:
                conditions.append(self._jid_filter(to_jid, dest=True))

        if filters:
            if 'timestamp_start' in filters:
                conditions.append(
                    History.timestamp >= float(filters['timestamp_start'])
                )
            if 'before_uid' in filters:
                # originally this query was using SQLITE's rowid. This has been changed
                # to use coalesce(received_timestamp, timestamp) to be SQL engine
                # independent
                conditions.append(
                    coalesce(
                        History.received_timestamp,
                        History.timestamp
                    ) < (
                        select(coalesce(History.received_timestamp, History.timestamp))
                        .where(History.uid == filters["before_uid"])
                        .scalar_subquery()
                    )
                )
            if 'body' in filters:
                # TODO: use REGEXP (function to be defined) instead of LIKE: https://www.sqlite.org/lang_expr.html
                conditions.append(Message.message.like(f"%{filters['body']}%"))
            if 'search' in filters:
                search_term = f"%{filters['search']}%"
                conditions.append(or_(
                    Message.message.like(search_term),
                    History.source_res.like(search_term)
                ))
            if 'types' in filters:
                types = filters['types'].split()
                conditions.append(History.type.in_(types))
            if 'not_types' in filters:
                types = filters['not_types'].split()
                conditions.append(History.type.not_in(types))
            if 'last_stanza_id' in filters:
                # this request get the last message with a "stanza_id" that we
                # have in history. This is mainly used to retrieve messages sent
                # while we were offline, using MAM (XEP-0313).
                if (filters['last_stanza_id'] is not True
                    or limit != 1):
                    raise ValueError("Unexpected values for last_stanza_id filter")
                conditions.append(History.stanza_id.is_not(None))

        # timestamp may be identical for 2 close messages (specially when delay is
        # used) that's why we order ties by received_timestamp. We'll reverse the
        # order when returning the result. We use DESC here so LIMIT keep the last
        # messages
        order = (History.timestamp.desc(), History.received_timestamp.desc())

        stmt = (
            select(History)
            .outerjoin(History.messages)
            .outerjoin(History.subjects)
            .outerjoin(History.thread)
            .options(
                contains_eager(History.messages),
                contains_eager(History.subjects),
                contains_eager(History.thread),
            )
            .where(*conditions)
            .order_by(*order)
        )

        if limit is not None:
            # LIMIT must count History entries, not the joined rows: an entry with
            # several messages/subjects (e.g. several languages) produces several
            # rows, so limiting the main statement would return fewer entries than
            # requested, possibly with truncated collections. We thus first select
            # the uids of the wanted entries (Message is joined for "body" and
            # "search" filters).
            limited_uids = (
                select(History.uid)
                .outerjoin(History.messages)
                .where(*conditions)
                .group_by(History.uid)
                .order_by(*order)
                .limit(limit)
            )
            stmt = stmt.where(History.uid.in_(limited_uids))

        async with self.session() as session:
            result = await session.execute(stmt)
            entries = result.scalars().unique().all()
            # entries are newest first (see ORDER BY above), we want oldest first
            entries.reverse()
            return [h.as_tuple() for h in entries]

    @aio
    async def addToHistory(self, data: dict, profile: str) -> None:
        """Store a new message in history

        Errors while storing are logged and not propagated; a message already in
        history (unique constraint violation) is silently ignored.

        @param data: message data as build by SatMessageProtocol.onMessage
        @param profile: profile which received/sent the message
        """
        extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA}
        messages = [Message(message=mess, language=lang)
                    for lang, mess in data["message"].items()]
        subjects = [Subject(subject=mess, language=lang)
                    for lang, mess in data["subject"].items()]
        if "thread" in data["extra"]:
            thread = Thread(thread_id=data["extra"]["thread"],
                            parent_id=data["extra"].get("thread_parent"))
        else:
            thread = None
        try:
            async with self.session() as session:
                async with session.begin():
                    session.add(History(
                        uid=data["uid"],
                        stanza_id=data["extra"].get("stanza_id"),
                        update_uid=data["extra"].get("update_uid"),
                        profile_id=self.profiles[profile],
                        source_jid=data["from"],
                        dest_jid=data["to"],
                        timestamp=data["timestamp"],
                        received_timestamp=data.get("received_timestamp"),
                        type=data["type"],
                        extra=extra,
                        messages=messages,
                        subjects=subjects,
                        thread=thread,
                    ))
        except IntegrityError as e:
            if "unique" in str(e.orig).lower():
                log.debug(
                    f"message {data['uid']!r} is already in history, not storing it again"
                )
            else:
                log.error(f"Can't store message {data['uid']!r} in history: {e}")
        except Exception as e:
            log.critical(
                f"Can't store message, unexpected exception (uid: {data['uid']}): {e}"
            )

    ## Private values

    def _getPrivateClass(self, binary, profile):
        """Get ORM class to use for private values

        @param binary: True for serialised (binary) values
        @param profile: None for general values, else individual values are used
        """
        if profile is None:
            return PrivateGenBin if binary else PrivateGen
        else:
            return PrivateIndBin if binary else PrivateInd


    @aio
    async def getPrivates(
        self,
        namespace:str,
        keys: Optional[Iterable[str]] = None,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get private value(s) from databases

        @param namespace: namespace of the values
        @param keys: keys of the values to get None to get all keys/values
            an empty iterable returns an empty dict
        @param binary: True to deserialise binary values
        @param profile: profile to use for individual values
            None to use general values
        @return: gotten keys/values
        """
        if keys is not None:
            keys = list(keys)
        log.debug(
            f"getting {'general' if profile is None else 'individual'}"
            f"{' binary' if binary else ''} private values from database for namespace "
            f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}"
        )
        cls = self._getPrivateClass(binary, profile)
        stmt = select(cls).filter_by(namespace=namespace)
        if keys is not None:
            if not keys:
                # no key requested: nothing to retrieve (and we must not fall back
                # to returning the whole namespace)
                return {}
            stmt = stmt.where(cls.key.in_(keys))
        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])
        async with self.session() as session:
            result = await session.execute(stmt)
        return {p.key: p.value for p in result.scalars()}

    @aio
    async def setPrivateValue(
        self,
        namespace: str,
        key:str,
        value: Any,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Set a private value in database

        The value is inserted, or updated if it already exists.

        @param namespace: namespace of the values
        @param key: key of the value to set
        @param value: value to set
        @param binary: True if it's a binary values
            binary values need to be serialised, used for everything but strings
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        values = {
            "namespace": namespace,
            "key": key,
            "value": value
        }
        index_elements = [cls.namespace, cls.key]

        if profile is not None:
            values["profile_id"] = self.profiles[profile]
            index_elements.append(cls.profile_id)

        async with self.session() as session:
            await session.execute(
                insert(cls).values(**values).on_conflict_do_update(
                    index_elements=index_elements,
                    set_={
                        cls.value: value
                    }
                )
            )
            await session.commit()

    @aio
    async def delPrivateValue(
        self,
        namespace: str,
        key: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete private value from database

        @param namespace: namespace of the private value
        @param key: key of the private value
        @param binary: True if it's a binary values
        @param profile: profile to use for individual value
            if None, it's a general value
        """
        cls = self._getPrivateClass(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace, key=key)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    @aio
    async def delPrivateNamespace(
        self,
        namespace: str,
        binary: bool = False,
        profile: Optional[str] = None
    ) -> None:
        """Delete all data from a private namespace

        Be really cautious when you use this method, as all data with given namespace are
        removed.
        Params are the same as for delPrivateValue
        """
        cls = self._getPrivateClass(binary, profile)

        stmt = delete(cls).filter_by(namespace=namespace)

        if profile is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[profile])

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()

    ## Files

    @aio
    async def getFiles(
        self,
        client: Optional[SatXMPPEntity],
        file_id: Optional[str] = None,
        version: Optional[str] = '',
        parent: Optional[str] = None,
        type_: Optional[str] = None,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        name: Optional[str] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        projection: Optional[List[str]] = None,
        unique: bool = False
    ) -> List[dict]:
        """Retrieve files with with given filters

        @param client: client owning the files
            None is only allowed when public_id is set
        @param file_id: id of the file
            None to ignore
        @param version: version of the file
            None to ignore
            empty string to look for current version
        @param parent: id of the directory containing the files
            None to ignore
            empty string to look for root files/directories
        @param mime_type: media type, or main type only (without "/") to match
            all subtypes
        @param projection: name of columns to retrieve
            None to retrieve all
        @param unique: if True will remove duplicates
        other params are the same as for [setFile]
        @return: files corresponding to filters, as column name => value dicts
        @raise exceptions.InternalError: neither client nor public_id is set
        @raise NotImplementedError: access is set (not supported yet)
        """
        if projection is None:
            projection = [
                'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name',
                'size', 'namespace', 'media_type', 'media_subtype', 'public_id',
                'created', 'modified', 'owner', 'access', 'extra'
            ]

        stmt = select(*[getattr(File, f) for f in projection])

        if unique:
            stmt = stmt.distinct()

        if client is not None:
            stmt = stmt.filter_by(profile_id=self.profiles[client.profile])
        else:
            if public_id is None:
                raise exceptions.InternalError(
                    "client can only be omitted when public_id is set"
                )
        if file_id is not None:
            stmt = stmt.filter_by(id=file_id)
        if version is not None:
            stmt = stmt.filter_by(version=version)
        if parent is not None:
            stmt = stmt.filter_by(parent=parent)
        if type_ is not None:
            stmt = stmt.filter_by(type=type_)
        if file_hash is not None:
            stmt = stmt.filter_by(file_hash=file_hash)
        if hash_algo is not None:
            stmt = stmt.filter_by(hash_algo=hash_algo)
        if name is not None:
            stmt = stmt.filter_by(name=name)
        if namespace is not None:
            stmt = stmt.filter_by(namespace=namespace)
        if mime_type is not None:
            if '/' in mime_type:
                media_type, media_subtype = mime_type.split("/", 1)
                stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype)
            else:
                stmt = stmt.filter_by(media_type=mime_type)
        if public_id is not None:
            stmt = stmt.filter_by(public_id=public_id)
        if owner is not None:
            stmt = stmt.filter_by(owner=owner)
        if access is not None:
            raise NotImplementedError('Access check is not implemented yet')
            # a JSON comparison is needed here

        async with self.session() as session:
            result = await session.execute(stmt)

        # rows are tuple-like, _asdict() gives the column name => value mapping
        return [r._asdict() for r in result]

    @aio
    async def setFile(
        self,
        client: SatXMPPEntity,
        name: str,
        file_id: str,
        version: str = "",
        parent: str = "",
        type_: str = C.FILE_TYPE_FILE,
        file_hash: Optional[str] = None,
        hash_algo: Optional[str] = None,
        size: Optional[int] = None,
        namespace: Optional[str] = None,
        mime_type: Optional[str] = None,
        public_id: Optional[str] = None,
        created: Optional[float] = None,
        modified: Optional[float] = None,
        owner: Optional[jid.JID] = None,
        access: Optional[dict] = None,
        extra: Optional[dict] = None
    ) -> None:
        """Set a file metadata

        @param client: client owning the file
        @param name: name of the file (must not contain "/")
        @param file_id: unique id of the file
        @param version: version of this file (surrounding whitespace is stripped)
        @param parent: id of the directory containing this file
            Empty string if it is a root file/directory
        @param type_: one of:
            - file
            - directory
        @param file_hash: unique hash of the payload
        @param hash_algo: algorithm used for hashing the file (usually sha-256)
        @param size: size in bytes
        @param namespace: identifier (human readable is better) to group files
            for instance, namespace could be used to group files in a specific photo album
        @param mime_type: media type of the file, or None if not known/guessed
        @param public_id: ID used to server the file publicly via HTTP
        @param created: UNIX time of creation, None to use current time
        @param modified: UNIX time of last modification, or None to use created date
        @param owner: jid of the owner of the file (mainly useful for component)
        @param access: serialisable dictionary with access rules. See [memory.memory] for details
        @param extra: serialisable dictionary of any extra data
            will be encoded to json in database
        """
        # mime_type is stored split in its main type and subtype
        if mime_type is None:
            media_type = media_subtype = None
        elif '/' in mime_type:
            media_type, media_subtype = mime_type.split('/', 1)
        else:
            media_type, media_subtype = mime_type, None

        async with self.session() as session:
            async with session.begin():
                session.add(File(
                    id=file_id,
                    version=version.strip(),
                    parent=parent,
                    type=type_,
                    file_hash=file_hash,
                    hash_algo=hash_algo,
                    name=name,
                    size=size,
                    namespace=namespace,
                    media_type=media_type,
                    media_subtype=media_subtype,
                    public_id=public_id,
                    created=time.time() if created is None else created,
                    modified=modified,
                    owner=owner,
                    access=access,
                    extra=extra,
                    profile_id=self.profiles[client.profile]
                ))

    @aio
    async def fileGetUsedSpace(self, client: SatXMPPEntity, owner: jid.JID) -> int:
        """Return the total size of the files (not directories) of an owner

        @param client: client whose profile the files belong to
        @param owner: owner of the files
        @return: sum of the sizes in bytes, 0 if there is no file
        """
        async with self.session() as session:
            result = await session.execute(
                select(sum_(File.size)).filter_by(
                    owner=owner,
                    type=C.FILE_TYPE_FILE,
                    profile_id=self.profiles[client.profile]
                ))
        # SUM() returns NULL when no row matches
        return result.scalar_one_or_none() or 0

    @aio
    async def fileDelete(self, file_id: str) -> None:
        """Delete file metadata from the database

        @param file_id: id of the file to delete
        NOTE: file itself must still be removed, this method only handle metadata in
            database
        """
        async with self.session() as session:
            await session.execute(delete(File).filter_by(id=file_id))
            await session.commit()

    @aio
    async def fileUpdate(
        self,
        file_id: str,
        column: str,
        update_cb: Callable[[dict], None]
    ) -> None:
        """Update a column value using a method to avoid race conditions

        the older value will be retrieved from database, then update_cb will be applied to
        update it, and file will be updated checking that older value has not been changed
        meanwhile by an other user. If it has changed, it tries again a couple of times
        before failing
        @param file_id: id of the file to update
        @param column: column name (only "access" or "extra" are allowed)
        @param update_cb: method to update the value of the colum
            the method will take older value as argument, and must update it in place
            update_cb must not care about serialization,
            it get the deserialized data (i.e. a Python object) directly
        @raise exceptions.InternalError: invalid column name
        @raise exceptions.NotFound: there is not file with this id
        @raise exceptions.DatabaseError: the update failed after all the retries
        """
        import copy

        if column not in ('access', 'extra'):
            raise exceptions.InternalError('bad column name')
        orm_col = getattr(File, column)

        for i in range(5):
            async with self.session() as session:
                try:
                    value = (await session.execute(
                        select(orm_col).filter_by(id=file_id)
                    )).scalar_one()
                except NoResultFound:
                    raise exceptions.NotFound
                # update_cb modifies value in place, so we keep a copy of the value
                # read from database to check that it has not changed meanwhile
                old_value = copy.deepcopy(value)
                update_cb(value)
                stmt = (
                    update(File)
                    .filter_by(id=file_id)
                    .values({orm_col: value})
                    # no File instance is loaded in this session, nothing to sync
                    .execution_options(synchronize_session=False)
                )
                if not old_value:
                    # because JsonDefaultDict convert NULL to an empty dict, we have to
                    # test both for empty dict and None when we have an empty dict
                    stmt = stmt.where(orm_col.is_(None) | (orm_col == old_value))
                else:
                    stmt = stmt.where(orm_col == old_value)
                result = await session.execute(stmt)
                await session.commit()

            if result.rowcount == 1:
                break

            log.warning(
                _("table not updated, probably due to race condition, trying again "
                  "({tries})").format(tries=i+1)
            )

        else:
            raise exceptions.DatabaseError(
                _("Can't update file {file_id} due to race condition")
                .format(file_id=file_id)
            )
872 log.warning(
873 _("table not updated, probably due to race condition, trying again "
874 "({tries})").format(tries=i+1)
875 )
876
877 else:
878 raise exceptions.DatabaseError(
879 _("Can't update file {file_id} due to race condition")
880 .format(file_id=file_id)
881 )