diff libervia/backend/core/xmpp.py @ 4151:18026ce0819c

core (xmpp): message reception workflow refactoring: - Call methods from a root async one instead of using Deferred callbacks chain. - Use a queue to be sure to process messages in order.
author Goffi <goffi@goffi.org>
date Wed, 22 Nov 2023 14:50:35 +0100
parents bc7d45dedeb0
children d67eaa684484
line wrap: on
line diff
--- a/libervia/backend/core/xmpp.py	Wed Nov 22 14:45:26 2023 +0100
+++ b/libervia/backend/core/xmpp.py	Wed Nov 22 14:50:35 2023 +0100
@@ -49,6 +49,7 @@
 from libervia.backend.memory import cache
 from libervia.backend.memory import encryption
 from libervia.backend.memory import persistent
+from libervia.backend.models.core import MessageData
 from libervia.backend.tools import xml_tools
 from libervia.backend.tools import utils
 from libervia.backend.tools.common import data_format
@@ -1224,6 +1225,11 @@
     def __init__(self, host):
         xmppim.MessageProtocol.__init__(self)
         self.host = host
+        self.messages_queue  = defer.DeferredQueue()
+
+    def setHandlerParent(self, parent):
+        super().setHandlerParent(parent)
+        defer.ensureDeferred(self.process_messages())
 
     @property
     def client(self):
@@ -1235,7 +1241,7 @@
         for child in elt.elements():
             self.normalize_ns(child, namespace)
 
-    def parse_message(self, message_elt):
+    def parse_message(self, message_elt: domish.Element) -> MessageData:
         """Parse a message XML and return message_data
 
         @param message_elt(domish.Element): raw <message> xml
@@ -1265,7 +1271,7 @@
         message = {}
         subject = {}
         extra = {}
-        data = {
+        data: MessageData = {
             "from": jid.JID(message_elt["from"]),
             "to": jid.JID(message_elt["to"]),
             "uid": message_elt.getAttribute(
@@ -1316,33 +1322,34 @@
         self.host.trigger.point("message_parse", client,  message_elt, data)
         return data
 
-    def _on_message_start_workflow(self, cont, client, message_elt, post_treat):
-        """Parse message and do post treatments
+
+    def onMessage(self, message_elt: domish.Element) -> None:
+        message_elt._received_timestamp = time.time()
+        self.messages_queue.put(message_elt)
 
-        It is the first callback called after message_received trigger
-        @param cont(bool): workflow will continue only if this is True
-        @param message_elt(domish.Element): message stanza
-            may have be modified by triggers
-        @param post_treat(defer.Deferred): post parsing treatments
+    async def process_messages(self) -> None:
+        """Process message in order
+
+        Messages are processed in a queue to avoid race conditions and ensure orderly
+        processing.
         """
-        if not cont:
-            return
-        data = self.parse_message(message_elt)
-        post_treat.addCallback(self.complete_attachments)
-        post_treat.addCallback(self.skip_empty_message)
-        if not client.is_component or client.receiveHistory:
-            post_treat.addCallback(
-                lambda ret: defer.ensureDeferred(self.add_to_history(ret))
-            )
-        if not client.is_component:
-            post_treat.addCallback(self.bridge_signal, data)
-        post_treat.addErrback(self.cancel_error_trap)
-        post_treat.callback(data)
+        client = self.parent
+        if client is None:
+            log.error("client should not be None!")
+            raise exceptions.InternalError()
+        while True:
+            message_elt = await self.messages_queue.get()
+            try:
+                await self.process_message(client, message_elt)
+            except Exception:
+                log.exception(f"Can't process message {message_elt.toXml()}")
 
-    def onMessage(self, message_elt):
+    async def process_message(
+        self,
+        client: SatXMPPEntity,
+        message_elt: domish.Element
+    ) -> None:
         # TODO: handle threads
-        message_elt._received_timestamp = time.time()
-        client = self.parent
         if not "from" in message_elt.attributes:
             message_elt["from"] = client.jid.host
         log.debug(_("got message from: {from_}").format(from_=message_elt["from"]))
@@ -1352,14 +1359,24 @@
 
         # plugin can add their treatments to this deferred
         post_treat = defer.Deferred()
-
-        d = self.host.trigger.async_point(
+        if not await self.host.trigger.async_point(
             "message_received", client, message_elt, post_treat
-        )
+        ):
+            return
+        try:
+            data = self.parse_message(message_elt)
 
-        d.addCallback(self._on_message_start_workflow, client, message_elt, post_treat)
+            self.complete_attachments(data)
+            if not data["message"] and not data["extra"] and not data["subject"]:
+                raise exceptions.CancelError("Cancelled empty message")
+            if not client.is_component or client.receiveHistory:
+                await self.add_to_history(data)
+            if not client.is_component:
+                self.bridge_signal(data)
+        except exceptions.CancelError:
+            pass
 
-    def complete_attachments(self, data):
+    def complete_attachments(self, data: MessageData) -> MessageData:
         """Complete missing metadata of attachments"""
         for attachment in data['extra'].get(C.KEY_ATTACHMENTS, []):
             if "name" not in attachment and "url" in attachment:
@@ -1371,15 +1388,9 @@
                 media_type = mimetypes.guess_type(attachment['name'], strict=False)[0]
                 if media_type:
                     attachment[C.KEY_ATTACHMENTS_MEDIA_TYPE] = media_type
-
         return data
 
-    def skip_empty_message(self, data):
-        if not data["message"] and not data["extra"] and not data["subject"]:
-            raise failure.Failure(exceptions.CancelError("Cancelled empty message"))
-        return data
-
-    async def add_to_history(self, data):
+    async def add_to_history(self, data: MessageData) -> MessageData:
         if data.pop("history", None) == C.HISTORY_SKIP:
             log.debug("history is skipped as requested")
             data["extra"]["history"] = C.HISTORY_SKIP
@@ -1390,8 +1401,10 @@
             else:
                 log.debug("not storing empty message to history: {data}"
                     .format(data=data))
+        return data
 
-    def bridge_signal(self, __, data):
+    def bridge_signal(self, data: MessageData) -> MessageData:
+        """Send signal to frontends for the given message"""
         try:
             data["extra"]["received_timestamp"] = str(data["received_timestamp"])
             data["extra"]["delay_sender"] = data["delay_sender"]
@@ -1417,10 +1430,6 @@
                     data=data))
         return data
 
-    def cancel_error_trap(self, failure_):
-        """A message sending can be cancelled by a plugin treatment"""
-        failure_.trap(exceptions.CancelError)
-
 
 class LiberviaRosterProtocol(xmppim.RosterClientProtocol):