changeset 353:7c5d85c6fb3a

backend: enforce schema on get/publish and notifications
author Goffi <goffi@goffi.org>
date Fri, 08 Sep 2017 08:02:05 +0200
parents efbdca10f0fb
children 18b983fe9e1b
files sat_pubsub/backend.py
diffstat 1 files changed, 100 insertions(+), 60 deletions(-) [+]
line wrap: on
line diff
--- a/sat_pubsub/backend.py	Fri Sep 08 08:02:05 2017 +0200
+++ b/sat_pubsub/backend.py	Fri Sep 08 08:02:05 2017 +0200
@@ -163,27 +163,21 @@
         self.storage = storage
         self._callbackList = []
 
-
     def supportsPublisherAffiliation(self):
         return True
 
-
     def supportsGroupBlog(self):
         return True
 
-
     def supportsOutcastAffiliation(self):
         return True
 
-
     def supportsPersistentItems(self):
         return True
 
-
     def supportsPublishModel(self):
         return True
 
-
     def getNodeType(self, nodeIdentifier, pep, recipient=None):
         # FIXME: manage pep and recipient
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
@@ -206,7 +200,6 @@
             return d
         return self.storage.getNodeIds(pep, recipient)
 
-
     def getNodeMetaData(self, nodeIdentifier, pep, recipient=None):
         # FIXME: manage pep and recipient
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
@@ -214,7 +207,6 @@
         d.addCallback(self._makeMetaData)
         return d
 
-
     def _makeMetaData(self, metaData):
         options = []
         for key, value in metaData.iteritems():
@@ -226,7 +218,6 @@
 
         return options
 
-
     def _checkAuth(self, node, requestor):
         """ Check authorisation of publishing in node for requestor """
 
@@ -298,6 +289,52 @@
 
         return categories
 
+    def enforceSchema(self, item_elt, schema, affiliation):
+        """modifify item according to element, or refuse publishing
+
+        @param item_elt(domish.Element): item to check/modify
+        @param schema(domish.Eement): schema to enfore
+        @param affiliation(unicode): affiliation of the publisher
+        """
+        try:
+            x_elt = next(item_elt.elements(data_form.NS_X_DATA, 'x'))
+            item_form = data_form.Form.fromElement(x_elt)
+        except (StopIteration, data_form.Error):
+            raise pubsub.BadRequest(text="node has a schema but item has no form")
+        else:
+            item_elt.children.remove(x_elt)
+
+        schema_form = data_form.Form.fromElement(schema)
+
+        # we enforce restrictions
+        for field_elt in schema.elements(data_form.NS_X_DATA, 'field'):
+            var = field_elt['var']
+            for restrict_elt in field_elt.elements(const.NS_SCHEMA_RESTRICT, 'restrict'):
+                write_restriction = restrict_elt.attributes.get('write')
+                if write_restriction is not None:
+                    if write_restriction == 'owner':
+                        if affiliation != 'owner':
+                            # write is not allowed on this field, we use default value
+                            # we can safely use Field from schema_form because
+                            # we have created this instance only for this method
+                            try:
+                                item_form.removeField(item_form.fields[var])
+                            except KeyError:
+                                pass
+                            item_form.addField(schema_form.fields[var])
+                    else:
+                        raise StanzaError('feature-not-implemented', text='unknown write restriction {}'.format(write_restriction))
+
+        # we now remove every field which is not in data schema
+        to_remove = set()
+        for item_var, item_field in item_form.fields.iteritems():
+            if item_var not in schema_form.fields:
+                to_remove.add(item_field)
+
+        for field in to_remove:
+            item_form.removeField(field)
+        item_elt.addChild(item_form.toElement())
+
     def _checkOverwrite(self, node, itemIdentifiers, publisher):
         """Check that the itemIdentifiers correspond to items published
         by the current publisher"""
@@ -310,7 +347,6 @@
         d.addCallback(doCheck)
         return d
 
-
     def publish(self, nodeIdentifier, items, requestor, pep, recipient):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(self._checkAuth, requestor)
@@ -319,7 +355,6 @@
         d.addCallback(self._doPublish, items, requestor, pep, recipient)
         return d
 
-
     def _doPublish(self, result, items, requestor, pep, recipient):
         affiliation, node = result
         if node.nodeType == 'collection':
@@ -348,6 +383,9 @@
                     check_overwrite = True
             access_model, item_config = self.parseItemConfig(item)
             categories = self.parseCategories(item)
+            schema = node.getSchema()
+            if schema is not None:
+                self.enforceSchema(item, schema, affiliation)
             items_data.append(container.ItemData(item, access_model, item_config, categories))
 
         if persistItems:
@@ -365,7 +403,6 @@
                       deliverPayloads, pep, recipient)
         return d
 
-
     def _doNotify(self, result, node, items_data, deliverPayloads, pep, recipient):
         if items_data and not deliverPayloads:
             for item_data in items_data:
@@ -373,7 +410,6 @@
         self.dispatch({'items_data': items_data, 'node': node, 'pep': pep, 'recipient': recipient},
                       '//event/pubsub/notify')
 
-
     def getNotifications(self, node, items_data):
         """Build a list of subscriber to the node
 
@@ -426,7 +462,6 @@
         d.addCallback(self._doSubscribe, subscriber)
         return d
 
-
     def _doSubscribe(self, result, subscriber):
         # TODO: implement other access models
         node, affiliation = result
@@ -455,7 +490,6 @@
 
         return d
 
-
     def _sendLastPublished(self, subscription, node):
 
         def notifyItem(items):
@@ -478,7 +512,6 @@
 
         return subscription
 
-
     def unsubscribe(self, nodeIdentifier, subscriber, requestor, pep, recipient):
         if subscriber.userhostJID() != requestor.userhostJID():
             return defer.fail(error.Forbidden())
@@ -487,7 +520,6 @@
         d.addCallback(lambda node: node.removeSubscription(subscriber))
         return d
 
-
     def getSubscriptions(self, entity):
         return self.storage.getSubscriptions(entity)
 
@@ -500,7 +532,6 @@
     def supportsInstantNodes(self):
         return True
 
-
     def createNode(self, nodeIdentifier, requestor, options = None, pep=False, recipient=None):
         if not nodeIdentifier:
             nodeIdentifier = 'generic/%s' % uuid.uuid4()
@@ -532,12 +563,10 @@
         d.addCallback(lambda _: nodeIdentifier)
         return d
 
-
     def getDefaultConfiguration(self, nodeType):
         d = defer.succeed(self.storage.getDefaultConfiguration(nodeType))
         return d
 
-
     def getNodeConfiguration(self, nodeIdentifier, pep, recipient):
         if not nodeIdentifier:
             return defer.fail(error.NoRootNode())
@@ -547,7 +576,6 @@
 
         return d
 
-
     def setNodeConfiguration(self, nodeIdentifier, options, requestor, pep, recipient):
         if not nodeIdentifier:
             return defer.fail(error.NoRootNode())
@@ -557,7 +585,6 @@
         d.addCallback(self._doSetNodeConfiguration, options)
         return d
 
-
     def _doSetNodeConfiguration(self, result, options):
         node, affiliation = result
 
@@ -566,7 +593,6 @@
 
         return node.setConfiguration(options)
 
-
     def getNodeSchema(self, nodeIdentifier, pep, recipient):
         if not nodeIdentifier:
             return defer.fail(error.NoRootNode())
@@ -592,7 +618,6 @@
         d.addCallback(self._doSetNodeSchema, schema)
         return d
 
-
     def _doSetNodeSchema(self, result, schema):
         node, affiliation = result
 
@@ -601,18 +626,15 @@
 
         return node.setSchema(schema)
 
-
     def getAffiliations(self, entity, nodeIdentifier, pep, recipient):
         return self.storage.getAffiliations(entity, nodeIdentifier, pep, recipient)
 
-
     def getAffiliationsOwner(self, nodeIdentifier, requestor, pep, recipient):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(_getAffiliation, requestor)
         d.addCallback(self._doGetAffiliationsOwner)
         return d
 
-
     def _doGetAffiliationsOwner(self, result):
         node, affiliation = result
 
@@ -620,14 +642,12 @@
             raise error.Forbidden()
         return node.getAffiliations()
 
-
     def setAffiliationsOwner(self, nodeIdentifier, requestor, affiliations, pep, recipient):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(_getAffiliation, requestor)
         d.addCallback(self._doSetAffiliationsOwner, requestor, affiliations)
         return d
 
-
     def _doSetAffiliationsOwner(self, result, requestor, affiliations):
         # Check that requestor is allowed to set affiliations, and delete entities
         # with "none" affiliation
@@ -658,14 +678,12 @@
 
         return d
 
-
     def getSubscriptionsOwner(self, nodeIdentifier, requestor, pep, recipient):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(_getAffiliation, requestor)
         d.addCallback(self._doGetSubscriptionsOwner)
         return d
 
-
     def _doGetSubscriptionsOwner(self, result):
         node, affiliation = result
 
@@ -673,7 +691,6 @@
             raise error.Forbidden()
         return node.getSubscriptions()
 
-
     def setSubscriptionsOwner(self, nodeIdentifier, requestor, subscriptions, pep, recipient):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(_getAffiliation, requestor)
@@ -709,6 +726,46 @@
         d.addErrback(self.unwrapFirstError)
         return d
 
+    def filterItemsWithSchema(self, items_data, schema, owner):
+        """check schema restriction and remove fields/items if they don't comply
+
+        @param items_data(list[ItemData]): items to filter
+            items in this list will be modified
+        @param schema(domish.Element): node schema
+        @param owner(bool): True is requestor is a owner of the node
+        """
+        fields_to_remove = set()
+        for field_elt in schema.elements(data_form.NS_X_DATA, 'field'):
+            for restrict_elt in field_elt.elements(const.NS_SCHEMA_RESTRICT, 'restrict'):
+                read_restriction = restrict_elt.attributes.get('read')
+                if read_restriction is not None:
+                    if read_restriction == 'owner':
+                        if not owner:
+                            fields_to_remove.add(field_elt['var'])
+                    else:
+                        raise StanzaError('feature-not-implemented', text='unknown read restriction {}'.format(read_restriction))
+        items_to_remove = []
+        for idx, item_data in enumerate(items_data):
+            item_elt = item_data.item
+            try:
+                x_elt = next(item_elt.elements(data_form.NS_X_DATA, 'x'))
+            except StopIteration:
+                log.msg("WARNING, item {id} has a schema but no form, ignoring it")
+                items_to_remove.append(item_data)
+                continue
+            form = data_form.Form.fromElement(x_elt)
+            # we remove fields which are not visible for this user
+            for field in fields_to_remove:
+                try:
+                    form.removeField(form.fields[field])
+                except KeyError:
+                    continue
+            item_elt.children.remove(x_elt)
+            item_elt.addChild(form.toElement())
+
+        for item_data in items_to_remove:
+            items_data.remove(item_data)
+
     @defer.inlineCallbacks
     def checkNodeAccess(self, node, requestor):
         """check if a requestor can access data of a node
