diff sat/plugins/plugin_xep_0384.py @ 3715:b9718216a1c0 0.9

merge bookmark 0.9
author Goffi <goffi@goffi.org>
date Wed, 01 Dec 2021 16:13:31 +0100
parents 09f5ac48ffe3
children 11f7ca8afd15
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0384.py	Tue Nov 30 23:31:09 2021 +0100
+++ b/sat/plugins/plugin_xep_0384.py	Wed Dec 01 16:13:31 2021 +0100
@@ -114,7 +114,10 @@
     @return (defer.Deferred): deferred instance linked to the promise
     """
     d = defer.Deferred()
-    promise_.then(d.callback, d.errback)
+    promise_.then(
+        lambda result: reactor.callFromThread(d.callback, result),
+        lambda exc: reactor.callFromThread(d.errback, exc)
+    )
     return d
 
 
@@ -141,6 +144,28 @@
         deferred.addCallback(partial(callback, True))
         deferred.addErrback(partial(callback, False))
 
+    def _callMainThread(self, callback, method, *args, check_jid=None):
+        if check_jid is None:
+            d = method(*args)
+        else:
+            check_jid_d = self._checkJid(check_jid)
+            check_jid_d.addCallback(lambda __: method(*args))
+            d = check_jid_d
+
+        if callback is not None:
+            d.addCallback(partial(callback, True))
+            d.addErrback(partial(callback, False))
+
+    def _call(self, callback, method, *args, check_jid=None):
+        """Create Deferred and add Promise callback to it
+
+        This method use reactor.callLater to launch Deferred in main thread
+        @param check_jid: run self._checkJid before method
+        """
+        reactor.callFromThread(
+            self._callMainThread, callback, method, *args, check_jid=check_jid
+        )
+
     def _checkJid(self, bare_jid):
         """Check if jid is known, and store it if not
 
@@ -164,71 +189,50 @@
         callback(True, None)
 
     def loadState(self, callback):
-        d = self.data.get(KEY_STATE)
-        self.setCb(d, callback)
+        self._call(callback, self.data.get, KEY_STATE)
 
     def storeState(self, callback, state):
-        d = self.data.force(KEY_STATE, state)
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, KEY_STATE, state)
 
     def loadSession(self, callback, bare_jid, device_id):
         key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)])
-        d = self.data.get(key)
-        self.setCb(d, callback)
+        self._call(callback, self.data.get, key)
 
     def storeSession(self, callback, bare_jid, device_id, session):
         key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)])
-        d = self.data.force(key, session)
-        self.setCb(d, callback)
+        self._call(callback, self._data.force, key, session)
 
     def deleteSession(self, callback, bare_jid, device_id):
         key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)])
-        d = self.data.remove(key)
-        self.setCb(d, callback)
+        self._call(callback, self.data.remove, key)
 
     def loadActiveDevices(self, callback, bare_jid):
         key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
-        d = self.data.get(key, {})
-        if callback is not None:
-            self.setCb(d, callback)
-        return d
+        self._call(callback, self.data.get, key, {})
 
     def loadInactiveDevices(self, callback, bare_jid):
         key = '\n'.join([KEY_INACTIVE_DEVICES, bare_jid])
-        d = self.data.get(key, {})
-        if callback is not None:
-            self.setCb(d, callback)
-        return d
+        self._call(callback, self.data.get, key, {})
 
     def storeActiveDevices(self, callback, bare_jid, devices):
         key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
-        d = self._checkJid(bare_jid)
-        d.addCallback(lambda _: self.data.force(key, devices))
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, key, devices, check_jid=bare_jid)
 
     def storeInactiveDevices(self, callback, bare_jid, devices):
         key = '\n'.join([KEY_INACTIVE_DEVICES, bare_jid])
-        d = self._checkJid(bare_jid)
-        d.addCallback(lambda _: self.data.force(key, devices))
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, key, devices, check_jid=bare_jid)
 
     def storeTrust(self, callback, bare_jid, device_id, trust):
         key = '\n'.join([KEY_TRUST, bare_jid, str(device_id)])
