comparison sat_pubsub/pgsql_storage.py @ 405:c56a728412f1

file organisation + setup refactoring: - `/src` has been renamed to `/sat_pubsub`, this is the recommended naming convention - revamped `setup.py` on the basis of SàT's `setup.py` - added a `VERSION` which is the unique place where version number will now be set - use same trick as in SàT to specify dev version (`D` at the end) - use setuptools_scm to retrieve Mercurial hash when in dev version
author Goffi <goffi@goffi.org>
date Fri, 16 Aug 2019 12:00:02 +0200
parents src/pgsql_storage.py@1dc606612405
children a58610ab2983
comparison
equal deleted inserted replaced
404:105a0772eedd 405:c56a728412f1
1 #!/usr/bin/python
2 #-*- coding: utf-8 -*-
3
4 # Copyright (c) 2012-2019 Jérôme Poisson
5 # Copyright (c) 2013-2016 Adrien Cossa
6 # Copyright (c) 2003-2011 Ralph Meijer
7
8
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU Affero General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU Affero General Public License for more details.
18
19 # You should have received a copy of the GNU Affero General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
21 # --
22
23 # This program is based on Idavoll (http://idavoll.ik.nu/),
24 # originally written by Ralph Meijer (http://ralphm.net/blog/)
25 # It is sublicensed under AGPL v3 (or any later version) as allowed by the original
26 # license.
27
28 # --
29
30 # Here is a copy of the original license:
31
32 # Copyright (c) 2003-2011 Ralph Meijer
33
34 # Permission is hereby granted, free of charge, to any person obtaining
35 # a copy of this software and associated documentation files (the
36 # "Software"), to deal in the Software without restriction, including
37 # without limitation the rights to use, copy, modify, merge, publish,
38 # distribute, sublicense, and/or sell copies of the Software, and to
39 # permit persons to whom the Software is furnished to do so, subject to
40 # the following conditions:
41
42 # The above copyright notice and this permission notice shall be
43 # included in all copies or substantial portions of the Software.
44
45 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
46 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
47 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
48 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
49 # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
50 # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
51 # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
52
53
54 import copy, logging
55
56 from zope.interface import implements
57
58 from twisted.internet import reactor
59 from twisted.internet import defer
60 from twisted.words.protocols.jabber import jid
61 from twisted.python import log
62
63 from wokkel import generic
64 from wokkel.pubsub import Subscription
65
66 from sat_pubsub import error
67 from sat_pubsub import iidavoll
68 from sat_pubsub import const
69 from sat_pubsub import container
70 from sat_pubsub import exceptions
71 import uuid
72 import psycopg2
73 import psycopg2.extensions
74 # we want psycopg2 to return unicode objects, not str
75 psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
76 psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
77

78 # generic.parseXml handles str, but we get unicode: encode to UTF-8 first
79 parseXml = lambda unicode_data: generic.parseXml(unicode_data.encode('utf-8'))
# template for the name of the per-node sequence used for serial item ids
80 ITEMS_SEQ_NAME = u'node_{node_id}_seq'
# name of the column holding the PEP owner's bare jid (NULL for non-PEP nodes)
81 PEP_COL_NAME = 'pep'
# schema version expected in the metadata table (compared in Storage._checkVersion)
82 CURRENT_VERSION = '5'
83 # retrieve the maximum integer item id + 1 (1 if the node has no integer item id)
84 NEXT_ITEM_ID_QUERY = r"SELECT COALESCE(max(item::integer)+1,1) as val from items where node_id={node_id} and item ~ E'^\\d+$'"
85
86
def withPEP(query, values, pep, recipient):
    """Append the PEP restriction to an SQL query.

    The restriction is appended with ``AND``, so ``query`` must already
    end with a WHERE condition.

    @param query: SQL query basis
    @param values: current values to replace in query
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID): jid of the recipient (only read in PEP mode)
    @return (tuple): (query + PEP check, values)
        in PEP mode values is a new list ending with the recipient's bare jid,
        otherwise it is returned as received
    """
    if not pep:
        restricted = "{} {}".format(query, "AND {} IS NULL".format(PEP_COL_NAME))
        return restricted, values

    pep_values = list(values)
    pep_values.append(recipient.userhost())
    restricted = "{} {}".format(query, "AND {}=%s".format(PEP_COL_NAME))
    return restricted, pep_values
103
104
# PostgreSQL storage backend, declared as implementing iidavoll.IStorage
105 class Storage:
106

107 implements(iidavoll.IStorage)
108

# default configuration options, keyed by node type
# (returned as a copy by getDefaultConfiguration)
109 defaultConfig = {
110 'leaf': {
111 const.OPT_PERSIST_ITEMS: True,
112 const.OPT_DELIVER_PAYLOADS: True,
113 const.OPT_SEND_LAST_PUBLISHED_ITEM: 'on_sub',
114 const.OPT_ACCESS_MODEL: const.VAL_AMODEL_DEFAULT,
115 const.OPT_PUBLISH_MODEL: const.VAL_PMODEL_DEFAULT,
116 const.OPT_SERIAL_IDS: False,
117 const.OPT_CONSISTENT_PUBLISHER: False,
118 },
119 'collection': {
120 const.OPT_DELIVER_PAYLOADS: True,
121 const.OPT_SEND_LAST_PUBLISHED_ITEM: 'on_sub',
122 const.OPT_ACCESS_MODEL: const.VAL_AMODEL_DEFAULT,
123 const.OPT_PUBLISH_MODEL: const.VAL_PMODEL_DEFAULT,
124 }
125 }
126
def __init__(self, dbpool):
    """Keep the connection pool and launch the schema version check.

    @param dbpool: pool used to run every query of this storage
    """
    self.dbpool = dbpool
    version_d = dbpool.runQuery("SELECT value FROM metadata WHERE key='version'")
    version_d.addCallbacks(self._checkVersion, self._versionEb)
131
def _checkVersion(self, row):
    """Stop the reactor when the stored schema version is not CURRENT_VERSION.

    @param row: result of the version query, the ``value`` attribute of
        its first element is the stored version
    """
    found = row[0].value
    if found == CURRENT_VERSION:
        return
    logging.error("Bad database schema version ({current}), please upgrade to {needed}".format(
        current=found, needed=CURRENT_VERSION))
    reactor.stop()
138
def _versionEb(self, failure):
    """Errback of the version query: log the failure and stop the reactor.

    @param failure: failure received from the version query
    """
    message = "Can't check schema version: {reason}".format(reason=failure)
    logging.error(message)
    reactor.stop()
142
def _buildNode(self, row):
    """Build a node instance from a database result row.

    Columns are read by position: 0 is the database id, 1 the node
    identifier, 2 the node type, 3 to 9 the configuration options and
    10 the schema (leaf nodes only).

    @param row: row selected from the nodes table, false if nothing matched
    @return (LeafNode, CollectionNode): node using this storage's dbpool
    @raise error.NodeNotFound: row is empty
    @raise ValueError: node type is neither 'leaf' nor 'collection'
    """
    if not row:
        raise error.NodeNotFound()

    node_type = row[2]
    if node_type == 'leaf':
        options = {
            'pubsub#persist_items': row[3],
            'pubsub#deliver_payloads': row[4],
            'pubsub#send_last_published_item': row[5],
            const.OPT_ACCESS_MODEL: row[6],
            const.OPT_PUBLISH_MODEL: row[7],
            const.OPT_SERIAL_IDS: row[8],
            const.OPT_CONSISTENT_PUBLISHER: row[9],
        }
        raw_schema = row[10]
        # schema is stored as text, it is parsed only when there is one
        schema = None if raw_schema is None else parseXml(raw_schema)
        node = LeafNode(row[0], row[1], options, schema)
    elif node_type == 'collection':
        options = {
            'pubsub#deliver_payloads': row[4],
            'pubsub#send_last_published_item': row[5],
            const.OPT_ACCESS_MODEL: row[6],
            const.OPT_PUBLISH_MODEL: row[7],
        }
        node = CollectionNode(row[0], row[1], options, None)
    else:
        raise ValueError("Unknown node type !")

    node.dbpool = self.dbpool
    return node
178
def getNodeById(self, nodeDbId):
    """Get a node from its database ID instead of its pubsub identifier.

    @param nodeDbId(unicode): database ID
    @return: result of the _getNodeById interaction run on the pool
    """
    interaction = self._getNodeById
    return self.dbpool.runInteraction(interaction, nodeDbId)