@@ -859,6 +916,10 @@
                 else:
                     raise error.BadAccessTypeError(access_model)
 
+        schema = node.getSchema()
+        if schema is not None:
+            self.filterItemsWithSchema(items_data, schema, owner)
+
         yield self._items_rsm(items_data, node, requestor_groups, owner, itemIdentifiers, ext_data)
         defer.returnValue(items_data)
 
@@ -972,7 +1033,6 @@
             d.addCallback(self._doNotifyRetraction, node, pep, recipient)
         return d
 
-
     def _doNotifyRetraction(self, items_data, node, pep, recipient):
         self.dispatch({'items_data': items_data,
                        'node': node,
@@ -980,14 +1040,12 @@
                        'recipient': recipient},
                       '//event/pubsub/retract')
 
-
     def purgeNode(self, nodeIdentifier, requestor, pep, recipient):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(_getAffiliation, requestor)
         d.addCallback(self._doPurge)
         return d
 
-
     def _doPurge(self, result):
         node, affiliation = result
         persistItems = node.getConfiguration()[const.OPT_PERSIST_ITEMS]
@@ -1002,15 +1060,12 @@
         d.addCallback(self._doNotifyPurge, node.nodeIdentifier)
         return d
 
-
     def _doNotifyPurge(self, result, nodeIdentifier):
         self.dispatch(nodeIdentifier, '//event/pubsub/purge')
 
-
     def registerPreDelete(self, preDeleteFn):
         self._callbackList.append(preDeleteFn)
 
-
     def getSubscribers(self, nodeIdentifier, pep, recipient):
         def cb(subscriptions):
             return [subscription.subscriber for subscription in subscriptions]
@@ -1020,14 +1075,12 @@
         d.addCallback(cb)
         return d
 
-
     def deleteNode(self, nodeIdentifier, requestor, pep, recipient, redirectURI=None):
         d = self.storage.getNode(nodeIdentifier, pep, recipient)
         d.addCallback(_getAffiliation, requestor)
         d.addCallback(self._doPreDelete, redirectURI, pep, recipient)
         return d
 
-
     def _doPreDelete(self, result, redirectURI, pep, recipient):
         node, affiliation = result
 
@@ -1042,7 +1095,6 @@
                                consumeErrors=1)
         d.addCallback(self._doDelete, node.nodeDbId)
 
-
     def _doDelete(self, result, nodeDbId):
         dl = []
         for succeeded, r in result:
@@ -1054,13 +1106,11 @@
 
         return d
 
-
     def _doNotifyDelete(self, result, dl):
         for d in dl:
             d.callback(None)
 
 
-
 class PubSubResourceFromBackend(pubsub.PubSubResource):
     """
     Adapts a backend to an xmpp publish-subscribe service.
@@ -1151,7 +1201,6 @@
         # if self.backend.supportsPublishModel():       #XXX: this feature is not really described in XEP-0060, we just can see it in examples
         #     self.features.append("publish_model")     #     but it's necessary for microblogging comments (see XEP-0277)
 
-
     def getFullItem(self, item_data):
         """ Attach item configuration to this item
 
@@ -1233,7 +1282,6 @@
         d.addCallback(afterPrepare)
         return d
 
-
     @defer.inlineCallbacks
     def _prepareNotify(self, items_data, node, subscription=None, pep=None, recipient=None):
         """Do a bunch of permissions check and filter notifications
@@ -1272,6 +1320,7 @@
 
         #we filter items not allowed for the subscribers
         notifications_filtered = []
+        schema = node.getSchema()
 
         for subscriber, subscriptions, items_data in notifications:
             subscriber_bare = subscriber.userhostJID()
@@ -1281,6 +1330,12 @@
                 continue
             allowed_items = [] #we keep only item which subscriber can access
 
+            if schema is not None:
+                # we have to deepcopy items because different subscribers may receive
+                # different items (e.g. read restriction in schema)
+                items_data = deepcopy(items_data)
+                self.backend.filterItemsWithSchema(items_data, schema, False)
+
             for item_data in items_data:
                 item, access_model = item_data.item, item_data.access_model
                 access_list = item_data.config
@@ -1392,7 +1447,6 @@
                                       service)
         return d.addErrback(self._mapErrors)
 
-
     def getConfigurationOptions(self):
         return self.backend.nodeOptions
 
@@ -1430,7 +1484,6 @@
         d.addErrback(self._publish_errb, request)
         return d.addErrback(self._mapErrors)
 
-
     def subscribe(self, request):
         d = self.backend.subscribe(request.nodeIdentifier,
                                    request.subscriber,
@@ -1439,7 +1492,6 @@
                                    request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def unsubscribe(self, request):
         d = self.backend.unsubscribe(request.nodeIdentifier,
                                      request.subscriber,
@@ -1448,13 +1500,11 @@
                                      request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def subscriptions(self, request):
         d = self.backend.getSubscriptions(self._isPep(request),
                                           request.sender)
         return d.addErrback(self._mapErrors)
 
-
     def affiliations(self, request):
         """Retrieve affiliation for normal entity (cf. XEP-0060 §5.7)
 
@@ -1466,7 +1516,6 @@
                                          request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def create(self, request):
         d = self.backend.createNode(request.nodeIdentifier,
                                     request.sender, request.options,
@@ -1474,21 +1523,18 @@
                                     request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def default(self, request):
         d = self.backend.getDefaultConfiguration(request.nodeType,
                                                  self._isPep(request),
                                                  request.sender)
         return d.addErrback(self._mapErrors)
 
-
     def configureGet(self, request):
         d = self.backend.getNodeConfiguration(request.nodeIdentifier,
                                               self._isPep(request),
                                               request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def configureSet(self, request):
         d = self.backend.setNodeConfiguration(request.nodeIdentifier,
                                               request.options,
@@ -1497,7 +1543,6 @@
                                               request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def affiliationsGet(self, request):
         """Retrieve affiliations for owner (cf. XEP-0060 §8.9.1)
 
@@ -1517,7 +1562,6 @@
                                               request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def subscriptionsGet(self, request):
         """Retrieve subscriptions for owner (cf. XEP-0060 §8.8.1)
 
@@ -1529,7 +1573,6 @@
                                                request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def subscriptionsSet(self, request):
         d = self.backend.setSubscriptionsOwner(request.nodeIdentifier,
                                               request.sender,
@@ -1538,7 +1581,6 @@
                                               request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def items(self, request):
         ext_data = {}
         if const.FLAG_ENABLE_RSM and request.rsm is not None:
@@ -1564,7 +1606,6 @@
                                      request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def purge(self, request):
         d = self.backend.purgeNode(request.nodeIdentifier,
                                    request.sender,
@@ -1572,7 +1613,6 @@
                                    request.recipient)
         return d.addErrback(self._mapErrors)
 
-
     def delete(self, request):
         d = self.backend.deleteNode(request.nodeIdentifier,
                                     request.sender,