-        d = self.data.force(key, trust)
-        self.setCb(d, callback)
+        self._call(callback, self.data.force, key, trust)
 
     def loadTrust(self, callback, bare_jid, device_id):
         key = '\n'.join([KEY_TRUST, bare_jid, str(device_id)])
-        d = self.data.get(key)
-        if callback is not None:
-            self.setCb(d, callback)
-        return d
+        self._call(callback, self.data.get, key)
 
     def listJIDs(self, callback):
-        d = defer.succeed(self.all_jids)
         if callback is not None:
-            self.setCb(d, callback)
-        return d
+            callback(True, self.all_jids)
 
     def _deleteJID_logResults(self, results):
         failed = [success for success, __ in results if not success]
@@ -266,8 +270,7 @@
         d.addCallback(self._deleteJID_logResults)
         return d
 
-    def deleteJID(self, callback, bare_jid):
-        """Retrieve all (in)actives devices of bare_jid, and delete all related keys"""
+    def _deleteJID(self, callback, bare_jid):
         d_list = []
 
         key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
@@ -284,7 +287,10 @@
         d.addCallback(self._deleteJID_gotDevices, bare_jid)
         if callback is not None:
             self.setCb(d, callback)
-        return d
+
+    def deleteJID(self, callback, bare_jid):
+        """Retrieve all (in)actives devices of bare_jid, and delete all related keys"""
+        reactor.callFromThread(self._deleteJID, callback, bare_jid)
 
 
 class SatOTPKPolicy(omemo.DefaultOTPKPolicy):
@@ -728,7 +734,7 @@
             while device_id in devices:
                 device_id = random.randint(1, 2**31-1)
             # and we save it
-            persistent_dict[KEY_DEVICE_ID] = device_id
+            await persistent_dict.aset(KEY_DEVICE_ID, device_id)
 
         log.debug(f"our OMEMO device id is {device_id}")
 
@@ -788,8 +794,7 @@
                     devices.add(device_id)
         return devices
 
-    @defer.inlineCallbacks
-    def getDevices(self, client, entity_jid=None):
+    async def getDevices(self, client, entity_jid=None):
         """Retrieve list of registered OMEMO devices
 
         @param entity_jid(jid.JID, None): get devices from this entity
@@ -799,13 +804,13 @@
         if entity_jid is not None:
             assert not entity_jid.resource
         try:
-            items, metadata = yield self._p.getItems(client, entity_jid, NS_OMEMO_DEVICES)
+            items, metadata = await self._p.getItems(client, entity_jid, NS_OMEMO_DEVICES)
         except exceptions.NotFound:
             log.info(_("there is no node to handle OMEMO devices"))
-            defer.returnValue(set())
+            return set()
 
         devices = self.parseDevices(items)
-        defer.returnValue(devices)
+        return devices
 
     async def setDevices(self, client, devices):
         log.debug(f"setting devices with {', '.join(str(d) for d in devices)}")
@@ -827,8 +832,7 @@
 
     # bundles
 
-    @defer.inlineCallbacks
-    def getBundles(self, client, entity_jid, devices_ids):
+    async def getBundles(self, client, entity_jid, devices_ids):
         """Retrieve public bundles of an entity devices
 
         @param entity_jid(jid.JID): bare jid of entity
@@ -845,7 +849,7 @@
         for device_id in devices_ids:
             node = NS_OMEMO_BUNDLE.format(device_id=device_id)
             try:
-                items, metadata = yield self._p.getItems(client, entity_jid, node)
+                items, metadata = await self._p.getItems(client, entity_jid, node)
             except exceptions.NotFound:
                 log.warning(_("Bundle missing for device {device_id}")
                     .format(device_id=device_id))
@@ -906,7 +910,7 @@
             bundles[device_id] = ExtendedPublicBundle.parse(omemo_backend, ik, spk,
                                                             spkSignature, otpks)
 
-        defer.returnValue((bundles, missing))
+        return (bundles, missing)
 
     async def setBundle(self, client, bundle, device_id):
         """Set public bundle for this device.