185
def _getNodeById(self, cursor, nodeDbId):
    """Select the node row matching a database ID and build the node.

    @param cursor: current db cursor
    @param nodeDbId: database ID of the node
    @return: node built by _buildNode (which raises if no row matched)
    """
    query = """SELECT node_id,
                      node,
                      node_type,
                      persist_items,
                      deliver_payloads,
                      send_last_published_item,
                      access_model,
                      publish_model,
                      serial_ids,
                      consistent_publisher,
                      schema::text,
                      pep
               FROM nodes
               WHERE node_id=%s"""
    cursor.execute(query, (nodeDbId,))
    return self._buildNode(cursor.fetchone())
204
def getNode(self, nodeIdentifier, pep, recipient=None):
    """Get a node from its pubsub identifier.

    @param nodeIdentifier: pubsub identifier of the node
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID, None): jid of the recipient
    @return: result of the _getNode interaction run on the pool
    """
    interaction = self._getNode
    return self.dbpool.runInteraction(interaction, nodeIdentifier, pep, recipient)
207
def _getNode(self, cursor, nodeIdentifier, pep, recipient):
    """Select the node row matching a pubsub identifier and build the node.

    @param cursor: current db cursor
    @param nodeIdentifier: pubsub identifier of the node
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID): jid of the recipient
    @return: node built by _buildNode (which raises if no row matched)
    """
    base_query = """SELECT node_id,
                           node,
                           node_type,
                           persist_items,
                           deliver_payloads,
                           send_last_published_item,
                           access_model,
                           publish_model,
                           serial_ids,
                           consistent_publisher,
                           schema::text,
                           pep
                    FROM nodes
                    WHERE node=%s"""
    query, values = withPEP(base_query, (nodeIdentifier,), pep, recipient)
    cursor.execute(query, values)
    return self._buildNode(cursor.fetchone())
226
def getNodeIds(self, pep, recipient, allowed_accesses=None):
    """retrieve ids of existing nodes

    @param pep(bool): True if it's a PEP request
    @param recipient(jid.JID, None): recipient of the PEP request
    @param allowed_accesses(None, set): only nodes with access
        in this set will be returned
        None to return all nodes
    @return (list[unicode]): ids of nodes
    """
    if pep:
        query = "SELECT node from nodes WHERE pep=%s"
        values = [recipient.userhost()]
    else:
        query = "SELECT node from nodes WHERE pep is NULL"
        values = []

    if allowed_accesses is not None:
        if not allowed_accesses:
            # no access model is allowed, so no node can match
            # (and "IN ()" would be a SQL syntax error)
            return defer.succeed([])
        # the leading space is mandatory: without it the clause is glued
        # to the end of the WHERE condition ("... pep is NULLAND ...")
        query += " AND access_model IN %s"
        values.append(tuple(allowed_accesses))

    d = self.dbpool.runQuery(query, values)
    d.addCallback(lambda results: [r[0] for r in results])
    return d
251
def createNode(self, nodeIdentifier, owner, config, schema, pep, recipient=None):
    """Create a node, see _createNode for the database work.

    @return: result of the _createNode interaction run on the pool
    """
    interaction_args = (nodeIdentifier, owner, config, schema, pep, recipient)
    return self.dbpool.runInteraction(self._createNode, *interaction_args)
255
# Insert a new leaf node, make `owner` its owner, store authorized roster
# groups if needed, then run the configuration triggers.
# Raises error.NoCollections for any node type other than 'leaf',
# error.NodeExists on a unique violation and
# error.InvalidConfigurationOption on any other integrity error.
256 def _createNode(self, cursor, nodeIdentifier, owner, config, schema, pep, recipient):
257 if config['pubsub#node_type'] != 'leaf':
258 raise error.NoCollections()
259

# only the bare jid of the owner is stored
260 owner = owner.userhost()
261

262 try:
263 cursor.execute("""INSERT INTO nodes
264 (node,
265 node_type,
266 persist_items,
267 deliver_payloads,
268 send_last_published_item,
269 access_model,
270 publish_model,
271 serial_ids,
272 consistent_publisher,
273 schema,
274 pep)
275 VALUES
276 (%s, 'leaf', %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
277 (nodeIdentifier,
278 config['pubsub#persist_items'],
279 config['pubsub#deliver_payloads'],
280 config['pubsub#send_last_published_item'],
281 config[const.OPT_ACCESS_MODEL],
282 config[const.OPT_PUBLISH_MODEL],
283 config[const.OPT_SERIAL_IDS],
284 config[const.OPT_CONSISTENT_PUBLISHER],
285 schema,
286 recipient.userhost() if pep else None
287 )
288 )
289 except cursor._pool.dbapi.IntegrityError as e:
290 if e.pgcode == "23505":
291 # unique_violation
292 raise error.NodeExists()
293 else:
294 raise error.InvalidConfigurationOption()
295

# retrieve the database id of the node we have just inserted
296 cursor.execute(*withPEP("""SELECT node_id FROM nodes WHERE node=%s""",
297 (nodeIdentifier,), pep, recipient));
298 node_id = cursor.fetchone()[0]
299

300 cursor.execute("""SELECT 1 as bool from entities where jid=%s""",
301 (owner,))
302

303 if not cursor.fetchone():
304 # XXX: we can NOT rely on the previous query! Commit is needed now because
305 # if the entry exists the next query will leave the database in a corrupted
306 # state: the solution is to rollback. I tried with other methods like
307 # "WHERE NOT EXISTS" but none of them worked, so the following solution
308 # looks like the sole - unless you have auto-commit on. More info
309 # about this issue: http://cssmay.com/question/tag/tag-psycopg2
310 cursor.connection.commit()
311 try:
312 cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""",
313 (owner,))
314 except psycopg2.IntegrityError as e:
315 cursor.connection.rollback()
316 logging.warning("during node creation: %s" % e.message)
317

# the owner affiliation is inserted from the entity row matching the bare jid
318 cursor.execute("""INSERT INTO affiliations
319 (node_id, entity_id, affiliation)
320 SELECT %s, entity_id, 'owner' FROM
321 (SELECT entity_id FROM entities
322 WHERE jid=%s) as e""",
323 (node_id, owner))
324

325 if config[const.OPT_ACCESS_MODEL] == const.VAL_AMODEL_PUBLISHER_ROSTER:
326 if const.OPT_ROSTER_GROUPS_ALLOWED in config:
327 allowed_groups = config[const.OPT_ROSTER_GROUPS_ALLOWED]
328 else:
329 allowed_groups = []
330 for group in allowed_groups:
331 #TODO: check that groups are actually in roster
332 cursor.execute("""INSERT INTO node_groups_authorized (node_id, groupname)
333 VALUES (%s,%s)""" , (node_id, group))
334 # XXX: affiliations can't be set during node creation (at least not with XEP-0060 alone)
335 # so whitelist affiliations need to be done afterward
336

337 # now we may have to do extra things according to config options
338 default_conf = self.defaultConfig['leaf']
339 # XXX: trigger works on node creation because OPT_SERIAL_IDS is False in defaultConfig
340 # if this value is changed, the _configurationTriggers method should be adapted.
341 Node._configurationTriggers(cursor, node_id, default_conf, config)
342
def deleteNodeByDbId(self, db_id):
    """Delete a node directly from its database id.

    @param db_id: database id of the node
    @return: result of the _deleteNodeByDbId interaction run on the pool
    """
    interaction = self._deleteNodeByDbId
    return self.dbpool.runInteraction(interaction, db_id)
346
def _deleteNodeByDbId(self, cursor, db_id):
    """Delete the node row with the given database id.

    @param cursor: current db cursor
    @param db_id: database id of the node
    @raise error.NodeNotFound: the query did not delete exactly one row
    """
    cursor.execute("""DELETE FROM nodes WHERE node_id=%s""", (db_id,))
    deleted = cursor.rowcount
    if deleted != 1:
        raise error.NodeNotFound()
353
def deleteNode(self, nodeIdentifier, pep, recipient=None):
    """Delete a node from its pubsub identifier.

    @param nodeIdentifier: pubsub identifier of the node
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID, None): jid of the recipient
    @return: result of the _deleteNode interaction run on the pool
    """
    interaction = self._deleteNode
    return self.dbpool.runInteraction(interaction, nodeIdentifier, pep, recipient)
