diff sat/plugins/plugin_xep_0045.py @ 4001:32d714a8ea51

plugin XEP-0045: dot not wait for MAM retrieval to be completed: in `_join_MAM`, `room.fully_joined` is called before retrieving the MAM archive, as the process can be very long, and is not necessary to have the room working (message can be received after being in the room, and added out of order). This avoid blocking the `join` workflow for an extended time. Some renaming and coroutine integrations.
author Goffi <goffi@goffi.org>
date Fri, 10 Mar 2023 17:22:41 +0100
parents 8289ac1b34f4
children 524856bd7b19
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0045.py	Fri Mar 10 17:01:09 2023 +0100
+++ b/sat/plugins/plugin_xep_0045.py	Fri Mar 10 17:22:41 2023 +0100
@@ -17,29 +17,26 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-from sat.core.i18n import _, D_
-from sat.core.constants import Const as C
-from sat.core.log import getLogger
+import time
+from typing import Optional
+import uuid
+
 from twisted.internet import defer
+from twisted.python import failure
 from twisted.words.protocols.jabber import jid
 from twisted.words.protocols.jabber import error as xmpp_error
-from twisted.python import failure
+from wokkel import disco, iwokkel, muc
+from wokkel import rsm
+from wokkel import mam
+from zope.interface import implementer
 
 from sat.core import exceptions
-from sat.core.xmpp import SatXMPPClient
+from sat.core.core_types import SatXMPPEntity
+from sat.core.constants import Const as C
+from sat.core.i18n import D_, _
+from sat.core.log import getLogger
 from sat.memory import memory
-
-import time
-import uuid
-
-from wokkel import muc, disco, iwokkel
-from sat.tools import xml_tools
-
-from zope.interface import implementer
-
-# XXX: mam and rsm come from sat_tmp.wokkel
-from wokkel import rsm
-from wokkel import mam
+from sat.tools import xml_tools, utils
 
 
 log = getLogger(__name__)
@@ -161,10 +158,8 @@
         host.trigger.add("messageReceived", self.messageReceivedTrigger, priority=1000000)
         host.trigger.add("message_parse", self._message_parseTrigger)
 
-    def profileConnected(self, client):
-        def assign_service(service):
-            client.muc_service = service
-        return self.getMUCService(client).addCallback(assign_service)
+    async def profileConnected(self, client):
+        client.muc_service = await self.get_MUC_service(client)
 
     def _message_parseTrigger(self, client, message_elt, data):
         """Add stanza-id from the room if present"""
@@ -205,7 +200,7 @@
                 return False
         return True
 
-    def getRoom(self, client: SatXMPPClient, room_jid: jid.JID) -> muc.Room:
+    def getRoom(self, client: SatXMPPEntity, room_jid: jid.JID) -> muc.Room:
         """Retrieve Room instance from its jid
 
         @param room_jid: jid of the room
@@ -224,7 +219,7 @@
         if room_jid not in client._muc_client.joined_rooms:
             raise exceptions.NotFound(_("This room has not been joined"))
 
-    def isJoinedRoom(self, client: SatXMPPClient, room_jid: jid.JID) -> bool:
+    def isJoinedRoom(self, client: SatXMPPEntity, room_jid: jid.JID) -> bool:
         """Tell if a jid is a known and joined room
 
         @room_jid: jid of the room
@@ -458,16 +453,18 @@
 
     def _getMUCService(self, jid_=None, profile=C.PROF_KEY_NONE):
         client = self.host.getClient(profile)
-        d = self.getMUCService(client, jid_ or None)
+        d = defer.ensureDeferred(self.get_MUC_service(client, jid_ or None))
         d.addCallback(lambda service_jid: service_jid.full() if service_jid is not None else '')
         return d
 
-    @defer.inlineCallbacks
-    def getMUCService(self, client, jid_=None):
+    async def get_MUC_service(
+        self,
+        client: SatXMPPEntity,
+        jid_: Optional[jid.JID] = None) -> Optional[jid.JID]:
         """Return first found MUC service of an entity
 
         @param jid_: entity which may have a MUC service, or None for our own server
-        @return (jid.JID, None): found service jid or None
+        @return: found service jid or None
         """
         if jid_ is None:
             try:
@@ -476,8 +473,8 @@
                 pass
             else:
                 # we have a cached value, we return it
-                defer.returnValue(muc_service)
-        services = yield self.host.findServiceEntities(client, "conference", "text", jid_)
+                return muc_service
+        services = await self.host.findServiceEntities(client, "conference", "text", jid_)
         for service in services:
             if ".irc." not in service.userhost():
                 # FIXME:
@@ -487,7 +484,7 @@
                 break
         else:
             muc_service = None
-        defer.returnValue(muc_service)
+        return muc_service
 
     def _getUniqueName(self, muc_service="", profile_key=C.PROF_KEY_NONE):
         client = self.host.getClient(profile_key)
@@ -549,7 +546,13 @@
         d.addErrback(self._join_eb, client)
         return d
 
-    def join(self, client, room_jid, nick=None, options=None):
+    async def join(
+        self,
+        client: SatXMPPEntity,
+        room_jid: jid.JID,
+        nick: Optional[str] = None,
+        options: Optional[dict] = None
+    ) -> Optional[muc.Room]:
         if not nick:
             nick = client.jid.user
         if options is None:
@@ -558,18 +561,24 @@
             room = client._muc_client.joined_rooms[room_jid]
             log.info(_('{profile} is already in room {room_jid}').format(
                 profile=client.profile, room_jid = room_jid.userhost()))
-            return defer.fail(AlreadyJoined(room))
+            raise AlreadyJoined(room)
         log.info(_("[{profile}] is joining room {room} with nick {nick}").format(
             profile=client.profile, room=room_jid.userhost(), nick=nick))
         self.host.bridge.mucRoomPrepareJoin(room_jid.userhost(), client.profile)
 
         password = options.get("password")
 
