comparison sat/memory/sqla.py @ 3595:7510648e8e3a

core (memory/sqla): methods to manipulate pubsub tables
author Goffi <goffi@goffi.org>
date Thu, 29 Jul 2021 22:51:01 +0200
parents 16ade4ad63f3
children 2d97c695af05
comparison
equal deleted inserted replaced
3594:d5116197e403 3595:7510648e8e3a
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. 17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 18
19 import sys 19 import sys
20 import time 20 import time
21 import asyncio 21 import asyncio
22 from datetime import datetime
22 from asyncio.subprocess import PIPE 23 from asyncio.subprocess import PIPE
23 from pathlib import Path 24 from pathlib import Path
24 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional 25 from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional
25 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine 26 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
26 from sqlalchemy.exc import IntegrityError, NoResultFound 27 from sqlalchemy.exc import IntegrityError, NoResultFound
27 from sqlalchemy.orm import sessionmaker, subqueryload, contains_eager 28 from sqlalchemy.orm import (
29 sessionmaker, subqueryload, joinedload, contains_eager # , aliased
30 )
31 from sqlalchemy.orm.decl_api import DeclarativeMeta
28 from sqlalchemy.future import select 32 from sqlalchemy.future import select
29 from sqlalchemy.engine import Engine, Connection 33 from sqlalchemy.engine import Engine, Connection
30 from sqlalchemy import update, delete, and_, or_, event 34 from sqlalchemy import update, delete, and_, or_, event, func
31 from sqlalchemy.sql.functions import coalesce, sum as sum_ 35 from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count
32 from sqlalchemy.dialects.sqlite import insert 36 from sqlalchemy.dialects.sqlite import insert
33 from alembic import script as al_script, config as al_config 37 from alembic import script as al_script, config as al_config
34 from alembic.runtime import migration as al_migration 38 from alembic.runtime import migration as al_migration
35 from twisted.internet import defer 39 from twisted.internet import defer
36 from twisted.words.protocols.jabber import jid 40 from twisted.words.protocols.jabber import jid
41 from twisted.words.xish import domish
37 from sat.core.i18n import _ 42 from sat.core.i18n import _
38 from sat.core import exceptions 43 from sat.core import exceptions
39 from sat.core.log import getLogger 44 from sat.core.log import getLogger
40 from sat.core.constants import Const as C 45 from sat.core.constants import Const as C
41 from sat.core.core_types import SatXMPPEntity 46 from sat.core.core_types import SatXMPPEntity
42 from sat.tools.utils import aio 47 from sat.tools.utils import aio
48 from sat.tools.common import uri
43 from sat.memory import migration 49 from sat.memory import migration
44 from sat.memory import sqla_config 50 from sat.memory import sqla_config
45 from sat.memory.sqla_mapping import ( 51 from sat.memory.sqla_mapping import (
46 NOT_IN_EXTRA, 52 NOT_IN_EXTRA,
53 SyncState,
47 Base, 54 Base,
48 Profile, 55 Profile,
49 Component, 56 Component,
50 History, 57 History,
51 Message, 58 Message,
55 ParamInd, 62 ParamInd,
56 PrivateGen, 63 PrivateGen,
57 PrivateInd, 64 PrivateInd,
58 PrivateGenBin, 65 PrivateGenBin,
59 PrivateIndBin, 66 PrivateIndBin,
60 File 67 File,
68 PubsubNode,
69 PubsubItem,
61 ) 70 )
62 71
63 72
64 log = getLogger(__name__) 73 log = getLogger(__name__)
65 migration_path = Path(migration.__file__).parent 74 migration_path = Path(migration.__file__).parent
151 log.info(_("Connecting database")) 160 log.info(_("Connecting database"))
152 161
153 db_config = sqla_config.getDbConfig() 162 db_config = sqla_config.getDbConfig()
154 engine = create_async_engine( 163 engine = create_async_engine(
155 db_config["url"], 164 db_config["url"],
156 future=True 165 future=True,
157 ) 166 )
158 167
159 new_base = not db_config["path"].exists() 168 new_base = not db_config["path"].exists()
160 if new_base: 169 if new_base:
161 log.info(_("The database is new, creating the tables")) 170 log.info(_("The database is new, creating the tables"))
950 else: 959 else:
951 raise exceptions.DatabaseError( 960 raise exceptions.DatabaseError(
952 _("Can't update file {file_id} due to race condition") 961 _("Can't update file {file_id} due to race condition")
953 .format(file_id=file_id) 962 .format(file_id=file_id)
954 ) 963 )
964
965 @aio
966 async def getPubsubNode(
967 self,
968 client: SatXMPPEntity,
969 service: jid.JID,
970 name: str,
971 with_items: bool = False,
972 ) -> Optional[PubsubNode]:
973 """
974 """
975 async with self.session() as session:
976 stmt = (
977 select(PubsubNode)
978 .filter_by(
979 service=service,
980 name=name,
981 profile_id=self.profiles[client.profile],
982 )
983 )
984 if with_items:
985 stmt = stmt.options(
986 joinedload(PubsubNode.items)
987 )
988 result = await session.execute(stmt)
989 return result.unique().scalar_one_or_none()
990
991 @aio
992 async def setPubsubNode(
993 self,
994 client: SatXMPPEntity,
995 service: jid.JID,
996 name: str,
997 analyser: Optional[str] = None,
998 type_: Optional[str] = None,
999 subtype: Optional[str] = None,
1000 ) -> PubsubNode:
1001 node = PubsubNode(
1002 profile_id=self.profiles[client.profile],
1003 service=service,
1004 name=name,
1005 subscribed=False,
1006 analyser=analyser,
1007 type_=type_,
1008 subtype=subtype,
1009 )
1010 async with self.session() as session:
1011 async with session.begin():
1012 session.add(node)
1013 return node
1014
1015 @aio
1016 async def updatePubsubNodeSyncState(
1017 self,
1018 node: PubsubNode,
1019 state: SyncState
1020 ) -> None:
1021 async with self.session() as session:
1022 async with session.begin():
1023 await session.execute(
1024 update(PubsubNode)
1025 .filter_by(id=node.id)
1026 .values(
1027 sync_state=state,
1028 sync_state_updated=time.time(),
1029 )
1030 )
1031
1032 @aio
1033 async def deletePubsubNode(
1034 self,
1035 profiles: Optional[List[str]],
1036 services: Optional[List[jid.JID]],
1037 names: Optional[List[str]]
1038 ) -> None:
1039 """Delete items cached for a node
1040
1041 @param profiles: profile names from which nodes must be deleted.
1042 None to remove nodes from ALL profiles
1043 @param services: JIDs of pubsub services from which nodes must be deleted.
1044 None to remove nodes from ALL services
1045 @param names: names of nodes which must be deleted.
1046 None to remove ALL nodes whatever is their names
1047 """
1048 stmt = delete(PubsubNode)
1049 if profiles is not None:
1050 stmt = stmt.where(
1051 PubsubNode.profile.in_(
1052 [self.profiles[p] for p in profiles]
1053 )
1054 )
1055 if services is not None:
1056 stmt = stmt.where(PubsubNode.service.in_(services))
1057 if names is not None:
1058 stmt = stmt.where(PubsubNode.name.in_(names))
1059 async with self.session() as session:
1060 await session.execute(stmt)
1061 await session.commit()
1062
1063 @aio
1064 async def cachePubsubItems(
1065 self,
1066 client: SatXMPPEntity,
1067 node: PubsubNode,
1068 items: List[domish.Element],
1069 parsed_items: Optional[List[dict]] = None,
1070 ) -> None:
1071 """Add items to database, using an upsert taking care of "updated" field"""
1072 if parsed_items is not None and len(items) != len(parsed_items):
1073 raise exceptions.InternalError(
1074 "parsed_items must have the same lenght as items"
1075 )
1076 async with self.session() as session:
1077 async with session.begin():
1078 for idx, item in enumerate(items):
1079 parsed = parsed_items[idx] if parsed_items else None
1080 stmt = insert(PubsubItem).values(
1081 node_id = node.id,
1082 name = item["id"],
1083 data = item,
1084 parsed = parsed,
1085 ).on_conflict_do_update(
1086 index_elements=(PubsubItem.node_id, PubsubItem.name),
1087 set_={
1088 PubsubItem.data: item,
1089 PubsubItem.parsed: parsed,
1090 PubsubItem.updated: now()
1091 }
1092 )
1093 await session.execute(stmt)
1094 await session.commit()
1095
1096 @aio
1097 async def deletePubsubItems(
1098 self,
1099 node: PubsubNode,
1100 items_names: Optional[List[str]] = None
1101 ) -> None:
1102 """Delete items cached for a node
1103
1104 @param node: node from which items must be deleted
1105 @param items_names: names of items to delete
1106 if None, ALL items will be deleted
1107 """
1108 stmt = delete(PubsubItem)
1109 if node is not None:
1110 if isinstance(node, list):
1111 stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node]))
1112 else:
1113 stmt = stmt.filter_by(node_id=node.id)
1114 if items_names is not None:
1115 stmt = stmt.where(PubsubItem.name.in_(items_names))
1116 async with self.session() as session:
1117 await session.execute(stmt)
1118 await session.commit()
1119
    @aio
    async def purgePubsubItems(
        self,
        services: Optional[List[jid.JID]] = None,
        names: Optional[List[str]] = None,
        types: Optional[List[str]] = None,
        subtypes: Optional[List[str]] = None,
        profiles: Optional[List[str]] = None,
        created_before: Optional[datetime] = None,
        updated_before: Optional[datetime] = None,
    ) -> None:
        """Delete cached pubsub items matching all given criteria

        Every argument is an additional filter; None means "don't filter on
        this criterion". Called with no argument at all, this deletes ALL
        cached items.

        @param services: only delete items of nodes hosted on those services
        @param names: only delete items of nodes with those names
        @param types: only delete items of nodes with those types
        @param subtypes: only delete items of nodes with those subtypes
        @param profiles: only delete items cached for those profiles
        @param created_before: only delete items created before this date
        @param updated_before: only delete items last updated before this date
        """
        stmt = delete(PubsubItem)
        # node-level criteria are applied through a sub-query selecting the
        # ids of the matching PubsubNode rows
        node_fields = {
            "service": services,
            "name": names,
            "type_": types,
            "subtype": subtypes,
        }
        if any(x is not None for x in node_fields.values()):
            sub_q = select(PubsubNode.id)
            for col, values in node_fields.items():
                if values is None:
                    continue
                sub_q = sub_q.where(getattr(PubsubNode, col).in_(values))
            stmt = (
                stmt
                .where(PubsubItem.node_id.in_(sub_q))
                # required when deleting with a correlated sub-query
                .execution_options(synchronize_session=False)
            )

        if profiles is not None:
            # NOTE(review): assumes PubsubItem has a profile_id column —
            # confirm against sqla_mapping
            stmt = stmt.where(
                PubsubItem.profile_id.in_([self.profiles[p] for p in profiles])
            )

        if created_before is not None:
            stmt = stmt.where(PubsubItem.created < created_before)

        if updated_before is not None:
            stmt = stmt.where(PubsubItem.updated < updated_before)

        async with self.session() as session:
            await session.execute(stmt)
            await session.commit()