356
def _deleteNode(self, cursor, nodeIdentifier, pep, recipient):
    """Delete the node row matching a pubsub identifier (PEP aware).

    @param cursor: current db cursor
    @param nodeIdentifier: pubsub identifier of the node
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID): jid of the recipient
    @raise error.NodeNotFound: the query did not delete exactly one row
    """
    query, values = withPEP("""DELETE FROM nodes WHERE node=%s""",
                            (nodeIdentifier,), pep, recipient)
    cursor.execute(query, values)
    if cursor.rowcount != 1:
        raise error.NodeNotFound()
363
def getAffiliations(self, entity, nodeIdentifier, pep, recipient=None):
    """Retrieve the affiliations of an entity.

    @param entity(jid.JID): entity to check
    @param nodeIdentifier: pubsub identifier of a node, None for all nodes
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID, None): jid of the recipient
    @return: result of the _getAffiliations interaction run on the pool
    """
    interaction = self._getAffiliations
    return self.dbpool.runInteraction(interaction, entity, nodeIdentifier, pep, recipient)
366
def _getAffiliations(self, cursor, entity, nodeIdentifier, pep, recipient=None):
    """Select (node, affiliation) couples of an entity.

    @param cursor: current db cursor
    @param entity(jid.JID): entity to check, its bare jid is used
    @param nodeIdentifier: restrict to this node, None for all nodes
    @param pep(bool): True if we are in PEP mode
    @param recipient(jid.JID, None): jid of the recipient
    @return (list[tuple]): one (node, affiliation) tuple per row found
    """
    query_parts = ["""SELECT node, affiliation FROM entities
                      NATURAL JOIN affiliations
                      NATURAL JOIN nodes
                      WHERE jid=%s"""]
    values = [entity.userhost()]

    if nodeIdentifier is not None:
        query_parts.append("AND node=%s")
        values.append(nodeIdentifier)

    query, values = withPEP(' '.join(query_parts), values, pep, recipient)
    cursor.execute(query, values)
    return [tuple(row) for row in cursor.fetchall()]
381
def getSubscriptions(self, entity, nodeIdentifier=None, pep=False, recipient=None):
    """retrieve subscriptions of an entity

    @param entity(jid.JID): entity to check
    @param nodeIdentifier(unicode, None): node identifier
        None to retrieve all subscriptions
    @param pep: True if we are in PEP mode
    @param recipient: jid of the recipient
    @return: deferred of the query, with a callback converting rows
        to a list of Subscription
    """

    def toSubscriptions(rows):
        # the subscriber jid is rebuilt from the bare jid and resource columns
        return [
            Subscription(row.node,
                         jid.internJID('%s/%s' % (row.jid, row.resource)),
                         row.state)
            for row in rows]

    query_parts = ["""SELECT node,
                             jid,
                             resource,
                             state
                      FROM entities
                      NATURAL JOIN subscriptions
                      NATURAL JOIN nodes
                      WHERE jid=%s"""]
    values = [entity.userhost()]

    if nodeIdentifier is not None:
        query_parts.append("AND node=%s")
        values.append(nodeIdentifier)

    query, values = withPEP(' '.join(query_parts), values, pep, recipient)
    d = self.dbpool.runQuery(query, values)
    d.addCallback(toSubscriptions)
    return d
419
def getDefaultConfiguration(self, nodeType):
    """Return a copy of the default configuration of a node type.

    @param nodeType: key of defaultConfig ('leaf' or 'collection')
    @return (dict): shallow copy of the default options
    """
    defaults = self.defaultConfig[nodeType]
    return defaults.copy()
422
def formatLastItems(self, result):
    """Convert rows of the last items query to usable objects.

    @param result: iterable of (pep jid, node, item data, item access model)
    @return (list[tuple]): same tuples with the jid converted to jid.JID
        and the data parsed then stripped of its namespace
    """
    return [
        (jid.JID(pep_jid_s), node, generic.stripNamespace(parseXml(data)), item_access_model)
        for pep_jid_s, node, data, item_access_model in result]
430
def getLastItems(self, entities, nodes, node_accesses, item_accesses, pep):
    """get last item for several nodes and entities in a single request

    @param entities: entities whose bare jid is matched against nodes.pep
    @param nodes: node identifiers to check
    @param node_accesses: only ('open',) is accepted
    @param item_accesses: only ('open',) is accepted
    @param pep(bool): must be True
    @return: deferred of the query, with formatLastItems as callback
    @raise ValueError: one of entities, nodes or accesses is empty
    @raise NotImplementedError: non "open" access or non PEP request
    """
    if not (entities and nodes and node_accesses and item_accesses):
        raise ValueError("entities, nodes and accesses must not be empty")
    only_open = ('open',)
    if not (node_accesses == only_open and item_accesses == only_open):
        raise NotImplementedError('only "open" access model is handled for now')
    if not pep:
        raise NotImplementedError(u"getLastItems is only implemented for PEP at the moment")

    query = """SELECT DISTINCT ON (node_id) pep, node, data::text, items.access_model
               FROM items
               NATURAL JOIN nodes
               WHERE nodes.pep IN %s
               AND node IN %s
               AND nodes.access_model in %s
               AND items.access_model in %s
               ORDER BY node_id DESC, items.updated DESC"""
    bare_jids = tuple([entity.userhost() for entity in entities])
    d = self.dbpool.runQuery(query, (bare_jids, nodes, node_accesses, item_accesses))
    d.addCallback(self.formatLastItems)
    return d
453
454
# Base class of storage nodes, declared as implementing iidavoll.INode.
# Its methods read self.dbpool, which Storage._buildNode sets after construction.
455 class Node:
456

457 implements(iidavoll.INode)
458
def __init__(self, nodeDbId, nodeIdentifier, config, schema):
    """Keep the identifiers, configuration and schema of the node.

    @param nodeDbId: database ID of the node
    @param nodeIdentifier: pubsub identifier of the node
    @param config(dict): configuration options of the node
    @param schema: schema of the node, or None
    """
    self.nodeDbId, self.nodeIdentifier = nodeDbId, nodeIdentifier
    self._config, self._schema = config, schema
464
def _checkNodeExists(self, cursor):
    """Check that this node is still in the nodes table.

    @param cursor: current db cursor
    @raise error.NodeNotFound: no row has this node's database ID
    """
    cursor.execute("""SELECT 1 as exist FROM nodes WHERE node_id=%s""",
                   (self.nodeDbId,))
    found = cursor.fetchone()
    if not found:
        raise error.NodeNotFound()
470
def getType(self):
    """Return the node type, i.e. the nodeType attribute of the instance."""
    node_type = self.nodeType
    return node_type
473
def getOwners(self):
    """Retrieve the owners of this node.

    @return: deferred of the query, with a callback converting the
        first column of each row to jid.JID
    """
    def toJids(rows):
        return [jid.JID(row[0]) for row in rows]

    query = """SELECT jid FROM nodes NATURAL JOIN affiliations NATURAL JOIN entities WHERE node_id=%s and affiliation='owner'"""
    d = self.dbpool.runQuery(query, (self.nodeDbId,))
    d.addCallback(toJids)
    return d
478
def getConfiguration(self):
    """Return the cached configuration (the dict itself, not a copy)."""
    cached = self._config
    return cached
481
def getNextId(self):
    """return XMPP item id usable for next item to publish

    the value is the next value of the node's sequence if the
    serial ids option is set, else a random UUID
    @return: deferred firing with the id as unicode
    """
    if not self._config[const.OPT_SERIAL_IDS]:
        return defer.succeed(unicode(uuid.uuid4()))

    seq_name = ITEMS_SEQ_NAME.format(node_id=self.nodeDbId)
    d = self.dbpool.runQuery("SELECT nextval('{seq_name}')".format(seq_name=seq_name))
    d.addCallback(lambda rows: unicode(rows[0][0]))
    return d
