libervia-pubsub: comparison of sat_pubsub/pgsql_storage.py @ 405:c56a728412f1
file organisation + setup refactoring:
- `/src` has been renamed to `/sat_pubsub`, which is the recommended naming convention
- revamped `setup.py` on the basis of SàT's `setup.py`
- added a `VERSION` file, which is now the single place where the version number is set
- use the same trick as in SàT to mark dev versions (a `D` at the end)
- use setuptools_scm to retrieve the Mercurial hash when in a dev version (see the sketch below)
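A minimal, hypothetical sketch of how such a `setup.py` could derive the version: it assumes a `VERSION` file next to `setup.py` and treats a trailing `D` as the dev marker, then asks setuptools_scm for the SCM identifier. The file layout, names and exact version format are assumptions for illustration, not the actual libervia-pubsub `setup.py`.

```python
# Hypothetical sketch only: layout, names and version format are assumptions,
# not the project's actual setup.py.
import os
from setuptools import setup, find_packages

here = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(here, "VERSION")) as f:
    version = f.read().strip()            # e.g. "0.3.0", or "0.3.0D" for a dev version

if version.endswith("D"):                 # trailing "D" marks a development version
    base = version[:-1]
    try:
        from setuptools_scm import get_version
        # get_version() returns a PEP 440 string whose local part (after "+")
        # encodes the SCM state, e.g. the Mercurial node of the working copy.
        local = get_version(root=here).partition("+")[2]
        version = base + ".dev0" + ("+" + local if local else "")
    except Exception:                      # not an hg checkout, or setuptools_scm missing
        version = base + ".dev0"

setup(
    name="sat_pubsub",
    version=version,
    packages=find_packages(),
)
```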
author:   Goffi <goffi@goffi.org>
date:     Fri, 16 Aug 2019 12:00:02 +0200
parents:  src/pgsql_storage.py@1dc606612405
children: a58610ab2983
comparing revisions 404:105a0772eedd and 405:c56a728412f1
1 #!/usr/bin/python | |
2 #-*- coding: utf-8 -*- | |
3 | |
4 # Copyright (c) 2012-2019 Jérôme Poisson | |
5 # Copyright (c) 2013-2016 Adrien Cossa | |
6 # Copyright (c) 2003-2011 Ralph Meijer | |
7 | |
8 | |
9 # This program is free software: you can redistribute it and/or modify | |
10 # it under the terms of the GNU Affero General Public License as published by | |
11 # the Free Software Foundation, either version 3 of the License, or | |
12 # (at your option) any later version. | |
13 | |
14 # This program is distributed in the hope that it will be useful, | |
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
17 # GNU Affero General Public License for more details. | |
18 | |
19 # You should have received a copy of the GNU Affero General Public License | |
20 # along with this program. If not, see <http://www.gnu.org/licenses/>. | |
21 # -- | |
22 | |
23 # This program is based on Idavoll (http://idavoll.ik.nu/), | |
24 # originally written by Ralph Meijer (http://ralphm.net/blog/) | |
25 # It is sublicensed under AGPL v3 (or any later version) as allowed by the original | |
26 # license. | |
27 | |
28 # -- | |
29 | |
30 # Here is a copy of the original license: | |
31 | |
32 # Copyright (c) 2003-2011 Ralph Meijer | |
33 | |
34 # Permission is hereby granted, free of charge, to any person obtaining | |
35 # a copy of this software and associated documentation files (the | |
36 # "Software"), to deal in the Software without restriction, including | |
37 # without limitation the rights to use, copy, modify, merge, publish, | |
38 # distribute, sublicense, and/or sell copies of the Software, and to | |
39 # permit persons to whom the Software is furnished to do so, subject to | |
40 # the following conditions: | |
41 | |
42 # The above copyright notice and this permission notice shall be | |
43 # included in all copies or substantial portions of the Software. | |
44 | |
45 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |
46 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |
47 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |
48 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |
49 # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |
50 # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |
51 # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |
52 | |
53 | |
54 import copy, logging | |
55 | |
56 from zope.interface import implements | |
57 | |
58 from twisted.internet import reactor | |
59 from twisted.internet import defer | |
60 from twisted.words.protocols.jabber import jid | |
61 from twisted.python import log | |
62 | |
63 from wokkel import generic | |
64 from wokkel.pubsub import Subscription | |
65 | |
66 from sat_pubsub import error | |
67 from sat_pubsub import iidavoll | |
68 from sat_pubsub import const | |
69 from sat_pubsub import container | |
70 from sat_pubsub import exceptions | |
71 import uuid | |
72 import psycopg2 | |
73 import psycopg2.extensions | |
74 # we want psycopg2 to return unicode, not str | |
75 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) | |
76 psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) | |
77 | |
78 # parseXml manages str, but we get unicode | |
79 parseXml = lambda unicode_data: generic.parseXml(unicode_data.encode('utf-8')) | |
80 ITEMS_SEQ_NAME = u'node_{node_id}_seq' | |
81 PEP_COL_NAME = 'pep' | |
82 CURRENT_VERSION = '5' | |
83 # retrieve the maximum integer item id + 1 | |
84 NEXT_ITEM_ID_QUERY = r"SELECT COALESCE(max(item::integer)+1,1) as val from items where node_id={node_id} and item ~ E'^\\d+$'" | |
85 | |
86 | |
87 def withPEP(query, values, pep, recipient): | |
88 """Helper method to facilitate PEP management | |
89 | |
90 @param query: SQL query basis | |
91 @param values: current values to replace in query | |
92 @param pep(bool): True if we are in PEP mode | |
93 @param recipient(jid.JID): jid of the recipient | |
94 @return: query with the PEP AND check appended, | |
95 recipient's bare jid is added to values if needed | |
96 """ | |
97 if pep: | |
98 pep_check="AND {}=%s".format(PEP_COL_NAME) | |
99 values=list(values) + [recipient.userhost()] | |
100 else: | |
101 pep_check="AND {} IS NULL".format(PEP_COL_NAME) | |
102 return "{} {}".format(query, pep_check), values | |
103 | |
104 | |
105 class Storage: | |
106 | |
107 implements(iidavoll.IStorage) | |
108 | |
109 defaultConfig = { | |
110 'leaf': { | |
111 const.OPT_PERSIST_ITEMS: True, | |
112 const.OPT_DELIVER_PAYLOADS: True, | |
113 const.OPT_SEND_LAST_PUBLISHED_ITEM: 'on_sub', | |
114 const.OPT_ACCESS_MODEL: const.VAL_AMODEL_DEFAULT, | |
115 const.OPT_PUBLISH_MODEL: const.VAL_PMODEL_DEFAULT, | |
116 const.OPT_SERIAL_IDS: False, | |
117 const.OPT_CONSISTENT_PUBLISHER: False, | |
118 }, | |
119 'collection': { | |
120 const.OPT_DELIVER_PAYLOADS: True, | |
121 const.OPT_SEND_LAST_PUBLISHED_ITEM: 'on_sub', | |
122 const.OPT_ACCESS_MODEL: const.VAL_AMODEL_DEFAULT, | |
123 const.OPT_PUBLISH_MODEL: const.VAL_PMODEL_DEFAULT, | |
124 } | |
125 } | |
126 | |
127 def __init__(self, dbpool): | |
128 self.dbpool = dbpool | |
129 d = self.dbpool.runQuery("SELECT value FROM metadata WHERE key='version'") | |
130 d.addCallbacks(self._checkVersion, self._versionEb) | |
131 | |
132 def _checkVersion(self, row): | |
133 version = row[0].value | |
134 if version != CURRENT_VERSION: | |
135 logging.error("Bad database schema version ({current}), please upgrade to {needed}".format( | |
136 current=version, needed=CURRENT_VERSION)) | |
137 reactor.stop() | |
138 | |
139 def _versionEb(self, failure): | |
140 logging.error("Can't check schema version: {reason}".format(reason=failure)) | |
141 reactor.stop() | |
142 | |
143 def _buildNode(self, row): | |
144 """Build a node instance from a database result row""" | |
145 configuration = {} | |
146 | |
147 if not row: | |
148 raise error.NodeNotFound() | |
149 | |
150 if row[2] == 'leaf': | |
151 configuration = { | |
152 'pubsub#persist_items': row[3], | |
153 'pubsub#deliver_payloads': row[4], | |
154 'pubsub#send_last_published_item': row[5], | |
155 const.OPT_ACCESS_MODEL:row[6], | |
156 const.OPT_PUBLISH_MODEL:row[7], | |
157 const.OPT_SERIAL_IDS:row[8], | |
158 const.OPT_CONSISTENT_PUBLISHER:row[9], | |
159 } | |
160 schema = row[10] | |
161 if schema is not None: | |
162 schema = parseXml(schema) | |
163 node = LeafNode(row[0], row[1], configuration, schema) | |
164 node.dbpool = self.dbpool | |
165 return node | |
166 elif row[2] == 'collection': | |
167 configuration = { | |
168 'pubsub#deliver_payloads': row[4], | |
169 'pubsub#send_last_published_item': row[5], | |
170 const.OPT_ACCESS_MODEL: row[6], | |
171 const.OPT_PUBLISH_MODEL:row[7], | |
172 } | |
173 node = CollectionNode(row[0], row[1], configuration, None) | |
174 node.dbpool = self.dbpool | |
175 return node | |
176 else: | |
177 raise ValueError("Unknown node type !") | |
178 | |
179 def getNodeById(self, nodeDbId): | |
180 """Get node using database ID instead of pubsub identifier | |
181 | |
182 @param nodeDbId(unicode): database ID | |
183 """ | |
184 return self.dbpool.runInteraction(self._getNodeById, nodeDbId) | |
185 | |
186 def _getNodeById(self, cursor, nodeDbId): | |
187 cursor.execute("""SELECT node_id, | |
188 node, | |
189 node_type, | |
190 persist_items, | |
191 deliver_payloads, | |
192 send_last_published_item, | |
193 access_model, | |
194 publish_model, | |
195 serial_ids, | |
196 consistent_publisher, | |
197 schema::text, | |
198 pep | |
199 FROM nodes | |
200 WHERE node_id=%s""", | |
201 (nodeDbId,)) | |
202 row = cursor.fetchone() | |
203 return self._buildNode(row) | |
204 | |
205 def getNode(self, nodeIdentifier, pep, recipient=None): | |
206 return self.dbpool.runInteraction(self._getNode, nodeIdentifier, pep, recipient) | |
207 | |
208 def _getNode(self, cursor, nodeIdentifier, pep, recipient): | |
209 cursor.execute(*withPEP("""SELECT node_id, | |
210 node, | |
211 node_type, | |
212 persist_items, | |
213 deliver_payloads, | |
214 send_last_published_item, | |
215 access_model, | |
216 publish_model, | |
217 serial_ids, | |
218 consistent_publisher, | |
219 schema::text, | |
220 pep | |
221 FROM nodes | |
222 WHERE node=%s""", | |
223 (nodeIdentifier,), pep, recipient)) | |
224 row = cursor.fetchone() | |
225 return self._buildNode(row) | |
226 | |
227 def getNodeIds(self, pep, recipient, allowed_accesses=None): | |
228 """retrieve ids of existing nodes | |
229 | |
230 @param pep(bool): True if it's a PEP request | |
231 @param recipient(jid.JID, None): recipient of the PEP request | |
232 @param allowed_accesses(None, set): only nodes with access | |
233 in this set will be returned | |
234 None to return all nodes | |
235 @return (list[unicode]): ids of nodes | |
236 """ | |
237 if not pep: | |
238 query = "SELECT node from nodes WHERE pep is NULL" | |
239 values = [] | |
240 else: | |
241 query = "SELECT node from nodes WHERE pep=%s" | |
242 values = [recipient.userhost()] | |
243 | |
244 if allowed_accesses is not None: | |
245 query += " AND access_model IN %s"  # leading space keeps the generated SQL valid | |
246 values.append(tuple(allowed_accesses)) | |
247 | |
248 d = self.dbpool.runQuery(query, values) | |
249 d.addCallback(lambda results: [r[0] for r in results]) | |
250 return d | |
251 | |
252 def createNode(self, nodeIdentifier, owner, config, schema, pep, recipient=None): | |
253 return self.dbpool.runInteraction(self._createNode, nodeIdentifier, | |
254 owner, config, schema, pep, recipient) | |
255 | |
256 def _createNode(self, cursor, nodeIdentifier, owner, config, schema, pep, recipient): | |
257 if config['pubsub#node_type'] != 'leaf': | |
258 raise error.NoCollections() | |
259 | |
260 owner = owner.userhost() | |
261 | |
262 try: | |
263 cursor.execute("""INSERT INTO nodes | |
264 (node, | |
265 node_type, | |
266 persist_items, | |
267 deliver_payloads, | |
268 send_last_published_item, | |
269 access_model, | |
270 publish_model, | |
271 serial_ids, | |
272 consistent_publisher, | |
273 schema, | |
274 pep) | |
275 VALUES | |
276 (%s, 'leaf', %s, %s, %s, %s, %s, %s, %s, %s, %s)""", | |
277 (nodeIdentifier, | |
278 config['pubsub#persist_items'], | |
279 config['pubsub#deliver_payloads'], | |
280 config['pubsub#send_last_published_item'], | |
281 config[const.OPT_ACCESS_MODEL], | |
282 config[const.OPT_PUBLISH_MODEL], | |
283 config[const.OPT_SERIAL_IDS], | |
284 config[const.OPT_CONSISTENT_PUBLISHER], | |
285 schema, | |
286 recipient.userhost() if pep else None | |
287 ) | |
288 ) | |
289 except cursor._pool.dbapi.IntegrityError as e: | |
290 if e.pgcode == "23505": | |
291 # unique_violation | |
292 raise error.NodeExists() | |
293 else: | |
294 raise error.InvalidConfigurationOption() | |
295 | |
296 cursor.execute(*withPEP("""SELECT node_id FROM nodes WHERE node=%s""", | |
297 (nodeIdentifier,), pep, recipient)); | |
298 node_id = cursor.fetchone()[0] | |
299 | |
300 cursor.execute("""SELECT 1 as bool from entities where jid=%s""", | |
301 (owner,)) | |
302 | |
303 if not cursor.fetchone(): | |
304 # XXX: we can NOT rely on the previous query! Commit is needed now because | |
305 # if the entry exists the next query will leave the database in a corrupted | |
306 # state: the solution is to rollback. I tried with other methods like | |
307 # "WHERE NOT EXISTS" but none of them worked, so the following solution | |
308 # looks like the sole - unless you have auto-commit on. More info | |
309 # about this issue: http://cssmay.com/question/tag/tag-psycopg2 | |
310 cursor.connection.commit() | |
311 try: | |
312 cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""", | |
313 (owner,)) | |
314 except psycopg2.IntegrityError as e: | |
315 cursor.connection.rollback() | |
316 logging.warning("during node creation: %s" % e.message) | |
317 | |
318 cursor.execute("""INSERT INTO affiliations | |
319 (node_id, entity_id, affiliation) | |
320 SELECT %s, entity_id, 'owner' FROM | |
321 (SELECT entity_id FROM entities | |
322 WHERE jid=%s) as e""", | |
323 (node_id, owner)) | |
324 | |
325 if config[const.OPT_ACCESS_MODEL] == const.VAL_AMODEL_PUBLISHER_ROSTER: | |
326 if const.OPT_ROSTER_GROUPS_ALLOWED in config: | |
327 allowed_groups = config[const.OPT_ROSTER_GROUPS_ALLOWED] | |
328 else: | |
329 allowed_groups = [] | |
330 for group in allowed_groups: | |
331 #TODO: check that the groups are actually in the roster | |
332 cursor.execute("""INSERT INTO node_groups_authorized (node_id, groupname) | |
333 VALUES (%s,%s)""" , (node_id, group)) | |
334 # XXX: affiliations can't be set during node creation (at least not with XEP-0060 alone) | |
335 # so whitelist affiliations need to be done afterward | |
336 | |
337 # now we may have to do extra things according to config options | |
338 default_conf = self.defaultConfig['leaf'] | |
339 # XXX: trigger works on node creation because OPT_SERIAL_IDS is False in defaultConfig | |
340 # if this value is changed, the _configurationTriggers method should be adapted. | |
341 Node._configurationTriggers(cursor, node_id, default_conf, config) | |
342 | |
343 def deleteNodeByDbId(self, db_id): | |
344 """Delete a node using directly its database id""" | |
345 return self.dbpool.runInteraction(self._deleteNodeByDbId, db_id) | |
346 | |
347 def _deleteNodeByDbId(self, cursor, db_id): | |
348 cursor.execute("""DELETE FROM nodes WHERE node_id=%s""", | |
349 (db_id,)) | |
350 | |
351 if cursor.rowcount != 1: | |
352 raise error.NodeNotFound() | |
353 | |
354 def deleteNode(self, nodeIdentifier, pep, recipient=None): | |
355 return self.dbpool.runInteraction(self._deleteNode, nodeIdentifier, pep, recipient) | |
356 | |
357 def _deleteNode(self, cursor, nodeIdentifier, pep, recipient): | |
358 cursor.execute(*withPEP("""DELETE FROM nodes WHERE node=%s""", | |
359 (nodeIdentifier,), pep, recipient)) | |
360 | |
361 if cursor.rowcount != 1: | |
362 raise error.NodeNotFound() | |
363 | |
364 def getAffiliations(self, entity, nodeIdentifier, pep, recipient=None): | |
365 return self.dbpool.runInteraction(self._getAffiliations, entity, nodeIdentifier, pep, recipient) | |
366 | |
367 def _getAffiliations(self, cursor, entity, nodeIdentifier, pep, recipient=None): | |
368 query = ["""SELECT node, affiliation FROM entities | |
369 NATURAL JOIN affiliations | |
370 NATURAL JOIN nodes | |
371 WHERE jid=%s"""] | |
372 args = [entity.userhost()] | |
373 | |
374 if nodeIdentifier is not None: | |
375 query.append("AND node=%s") | |
376 args.append(nodeIdentifier) | |
377 | |
378 cursor.execute(*withPEP(' '.join(query), args, pep, recipient)) | |
379 rows = cursor.fetchall() | |
380 return [tuple(r) for r in rows] | |
381 | |
382 def getSubscriptions(self, entity, nodeIdentifier=None, pep=False, recipient=None): | |
383 """retrieve subscriptions of an entity | |
384 | |
385 @param entity(jid.JID): entity to check | |
386 @param nodeIdentifier(unicode, None): node identifier | |
387 None to retrieve all subscriptions | |
388 @param pep: True if we are in PEP mode | |
389 @param recipient: jid of the recipient | |
390 """ | |
391 | |
392 def toSubscriptions(rows): | |
393 subscriptions = [] | |
394 for row in rows: | |
395 subscriber = jid.internJID('%s/%s' % (row.jid, | |
396 row.resource)) | |
397 subscription = Subscription(row.node, subscriber, row.state) | |
398 subscriptions.append(subscription) | |
399 return subscriptions | |
400 | |
401 query = ["""SELECT node, | |
402 jid, | |
403 resource, | |
404 state | |
405 FROM entities | |
406 NATURAL JOIN subscriptions | |
407 NATURAL JOIN nodes | |
408 WHERE jid=%s"""] | |
409 | |
410 args = [entity.userhost()] | |
411 | |
412 if nodeIdentifier is not None: | |
413 query.append("AND node=%s") | |
414 args.append(nodeIdentifier) | |
415 | |
416 d = self.dbpool.runQuery(*withPEP(' '.join(query), args, pep, recipient)) | |
417 d.addCallback(toSubscriptions) | |
418 return d | |
419 | |
420 def getDefaultConfiguration(self, nodeType): | |
421 return self.defaultConfig[nodeType].copy() | |
422 | |
423 def formatLastItems(self, result): | |
424 last_items = [] | |
425 for pep_jid_s, node, data, item_access_model in result: | |
426 pep_jid = jid.JID(pep_jid_s) | |
427 item = generic.stripNamespace(parseXml(data)) | |
428 last_items.append((pep_jid, node, item, item_access_model)) | |
429 return last_items | |
430 | |
431 def getLastItems(self, entities, nodes, node_accesses, item_accesses, pep): | |
432 """get last item for several nodes and entities in a single request""" | |
433 if not entities or not nodes or not node_accesses or not item_accesses: | |
434 raise ValueError("entities, nodes and accesses must not be empty") | |
435 if node_accesses != ('open',) or item_accesses != ('open',): | |
436 raise NotImplementedError('only "open" access model is handled for now') | |
437 if not pep: | |
438 raise NotImplementedError(u"getLastItems is only implemented for PEP at the moment") | |
439 d = self.dbpool.runQuery("""SELECT DISTINCT ON (node_id) pep, node, data::text, items.access_model | |
440 FROM items | |
441 NATURAL JOIN nodes | |
442 WHERE nodes.pep IN %s | |
443 AND node IN %s | |
444 AND nodes.access_model in %s | |
445 AND items.access_model in %s | |
446 ORDER BY node_id DESC, items.updated DESC""", | |
447 (tuple([e.userhost() for e in entities]), | |
448 nodes, | |
449 node_accesses, | |
450 item_accesses)) | |
451 d.addCallback(self.formatLastItems) | |
452 return d | |
453 | |
454 | |
455 class Node: | |
456 | |
457 implements(iidavoll.INode) | |
458 | |
459 def __init__(self, nodeDbId, nodeIdentifier, config, schema): | |
460 self.nodeDbId = nodeDbId | |
461 self.nodeIdentifier = nodeIdentifier | |
462 self._config = config | |
463 self._schema = schema | |
464 | |
465 def _checkNodeExists(self, cursor): | |
466 cursor.execute("""SELECT 1 as exist FROM nodes WHERE node_id=%s""", | |
467 (self.nodeDbId,)) | |
468 if not cursor.fetchone(): | |
469 raise error.NodeNotFound() | |
470 | |
471 def getType(self): | |
472 return self.nodeType | |
473 | |
474 def getOwners(self): | |
475 d = self.dbpool.runQuery("""SELECT jid FROM nodes NATURAL JOIN affiliations NATURAL JOIN entities WHERE node_id=%s and affiliation='owner'""", (self.nodeDbId,)) | |
476 d.addCallback(lambda rows: [jid.JID(r[0]) for r in rows]) | |
477 return d | |
478 | |
479 def getConfiguration(self): | |
480 return self._config | |
481 | |
482 def getNextId(self): | |
483 """return XMPP item id usable for next item to publish | |
484 | |
485 the return value will be the next integer if serial_ids is set, | |
486 otherwise a UUID will be returned | |
487 """ | |
488 if self._config[const.OPT_SERIAL_IDS]: | |
489 d = self.dbpool.runQuery("SELECT nextval('{seq_name}')".format( | |
490 seq_name = ITEMS_SEQ_NAME.format(node_id=self.nodeDbId))) | |
491 d.addCallback(lambda rows: unicode(rows[0][0])) | |
492 return d | |
493 else: | |
494 return defer.succeed(unicode(uuid.uuid4())) | |
495 | |
496 @staticmethod | |
497 def _configurationTriggers(cursor, node_id, old_config, new_config): | |
498 """trigger database-related actions needed when a config is changed | |
499 | |
500 @param cursor: current db cursor | |
501 @param node_id(unicode): database ID of the node | |
502 @param old_config(dict): config of the node before the change | |
503 @param new_config(dict): new options that will be changed | |
504 """ | |
505 serial_ids = new_config[const.OPT_SERIAL_IDS] | |
506 if serial_ids != old_config[const.OPT_SERIAL_IDS]: | |
507 # serial_ids option has been modified, | |
508 # we need to handle corresponding sequence | |
509 | |
510 # XXX: we use .format in the following queries because the values | |
511 # are generated by ourselves, not by users | |
512 seq_name = ITEMS_SEQ_NAME.format(node_id=node_id) | |
513 if serial_ids: | |
514 # the next query gets the max value + 1 of all XMPP item ids | |
515 # which are integers, defaulting to 1 | |
516 cursor.execute(NEXT_ITEM_ID_QUERY.format(node_id=node_id)) | |
517 next_val = cursor.fetchone()[0] | |
518 cursor.execute("DROP SEQUENCE IF EXISTS {seq_name}".format(seq_name = seq_name)) | |
519 cursor.execute("CREATE SEQUENCE {seq_name} START {next_val} OWNED BY nodes.node_id".format( | |
520 seq_name = seq_name, | |
521 next_val = next_val)) | |
522 else: | |
523 cursor.execute("DROP SEQUENCE IF EXISTS {seq_name}".format(seq_name = seq_name)) | |
524 | |
525 def setConfiguration(self, options): | |
526 config = copy.copy(self._config) | |
527 | |
528 for option in options: | |
529 if option in config: | |
530 config[option] = options[option] | |
531 | |
532 d = self.dbpool.runInteraction(self._setConfiguration, config) | |
533 d.addCallback(self._setCachedConfiguration, config) | |
534 return d | |
535 | |
536 def _setConfiguration(self, cursor, config): | |
537 self._checkNodeExists(cursor) | |
538 self._configurationTriggers(cursor, self.nodeDbId, self._config, config) | |
539 cursor.execute("""UPDATE nodes SET persist_items=%s, | |
540 deliver_payloads=%s, | |
541 send_last_published_item=%s, | |
542 access_model=%s, | |
543 publish_model=%s, | |
544 serial_ids=%s, | |
545 consistent_publisher=%s | |
546 WHERE node_id=%s""", | |
547 (config[const.OPT_PERSIST_ITEMS], | |
548 config[const.OPT_DELIVER_PAYLOADS], | |
549 config[const.OPT_SEND_LAST_PUBLISHED_ITEM], | |
550 config[const.OPT_ACCESS_MODEL], | |
551 config[const.OPT_PUBLISH_MODEL], | |
552 config[const.OPT_SERIAL_IDS], | |
553 config[const.OPT_CONSISTENT_PUBLISHER], | |
554 self.nodeDbId)) | |
555 | |
556 def _setCachedConfiguration(self, void, config): | |
557 self._config = config | |
558 | |
559 def getSchema(self): | |
560 return self._schema | |
561 | |
562 def setSchema(self, schema): | |
563 d = self.dbpool.runInteraction(self._setSchema, schema) | |
564 d.addCallback(self._setCachedSchema, schema) | |
565 return d | |
566 | |
567 def _setSchema(self, cursor, schema): | |
568 self._checkNodeExists(cursor) | |
569 cursor.execute("""UPDATE nodes SET schema=%s | |
570 WHERE node_id=%s""", | |
571 (schema.toXml() if schema else None, | |
572 self.nodeDbId)) | |
573 | |
574 def _setCachedSchema(self, void, schema): | |
575 self._schema = schema | |
576 | |
577 def getMetaData(self): | |
578 config = copy.copy(self._config) | |
579 config["pubsub#node_type"] = self.nodeType | |
580 return config | |
581 | |
582 def getAffiliation(self, entity): | |
583 return self.dbpool.runInteraction(self._getAffiliation, entity) | |
584 | |
585 def _getAffiliation(self, cursor, entity): | |
586 self._checkNodeExists(cursor) | |
587 cursor.execute("""SELECT affiliation FROM affiliations | |
588 NATURAL JOIN nodes | |
589 NATURAL JOIN entities | |
590 WHERE node_id=%s AND jid=%s""", | |
591 (self.nodeDbId, | |
592 entity.userhost())) | |
593 | |
594 try: | |
595 return cursor.fetchone()[0] | |
596 except TypeError: | |
597 return None | |
598 | |
599 def getAccessModel(self): | |
600 return self._config[const.OPT_ACCESS_MODEL] | |
601 | |
602 def getSubscription(self, subscriber): | |
603 return self.dbpool.runInteraction(self._getSubscription, subscriber) | |
604 | |
605 def _getSubscription(self, cursor, subscriber): | |
606 self._checkNodeExists(cursor) | |
607 | |
608 userhost = subscriber.userhost() | |
609 resource = subscriber.resource or '' | |
610 | |
611 cursor.execute("""SELECT state FROM subscriptions | |
612 NATURAL JOIN nodes | |
613 NATURAL JOIN entities | |
614 WHERE node_id=%s AND jid=%s AND resource=%s""", | |
615 (self.nodeDbId, | |
616 userhost, | |
617 resource)) | |
618 | |
619 row = cursor.fetchone() | |
620 if not row: | |
621 return None | |
622 else: | |
623 return Subscription(self.nodeIdentifier, subscriber, row[0]) | |
624 | |
625 def getSubscriptions(self, state=None): | |
626 return self.dbpool.runInteraction(self._getSubscriptions, state) | |
627 | |
628 def _getSubscriptions(self, cursor, state): | |
629 self._checkNodeExists(cursor) | |
630 | |
631 query = """SELECT node, jid, resource, state, | |
632 subscription_type, subscription_depth | |
633 FROM subscriptions | |
634 NATURAL JOIN nodes | |
635 NATURAL JOIN entities | |
636 WHERE node_id=%s""" | |
637 values = [self.nodeDbId] | |
638 | |
639 if state: | |
640 query += " AND state=%s" | |
641 values.append(state) | |
642 | |
643 cursor.execute(query, values) | |
644 rows = cursor.fetchall() | |
645 | |
646 subscriptions = [] | |
647 for row in rows: | |
648 subscriber = jid.JID(u'%s/%s' % (row.jid, row.resource)) | |
649 | |
650 options = {} | |
651 if row.subscription_type: | |
652 options['pubsub#subscription_type'] = row.subscription_type; | |
653 if row.subscription_depth: | |
654 options['pubsub#subscription_depth'] = row.subscription_depth; | |
655 | |
656 subscriptions.append(Subscription(row.node, subscriber, | |
657 row.state, options)) | |
658 | |
659 return subscriptions | |
660 | |
661 def addSubscription(self, subscriber, state, config): | |
662 return self.dbpool.runInteraction(self._addSubscription, subscriber, | |
663 state, config) | |
664 | |
665 def _addSubscription(self, cursor, subscriber, state, config): | |
666 self._checkNodeExists(cursor) | |
667 | |
668 userhost = subscriber.userhost() | |
669 resource = subscriber.resource or '' | |
670 | |
671 subscription_type = config.get('pubsub#subscription_type') | |
672 subscription_depth = config.get('pubsub#subscription_depth') | |
673 | |
674 try: | |
675 cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""", | |
676 (userhost,)) | |
677 except cursor._pool.dbapi.IntegrityError: | |
678 cursor.connection.rollback() | |
679 | |
680 try: | |
681 cursor.execute("""INSERT INTO subscriptions | |
682 (node_id, entity_id, resource, state, | |
683 subscription_type, subscription_depth) | |
684 SELECT %s, entity_id, %s, %s, %s, %s FROM | |
685 (SELECT entity_id FROM entities | |
686 WHERE jid=%s) AS ent_id""", | |
687 (self.nodeDbId, | |
688 resource, | |
689 state, | |
690 subscription_type, | |
691 subscription_depth, | |
692 userhost)) | |
693 except cursor._pool.dbapi.IntegrityError: | |
694 raise error.SubscriptionExists() | |
695 | |
696 def removeSubscription(self, subscriber): | |
697 return self.dbpool.runInteraction(self._removeSubscription, | |
698 subscriber) | |
699 | |
700 def _removeSubscription(self, cursor, subscriber): | |
701 self._checkNodeExists(cursor) | |
702 | |
703 userhost = subscriber.userhost() | |
704 resource = subscriber.resource or '' | |
705 | |
706 cursor.execute("""DELETE FROM subscriptions WHERE | |
707 node_id=%s AND | |
708 entity_id=(SELECT entity_id FROM entities | |
709 WHERE jid=%s) AND | |
710 resource=%s""", | |
711 (self.nodeDbId, | |
712 userhost, | |
713 resource)) | |
714 if cursor.rowcount != 1: | |
715 raise error.NotSubscribed() | |
716 | |
717 return None | |
718 | |
719 def setSubscriptions(self, subscriptions): | |
720 return self.dbpool.runInteraction(self._setSubscriptions, subscriptions) | |
721 | |
722 def _setSubscriptions(self, cursor, subscriptions): | |
723 self._checkNodeExists(cursor) | |
724 | |
725 entities = self.getOrCreateEntities(cursor, [s.subscriber for s in subscriptions]) | |
726 entities_map = {jid.JID(e.jid): e for e in entities} | |
727 | |
728 # then we construct values for subscriptions update according to entity_id we just got | |
729 placeholders = ','.join(len(subscriptions) * ["%s"]) | |
730 values = [] | |
731 for subscription in subscriptions: | |
732 entity_id = entities_map[subscription.subscriber].entity_id | |
733 resource = subscription.subscriber.resource or u'' | |
734 values.append((self.nodeDbId, entity_id, resource, subscription.state, None, None)) | |
735 # we use upsert so new values are inserted and existing ones updated. This feature is only available in PostgreSQL >= 9.5 | |
736 cursor.execute("INSERT INTO subscriptions(node_id, entity_id, resource, state, subscription_type, subscription_depth) VALUES " + placeholders + " ON CONFLICT (entity_id, resource, node_id) DO UPDATE SET state=EXCLUDED.state", [v for v in values]) | |
737 | |
738 def isSubscribed(self, entity): | |
739 return self.dbpool.runInteraction(self._isSubscribed, entity) | |
740 | |
741 def _isSubscribed(self, cursor, entity): | |
742 self._checkNodeExists(cursor) | |
743 | |
744 cursor.execute("""SELECT 1 as bool FROM entities | |
745 NATURAL JOIN subscriptions | |
746 NATURAL JOIN nodes | |
747 WHERE entities.jid=%s | |
748 AND node_id=%s AND state='subscribed'""", | |
749 (entity.userhost(), | |
750 self.nodeDbId)) | |
751 | |
752 return cursor.fetchone() is not None | |
753 | |
754 def getAffiliations(self): | |
755 return self.dbpool.runInteraction(self._getAffiliations) | |
756 | |
757 def _getAffiliations(self, cursor): | |
758 self._checkNodeExists(cursor) | |
759 | |
760 cursor.execute("""SELECT jid, affiliation FROM nodes | |
761 NATURAL JOIN affiliations | |
762 NATURAL JOIN entities | |
763 WHERE node_id=%s""", | |
764 (self.nodeDbId,)) | |
765 result = cursor.fetchall() | |
766 | |
767 return {jid.internJID(r[0]): r[1] for r in result} | |
768 | |
769 def getOrCreateEntities(self, cursor, entities_jids): | |
770 """Get entity_id from entities in entities table | |
771 | |
772 Entities will be inserted if they don't exist | |
773 @param entities_jid(list[jid.JID]): entities to get or create | |
774 @return list[record(entity_id, jid)]: list of entity_id and jid (as a plain string) | |
775 both existing and inserted entities are returned | |
776 """ | |
777 # cf. http://stackoverflow.com/a/35265559 | |
778 placeholders = ','.join(len(entities_jids) * ["(%s)"]) | |
779 query = ( | |
780 """ | |
781 WITH | |
782 jid_values (jid) AS ( | |
783 VALUES {placeholders} | |
784 ), | |
785 inserted (entity_id, jid) AS ( | |
786 INSERT INTO entities (jid) | |
787 SELECT jid | |
788 FROM jid_values | |
789 ON CONFLICT DO NOTHING | |
790 RETURNING entity_id, jid | |
791 ) | |
792 SELECT e.entity_id, e.jid | |
793 FROM entities e JOIN jid_values jv ON jv.jid = e.jid | |
794 UNION ALL | |
795 SELECT entity_id, jid | |
796 FROM inserted""".format(placeholders=placeholders)) | |
797 cursor.execute(query, [j.userhost() for j in entities_jids]) | |
798 return cursor.fetchall() | |
799 | |
800 def setAffiliations(self, affiliations): | |
801 return self.dbpool.runInteraction(self._setAffiliations, affiliations) | |
802 | |
803 def _setAffiliations(self, cursor, affiliations): | |
804 self._checkNodeExists(cursor) | |
805 | |
806 entities = self.getOrCreateEntities(cursor, affiliations) | |
807 | |
808 # then we construct values for affiliations update according to entity_id we just got | |
809 placeholders = ','.join(len(affiliations) * ["(%s,%s,%s)"]) | |
810 values = [] | |
811 map(values.extend, ((e.entity_id, affiliations[jid.JID(e.jid)], self.nodeDbId) for e in entities)) | |
812 | |
813 # we use upsert so new values are inserted and existing ones updated. This feature is only available in PostgreSQL >= 9.5 | |
814 cursor.execute("INSERT INTO affiliations(entity_id,affiliation,node_id) VALUES " + placeholders + " ON CONFLICT (entity_id,node_id) DO UPDATE SET affiliation=EXCLUDED.affiliation", values) | |
815 | |
816 def deleteAffiliations(self, entities): | |
817 return self.dbpool.runInteraction(self._deleteAffiliations, entities) | |
818 | |
819 def _deleteAffiliations(self, cursor, entities): | |
820 """delete affiliations and subscriptions for this entity""" | |
821 self._checkNodeExists(cursor) | |
822 placeholders = ','.join(len(entities) * ["%s"]) | |
823 cursor.execute("DELETE FROM affiliations WHERE node_id=%s AND entity_id in (SELECT entity_id FROM entities WHERE jid IN (" + placeholders + ")) RETURNING entity_id", [self.nodeDbId] + [e.userhost() for e in entities]) | |
824 | |
825 rows = cursor.fetchall() | |
826 placeholders = ','.join(len(rows) * ["%s"]) | |
827 cursor.execute("DELETE FROM subscriptions WHERE node_id=%s AND entity_id in (" + placeholders + ")", [self.nodeDbId] + [r[0] for r in rows]) | |
828 | |
829 def getAuthorizedGroups(self): | |
830 return self.dbpool.runInteraction(self._getAuthorizedGroups) | |
831 | |
832 def _getAuthorizedGroups(self, cursor): | |
833 cursor.execute("SELECT groupname FROM node_groups_authorized NATURAL JOIN nodes WHERE node_id=%s", | |
834 (self.nodeDbId,)) | |
835 rows = cursor.fetchall() | |
836 return [row[0] for row in rows] | |
837 | |
838 | |
839 class LeafNode(Node): | |
840 | |
841 implements(iidavoll.ILeafNode) | |
842 | |
843 nodeType = 'leaf' | |
844 | |
845 def getOrderBy(self, ext_data, direction='DESC'): | |
846 """Return the ORDER BY clause corresponding to the 'order_by' key in ext_data | |
847 | |
848 @param ext_data (dict): extra data as used in getItems | |
849 @param direction (unicode): ORDER BY direction (ASC or DESC) | |
850 @return (unicode): ORDER BY clause to use | |
851 """ | |
852 keys = ext_data.get('order_by') | |
853 if not keys: | |
854 return u'ORDER BY updated ' + direction | |
855 cols_statmnt = [] | |
856 for key in keys: | |
857 if key == 'creation': | |
858 column = 'item_id' # could work with items.created too | |
859 elif key == 'modification': | |
860 column = 'updated' | |
861 else: | |
862 log.msg(u"WARNING: Unknown order by key: {key}".format(key=key)) | |
863 column = 'updated' | |
864 cols_statmnt.append(column + u' ' + direction) | |
865 | |
866 return u"ORDER BY " + u",".join([col for col in cols_statmnt]) | |
867 | |
868 @defer.inlineCallbacks | |
869 def storeItems(self, items_data, publisher): | |
870 # XXX: runInteraction doesn't seem to work when there are several "insert" | |
871 # or "update". | |
872 # Before the unpacking was done in _storeItems, but this was causing trouble | |
873 # in case of multiple items_data. So this has now be moved here. | |
874 # FIXME: investigate the issue with runInteraction | |
875 for item_data in items_data: | |
876 yield self.dbpool.runInteraction(self._storeItems, item_data, publisher) | |
877 | |
878 def _storeItems(self, cursor, item_data, publisher): | |
879 self._checkNodeExists(cursor) | |
880 self._storeItem(cursor, item_data, publisher) | |
881 | |
882 def _storeItem(self, cursor, item_data, publisher): | |
883 # first try to insert the item | |
884 # - if it fails (conflict), the item is new and the serial_ids option is set, | |
885 # the current id is recomputed using the next item id query (this is not perfect, as | |
886 # the table is not locked, so it can fail if two items are added at the same time, | |
887 # but that can only happen with serial_ids when future ids have been set by a client, | |
888 # which should be rare enough to be acceptable) | |
889 # - if item insertion fails and the item is not new, we do an update | |
890 # - in other cases, the exception is raised | |
891 item, access_model, item_config = item_data.item, item_data.access_model, item_data.config | |
892 data = item.toXml() | |
893 | |
894 insert_query = """INSERT INTO items (node_id, item, publisher, data, access_model) | |
895 SELECT %s, %s, %s, %s, %s FROM nodes | |
896 WHERE node_id=%s | |
897 RETURNING item_id""" | |
898 insert_data = [self.nodeDbId, | |
899 item["id"], | |
900 publisher.full(), | |
901 data, | |
902 access_model, | |
903 self.nodeDbId] | |
904 | |
905 try: | |
906 cursor.execute(insert_query, insert_data) | |
907 except cursor._pool.dbapi.IntegrityError as e: | |
908 if e.pgcode != "23505": | |
909 # we only handle unique_violation, every other exception must be raised | |
910 raise e | |
911 cursor.connection.rollback() | |
912 # the item already exists | |
913 if item_data.new: | |
914 # the item is new | |
915 if self._config[const.OPT_SERIAL_IDS]: | |
916 # this can happen with serial_ids, if an item has been stored | |
917 # with a future id (generated by XMPP client) | |
918 cursor.execute(NEXT_ITEM_ID_QUERY.format(node_id=self.nodeDbId)) | |
919 next_id = cursor.fetchone()[0] | |
920 # we update the sequence, so we can skip conflicting ids | |
921 cursor.execute(u"SELECT setval('{seq_name}', %s)".format( | |
922 seq_name = ITEMS_SEQ_NAME.format(node_id=self.nodeDbId)), [next_id]) | |
923 # and now we can retry the query with the new id | |
924 item['id'] = insert_data[1] = unicode(next_id) | |
925 # item saved in DB must also be updated with the new id | |
926 insert_data[3] = item.toXml() | |
927 cursor.execute(insert_query, insert_data) | |
928 else: | |
929 # but if we don't have serial_ids, we have a real problem | |
930 raise e | |
931 else: | |
932 # this is an update | |
933 cursor.execute("""UPDATE items SET updated=now(), publisher=%s, data=%s | |
934 FROM nodes | |
935 WHERE nodes.node_id = items.node_id AND | |
936 nodes.node_id = %s and items.item=%s | |
937 RETURNING item_id""", | |
938 (publisher.full(), | |
939 data, | |
940 self.nodeDbId, | |
941 item["id"])) | |
942 if cursor.rowcount != 1: | |
943 raise exceptions.InternalError("item has not been updated correctly") | |
944 item_id = cursor.fetchone()[0]; | |
945 self._storeCategories(cursor, item_id, item_data.categories, update=True) | |
946 return | |
947 | |
948 item_id = cursor.fetchone()[0]; | |
949 self._storeCategories(cursor, item_id, item_data.categories) | |
950 | |
951 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER: | |
952 if const.OPT_ROSTER_GROUPS_ALLOWED in item_config: | |
953 item_config.fields[const.OPT_ROSTER_GROUPS_ALLOWED].fieldType='list-multi' #XXX: needed to force list if there is only one value | |
954 allowed_groups = item_config[const.OPT_ROSTER_GROUPS_ALLOWED] | |
955 else: | |
956 allowed_groups = [] | |
957 for group in allowed_groups: | |
958 #TODO: check that the groups are actually in the roster | |
959 cursor.execute("""INSERT INTO item_groups_authorized (item_id, groupname) | |
960 VALUES (%s,%s)""" , (item_id, group)) | |
961 # TODO: whitelist access model | |
962 | |
963 def _storeCategories(self, cursor, item_id, categories, update=False): | |
964 # TODO: handle canonical form | |
965 if update: | |
966 cursor.execute("""DELETE FROM item_categories | |
967 WHERE item_id=%s""", (item_id,)) | |
968 | |
969 # we use a set to avoid duplicates | |
970 for category in set(categories): | |
971 cursor.execute("""INSERT INTO item_categories (item_id, category) | |
972 VALUES (%s, %s)""", (item_id, category)) | |
973 | |
974 def removeItems(self, itemIdentifiers): | |
975 return self.dbpool.runInteraction(self._removeItems, itemIdentifiers) | |
976 | |
977 def _removeItems(self, cursor, itemIdentifiers): | |
978 self._checkNodeExists(cursor) | |
979 | |
980 deleted = [] | |
981 | |
982 for itemIdentifier in itemIdentifiers: | |
983 cursor.execute("""DELETE FROM items WHERE | |
984 node_id=%s AND | |
985 item=%s""", | |
986 (self.nodeDbId, | |
987 itemIdentifier)) | |
988 | |
989 if cursor.rowcount: | |
990 deleted.append(itemIdentifier) | |
991 | |
992 return deleted | |
993 | |
994 def getItems(self, authorized_groups, unrestricted, maxItems=None, ext_data=None): | |
995 """ Get all authorised items | |
996 | |
997 @param authorized_groups: we want to get items that these groups can access | |
998 @param unrestricted: if true, don't check permissions (i.e.: get all items) | |
999 @param maxItems: nb of items we want to get | |
1000 @param ext_data: options for extra features like RSM and MAM | |
1001 | |
1002 @return: list of container.ItemData | |
1003 if unrestricted is False, access_model and config will be None | |
1004 """ | |
1005 if ext_data is None: | |
1006 ext_data = {} | |
1007 return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems, ext_data, ids_only=False) | |
1008 | |
1009 def getItemsIds(self, authorized_groups, unrestricted, maxItems=None, ext_data=None): | |
1010 """ Get all authorised items ids | |
1011 | |
1012 @param authorized_groups: we want to get items that these groups can access | |
1013 @param unrestricted: if true, don't check permissions (i.e.: get all items) | |
1014 @param maxItems: nb of items we want to get | |
1015 @param ext_data: options for extra features like RSM and MAM | |
1016 | |
1017 @return list(unicode): list of ids | |
1018 """ | |
1019 if ext_data is None: | |
1020 ext_data = {} | |
1021 return self.dbpool.runInteraction(self._getItems, authorized_groups, unrestricted, maxItems, ext_data, ids_only=True) | |
1022 | |
1023 def _appendSourcesAndFilters(self, query, args, authorized_groups, unrestricted, ext_data): | |
1024 """append sources and filters to the SQL query requesting items and return the ORDER BY clause | |
1025 | |
1026 arguments query, args, authorized_groups, unrestricted and ext_data are the same as for | |
1027 _getItems | |
1028 """ | |
1029 # SOURCES | |
1030 query.append("FROM nodes INNER JOIN items USING (node_id)") | |
1031 | |
1032 if unrestricted: | |
1033 query_filters = ["WHERE node_id=%s"] | |
1034 args.append(self.nodeDbId) | |
1035 else: | |
1036 query.append("LEFT JOIN item_groups_authorized USING (item_id)") | |
1037 args.append(self.nodeDbId) | |
1038 if authorized_groups: | |
1039 get_groups = " or (items.access_model='roster' and groupname in %s)" | |
1040 args.append(authorized_groups) | |
1041 else: | |
1042 get_groups = "" | |
1043 | |
1044 query_filters = ["WHERE node_id=%s AND (items.access_model='open'" + get_groups + ")"] | |
1045 | |
1046 # FILTERS | |
1047 if 'filters' in ext_data: # MAM filters | |
1048 for filter_ in ext_data['filters']: | |
1049 if filter_.var == 'start': | |
1050 query_filters.append("AND created>=%s") | |
1051 args.append(filter_.value) | |
1052 elif filter_.var == 'end': | |
1053 query_filters.append("AND created<=%s") | |
1054 args.append(filter_.value) | |
1055 elif filter_.var == 'with': | |
1056 jid_s = filter_.value | |
1057 if '/' in jid_s: | |
1058 query_filters.append("AND publisher=%s") | |
1059 args.append(filter_.value) | |
1060 else: | |
1061 query_filters.append("AND publisher LIKE %s") | |
1062 args.append(u"{}%".format(filter_.value)) | |
1063 elif filter_.var == const.MAM_FILTER_CATEGORY: | |
1064 query.append("LEFT JOIN item_categories USING (item_id)") | |
1065 query_filters.append("AND category=%s") | |
1066 args.append(filter_.value) | |
1067 else: | |
1068 log.msg("WARNING: unknown filter: {}".format(filter_.encode('utf-8'))) | |
1069 | |
1070 query.extend(query_filters) | |
1071 | |
1072 return self.getOrderBy(ext_data) | |
1073 | |
1074 def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data, ids_only): | |
1075 self._checkNodeExists(cursor) | |
1076 | |
1077 if maxItems == 0: | |
1078 return [] | |
1079 | |
1080 args = [] | |
1081 | |
1082 # SELECT | |
1083 if ids_only: | |
1084 query = ["SELECT item"] | |
1085 else: | |
1086 query = ["SELECT data::text,items.access_model,item_id,created,updated"] | |
1087 | |
1088 query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data) | |
1089 | |
1090 if 'rsm' in ext_data: | |
1091 rsm = ext_data['rsm'] | |
1092 maxItems = rsm.max | |
1093 if rsm.index is not None: | |
1094 # We need to know the item_id corresponding to the index (offset) of the current query | |
1095 # so we execute the query to look for the item_id | |
1096 tmp_query = query[:] | |
1097 tmp_args = args[:] | |
1098 tmp_query[0] = "SELECT item_id" | |
1099 tmp_query.append("{} LIMIT 1 OFFSET %s".format(query_order)) | |
1100 tmp_args.append(rsm.index) | |
1101 cursor.execute(' '.join(tmp_query), tmp_args) | |
1102 # FIXME: bad index is not managed yet | |
1103 item_id = cursor.fetchall()[0][0] | |
1104 | |
1105 # now that we have the id, we can use it | |
1106 query.append("AND item_id<=%s") | |
1107 args.append(item_id) | |
1108 elif rsm.before is not None: | |
1109 if rsm.before != '': | |
1110 query.append("AND item_id>(SELECT item_id FROM items WHERE item=%s LIMIT 1)") | |
1111 args.append(rsm.before) | |
1112 if maxItems is not None: | |
1113 # if we have maxItems (i.e. a limit), we need to reverse order | |
1114 # in a first query to get the right items | |
1115 query.insert(0,"SELECT * from (") | |
1116 query.append(self.getOrderBy(ext_data, direction='ASC')) | |
1117 query.append("LIMIT %s) as x") | |
1118 args.append(maxItems) | |
1119 elif rsm.after: | |
1120 query.append("AND item_id<(SELECT item_id FROM items WHERE item=%s LIMIT 1)") | |
1121 args.append(rsm.after) | |
1122 | |
1123 query.append(query_order) | |
1124 | |
1125 if maxItems is not None: | |
1126 query.append("LIMIT %s") | |
1127 args.append(maxItems) | |
1128 | |
1129 cursor.execute(' '.join(query), args) | |
1130 | |
1131 result = cursor.fetchall() | |
1132 if unrestricted and not ids_only: | |
1133 # with an unrestricted query, we need to fill the access_list for roster access items | |
1134 ret = [] | |
1135 for item_data in result: | |
1136 item = generic.stripNamespace(parseXml(item_data.data)) | |
1137 access_model = item_data.access_model | |
1138 item_id = item_data.item_id | |
1139 created = item_data.created | |
1140 updated = item_data.updated | |
1141 access_list = {} | |
1142 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER: | |
1143 cursor.execute('SELECT groupname FROM item_groups_authorized WHERE item_id=%s', (item_id,)) | |
1144 access_list[const.OPT_ROSTER_GROUPS_ALLOWED] = [r.groupname for r in cursor.fetchall()] | |
1145 | |
1146 ret.append(container.ItemData(item, access_model, access_list, created=created, updated=updated)) | |
1147 # TODO: whitelist item access model | |
1148 return ret | |
1149 | |
1150 if ids_only: | |
1151 return [r.item for r in result] | |
1152 else: | |
1153 items_data = [container.ItemData(generic.stripNamespace(parseXml(r.data)), r.access_model, created=r.created, updated=r.updated) for r in result] | |
1154 return items_data | |
1155 | |
1156 def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers): | |
1157 """Get items which are in the given list | |
1158 | |
1159 @param authorized_groups: we want to get items that these groups can access | |
1160 @param unrestricted: if true, don't check permissions | |
1161 @param itemIdentifiers: list of ids of the items we want to get | |
1162 @return: list of container.ItemData | |
1163 ItemData.config will contain access_list (managed as a dictionary with the same keys as item_config) | |
1164 if unrestricted is False, access_model and config will be None | |
1165 """ | |
1166 return self.dbpool.runInteraction(self._getItemsById, authorized_groups, unrestricted, itemIdentifiers) | |
1167 | |
1168 def _getItemsById(self, cursor, authorized_groups, unrestricted, itemIdentifiers): | |
1169 self._checkNodeExists(cursor) | |
1170 ret = [] | |
1171 if unrestricted: #we get everything without checking permissions | |
1172 for itemIdentifier in itemIdentifiers: | |
1173 cursor.execute("""SELECT data::text,items.access_model,item_id,created,updated FROM nodes | |
1174 INNER JOIN items USING (node_id) | |
1175 WHERE node_id=%s AND item=%s""", | |
1176 (self.nodeDbId, | |
1177 itemIdentifier)) | |
1178 result = cursor.fetchone() | |
1179 if not result: | |
1180 raise error.ItemNotFound() | |
1181 | |
1182 item = generic.stripNamespace(parseXml(result[0])) | |
1183 access_model = result[1] | |
1184 item_id = result[2] | |
1185 created= result[3] | |
1186 updated= result[4] | |
1187 access_list = {} | |
1188 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER: | |
1189 cursor.execute('SELECT groupname FROM item_groups_authorized WHERE item_id=%s', (item_id,)) | |
1190 access_list[const.OPT_ROSTER_GROUPS_ALLOWED] = [r[0] for r in cursor.fetchall()] | |
1191 #TODO: WHITELIST access_model | |
1192 | |
1193 ret.append(container.ItemData(item, access_model, access_list, created=created, updated=updated)) | |
1194 else: # we check permissions before returning items | |
1195 for itemIdentifier in itemIdentifiers: | |
1196 args = [self.nodeDbId, itemIdentifier] | |
1197 if authorized_groups: | |
1198 args.append(authorized_groups) | |
1199 cursor.execute("""SELECT data::text, created, updated FROM nodes | |
1200 INNER JOIN items USING (node_id) | |
1201 LEFT JOIN item_groups_authorized USING (item_id) | |
1202 WHERE node_id=%s AND item=%s AND | |
1203 (items.access_model='open' """ + | |
1204 ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') + ")", | |
1205 args) | |
1206 | |
1207 result = cursor.fetchone() | |
1208 if result: | |
1209 ret.append(container.ItemData(generic.stripNamespace(parseXml(result[0])), created=result[1], updated=result[2])) | |
1210 | |
1211 return ret | |
1212 | |
1213 def getItemsCount(self, authorized_groups, unrestricted, ext_data=None): | |
1214 """Count expected number of items in a getItems query | |
1215 | |
1216 @param authorized_groups: we want to get items that these groups can access | |
1217 @param unrestricted: if true, don't check permissions (i.e.: get all items) | |
1218 @param ext_data: options for extra features like RSM and MAM | |
1219 """ | |
1220 if ext_data is None: | |
1221 ext_data = {} | |
1222 return self.dbpool.runInteraction(self._getItemsCount, authorized_groups, unrestricted, ext_data) | |
1223 | |
1224 def _getItemsCount(self, cursor, authorized_groups, unrestricted, ext_data): | |
1225 self._checkNodeExists(cursor) | |
1226 args = [] | |
1227 | |
1228 # SELECT | |
1229 query = ["SELECT count(1)"] | |
1230 | |
1231 self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data) | |
1232 | |
1233 cursor.execute(' '.join(query), args) | |
1234 return cursor.fetchall()[0][0] | |
1235 | |
1236 def getItemsIndex(self, item_id, authorized_groups, unrestricted, ext_data=None): | |
1237 """Get expected index of first item in the window of a getItems query | |
1238 | |
1239 @param item_id: id of the item | |
1240 @param authorized_groups: we want to get items that these groups can access | |
1241 @param unrestricted: if true, don't check permissions (i.e.: get all items) | |
1242 @param ext_data: options for extra features like RSM and MAM | |
1243 """ | |
1244 if ext_data is None: | |
1245 ext_data = {} | |
1246 return self.dbpool.runInteraction(self._getItemsIndex, item_id, authorized_groups, unrestricted, ext_data) | |
1247 | |
1248 def _getItemsIndex(self, cursor, item_id, authorized_groups, unrestricted, ext_data): | |
1249 self._checkNodeExists(cursor) | |
1250 args = [] | |
1251 | |
1252 # SELECT | |
1253 query = [] | |
1254 | |
1255 query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data) | |
1256 | |
1257 query_select = "SELECT row_number from (SELECT row_number() OVER ({}), item".format(query_order) | |
1258 query.insert(0, query_select) | |
1259 query.append(") as x WHERE item=%s") | |
1260 args.append(item_id) | |
1261 | |
1262 cursor.execute(' '.join(query), args) | |
1263 # XXX: row_number starts at 1, but we want the index to start at 0 | |
1264 try: | |
1265 return cursor.fetchall()[0][0] - 1 | |
1266 except IndexError: | |
1267 raise error.NodeNotFound() | |
1268 | |
1269 def getItemsPublishers(self, itemIdentifiers): | |
1270 """Get the publishers for all given identifiers | |
1271 | |
1272 @return (dict[unicode, jid.JID]): map of itemIdentifiers to publishers | |
1273 if an item is not found, its key is skipped in the resulting dict | |
1274 """ | |
1275 return self.dbpool.runInteraction(self._getItemsPublishers, itemIdentifiers) | |
1276 | |
1277 def _getItemsPublishers(self, cursor, itemIdentifiers): | |
1278 self._checkNodeExists(cursor) | |
1279 ret = {} | |
1280 for itemIdentifier in itemIdentifiers: | |
1281 cursor.execute("""SELECT publisher FROM items | |
1282 WHERE node_id=%s AND item=%s""", | |
1283 (self.nodeDbId, itemIdentifier,)) | |
1284 result = cursor.fetchone() | |
1285 if result: | |
1286 ret[itemIdentifier] = jid.JID(result[0]) | |
1287 return ret | |
1288 | |
1289 def purge(self): | |
1290 return self.dbpool.runInteraction(self._purge) | |
1291 | |
1292 def _purge(self, cursor): | |
1293 self._checkNodeExists(cursor) | |
1294 | |
1295 cursor.execute("""DELETE FROM items WHERE | |
1296 node_id=%s""", | |
1297 (self.nodeDbId,)) | |
1298 | |
1299 | |
1300 class CollectionNode(Node): | |
1301 | |
1302 nodeType = 'collection' | |
1303 | |
1304 | |
1305 | |
1306 class GatewayStorage(object): | |
1307 """ | |
1308 PostgreSQL-based storage facility for the XMPP-HTTP gateway. | |
1309 """ | |
1310 | |
1311 def __init__(self, dbpool): | |
1312 self.dbpool = dbpool | |
1313 | |
1314 def _countCallbacks(self, cursor, service, nodeIdentifier): | |
1315 """ | |
1316 Count number of callbacks registered for a node. | |
1317 """ | |
1318 cursor.execute("""SELECT count(*) FROM callbacks | |
1319 WHERE service=%s and node=%s""", | |
1320 (service.full(), | |
1321 nodeIdentifier)) | |
1322 results = cursor.fetchall() | |
1323 return results[0][0] | |
1324 | |
1325 def addCallback(self, service, nodeIdentifier, callback): | |
1326 def interaction(cursor): | |
1327 cursor.execute("""SELECT 1 as bool FROM callbacks | |
1328 WHERE service=%s and node=%s and uri=%s""", | |
1329 (service.full(), | |
1330 nodeIdentifier, | |
1331 callback)) | |
1332 if cursor.fetchall(): | |
1333 return | |
1334 | |
1335 cursor.execute("""INSERT INTO callbacks | |
1336 (service, node, uri) VALUES | |
1337 (%s, %s, %s)""", | |
1338 (service.full(), | |
1339 nodeIdentifier, | |
1340 callback)) | |
1341 | |
1342 return self.dbpool.runInteraction(interaction) | |
1343 | |
1344 def removeCallback(self, service, nodeIdentifier, callback): | |
1345 def interaction(cursor): | |
1346 cursor.execute("""DELETE FROM callbacks | |
1347 WHERE service=%s and node=%s and uri=%s""", | |
1348 (service.full(), | |
1349 nodeIdentifier, | |
1350 callback)) | |
1351 | |
1352 if cursor.rowcount != 1: | |
1353 raise error.NotSubscribed() | |
1354 | |
1355 last = not self._countCallbacks(cursor, service, nodeIdentifier) | |
1356 return last | |
1357 | |
1358 return self.dbpool.runInteraction(interaction) | |
1359 | |
1360 def getCallbacks(self, service, nodeIdentifier): | |
1361 def interaction(cursor): | |
1362 cursor.execute("""SELECT uri FROM callbacks | |
1363 WHERE service=%s and node=%s""", | |
1364 (service.full(), | |
1365 nodeIdentifier)) | |
1366 results = cursor.fetchall() | |
1367 | |
1368 if not results: | |
1369 raise error.NoCallbacks() | |
1370 | |
1371 return [result[0] for result in results] | |
1372 | |
1373 return self.dbpool.runInteraction(interaction) | |
1374 | |
1375 def hasCallbacks(self, service, nodeIdentifier): | |
1376 def interaction(cursor): | |
1377 return bool(self._countCallbacks(cursor, service, nodeIdentifier)) | |
1378 | |
1379 return self.dbpool.runInteraction(interaction) |