1170
1171 @aio
1172 async def getItems(
1173 self,
1174 node: PubsubNode,
1175 max_items: Optional[int] = None,
1176 before: Optional[str] = None,
1177 after: Optional[str] = None,
1178 from_index: Optional[int] = None,
1179 order_by: Optional[List[str]] = None,
1180 desc: bool = True,
1181 force_rsm: bool = False,
1182 ) -> Tuple[List[PubsubItem], dict]:
1183 """Get Pubsub Items from cache
1184
1185 @param node: retrieve items from this node (must be synchronised)
1186 @param max_items: maximum number of items to retrieve
1187 @param before: get items which are before the item with this name in given order
1188 empty string is not managed here, use desc order to reproduce RSM
1189 behaviour.
1190 @param after: get items which are after the item with this name in given order
1191 @param from_index: get items with item index (as defined in RSM spec)
1192 starting from this number
1193 @param order_by: sorting order of items (one of C.ORDER_BY_*)
1194 @param desc: direction or ordering
1195 @param force_rsm: if True, force the use of RSM worklow.
1196 RSM workflow is automatically used if any of before, after or
1197 from_index is used, but if only RSM max_items is used, it won't be
1198 used by default. This parameter let's use RSM workflow in this
1199 case. Note that in addition to RSM metadata, the result will not be
1200 the same (max_items without RSM will returns most recent items,
1201 i.e. last items in modification order, while max_items with RSM
1202 will return the oldest ones (i.e. first items in modification
1203 order).
1204 to be used when max_items is used from RSM
1205 """
1206
1207 metadata = {
1208 "service": node.service,
1209 "node": node.name,
1210 "uri": uri.buildXMPPUri(
1211 "pubsub",
1212 path=node.service.full(),
1213 node=node.name,
1214 ),
1215 }
1216 if max_items is None:
1217 max_items = 20
1218
1219 use_rsm = any((before, after, from_index is not None))
1220 if force_rsm and not use_rsm:
1221 #
1222 use_rsm = True
1223 from_index = 0
1224
1225 stmt = (
1226 select(PubsubItem)
1227 .filter_by(node_id=node.id)
1228 .limit(max_items)
1229 )
1230
1231 if not order_by:
1232 order_by = [C.ORDER_BY_MODIFICATION]
1233
1234 order = []
1235 for order_type in order_by:
1236 if order_type == C.ORDER_BY_MODIFICATION:
1237 if desc:
1238 order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc()))
1239 else:
1240 order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc()))
1241 elif order_type == C.ORDER_BY_CREATION:
1242 if desc:
1243 order.append(PubsubItem.id.desc())
1244 else:
1245 order.append(PubsubItem.id.asc())
1246 else:
1247 raise exceptions.InternalError(f"Unknown order type {order_type!r}")
1248
1249 stmt = stmt.order_by(*order)
1250
1251 if use_rsm:
1252 # CTE to have result row numbers
1253 row_num_q = select(
1254 PubsubItem.id,
1255 PubsubItem.name,
1256 # row_number starts from 1, but RSM index must start from 0
1257 (func.row_number().over(order_by=order)-1).label("item_index")
1258 ).filter_by(node_id=node.id)
1259
1260 row_num_cte = row_num_q.cte()
1261
1262 if max_items > 0:
1263 # as we can't simply use PubsubItem.id when we order by modification,
1264 # we need to use row number
1265 item_name = before or after
1266 row_num_limit_q = (
1267 select(row_num_cte.c.item_index)
1268 .where(row_num_cte.c.name==item_name)
1269 ).scalar_subquery()
1270
1271 stmt = (
1272 select(row_num_cte.c.item_index, PubsubItem)
1273 .join(row_num_cte, PubsubItem.id == row_num_cte.c.id)
1274 .limit(max_items)
1275 )
1276 if before:
1277 stmt = (
1278 stmt
1279 .where(row_num_cte.c.item_index<row_num_limit_q)
1280 .order_by(row_num_cte.c.item_index.desc())
1281 )
1282 elif after:
1283 stmt = (
1284 stmt
1285 .where(row_num_cte.c.item_index>row_num_limit_q)
1286 .order_by(row_num_cte.c.item_index.asc())
1287 )
1288 else:
1289 stmt = (
1290 stmt
1291 .where(row_num_cte.c.item_index>=from_index)
1292 .order_by(row_num_cte.c.item_index.asc())
1293 )
1294 # from_index is used
1295
1296 async with self.session() as session:
1297 if max_items == 0:
1298 items = result = []
1299 else:
1300 result = await session.execute(stmt)
1301 result = result.all()
1302 if before:
1303 result.reverse()
1304 items = [row[-1] for row in result]
1305 rows_count = (
1306 await session.execute(row_num_q.with_only_columns(count()))
1307 ).scalar_one()
1308
1309 try:
1310 index = result[0][0]
1311 except IndexError:
1312 index = None
1313
1314 try:
1315 first = result[0][1].name
1316 except IndexError:
1317 first = None
1318 last = None
1319 else:
1320 last = result[-1][1].name
1321
1322
1323 metadata["rsm"] = {
1324 "index": index,
1325 "count": rows_count,
1326 "first": first,
1327 "last": last,
1328 }
1329 metadata["complete"] = index + len(result) == rows_count
1330
1331 return items, metadata
1332
1333 async with self.session() as session:
1334 result = await session.execute(stmt)
1335
1336 result = result.scalars().all()
1337 if desc:
1338 result.reverse()
1339 return result, metadata