495
@staticmethod
def _configurationTriggers(cursor, node_id, old_config, new_config):
    """trigger database relative actions needed when a config is changed

    Only the serial ids option is handled: the node's sequence is
    (re)created when the option is switched on, dropped when switched off.

    @param cursor(): current db cursor
    @param node_id(unicode): database ID of the node
    @param old_config(dict): config of the node before the change
    @param new_config(dict): new options that will be changed
    """
    serial_ids = new_config[const.OPT_SERIAL_IDS]
    if serial_ids == old_config[const.OPT_SERIAL_IDS]:
        # serial_ids option has not been modified, nothing to do
        return

    # XXX: we use .format in following queries because values
    #      are generated by ourself
    seq_name = ITEMS_SEQ_NAME.format(node_id=node_id)
    drop_query = "DROP SEQUENCE IF EXISTS {seq_name}".format(seq_name=seq_name)

    if not serial_ids:
        cursor.execute(drop_query)
        return

    # the next query get the max value +1 of all XMPP items ids
    # which are integers, and default to 1
    cursor.execute(NEXT_ITEM_ID_QUERY.format(node_id=node_id))
    next_val = cursor.fetchone()[0]
    cursor.execute(drop_query)
    cursor.execute("CREATE SEQUENCE {seq_name} START {next_val} OWNED BY nodes.node_id".format(
        seq_name=seq_name,
        next_val=next_val))
524
def setConfiguration(self, options):
    """Update the node configuration in database, then in cache.

    Options which are not already in the current configuration are ignored.

    @param options(dict): options to change
    @return: deferred of the _setConfiguration interaction
    """
    new_config = copy.copy(self._config)
    for name in options:
        if name not in new_config:
            continue
        new_config[name] = options[name]

    d = self.dbpool.runInteraction(self._setConfiguration, new_config)
    # the cache is only updated once the database has been
    d.addCallback(self._setCachedConfiguration, new_config)
    return d
535
def _setConfiguration(self, cursor, config):
    """Run configuration triggers then save the configuration of the node.

    @param cursor: current db cursor
    @param config(dict): complete new configuration
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)
    # triggers compare the cached configuration with the new one
    self._configurationTriggers(cursor, self.nodeDbId, self._config, config)
    query = """UPDATE nodes SET persist_items=%s,
                                deliver_payloads=%s,
                                send_last_published_item=%s,
                                access_model=%s,
                                publish_model=%s,
                                serial_ids=%s,
                                consistent_publisher=%s
               WHERE node_id=%s"""
    option_names = (
        const.OPT_PERSIST_ITEMS,
        const.OPT_DELIVER_PAYLOADS,
        const.OPT_SEND_LAST_PUBLISHED_ITEM,
        const.OPT_ACCESS_MODEL,
        const.OPT_PUBLISH_MODEL,
        const.OPT_SERIAL_IDS,
        const.OPT_CONSISTENT_PUBLISHER,
    )
    values = tuple(config[name] for name in option_names) + (self.nodeDbId,)
    cursor.execute(query, values)
555
def _setCachedConfiguration(self, void, config):
    """Replace the cached configuration.

    @param void: ignored (result of the previous callback)
    @param config(dict): configuration to cache
    """
    setattr(self, '_config', config)
558
def getSchema(self):
    """Return the cached schema of the node."""
    cached = self._schema
    return cached
561
def setSchema(self, schema):
    """Save the schema of the node in database, then in cache.

    @param schema: new schema, a false value to remove it
    @return: deferred of the _setSchema interaction
    """
    update_d = self.dbpool.runInteraction(self._setSchema, schema)
    # the cache is only updated once the database has been
    update_d.addCallback(self._setCachedSchema, schema)
    return update_d
566
def _setSchema(self, cursor, schema):
    """Save the schema of the node, serialised with its toXml method.

    @param cursor: current db cursor
    @param schema: new schema, a false value stores NULL
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)
    if schema:
        serialised = schema.toXml()
    else:
        serialised = None
    cursor.execute("""UPDATE nodes SET schema=%s
                      WHERE node_id=%s""",
                   (serialised, self.nodeDbId))
573
def _setCachedSchema(self, void, schema):
    """Replace the cached schema.

    @param void: ignored (result of the previous callback)
    @param schema: schema to cache
    """
    setattr(self, '_schema', schema)
576
def getMetaData(self):
    """Return a copy of the configuration completed with the node type.

    @return (dict): configuration with the "pubsub#node_type" key added
    """
    metadata = copy.copy(self._config)
    metadata.update({"pubsub#node_type": self.nodeType})
    return metadata
581
def getAffiliation(self, entity):
    """Retrieve the affiliation of an entity with this node.

    @param entity(jid.JID): entity to check
    @return: result of the _getAffiliation interaction run on the pool
    """
    interaction = self._getAffiliation
    return self.dbpool.runInteraction(interaction, entity)
584
def _getAffiliation(self, cursor, entity):
    """Select the affiliation of an entity with this node.

    @param cursor: current db cursor
    @param entity(jid.JID): entity to check, its bare jid is used
    @return: affiliation, or None if the entity has none
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)
    cursor.execute("""SELECT affiliation FROM affiliations
                      NATURAL JOIN nodes
                      NATURAL JOIN entities
                      WHERE node_id=%s AND jid=%s""",
                   (self.nodeDbId, entity.userhost()))

    row = cursor.fetchone()
    if row is None:
        return None
    return row[0]
598
def getAccessModel(self):
    """Return the access model option of the cached configuration."""
    option = const.OPT_ACCESS_MODEL
    return self._config[option]
601
def getSubscription(self, subscriber):
    """Retrieve the subscription of a subscriber to this node.

    @param subscriber(jid.JID): subscriber to check
    @return: result of the _getSubscription interaction run on the pool
    """
    interaction = self._getSubscription
    return self.dbpool.runInteraction(interaction, subscriber)
604
def _getSubscription(self, cursor, subscriber):
    """Select the subscription matching bare jid and resource of a subscriber.

    @param cursor: current db cursor
    @param subscriber(jid.JID): subscriber to check,
        a missing resource is matched as an empty string
    @return (Subscription, None): None if there is no subscription
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    values = (self.nodeDbId, subscriber.userhost(), subscriber.resource or '')
    cursor.execute("""SELECT state FROM subscriptions
                      NATURAL JOIN nodes
                      NATURAL JOIN entities
                      WHERE node_id=%s AND jid=%s AND resource=%s""",
                   values)

    row = cursor.fetchone()
    if row:
        return Subscription(self.nodeIdentifier, subscriber, row[0])
    return None
624
def getSubscriptions(self, state=None):
    """Retrieve the subscriptions to this node.

    @param state: only return subscriptions in this state,
        a false value to return all of them
    @return: result of the _getSubscriptions interaction run on the pool
    """
    interaction = self._getSubscriptions
    return self.dbpool.runInteraction(interaction, state)
