diff sat/plugins/plugin_xep_0060.py @ 3849:bc24ce903835

plugin XEP-0060: handle priority in `addManagedNode`: Priority lets order the callback list when an event is received. This is important in some use case, notably when a plugin needs to check the former item before it is deleted from cache or updated. rel 370
author Goffi <goffi@goffi.org>
date Thu, 14 Jul 2022 12:55:30 +0200
parents 853cbaf56e9e
children 8a2c46122a11
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0060.py	Thu Jul 14 12:55:30 2022 +0200
+++ b/sat/plugins/plugin_xep_0060.py	Thu Jul 14 12:55:30 2022 +0200
@@ -19,7 +19,7 @@
 
 from collections import namedtuple
 from functools import reduce
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
 import urllib.error
 import urllib.parse
 import urllib.request
@@ -420,11 +420,18 @@
 
         return Extra(rsm_request, extra)
 
-    def addManagedNode(self, node, **kwargs):
+    def addManagedNode(
+        self,
+        node: str,
+        priority: int = 0,
+        **kwargs: Callable
+    ):
         """Add a handler for a node
 
-        @param node(unicode): node to monitor
+        @param node: node to monitor
             all node *prefixed* with this one will be triggered
+        @param priority: priority of the callback. Callbacks with higher priority will be
+            called first.
         @param **kwargs: method(s) to call when the node is found
             the method must be named after PubSub constants in lower case
             and suffixed with "_cb"
@@ -437,7 +444,9 @@
         for event, cb in kwargs.items():
             event_name = event[:-3]
             assert event_name in C.PS_EVENTS
-            callbacks.setdefault(event_name, []).append(cb)
+            cb_list = callbacks.setdefault(event_name, [])
+            cb_list.append((cb, priority))
+            cb_list.sort(key=lambda c: c[1], reverse=True)
 
     def removeManagedNode(self, node, *args):
         """Add a handler for a node
@@ -451,28 +460,26 @@
         except KeyError:
             pass
         else:
+            removed = False
             for callback in args:
                 for event, cb_list in registred_cb.items():
-                    try:
-                        cb_list.remove(callback)
-                    except ValueError:
-                        pass
-                    else:
-                        log.debug(
-                            "removed callback {cb} for event {event} on node {node}".format(
-                                cb=callback, event=event, node=node
-                            )
-                        )
-                        if not cb_list:
-                            del registred_cb[event]
-                        if not registred_cb:
-                            del self._node_cb[node]
-                        return
-        log.error(
-            "Trying to remove inexistant callback {cb} for node {node}".format(
-                cb=callback, node=node
-            )
-        )
+                    to_remove = []
+                    for cb in cb_list:
+                        if cb[0] == callback:
+                            to_remove.append(cb)
+                            for cb in to_remove:
+                                cb_list.remove(cb)
+                            if not cb_list:
+                                del registred_cb[event]
+                            if not registred_cb:
+                                del self._node_cb[node]
+                            removed = True
+                            break
+
+            if not removed:
+                log.error(
+                    f"Trying to remove inexistant callback {callback} for node {node}"
+                )
 
     # def listNodes(self, service, nodeIdentifier='', profile=C.PROF_KEY_NONE):
     #     """Retrieve the name of the nodes that are accessible on the target service.
@@ -1614,20 +1621,30 @@
             if not node.startswith(registered_node):
                 continue
             try:
-                for callback in callbacks_dict[event]:
-                    yield callback
+                for callback_data in callbacks_dict[event]:
+                    yield callback_data[0]
             except KeyError:
                 continue
 
+    async def _callNodeCallbacks(self, client, event: pubsub.ItemsEvent) -> None:
+        """Call sequencially event callbacks of a node
+
+        Callbacks are called sequencially and not in parallel to be sure to respect
+        priority (notably for plugin needing to get old items before they are modified or
+        deleted from cache).
+        """
+        for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_ITEMS):
+            try:
+                await utils.asDeferred(callback, client, event)
+            except Exception as e:
+                log.error(
+                    f"Error while running items event callback {callback}: {e}"
+                )
 
     def itemsReceived(self, event):
         log.debug("Pubsub items received")
-        for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_ITEMS):
-            d = utils.asDeferred(callback, self.parent, event)
-            d.addErrback(lambda f: log.error(
-                f"Error while running items event callback {callback}: {f}"
-            ))
         client = self.parent
+        defer.ensureDeferred(self._callNodeCallbacks(client, event))
         if (event.sender, event.nodeIdentifier) in client.pubsub_watching:
             raw_items = [i.toXml() for i in event.items]
             self.host.bridge.psEventRaw(