diff sat/plugins/plugin_xep_0384.py @ 3541:888109774673

core: various changes and fixes to work with new storage and D-Bus bridge: - fixes coroutines handling in various places - fixes types which are not serialised by Tx DBus - XEP-0384: call storage methods in main thread in XEP: Python OMEMO's Promise use thread which prevent the use of AsyncIO loop. To work around that, callLater is used to launch storage method in main thread. This is a temporary workaround, as Python OMEMO should get rid of Promise implementation and threads soon.
author Goffi <goffi@goffi.org>
date Thu, 03 Jun 2021 15:21:43 +0200
parents be6d91572633
children edc79cefe968
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0384.py	Thu Jun 03 15:21:43 2021 +0200
+++ b/sat/plugins/plugin_xep_0384.py	Thu Jun 03 15:21:43 2021 +0200
@@ -114,7 +114,10 @@
     @return (defer.Deferred): deferred instance linked to the promise
     """
     d = defer.Deferred()
-    promise_.then(d.callback, d.errback)
+    promise_.then(
+        lambda result: reactor.callLater(0, d.callback, result),
+        lambda exc: reactor.callLater(0, d.errback, exc)
+    )
     return d
 
 
@@ -141,6 +144,26 @@
         deferred.addCallback(partial(callback, True))
         deferred.addErrback(partial(callback, False))
 
+    def _callMainThread(self, callback, method, *args, check_jid=None):
+        d = method(*args)
+        if check_jid is not None:
+            check_jid_d = self._checkJid(check_jid)
+            check_jid_d.addCallback(lambda __: d)
+            d = check_jid_d
+        if callback is not None:
+            d.addCallback(partial(callback, True))
+            d.addErrback(partial(callback, False))
+
+    def _call(self, callback, method, *args, check_jid=None):
+        """Create Deferred and add Promise callback to it
+
+        This method use reactor.callLater to launch Deferred in main thread
+        @param check_jid: run self._checkJid before method
+        """
+        reactor.callLater(
+            0, self._callMainThread, callback, method, *args, check_jid=check_jid
+        )
+
     def _checkJid(self, bare_jid):
         """Check if jid is known, and store it if not
 
@@ -164,71 +187,50 @@
         callback(True, None)
 
     def loadState(self, callback):
-        d = self.data.get(KEY_STATE)
-        self.setCb(d, callback)
+        self._call(callback, self.data.get, KEY_STATE)
 
     def storeState(self, callback, state):
-        d = self.data.force(KEY_STATE, state)
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, KEY_STATE, state)
 
     def loadSession(self, callback, bare_jid, device_id):
         key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)])
-        d = self.data.get(key)
-        self.setCb(d, callback)
+        self._call(callback, self.data.get, key)
 
     def storeSession(self, callback, bare_jid, device_id, session):
         key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)])
-        d = self.data.force(key, session)
-        self.setCb(d, callback)
+        self._call(callback, self._data.force, key, session)
 
     def deleteSession(self, callback, bare_jid, device_id):
         key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)])
-        d = self.data.remove(key)
-        self.setCb(d, callback)
+        self._call(callback, self.data.remove, key)
 
     def loadActiveDevices(self, callback, bare_jid):
         key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
-        d = self.data.get(key, {})
-        if callback is not None:
-            self.setCb(d, callback)
-        return d
+        self._call(callback, self.data.get, key, {})
 
     def loadInactiveDevices(self, callback, bare_jid):
         key = '\n'.join([KEY_INACTIVE_DEVICES, bare_jid])
-        d = self.data.get(key, {})
-        if callback is not None:
-            self.setCb(d, callback)
-        return d
+        self._call(callback, self.data.get, key, {})
 
     def storeActiveDevices(self, callback, bare_jid, devices):
         key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
-        d = self._checkJid(bare_jid)
-        d.addCallback(lambda _: self.data.force(key, devices))
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, key, devices, check_jid=bare_jid)
 
     def storeInactiveDevices(self, callback, bare_jid, devices):
         key = '\n'.join([KEY_INACTIVE_DEVICES, bare_jid])
-        d = self._checkJid(bare_jid)
-        d.addCallback(lambda _: self.data.force(key, devices))
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, key, devices, check_jid=bare_jid)
 
     def storeTrust(self, callback, bare_jid, device_id, trust):
         key = '\n'.join([KEY_TRUST, bare_jid, str(device_id)])
-        d = self.data.force(key, trust)
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, key, trust)
 
     def loadTrust(self, callback, bare_jid, device_id):
         key = '\n'.join([KEY_TRUST, bare_jid, str(device_id)])
-        d = self.data.get(key)
-        if callback is not None:
-            self.setCb(d, callback)
-        return d
+        self._call(callback, self.data.get, key)
 
     def listJIDs(self, callback):
-        d = defer.succeed(self.all_jids)
         if callback is not None:
-            self.setCb(d, callback)
-        return d
+            callback(True, self.all_jids)
 
     def _deleteJID_logResults(self, results):
         failed = [success for success, __ in results if not success]
@@ -266,8 +268,7 @@
         d.addCallback(self._deleteJID_logResults)
         return d
 
-    def deleteJID(self, callback, bare_jid):
-        """Retrieve all (in)actives devices of bare_jid, and delete all related keys"""
+    def _deleteJID(self, callback, bare_jid):
         d_list = []
 
         key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
@@ -284,7 +285,10 @@
         d.addCallback(self._deleteJID_gotDevices, bare_jid)
         if callback is not None:
             self.setCb(d, callback)
-        return d
+
+    def deleteJID(self, callback, bare_jid):
+        """Retrieve all (in)actives devices of bare_jid, and delete all related keys"""
+        reactor.callLater(0, self._deleteJID, callback, bare_jid)
 
 
 class SatOTPKPolicy(omemo.DefaultOTPKPolicy):
@@ -728,7 +732,7 @@
             while device_id in devices:
                 device_id = random.randint(1, 2**31-1)
             # and we save it
-            persistent_dict[KEY_DEVICE_ID] = device_id
+            await persistent_dict.aset(KEY_DEVICE_ID, device_id)
 
         log.debug(f"our OMEMO device id is {device_id}")