627
def _getSubscriptions(self, cursor, state):
    """Select the subscriptions to this node and build Subscription objects.

    @param cursor: current db cursor
    @param state: only select subscriptions in this state,
        a false value to select all of them
    @return (list[Subscription]): subscriptions with their options
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    query = """SELECT node, jid, resource, state,
                      subscription_type, subscription_depth
               FROM subscriptions
               NATURAL JOIN nodes
               NATURAL JOIN entities
               WHERE node_id=%s"""
    values = [self.nodeDbId]
    if state:
        query += " AND state=%s"
        values.append(state)
    cursor.execute(query, values)

    subscriptions = []
    for row in cursor.fetchall():
        subscriber = jid.JID(u'%s/%s' % (row.jid, row.resource))

        # only options with a true value are set
        options = {}
        for column in ('subscription_type', 'subscription_depth'):
            value = getattr(row, column)
            if value:
                options['pubsub#' + column] = value

        subscriptions.append(Subscription(row.node, subscriber, row.state, options))
    return subscriptions
660
def addSubscription(self, subscriber, state, config):
    """Subscribe an entity to this node.

    @param subscriber(jid.JID): entity to subscribe
    @param state: state of the subscription
    @param config(dict): subscription options
    @return: result of the _addSubscription interaction run on the pool
    """
    interaction = self._addSubscription
    return self.dbpool.runInteraction(interaction, subscriber, state, config)
664
def _addSubscription(self, cursor, subscriber, state, config):
    """Insert the entity if needed, then its subscription to this node.

    @param cursor: current db cursor
    @param subscriber(jid.JID): entity to subscribe,
        a missing resource is stored as an empty string
    @param state: state of the subscription
    @param config(dict): subscription options, "pubsub#subscription_type"
        and "pubsub#subscription_depth" are stored
    @raise error.NodeNotFound: the node is not in database anymore
    @raise error.SubscriptionExists: integrity error on subscription insertion
    """
    self._checkNodeExists(cursor)

    bare_jid = subscriber.userhost()
    resource = subscriber.resource or ''
    sub_type = config.get('pubsub#subscription_type')
    sub_depth = config.get('pubsub#subscription_depth')

    try:
        cursor.execute("""INSERT INTO entities (jid) VALUES (%s)""", (bare_jid,))
    except cursor._pool.dbapi.IntegrityError:
        # integrity error on entity insertion: roll back and carry on
        # with the subscription insertion
        cursor.connection.rollback()

    subscription_values = (self.nodeDbId, resource, state, sub_type, sub_depth, bare_jid)
    try:
        cursor.execute("""INSERT INTO subscriptions
                          (node_id, entity_id, resource, state,
                           subscription_type, subscription_depth)
                          SELECT %s, entity_id, %s, %s, %s, %s FROM
                          (SELECT entity_id FROM entities
                                            WHERE jid=%s) AS ent_id""",
                       subscription_values)
    except cursor._pool.dbapi.IntegrityError:
        raise error.SubscriptionExists()
695
def removeSubscription(self, subscriber):
    """Remove the subscription of an entity to this node.

    @param subscriber(jid.JID): entity to unsubscribe
    @return: result of the _removeSubscription interaction run on the pool
    """
    interaction = self._removeSubscription
    return self.dbpool.runInteraction(interaction, subscriber)
699
def _removeSubscription(self, cursor, subscriber):
    """Delete the subscription matching bare jid and resource of a subscriber.

    @param cursor: current db cursor
    @param subscriber(jid.JID): entity to unsubscribe,
        a missing resource is matched as an empty string
    @raise error.NodeNotFound: the node is not in database anymore
    @raise error.NotSubscribed: the query did not delete exactly one row
    """
    self._checkNodeExists(cursor)

    values = (self.nodeDbId, subscriber.userhost(), subscriber.resource or '')
    cursor.execute("""DELETE FROM subscriptions WHERE
                      node_id=%s AND
                      entity_id=(SELECT entity_id FROM entities
                                                  WHERE jid=%s) AND
                      resource=%s""",
                   values)
    if cursor.rowcount == 1:
        return None
    raise error.NotSubscribed()
718
def setSubscriptions(self, subscriptions):
    """Insert or update several subscriptions to this node.

    @param subscriptions: subscriptions to set
    @return: result of the _setSubscriptions interaction run on the pool
    """
    interaction = self._setSubscriptions
    return self.dbpool.runInteraction(interaction, subscriptions)
721
def _setSubscriptions(self, cursor, subscriptions):
    """Insert or update (upsert) several subscriptions to this node.

    @param cursor: current db cursor
    @param subscriptions: subscriptions to set, the ``subscriber`` and
        ``state`` attributes of each one are read
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    if not subscriptions:
        # nothing to set, and an empty VALUES list would be a SQL syntax error
        return

    entities = self.getOrCreateEntities(cursor, [s.subscriber for s in subscriptions])
    # entities are stored by bare jid
    entities_map = {jid.JID(e.jid): e for e in entities}

    # then we construct values for subscriptions update according to entity_id we just got
    # (each %s is replaced by a whole tuple of values)
    placeholders = ','.join(len(subscriptions) * ["%s"])
    values = []
    for subscription in subscriptions:
        subscriber = subscription.subscriber
        # the map is keyed by bare jid, so the resource must be dropped for
        # the lookup, else a subscriber with a resource would never be found
        entity_id = entities_map[subscriber.userhostJID()].entity_id
        resource = subscriber.resource or u''
        values.append((self.nodeDbId, entity_id, resource, subscription.state, None, None))
    # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5
    cursor.execute("INSERT INTO subscriptions(node_id, entity_id, resource, state, subscription_type, subscription_depth) VALUES " + placeholders + " ON CONFLICT (entity_id, resource, node_id) DO UPDATE SET state=EXCLUDED.state", values)
737
def isSubscribed(self, entity):
    """Check if an entity is subscribed to this node.

    @param entity(jid.JID): entity to check
    @return: result of the _isSubscribed interaction run on the pool
    """
    interaction = self._isSubscribed
    return self.dbpool.runInteraction(interaction, entity)
740
def _isSubscribed(self, cursor, entity):
    """Check if a bare jid has a subscription in 'subscribed' state.

    @param cursor: current db cursor
    @param entity(jid.JID): entity to check, its bare jid is used
    @return (bool): True if at least one row matches
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    values = (entity.userhost(), self.nodeDbId)
    cursor.execute("""SELECT 1 as bool FROM entities
                      NATURAL JOIN subscriptions
                      NATURAL JOIN nodes
                      WHERE entities.jid=%s
                      AND node_id=%s AND state='subscribed'""",
                   values)

    first_row = cursor.fetchone()
    return first_row is not None
753
def getAffiliations(self):
    """Retrieve all the affiliations with this node.

    @return: result of the _getAffiliations interaction run on the pool
    """
    interaction = self._getAffiliations
    return self.dbpool.runInteraction(interaction)
756
def _getAffiliations(self, cursor):
    """Select all the affiliations with this node.

    @param cursor: current db cursor
    @return (dict): affiliation by entity jid (interned jid.JID)
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    cursor.execute("""SELECT jid, affiliation FROM nodes
                      NATURAL JOIN affiliations
                      NATURAL JOIN entities
                      WHERE node_id=%s""",
                   (self.nodeDbId,))

    affiliations = {}
    for row in cursor.fetchall():
        affiliations[jid.internJID(row[0])] = row[1]
    return affiliations
768
def getOrCreateEntities(self, cursor, entities_jids):
    """Get entity_id from entities in entities table

    Entities will be inserted if they don't exist
    @param cursor: current db cursor
    @param entities_jids(list[jid.JID]): entities to get or create,
        only their bare jid is used
    @return list[record(entity_id,jid)]]: list of entity_id and jid (as plain string)
        both existing and inserted entities are returned
        (empty list if entities_jids is empty, no query is done then)
    """
    if not entities_jids:
        # "VALUES" followed by nothing is a SQL syntax error
        return []

    # cf. http://stackoverflow.com/a/35265559
    placeholders = ','.join(len(entities_jids) * ["(%s)"])
    query = (
        """
        WITH
        jid_values (jid) AS (
               VALUES {placeholders}
        ),
        inserted (entity_id, jid) AS (
            INSERT INTO entities (jid)
            SELECT jid
            FROM jid_values
            ON CONFLICT DO NOTHING
            RETURNING entity_id, jid
        )
        SELECT e.entity_id, e.jid
        FROM entities e JOIN jid_values jv ON jv.jid = e.jid
        UNION ALL
        SELECT entity_id, jid
        FROM inserted""".format(placeholders=placeholders))
    cursor.execute(query, [j.userhost() for j in entities_jids])
    return cursor.fetchall()
799
def setAffiliations(self, affiliations):
    """Insert or update several affiliations with this node.

    @param affiliations(dict): affiliation by entity jid
    @return: result of the _setAffiliations interaction run on the pool
    """
    interaction = self._setAffiliations
    return self.dbpool.runInteraction(interaction, affiliations)
802
def _setAffiliations(self, cursor, affiliations):
    """Insert or update (upsert) several affiliations with this node.

    @param cursor: current db cursor
    @param affiliations(dict): affiliation by entity jid (jid.JID)
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    if not affiliations:
        # nothing to set, and an empty VALUES list would be a SQL syntax error
        return

    entities = self.getOrCreateEntities(cursor, affiliations)

    # then we construct values for affiliations update according to entity_id we just got
    # an explicit loop is used: map() must not be relied on for side effects
    rows = [(e.entity_id, affiliations[jid.JID(e.jid)], self.nodeDbId) for e in entities]
    values = []
    for row in rows:
        values.extend(row)
    # one placeholder group per row actually built, so both always match
    placeholders = ','.join(len(rows) * ["(%s,%s,%s)"])

    # we use upsert so new values are inserted and existing one updated. This feature is only available for PostgreSQL >= 9.5
    cursor.execute("INSERT INTO affiliations(entity_id,affiliation,node_id) VALUES " + placeholders + " ON CONFLICT  (entity_id,node_id) DO UPDATE SET affiliation=EXCLUDED.affiliation", values)
815
def deleteAffiliations(self, entities):
    """Delete the affiliations of several entities with this node.

    @param entities: entities (jid.JID) to remove
    @return: result of the _deleteAffiliations interaction run on the pool
    """
    interaction = self._deleteAffiliations
    return self.dbpool.runInteraction(interaction, entities)
818
def _deleteAffiliations(self, cursor, entities):
    """delete affiliations and subscriptions for these entities

    Subscriptions are only deleted for entities which actually had
    an affiliation with this node.

    @param cursor: current db cursor
    @param entities: entities (jid.JID) to remove, their bare jid is used
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)

    if not entities:
        # "IN ()" is a SQL syntax error, and there is nothing to delete
        return

    placeholders = ','.join(len(entities) * ["%s"])
    cursor.execute("DELETE FROM affiliations WHERE node_id=%s AND entity_id in (SELECT entity_id FROM entities WHERE jid IN (" + placeholders + ")) RETURNING entity_id", [self.nodeDbId] + [e.userhost() for e in entities])

    rows = cursor.fetchall()
    if not rows:
        # no affiliation was deleted, so there is no subscription to remove
        # (the query below would contain an invalid "IN ()")
        return

    placeholders = ','.join(len(rows) * ["%s"])
    cursor.execute("DELETE FROM subscriptions WHERE node_id=%s AND entity_id in (" + placeholders + ")", [self.nodeDbId] + [r[0] for r in rows])
