Mercurial > libervia-backend
comparison sat/memory/sqla.py @ 3595:7510648e8e3a
core (memory/sqla): methods to manipulate pubsub tables
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 29 Jul 2021 22:51:01 +0200 |
parents | 16ade4ad63f3 |
children | 2d97c695af05 |
comparison
equal
deleted
inserted
replaced
3594:d5116197e403 | 3595:7510648e8e3a |
---|---|
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. | 17 # along with this program. If not, see <http://www.gnu.org/licenses/>. |
18 | 18 |
19 import sys | 19 import sys |
20 import time | 20 import time |
21 import asyncio | 21 import asyncio |
22 from datetime import datetime | |
22 from asyncio.subprocess import PIPE | 23 from asyncio.subprocess import PIPE |
23 from pathlib import Path | 24 from pathlib import Path |
24 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional | 25 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional |
25 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine | 26 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine |
26 from sqlalchemy.exc import IntegrityError, NoResultFound | 27 from sqlalchemy.exc import IntegrityError, NoResultFound |
27 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager | 28 from sqlalchemy.orm import ( |
29 sessionmaker, subqueryload, joinedload, contains_eager # , aliased | |
30 ) | |
31 from sqlalchemy.orm.decl_api import DeclarativeMeta | |
28 from sqlalchemy.future import select | 32 from sqlalchemy.future import select |
29 from sqlalchemy.engine import Engine, Connection | 33 from sqlalchemy.engine import Engine, Connection |
30 from sqlalchemy import update, delete, and_, or_, event | 34 from sqlalchemy import update, delete, and_, or_, event, func |
31 from sqlalchemy.sql.functions import coalesce, sum as sum_ | 35 from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count |
32 from sqlalchemy.dialects.sqlite import insert | 36 from sqlalchemy.dialects.sqlite import insert |
33 from alembic import script as al_script, config as al_config | 37 from alembic import script as al_script, config as al_config |
34 from alembic.runtime import migration as al_migration | 38 from alembic.runtime import migration as al_migration |
35 from twisted.internet import defer | 39 from twisted.internet import defer |
36 from twisted.words.protocols.jabber import jid | 40 from twisted.words.protocols.jabber import jid |
41 from twisted.words.xish import domish | |
37 from sat.core.i18n import _ | 42 from sat.core.i18n import _ |
38 from sat.core import exceptions | 43 from sat.core import exceptions |
39 from sat.core.log import getLogger | 44 from sat.core.log import getLogger |
40 from sat.core.constants import Const as C | 45 from sat.core.constants import Const as C |
41 from sat.core.core_types import SatXMPPEntity | 46 from sat.core.core_types import SatXMPPEntity |
42 from sat.tools.utils import aio | 47 from sat.tools.utils import aio |
48 from sat.tools.common import uri | |
43 from sat.memory import migration | 49 from sat.memory import migration |
44 from sat.memory import sqla_config | 50 from sat.memory import sqla_config |
45 from sat.memory.sqla_mapping import ( | 51 from sat.memory.sqla_mapping import ( |
46 NOT_IN_EXTRA, | 52 NOT_IN_EXTRA, |
53 SyncState, | |
47 Base, | 54 Base, |
48 Profile, | 55 Profile, |
49 Component, | 56 Component, |
50 History, | 57 History, |
51 Message, | 58 Message, |
55 ParamInd, | 62 ParamInd, |
56 PrivateGen, | 63 PrivateGen, |
57 PrivateInd, | 64 PrivateInd, |
58 PrivateGenBin, | 65 PrivateGenBin, |
59 PrivateIndBin, | 66 PrivateIndBin, |
60 File | 67 File, |
68 PubsubNode, | |
69 PubsubItem, | |
61 ) | 70 ) |
62 | 71 |
63 | 72 |
64 log = getLogger(__name__) | 73 log = getLogger(__name__) |
65 migration_path = Path(migration.__file__).parent | 74 migration_path = Path(migration.__file__).parent |
151 log.info(_("Connecting database")) | 160 log.info(_("Connecting database")) |
152 | 161 |
153 db_config = sqla_config.getDbConfig() | 162 db_config = sqla_config.getDbConfig() |
154 engine = create_async_engine( | 163 engine = create_async_engine( |
155 db_config["url"], | 164 db_config["url"], |
156 future=True | 165 future=True, |
157 ) | 166 ) |
158 | 167 |
159 new_base = not db_config["path"].exists() | 168 new_base = not db_config["path"].exists() |
160 if new_base: | 169 if new_base: |
161 log.info(_("The database is new, creating the tables")) | 170 log.info(_("The database is new, creating the tables")) |
950 else: | 959 else: |
951 raise exceptions.DatabaseError( | 960 raise exceptions.DatabaseError( |
952 _("Can't update file {file_id} due to race condition") | 961 _("Can't update file {file_id} due to race condition") |
953 .format(file_id=file_id) | 962 .format(file_id=file_id) |
954 ) | 963 ) |
964 | |
965 @aio | |
966 async def getPubsubNode( | |
967 self, | |
968 client: SatXMPPEntity, | |
969 service: jid.JID, | |
970 name: str, | |
971 with_items: bool = False, | |
972 ) -> Optional[PubsubNode]: | |
973 """ | |
974 """ | |
975 async with self.session() as session: | |
976 stmt = ( | |
977 select(PubsubNode) | |
978 .filter_by( | |
979 service=service, | |
980 name=name, | |
981 profile_id=self.profiles[client.profile], | |
982 ) | |
983 ) | |
984 if with_items: | |
985 stmt = stmt.options( | |
986 joinedload(PubsubNode.items) | |
987 ) | |
988 result = await session.execute(stmt) | |
989 return result.unique().scalar_one_or_none() | |
990 | |
991 @aio | |
992 async def setPubsubNode( | |
993 self, | |
994 client: SatXMPPEntity, | |
995 service: jid.JID, | |
996 name: str, | |
997 analyser: Optional[str] = None, | |
998 type_: Optional[str] = None, | |
999 subtype: Optional[str] = None, | |
1000 ) -> PubsubNode: | |
1001 node = PubsubNode( | |
1002 profile_id=self.profiles[client.profile], | |
1003 service=service, | |
1004 name=name, | |
1005 subscribed=False, | |
1006 analyser=analyser, | |
1007 type_=type_, | |
1008 subtype=subtype, | |
1009 ) | |
1010 async with self.session() as session: | |
1011 async with session.begin(): | |
1012 session.add(node) | |
1013 return node | |
1014 | |
1015 @aio | |
1016 async def updatePubsubNodeSyncState( | |
1017 self, | |
1018 node: PubsubNode, | |
1019 state: SyncState | |
1020 ) -> None: | |
1021 async with self.session() as session: | |
1022 async with session.begin(): | |
1023 await session.execute( | |
1024 update(PubsubNode) | |
1025 .filter_by(id=node.id) | |
1026 .values( | |
1027 sync_state=state, | |
1028 sync_state_updated=time.time(), | |
1029 ) | |
1030 ) | |
1031 | |
1032 @aio | |
1033 async def deletePubsubNode( | |
1034 self, | |
1035 profiles: Optional[List[str]], | |
1036 services: Optional[List[jid.JID]], | |
1037 names: Optional[List[str]] | |
1038 ) -> None: | |
1039 """Delete items cached for a node | |
1040 | |
1041 @param profiles: profile names from which nodes must be deleted. | |
1042 None to remove nodes from ALL profiles | |
1043 @param services: JIDs of pubsub services from which nodes must be deleted. | |
1044 None to remove nodes from ALL services | |
1045 @param names: names of nodes which must be deleted. | |
1046 None to remove ALL nodes whatever is their names | |
1047 """ | |
1048 stmt = delete(PubsubNode) | |
1049 if profiles is not None: | |
1050 stmt = stmt.where( | |
1051 PubsubNode.profile.in_( | |
1052 [self.profiles[p] for p in profiles] | |
1053 ) | |
1054 ) | |
1055 if services is not None: | |
1056 stmt = stmt.where(PubsubNode.service.in_(services)) | |
1057 if names is not None: | |
1058 stmt = stmt.where(PubsubNode.name.in_(names)) | |
1059 async with self.session() as session: | |
1060 await session.execute(stmt) | |
1061 await session.commit() | |
1062 | |
1063 @aio | |
1064 async def cachePubsubItems( | |
1065 self, | |
1066 client: SatXMPPEntity, | |
1067 node: PubsubNode, | |
1068 items: List[domish.Element], | |
1069 parsed_items: Optional[List[dict]] = None, | |
1070 ) -> None: | |
1071 """Add items to database, using an upsert taking care of "updated" field""" | |
1072 if parsed_items is not None and len(items) != len(parsed_items): | |
1073 raise exceptions.InternalError( | |
1074 "parsed_items must have the same lenght as items" | |
1075 ) | |
1076 async with self.session() as session: | |
1077 async with session.begin(): | |
1078 for idx, item in enumerate(items): | |
1079 parsed = parsed_items[idx] if parsed_items else None | |
1080 stmt = insert(PubsubItem).values( | |
1081 node_id = node.id, | |
1082 name = item["id"], | |
1083 data = item, | |
1084 parsed = parsed, | |
1085 ).on_conflict_do_update( | |
1086 index_elements=(PubsubItem.node_id, PubsubItem.name), | |
1087 set_={ | |
1088 PubsubItem.data: item, | |
1089 PubsubItem.parsed: parsed, | |
1090 PubsubItem.updated: now() | |
1091 } | |
1092 ) | |
1093 await session.execute(stmt) | |
1094 await session.commit() | |
1095 | |
    @aio
    async def deletePubsubItems(
        self,
        node: PubsubNode,
        items_names: Optional[List[str]] = None
    ) -> None:
        """Delete items cached for a node

        @param node: node from which items must be deleted.
            NOTE(review): the annotation says PubsubNode, but the code below also
            accepts a list of nodes (and tolerates None) — confirm intended
            callers and widen the annotation accordingly
        @param items_names: names of items to delete
            if None, ALL items will be deleted
        """
        stmt = delete(PubsubItem)
        # defensive: annotation is not Optional, but None is tolerated (delete all)
        if node is not None:
            if isinstance(node, list):
                # several nodes given: delete items of all of them
                stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node]))
            else:
                stmt = stmt.filter_by(node_id=node.id)
        if items_names is not None:
            stmt = stmt.where(PubsubItem.name.in_(items_names))
        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()