-        d = client._muc_client.join(room_jid, nick, password)
-        d.addCallbacks(self._joinCb, self._joinEb,
-                       (client, room_jid, nick),
-                       errbackArgs=(client, room_jid, nick, password))
-        return d
+        try:
+            room = await client._muc_client.join(room_jid, nick, password)
+        except Exception as e:
+            room = await utils.asDeferred(
+                self._joinEb(failure.Failure(e), client, room_jid, nick, password)
+            )
+        else:
+            await defer.ensureDeferred(
+                self._joinCb(room, client, room_jid, nick)
+            )
+        return room
 
     def popRooms(self, client):
         """Remove rooms and return data needed to re-join them
@@ -611,7 +620,7 @@
 
     def getHandler(self, client):
         # create a MUC client and associate it with profile' session
-        muc_client = client._muc_client = SatMUCClient(self)
+        muc_client = client._muc_client = LiberviaMUCClient(self)
         return muc_client
 
     def kick(self, client, nick, room_jid, options=None):
@@ -915,7 +924,7 @@
 
 
 @implementer(iwokkel.IDisco)
-class SatMUCClient(muc.MUCClient):
+class LiberviaMUCClient(muc.MUCClient):
 
     def __init__(self, plugin_parent):
         self.plugin_parent = plugin_parent
@@ -967,7 +976,7 @@
         room.state = new_state
 
     def _addRoom(self, room):
-        super(SatMUCClient, self)._addRoom(room)
+        super(LiberviaMUCClient, self)._addRoom(room)
         room._roster_ok = False  # True when occupants list has been fully received
         room.state = ROOM_STATE_OCCUPANTS
         # FIXME: check if history_d is not redundant with fully_joined
@@ -978,14 +987,21 @@
         # we only need to keep last presence status for each jid, so a dict is suitable
         room._cache_presence = {}
 
-    @defer.inlineCallbacks
-    def _joinLegacy(self, client, room_jid, nick, password):
+    async def _join_legacy(
+        self,
+        client: SatXMPPEntity,
+        room_jid: jid.JID,
+        nick: str,
+        password: Optional[str]
+    ) -> muc.Room:
         """Join room an retrieve history with legacy method"""
-        mess_data_list = yield self.host.memory.historyGet(room_jid,
-                                                           client.jid.userhostJID(),
-                                                           limit=1,
-                                                           between=True,
-                                                           profile=client.profile)
+        mess_data_list = await self.host.memory.historyGet(
+            room_jid,
+            client.jid.userhostJID(),
+            limit=1,
+            between=True,
+            profile=client.profile
+        )
         if mess_data_list:
             timestamp = mess_data_list[0][1]
             # we use seconds since last message to get backlog without duplicates
@@ -994,27 +1010,27 @@
         else:
             seconds = None
 
-        room = yield super(SatMUCClient, self).join(
+        room = await super(LiberviaMUCClient, self).join(
             room_jid, nick, muc.HistoryOptions(seconds=seconds), password)
         # used to send bridge signal once backlog are written in history
         room._history_type = HISTORY_LEGACY
         room._history_d = defer.Deferred()
         room._history_d.callback(None)
-        defer.returnValue(room)
+        return room
 
-    @defer.inlineCallbacks
-    def _joinMAM(self, client, room_jid, nick, password):
-        """Join room and retrieve history using MAM"""
-        room = yield super(SatMUCClient, self).join(
-            # we don't want any history from room as we'll get it with MAM
-            room_jid, nick, muc.HistoryOptions(maxStanzas=0), password=password)
-        room._history_type = HISTORY_MAM
+    async def _get_MAM_history(
+        self,
+        client: SatXMPPEntity,
+        room: muc.Room,
+        room_jid: jid.JID
+    ) -> None:
+        """Retrieve history for rooms handling MAM"""
         history_d = room._history_d = defer.Deferred()
         # we trigger now the deferred so all callback are processed as soon as possible
         # and in order
         history_d.callback(None)
 
-        last_mess = yield self.host.memory.historyGet(
+        last_mess = await self.host.memory.historyGet(
             room_jid,
             None,
             limit=1,
@@ -1040,7 +1056,7 @@
         count = 0
         while not complete:
             try:
-                mam_data = yield self._mam.getArchives(client, mam_req,
+                mam_data = await self._mam.getArchives(client, mam_req,
                                                        service=room_jid)
             except xmpp_error.StanzaError as e:
                 if last_mess and e.condition == 'item-not-found':
@@ -1107,20 +1123,35 @@
                                      errbackArgs=[room])
 
         # we wait for all callbacks to be processed
-        yield history_d
+        await history_d
 
-        defer.returnValue(room)
+    async def _join_MAM(
+        self,
+        client: SatXMPPEntity,
+        room_jid: jid.JID,
+        nick: str,
+        password: Optional[str]
+    ) -> muc.Room:
+        """Join room and retrieve history using MAM"""
+        room = await super(LiberviaMUCClient, self).join(
+            # we don't want any history from room as we'll get it with MAM
+            room_jid, nick, muc.HistoryOptions(maxStanzas=0), password=password
+        )
+        room._history_type = HISTORY_MAM
+        # MAM history retrieval can be very long, and doesn't need to be sync, so we don't
+        # wait for it
+        defer.ensureDeferred(self._get_MAM_history(client, room, room_jid))
+        room.fully_joined.callback(room)
 
-    @defer.inlineCallbacks
-    def join(self, room_jid, nick, password=None):
+        return room
+
+    async def join(self, room_jid, nick, password=None):
         room_service = jid.JID(room_jid.host)
-        has_mam = yield self.host.hasFeature(self.client, mam.NS_MAM, room_service)
+        has_mam = await self.host.hasFeature(self.client, mam.NS_MAM, room_service)
         if not self._mam or not has_mam:
-            room = yield self._joinLegacy(self.client, room_jid, nick, password)
-            defer.returnValue(room)
+            return await self._join_legacy(self.client, room_jid, nick, password)
         else:
-            room = yield self._joinMAM(self.client, room_jid, nick, password)
-            defer.returnValue(room)
+            return await self._join_MAM(self.client, room_jid, nick, password)
 
     ## presence/roster ##
 
@@ -1400,13 +1431,14 @@
 
         this method will finish joining by:
             - sending message to bridge
-            - calling fully_joined deferred
+            - calling fully_joined deferred (for legacy history)
             - sending stanza put in cache
             - cleaning variables not needed anymore
         """
         args = self.plugin_parent._getRoomJoinedArgs(room, self.client.profile)
         self.host.bridge.mucRoomJoined(*args)
-        room.fully_joined.callback(room)
+        if room._history_type == HISTORY_LEGACY:
+            room.fully_joined.callback(room)
         del room._history_d
         del room._history_type
         cache = room._cache