828
def getAuthorizedGroups(self):
    """Retrieve the roster groups authorized for this node.

    @return: result of the _getAuthorizedGroups interaction run on the pool
    """
    # the interaction is _getAuthorizedGroups: there is no _getNodeGroups method
    return self.dbpool.runInteraction(self._getAuthorizedGroups)
831
def _getAuthorizedGroups(self, cursor):
    """Select the names of the roster groups authorized for this node.

    @param cursor: current db cursor
    @return (list): group names
    """
    # the filter is on node_id as the value given is the database ID
    # ("node" is the column of the pubsub identifier)
    cursor.execute("SELECT groupname FROM node_groups_authorized NATURAL JOIN nodes WHERE node_id=%s",
                   (self.nodeDbId,))
    rows = cursor.fetchall()
    return [row[0] for row in rows]
837
838
# Node able to store items, declared as implementing iidavoll.ILeafNode
839 class LeafNode(Node):
840

841 implements(iidavoll.ILeafNode)
842

# value returned by Node.getType and added to the metadata by Node.getMetaData
843 nodeType = 'leaf'
844
def getOrderBy(self, ext_data, direction='DESC'):
    """Return ORDER BY clause corresponding to Order By key in ext_data

    Unknown keys are logged and sorted by modification date.

    @param ext_data (dict): extra data as used in getItems
    @param direction (unicode): ORDER BY direction (ASC or DESC)
    @return (unicode): ORDER BY clause to use
    """
    keys = ext_data.get('order_by')
    if not keys:
        return u'ORDER BY updated ' + direction

    cols_statmnt = []
    for key in keys:
        if key == 'creation':
            column = 'item_id'  # could work with items.created too
        else:
            if key != 'modification':
                log.msg(u"WARNING: Unknown order by key: {key}".format(key=key))
            column = 'updated'
        cols_statmnt.append(column + u' ' + direction)

    return u"ORDER BY " + u",".join(cols_statmnt)
867
@defer.inlineCallbacks
def storeItems(self, items_data, publisher):
    """Store items one by one, each in its own database interaction.

    @param items_data: data of the items to store
    @param publisher: publisher of the items
    """
    # XXX: runInteraction doesn't seem to work when there are several "insert"
    #      or "update".
    #      Before the unpacking was done in _storeItems, but this was causing trouble
    #      in case of multiple items_data. So this has now be moved here.
    # FIXME: investigate the issue with runInteraction
    store_one = self._storeItems
    for single_item_data in items_data:
        yield self.dbpool.runInteraction(store_one, single_item_data, publisher)
877
def _storeItems(self, cursor, item_data, publisher):
    """Check that the node exists, then store a single item.

    @param cursor: current db cursor
    @param item_data: data of the item to store
    @param publisher: publisher of the item
    @raise error.NodeNotFound: the node is not in database anymore
    """
    self._checkNodeExists(cursor)
    # the result of _storeItem is not returned
    self._storeItem(cursor, item_data, publisher)
881
def _storeItem(self, cursor, item_data, publisher):
    """Insert an item in this node, or update it if it is already stored.

    @param cursor: database cursor of the current interaction
    @param item_data: object whose item, access_model, config, new and
        categories attributes are read here
    @param publisher: publisher of the item, its full() value is stored
    @raise IntegrityError (from the DB-API module of the pool): for any
        integrity error other than a unique violation, and for a unique
        violation on a new item when serial ids are not enabled
    @raise exceptions.InternalError: the UPDATE did not affect exactly one row
    """
    # first try to insert the item
    # - if it fails (conflict), and the item is new and we have the serial_ids option,
    #   current id will be recomputed using next item id query (note that this is not
    #   perfect, as table is not locked and this can fail if two items are added at the
    #   same time but this can only happen with serial_ids and if future ids have been
    #   set by a client, this case should be rare enough to consider this situation
    #   acceptable)
    # - if item insertion fails and the item is not new, we do an update
    # - in other cases, exception is raised
    item, access_model, item_config = item_data.item, item_data.access_model, item_data.config
    data = item.toXml()

    insert_query = """INSERT INTO items (node_id, item, publisher, data, access_model)
                                         SELECT %s, %s, %s, %s, %s FROM nodes
                                                                   WHERE node_id=%s
                                                                   RETURNING item_id"""
    # indexes matter: [1] (item id) and [3] (serialised item) are replaced
    # below when the id has to be recomputed
    insert_data = [self.nodeDbId,
                   item["id"],
                   publisher.full(),
                   data,
                   access_model,
                   self.nodeDbId]

    try:
        cursor.execute(insert_query, insert_data)
    except cursor._pool.dbapi.IntegrityError as e:
        if e.pgcode != "23505":
            # we only handle unique_violation, every other exception must be raised
            raise e
        # the failed INSERT is rolled back before any other query is sent
        cursor.connection.rollback()
        # the item already exists
        if item_data.new:
            # the item is new
            if self._config[const.OPT_SERIAL_IDS]:
                # this can happen with serial_ids, if an item has been stored
                # with a future id (generated by XMPP client)
                cursor.execute(NEXT_ITEM_ID_QUERY.format(node_id=self.nodeDbId))
                next_id = cursor.fetchone()[0]
                # we update the sequence, so we can skip conflicting ids
                cursor.execute(u"SELECT setval('{seq_name}', %s)".format(
                    seq_name = ITEMS_SEQ_NAME.format(node_id=self.nodeDbId)), [next_id])
                # and now we can retry the query with the new id
                item['id'] = insert_data[1] = unicode(next_id)
                # item saved in DB must also be updated with the new id
                insert_data[3] = item.toXml()
                cursor.execute(insert_query, insert_data)
            else:
                # but if we don't have serial_ids, we have a real problem
                raise e
        else:
            # this is an update
            cursor.execute("""UPDATE items SET updated=now(), publisher=%s, data=%s
                              FROM nodes
                              WHERE nodes.node_id = items.node_id AND
                                    nodes.node_id = %s and items.item=%s
                              RETURNING item_id""",
                           (publisher.full(),
                            data,
                            self.nodeDbId,
                            item["id"]))
            if cursor.rowcount != 1:
                raise exceptions.InternalError("item has not been updated correctly")
            item_id = cursor.fetchone()[0];
            self._storeCategories(cursor, item_id, item_data.categories, update=True)
            # an update stops here: authorized groups below are only written
            # for freshly inserted items
            return

    # reached after a successful INSERT (first try, or retry with a new id)
    item_id = cursor.fetchone()[0];
    self._storeCategories(cursor, item_id, item_data.categories)

    if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER:
        if const.OPT_ROSTER_GROUPS_ALLOWED in item_config:
            item_config.fields[const.OPT_ROSTER_GROUPS_ALLOWED].fieldType='list-multi' #XXX: needed to force list if there is only one value
            allowed_groups = item_config[const.OPT_ROSTER_GROUPS_ALLOWED]
        else:
            allowed_groups = []
        for group in allowed_groups:
            #TODO: check that groups are actually in roster
            cursor.execute("""INSERT INTO item_groups_authorized (item_id, groupname)
                              VALUES (%s,%s)""" , (item_id, group))
    # TODO: whitelist access model
962
963 def _storeCategories(self, cursor, item_id, categories, update=False):
964 # TODO: handle canonical form
965 if update:
966 cursor.execute("""DELETE FROM item_categories
967 WHERE item_id=%s""", (item_id,))
968
969 # we use a set to avoid duplicates
970 for category in set(categories):
971 cursor.execute("""INSERT INTO item_categories (item_id, category)
972 VALUES (%s, %s)""", (item_id, category))
973
def removeItems(self, itemIdentifiers):
    """Delete the given items from this node.

    @param itemIdentifiers: ids of the items to delete
    @return: what dbpool.runInteraction returns for _removeItems, which
        itself produces the list of ids actually deleted
    """
    pool = self.dbpool
    return pool.runInteraction(self._removeItems, itemIdentifiers)