1119 | |
1120 @aio | |
1121 async def purgePubsubItems( | |
1122 self, | |
1123 services: Optional[List[jid.JID]] = None, | |
1124 names: Optional[List[str]] = None, | |
1125 types: Optional[List[str]] = None, | |
1126 subtypes: Optional[List[str]] = None, | |
1127 profiles: Optional[List[str]] = None, | |
1128 created_before: Optional[datetime] = None, | |
1129 updated_before: Optional[datetime] = None, | |
1130 ) -> None: | |
1131 """Delete items cached for a node | |
1132 | |
1133 @param node: node from which items must be deleted | |
1134 @param items_names: names of items to delete | |
1135 if None, ALL items will be deleted | |
1136 """ | |
1137 stmt = delete(PubsubItem) | |
1138 node_fields = { | |
1139 "service": services, | |
1140 "name": names, | |
1141 "type_": types, | |
1142 "subtype": subtypes, | |
1143 } | |
1144 if any(x is not None for x in node_fields.values()): | |
1145 sub_q = select(PubsubNode.id) | |
1146 for col, values in node_fields.items(): | |
1147 if values is None: | |
1148 continue | |
1149 sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) | |
1150 stmt = ( | |
1151 stmt | |
1152 .where(PubsubItem.node_id.in_(sub_q)) | |
1153 .execution_options(synchronize_session=False) | |
1154 ) | |
1155 | |
1156 if profiles is not None: | |
1157 stmt = stmt.where( | |
1158 PubsubItem.profile_id.in_([self.profiles[p] for p in profiles]) | |
1159 ) | |
1160 | |
1161 if created_before is not None: | |
1162 stmt = stmt.where(PubsubItem.created < created_before) | |
1163 | |
1164 if updated_before is not None: | |
1165 stmt = stmt.where(PubsubItem.updated < updated_before) | |
1166 | |
1167 async with self.session() as session: | |
1168 await session.execute(stmt) | |
1169 await session.commit() | |
1170 | |
    @aio
    async def getItems(
        self,
        node: PubsubNode,
        max_items: Optional[int] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        from_index: Optional[int] = None,
        order_by: Optional[List[str]] = None,
        desc: bool = True,
        force_rsm: bool = False,
    ) -> Tuple[List[PubsubItem], dict]:
        """Get Pubsub Items from cache

        @param node: retrieve items from this node (must be synchronised)
        @param max_items: maximum number of items to retrieve
        @param before: get items which are before the item with this name in given order
            empty string is not managed here, use desc order to reproduce RSM
            behaviour.
        @param after: get items which are after the item with this name in given order
        @param from_index: get items with item index (as defined in RSM spec)
            starting from this number
        @param order_by: sorting order of items (one of C.ORDER_BY_*)
        @param desc: direction or ordering
        @param force_rsm: if True, force the use of RSM worklow.
            RSM workflow is automatically used if any of before, after or
            from_index is used, but if only RSM max_items is used, it won't be
            used by default. This parameter let's use RSM workflow in this
            case. Note that in addition to RSM metadata, the result will not be
            the same (max_items without RSM will returns most recent items,
            i.e. last items in modification order, while max_items with RSM
            will return the oldest ones (i.e. first items in modification
            order).
            to be used when max_items is used from RSM
        @return: a (items, metadata) tuple; metadata always carries "service",
            "node" and "uri", plus "rsm" and "complete" when the RSM workflow
            is used
        """

        # metadata common to both workflows
        metadata = {
            "service": node.service,
            "node": node.name,
            "uri": uri.buildXMPPUri(
                "pubsub",
                path=node.service.full(),
                node=node.name,
            ),
        }
        if max_items is None:
            # default page size
            max_items = 20

        use_rsm = any((before, after, from_index is not None))
        if force_rsm and not use_rsm:
            # switch to the RSM workflow, starting from the first item
            use_rsm = True
            from_index = 0

        stmt = (
            select(PubsubItem)
            .filter_by(node_id=node.id)
            .limit(max_items)
        )

        if not order_by:
            order_by = [C.ORDER_BY_MODIFICATION]

        # build the ORDER BY clauses; a secondary sort on id keeps the order
        # stable when several items share the same "updated" timestamp
        order = []
        for order_type in order_by:
            if order_type == C.ORDER_BY_MODIFICATION:
                if desc:
                    order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
                else:
                    order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
            elif order_type == C.ORDER_BY_CREATION:
                # id is monotonically increasing, so it stands in for creation order
                if desc:
                    order.append(PubsubItem.id.desc())
                else:
                    order.append(PubsubItem.id.asc())
            else:
                raise exceptions.InternalError(f"Unknown order type {order_type!r}")

        stmt = stmt.order_by(*order)

        if use_rsm:
            # CTE to have result row numbers
            row_num_q = select(
                PubsubItem.id,
                PubsubItem.name,
                # row_number starts from 1, but RSM index must start from 0
                (func.row_number().over(order_by=order)-1).label("item_index")
            ).filter_by(node_id=node.id)

            row_num_cte = row_num_q.cte()

            if max_items > 0:
                # as we can't simply use PubsubItem.id when we order by modification,
                # we need to use row number
                item_name = before or after
                row_num_limit_q = (
                    select(row_num_cte.c.item_index)
                    .where(row_num_cte.c.name==item_name)
                ).scalar_subquery()

                stmt = (
                    select(row_num_cte.c.item_index, PubsubItem)
                    .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
                    .limit(max_items)
                )
                if before:
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index<row_num_limit_q)
                        .order_by(row_num_cte.c.item_index.desc())
                    )
                elif after:
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index>row_num_limit_q)
                        .order_by(row_num_cte.c.item_index.asc())
                    )
                else:
                    # from_index is used
                    stmt = (
                        stmt
                        .where(row_num_cte.c.item_index>=from_index)
                        .order_by(row_num_cte.c.item_index.asc())
                    )

            async with self.session() as session:
                if max_items == 0:
                    # MAX items request, we only need metadata
                    items = result = []
                else:
                    result = await session.execute(stmt)
                    result = result.all()
                    if before:
                        # rows were fetched in reverse order to apply the limit;
                        # restore ascending order for the caller
                        result.reverse()
                    items = [row[-1] for row in result]
                # total number of items in the node, for RSM "count"
                rows_count = (
                    await session.execute(row_num_q.with_only_columns(count()))
                ).scalar_one()

            try:
                index = result[0][0]
            except IndexError:
                index = None

            try:
                first = result[0][1].name
            except IndexError:
                first = None
                last = None
            else:
                last = result[-1][1].name


            metadata["rsm"] = {
                "index": index,
                "count": rows_count,
                "first": first,
                "last": last,
            }
            # NOTE(review): when the result is empty (e.g. max_items == 0),
            # index is None and the expression below raises TypeError — confirm
            # callers never reach this path with an empty result
            metadata["complete"] = index + len(result) == rows_count

            return items, metadata

        async with self.session() as session:
            result = await session.execute(stmt)

        result = result.scalars().all()
        if desc:
            # query fetched most recent first; return them in ascending order
            result.reverse()
        return result, metadata