976
977 def _removeItems(self, cursor, itemIdentifiers):
978 self._checkNodeExists(cursor)
979
980 deleted = []
981
982 for itemIdentifier in itemIdentifiers:
983 cursor.execute("""DELETE FROM items WHERE
984 node_id=%s AND
985 item=%s""",
986 (self.nodeDbId,
987 itemIdentifier))
988
989 if cursor.rowcount:
990 deleted.append(itemIdentifier)
991
992 return deleted
993
def getItems(self, authorized_groups, unrestricted, maxItems=None, ext_data=None):
    """Get all authorised items

    @param authorized_groups: we want to get items that these groups can access
    @param unrestricted: if true, don't check permissions (i.e.: get all items)
    @param maxItems: nb of items we want to get
    @param ext_data: options for extra features like RSM and MAM,
        None is handled as an empty dict

    @return: list of container.ItemData
        if unrestricted is False, access_model and config will be None
    """
    options = {} if ext_data is None else ext_data
    return self.dbpool.runInteraction(
        self._getItems,
        authorized_groups,
        unrestricted,
        maxItems,
        options,
        ids_only=False)
1008
def getItemsIds(self, authorized_groups, unrestricted, maxItems=None, ext_data=None):
    """Get the ids of all authorised items

    @param authorized_groups: we want to get items that these groups can access
    @param unrestricted: if true, don't check permissions (i.e.: get all items)
    @param maxItems: nb of items we want to get
    @param ext_data: options for extra features like RSM and MAM,
        None is handled as an empty dict

    @return list(unicode): list of ids
    """
    options = {} if ext_data is None else ext_data
    return self.dbpool.runInteraction(
        self._getItems,
        authorized_groups,
        unrestricted,
        maxItems,
        options,
        ids_only=True)
1022
1023 def _appendSourcesAndFilters(self, query, args, authorized_groups, unrestricted, ext_data):
1024 """append sources and filters to sql query requesting items and return ORDER BY
1025
1026 arguments query, args, authorized_groups, unrestricted and ext_data are the same as for
1027 _getItems
1028 """
1029 # SOURCES
1030 query.append("FROM nodes INNER JOIN items USING (node_id)")
1031
1032 if unrestricted:
1033 query_filters = ["WHERE node_id=%s"]
1034 args.append(self.nodeDbId)
1035 else:
1036 query.append("LEFT JOIN item_groups_authorized USING (item_id)")
1037 args.append(self.nodeDbId)
1038 if authorized_groups:
1039 get_groups = " or (items.access_model='roster' and groupname in %s)"
1040 args.append(authorized_groups)
1041 else:
1042 get_groups = ""
1043
1044 query_filters = ["WHERE node_id=%s AND (items.access_model='open'" + get_groups + ")"]
1045
1046 # FILTERS
1047 if 'filters' in ext_data: # MAM filters
1048 for filter_ in ext_data['filters']:
1049 if filter_.var == 'start':
1050 query_filters.append("AND created>=%s")
1051 args.append(filter_.value)
1052 elif filter_.var == 'end':
1053 query_filters.append("AND created<=%s")
1054 args.append(filter_.value)
1055 elif filter_.var == 'with':
1056 jid_s = filter_.value
1057 if '/' in jid_s:
1058 query_filters.append("AND publisher=%s")
1059 args.append(filter_.value)
1060 else:
1061 query_filters.append("AND publisher LIKE %s")
1062 args.append(u"{}%".format(filter_.value))
1063 elif filter_.var == const.MAM_FILTER_CATEGORY:
1064 query.append("LEFT JOIN item_categories USING (item_id)")
1065 query_filters.append("AND category=%s")
1066 args.append(filter_.value)
1067 else:
1068 log.msg("WARNING: unknown filter: {}".format(filter_.encode('utf-8')))
1069
1070 query.extend(query_filters)
1071
1072 return self.getOrderBy(ext_data)
1073
1074 def _getItems(self, cursor, authorized_groups, unrestricted, maxItems, ext_data, ids_only):
1075 self._checkNodeExists(cursor)
1076
1077 if maxItems == 0:
1078 return []
1079
1080 args = []
1081
1082 # SELECT
1083 if ids_only:
1084 query = ["SELECT item"]
1085 else:
1086 query = ["SELECT data::text,items.access_model,item_id,created,updated"]
1087
1088 query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
1089
1090 if 'rsm' in ext_data:
1091 rsm = ext_data['rsm']
1092 maxItems = rsm.max
1093 if rsm.index is not None:
1094 # We need to know the item_id of corresponding to the index (offset) of the current query
1095 # so we execute the query to look for the item_id
1096 tmp_query = query[:]
1097 tmp_args = args[:]
1098 tmp_query[0] = "SELECT item_id"
1099 tmp_query.append("{} LIMIT 1 OFFSET %s".format(query_order))
1100 tmp_args.append(rsm.index)
1101 cursor.execute(' '.join(query), args)
1102 # FIXME: bad index is not managed yet
1103 item_id = cursor.fetchall()[0][0]
1104
1105 # now that we have the id, we can use it
1106 query.append("AND item_id<=%s")
1107 args.append(item_id)
1108 elif rsm.before is not None:
1109 if rsm.before != '':
1110 query.append("AND item_id>(SELECT item_id FROM items WHERE item=%s LIMIT 1)")
1111 args.append(rsm.before)
1112 if maxItems is not None:
1113 # if we have maxItems (i.e. a limit), we need to reverse order
1114 # in a first query to get the right items
1115 query.insert(0,"SELECT * from (")
1116 query.append(self.getOrderBy(ext_data, direction='ASC'))
1117 query.append("LIMIT %s) as x")
1118 args.append(maxItems)
1119 elif rsm.after:
1120 query.append("AND item_id<(SELECT item_id FROM items WHERE item=%s LIMIT 1)")
1121 args.append(rsm.after)
1122
1123 query.append(query_order)
1124
1125 if maxItems is not None:
1126 query.append("LIMIT %s")
1127 args.append(maxItems)
1128
1129 cursor.execute(' '.join(query), args)
1130
1131 result = cursor.fetchall()
1132 if unrestricted and not ids_only:
1133 # with unrestricted query, we need to fill the access_list for a roster access items
1134 ret = []
1135 for item_data in result:
1136 item = generic.stripNamespace(parseXml(item_data.data))
1137 access_model = item_data.access_model
1138 item_id = item_data.item_id
1139 created = item_data.created
1140 updated = item_data.updated
1141 access_list = {}
1142 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER:
1143 cursor.execute('SELECT groupname FROM item_groups_authorized WHERE item_id=%s', (item_id,))
1144 access_list[const.OPT_ROSTER_GROUPS_ALLOWED] = [r.groupname for r in cursor.fetchall()]
1145
1146 ret.append(container.ItemData(item, access_model, access_list, created=created, updated=updated))
1147 # TODO: whitelist item access model
1148 return ret
1149
1150 if ids_only:
1151 return [r.item for r in result]
1152 else:
1153 items_data = [container.ItemData(generic.stripNamespace(parseXml(r.data)), r.access_model, created=r.created, updated=r.updated) for r in result]
1154 return items_data
1155
def getItemsById(self, authorized_groups, unrestricted, itemIdentifiers):
    """Get the items whose ids are in the given list

    @param authorized_groups: we want to get items that these groups can access
    @param unrestricted: if true, don't check permissions
    @param itemIdentifiers: list of ids of the items we want to get
    @return: list of container.ItemData
        ItemData.config will contain access_list (managed as a dictionary
        with the same keys as for item_config)
        if unrestricted is False, access_model and config will be None
    """
    pool = self.dbpool
    return pool.runInteraction(
        self._getItemsById, authorized_groups, unrestricted, itemIdentifiers)
1167
1168 def _getItemsById(self, cursor, authorized_groups, unrestricted, itemIdentifiers):
1169 self._checkNodeExists(cursor)
1170 ret = []
1171 if unrestricted: #we get everything without checking permissions
1172 for itemIdentifier in itemIdentifiers:
1173 cursor.execute("""SELECT data::text,items.access_model,item_id,created,updated FROM nodes
1174 INNER JOIN items USING (node_id)
1175 WHERE node_id=%s AND item=%s""",
1176 (self.nodeDbId,
1177 itemIdentifier))
1178 result = cursor.fetchone()
1179 if not result:
1180 raise error.ItemNotFound()
1181
1182 item = generic.stripNamespace(parseXml(result[0]))
1183 access_model = result[1]
1184 item_id = result[2]
1185 created= result[3]
1186 updated= result[4]
1187 access_list = {}
1188 if access_model == const.VAL_AMODEL_PUBLISHER_ROSTER:
1189 cursor.execute('SELECT groupname FROM item_groups_authorized WHERE item_id=%s', (item_id,))
1190 access_list[const.OPT_ROSTER_GROUPS_ALLOWED] = [r[0] for r in cursor.fetchall()]
1191 #TODO: WHITELIST access_model
1192
1193 ret.append(container.ItemData(item, access_model, access_list, created=created, updated=updated))
1194 else: #we check permission before returning items
1195 for itemIdentifier in itemIdentifiers:
1196 args = [self.nodeDbId, itemIdentifier]
1197 if authorized_groups:
1198 args.append(authorized_groups)
1199 cursor.execute("""SELECT data::text, created, updated FROM nodes
1200 INNER JOIN items USING (node_id)
1201 LEFT JOIN item_groups_authorized USING (item_id)
1202 WHERE node_id=%s AND item=%s AND
1203 (items.access_model='open' """ +
1204 ("or (items.access_model='roster' and groupname in %s)" if authorized_groups else '') + ")",
1205 args)
1206
1207 result = cursor.fetchone()
1208 if result:
1209 ret.append(container.ItemData(generic.stripNamespace(parseXml(result[0])), created=result[1], updated=result[2]))
1210
1211 return ret
1212
def getItemsCount(self, authorized_groups, unrestricted, ext_data=None):
    """Count the expected number of items of a getItems query

    @param authorized_groups: we want to get items that these groups can access
    @param unrestricted: if true, don't check permissions (i.e.: get all items)
    @param ext_data: options for extra features like RSM and MAM,
        None is handled as an empty dict
    """
    options = {} if ext_data is None else ext_data
    return self.dbpool.runInteraction(
        self._getItemsCount, authorized_groups, unrestricted, options)
1223
1224 def _getItemsCount(self, cursor, authorized_groups, unrestricted, ext_data):
1225 self._checkNodeExists(cursor)
1226 args = []
1227
1228 # SELECT
1229 query = ["SELECT count(1)"]
1230
1231 self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
1232
1233 cursor.execute(' '.join(query), args)
1234 return cursor.fetchall()[0][0]
1235
def getItemsIndex(self, item_id, authorized_groups, unrestricted, ext_data=None):
    """Get the expected index of the first item in the window of a getItems query

    @param item_id: id of the item
    @param authorized_groups: we want to get items that these groups can access
    @param unrestricted: if true, don't check permissions (i.e.: get all items)
    @param ext_data: options for extra features like RSM and MAM,
        None is handled as an empty dict
    """
    options = {} if ext_data is None else ext_data
    return self.dbpool.runInteraction(
        self._getItemsIndex, item_id, authorized_groups, unrestricted, options)
1247
1248 def _getItemsIndex(self, cursor, item_id, authorized_groups, unrestricted, ext_data):
1249 self._checkNodeExists(cursor)
1250 args = []
1251
1252 # SELECT
1253 query = []
1254
1255 query_order = self._appendSourcesAndFilters(query, args, authorized_groups, unrestricted, ext_data)
1256
1257 query_select = "SELECT row_number from (SELECT row_number() OVER ({}), item".format(query_order)
1258 query.insert(0, query_select)
1259 query.append(") as x WHERE item=%s")
1260 args.append(item_id)
1261
1262 cursor.execute(' '.join(query), args)
1263 # XXX: row_number start at 1, but we want that index start at 0
1264 try:
1265 return cursor.fetchall()[0][0] - 1
1266 except IndexError:
1267 raise error.NodeNotFound()
1268
def getItemsPublishers(self, itemIdentifiers):
    """Get the publishers of all the given items

    @param itemIdentifiers: ids of the items
    @return (dict[unicode, jid.JID]): map of itemIdentifiers to publisher
        if an item is not found, its key is skipped in the resulting dict
    """
    pool = self.dbpool
    return pool.runInteraction(self._getItemsPublishers, itemIdentifiers)
1276
1277 def _getItemsPublishers(self, cursor, itemIdentifiers):
1278 self._checkNodeExists(cursor)
1279 ret = {}
1280 for itemIdentifier in itemIdentifiers:
1281 cursor.execute("""SELECT publisher FROM items
1282 WHERE node_id=%s AND item=%s""",
1283 (self.nodeDbId, itemIdentifier,))
1284 result = cursor.fetchone()
1285 if result:
1286 ret[itemIdentifier] = jid.JID(result[0])
1287 return ret
1288
def purge(self):
    """Delete all the items of this node.

    @return: what dbpool.runInteraction returns for _purge
    """
    pool = self.dbpool
    return pool.runInteraction(self._purge)
1291
1292 def _purge(self, cursor):
1293 self._checkNodeExists(cursor)
1294
1295 cursor.execute("""DELETE FROM items WHERE
1296 node_id=%s""",
1297 (self.nodeDbId,))
1298
1299
class CollectionNode(Node):
    """Node whose nodeType is 'collection'.

    Nothing else is overridden here: all the behaviour comes from Node.
    """

    nodeType = 'collection'
1303
1304
1305
class GatewayStorage(object):
    """
    Database backed storage facility for the XMPP-HTTP gateway.

    Callbacks are kept in the "callbacks" table, one row per
    (service, node, uri). Every public method runs its queries through
    dbpool.runInteraction and returns what this call returns.
    """

    def __init__(self, dbpool):
        """
        @param dbpool: connection pool, its runInteraction method is used to
            run the queries
        """
        self.dbpool = dbpool

    def _countCallbacks(self, cursor, service, nodeIdentifier):
        """
        Count number of callbacks registered for a node.
        """
        node_key = (service.full(), nodeIdentifier)
        cursor.execute("""SELECT count(*) FROM callbacks
                          WHERE service=%s and node=%s""",
                       node_key)
        rows = cursor.fetchall()
        return rows[0][0]

    def addCallback(self, service, nodeIdentifier, callback):
        """
        Register a callback uri for a node, unless it is already registered.
        """
        def register(cursor):
            row = (service.full(), nodeIdentifier, callback)
            cursor.execute("""SELECT 1 as bool FROM callbacks
                              WHERE service=%s and node=%s and uri=%s""",
                           row)
            already_registered = cursor.fetchall()
            if not already_registered:
                cursor.execute("""INSERT INTO callbacks
                                  (service, node, uri) VALUES
                                  (%s, %s, %s)""",
                               row)

        return self.dbpool.runInteraction(register)

    def removeCallback(self, service, nodeIdentifier, callback):
        """
        Unregister a callback uri.

        The interaction result is True when no callback is left for the node.
        error.NotSubscribed is raised when the DELETE does not affect exactly
        one row.
        """
        def unregister(cursor):
            row = (service.full(), nodeIdentifier, callback)
            cursor.execute("""DELETE FROM callbacks
                              WHERE service=%s and node=%s and uri=%s""",
                           row)

            if cursor.rowcount != 1:
                raise error.NotSubscribed()

            remaining = self._countCallbacks(cursor, service, nodeIdentifier)
            return not remaining

        return self.dbpool.runInteraction(unregister)

    def getCallbacks(self, service, nodeIdentifier):
        """
        Get the callback uris registered for a node.

        error.NoCallbacks is raised when there is none.
        """
        def fetch_uris(cursor):
            node_key = (service.full(), nodeIdentifier)
            cursor.execute("""SELECT uri FROM callbacks
                              WHERE service=%s and node=%s""",
                           node_key)
            rows = cursor.fetchall()

            if rows:
                return [row[0] for row in rows]
            raise error.NoCallbacks()

        return self.dbpool.runInteraction(fetch_uris)

    def hasCallbacks(self, service, nodeIdentifier):
        """
        Tell if at least one callback is registered for a node.
        """
        def any_registered(cursor):
            count = self._countCallbacks(cursor, service, nodeIdentifier)
            return bool(count)

        return self.dbpool.runInteraction(any_registered)