diff sat/memory/encryption.py @ 3226:2f406b762788

core (memory/encryption): encryption session are now restored on client connection
author Goffi <goffi@goffi.org>
date Sun, 22 Mar 2020 18:39:12 +0100
parents 0469c53ed5dd
children cc3fea71c365
line wrap: on
line diff
--- a/sat/memory/encryption.py	Sun Mar 22 18:35:22 2020 +0100
+++ b/sat/memory/encryption.py	Sun Mar 22 18:39:12 2020 +0100
@@ -17,18 +17,20 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import copy
 from functools import partial
+from collections import namedtuple
+from twisted.words.protocols.jabber import jid
+from twisted.internet import defer
+from twisted.python import failure
 from sat.core.i18n import D_, _
 from sat.core.constants import Const as C
 from sat.core import exceptions
-from collections import namedtuple
 from sat.core.log import getLogger
 from sat.tools.common import data_format
-from twisted.words.protocols.jabber import jid
-from twisted.internet import defer
-from twisted.python import failure
-import copy
-log = getLogger(__name__)
+from sat.tools import utils
+from sat.memory import persistent
+
 
 log = getLogger(__name__)
 
@@ -39,18 +41,38 @@
                                                    "directed"))
 
 
-class EncryptionHandler(object):
+class EncryptionHandler:
     """Class to handle encryption sessions for a client"""
     plugins = []  # plugin able to encrypt messages
 
     def __init__(self, client):
         self.client = client
         self._sessions = {}  # bare_jid ==> encryption_data
+        self._stored_session = persistent.PersistentDict(
+            "core:encryption", profile=client.profile)
 
     @property
     def host(self):
         return self.client.host_app
 
+    async def loadSessions(self):
+        """Load persistent sessions"""
+        await self._stored_session.load()
+        start_d_list = []
+        for entity_jid_s, namespace in self._stored_session.items():
+            entity = jid.JID(entity_jid_s)
+            start_d_list.append(defer.ensureDeferred(self.start(entity, namespace)))
+
+        if start_d_list:
+            result = await defer.DeferredList(start_d_list)
+            for idx, (success, err) in enumerate(result):
+                if not success:
+                    entity_jid_s, namespace = list(self._stored_session.items())[idx]
+                    log.warning(_(
+                        "Could not restart {namespace!r} encryption with {entity}: {err}"
+                        ).format(namespace=namespace, entity=entity_jid_s, err=err))
+            log.info(_("encryption sessions restored"))
+
     @classmethod
     def registerPlugin(cls, plg_instance, name, namespace, priority=0, directed=False):
         """Register a plugin handling an encryption algorithm
@@ -142,40 +164,41 @@
 
         return data_format.serialise(bridge_data)
 
-    def _startEncryption(self, plugin, entity):
+    async def _startEncryption(self, plugin, entity):
         """Start encryption with a plugin
 
         This method must be called just before adding a plugin session.
         StartEncryptionn method of plugin will be called if it exists.
         """
+        if not plugin.directed:
+            await self._stored_session.aset(entity.userhost(), plugin.namespace)
         try:
             start_encryption = plugin.instance.startEncryption
         except AttributeError:
-            log.debug("No startEncryption method found for {plugin}".format(
-                plugin = plugin.namespace))
-            return defer.succeed(None)
+            log.debug(f"No startEncryption method found for {plugin.namespace}")
         else:
             # we copy entity to avoid having the resource changed by stop_encryption
-            return defer.maybeDeferred(start_encryption, self.client, copy.copy(entity))
+            await utils.asDeferred(start_encryption, self.client, copy.copy(entity))
 
-    def _stopEncryption(self, plugin, entity):
+    async def _stopEncryption(self, plugin, entity):
         """Stop encryption with a plugin
 
         This method must be called just before removing a plugin session.
         StopEncryptionn method of plugin will be called if it exists.
         """
         try:
+            await self._stored_session.adel(entity.userhost())
+        except KeyError:
+            pass
+        try:
             stop_encryption = plugin.instance.stopEncryption
         except AttributeError:
-            log.debug("No stopEncryption method found for {plugin}".format(
-                plugin = plugin.namespace))
-            return defer.succeed(None)
+            log.debug(f"No stopEncryption method found for {plugin.namespace}")
         else:
             # we copy entity to avoid having the resource changed by stop_encryption
-            return defer.maybeDeferred(stop_encryption, self.client, copy.copy(entity))
+            return utils.asDeferred(stop_encryption, self.client, copy.copy(entity))
 
-    @defer.inlineCallbacks
-    def start(self, entity, namespace=None, replace=False):
+    async def start(self, entity, namespace=None, replace=False):
         """Start an encryption session with an entity
 
         @param entity(jid.JID): entity to start an encryption session with
@@ -209,7 +232,7 @@
                 # there is a conflict, but replacement is requested
                 # so we stop previous encryption to use new one
                 del self._sessions[bare_jid]
-                yield self._stopEncryption(former_plugin, entity)
+                await self._stopEncryption(former_plugin, entity)
             else:
                 msg = (_("Session with {bare_jid} is already encrypted with {name}. "
                          "Please stop encryption session before changing algorithm.")
@@ -233,7 +256,7 @@
         elif entity.resource:
             raise ValueError(_("{name} encryption must be used with bare jids."))
 
-        yield self._startEncryption(plugin, entity)
+        await self._startEncryption(plugin, entity)
         self._sessions[entity.userhostJID()] = data
         log.info(_("Encryption session has been set for {entity_jid} with "
                    "{encryption_name}").format(
@@ -254,14 +277,13 @@
 
         self.client.feedback(bare_jid, msg)
 
-    @defer.inlineCallbacks
-    def stop(self, entity, namespace=None):
+    async def stop(self, entity, namespace=None):
         """Stop an encryption session with an entity
 
         @param entity(jid.JID): entity with who the encryption session must be stopped
             must be bare jid if the algorithm encrypt for all devices
         @param namespace(unicode): namespace of the session to stop
-            when specified, used to check we stop the right encryption session
+            when specified, used to check that we stop the right encryption session
         """
         session = self.getSession(entity.userhostJID())
         if not session:
@@ -295,12 +317,12 @@
                     # we stop the whole session
                     # see comment below for deleting session before stopping encryption
                     del self._sessions[entity.userhostJID()]
-                    yield self._stopEncryption(plugin, entity)
+                    await self._stopEncryption(plugin, entity)
         else:
             # plugin's stopEncryption may call stop again (that's the case with OTR)
             # so we need to remove plugin from session before calling self._stopEncryption
             del self._sessions[entity.userhostJID()]
-            yield self._stopEncryption(plugin, entity)
+            await self._stopEncryption(plugin, entity)
 
         log.info(_("encryption session stopped with entity {entity}").format(
             entity=entity.full()))
@@ -390,7 +412,7 @@
     def _onMenuUnencrypted(cls, data, host, profile):
         client = host.getClient(profile)
         peer_jid = jid.JID(data['jid']).userhostJID()
-        d = client.encryption.stop(peer_jid)
+        d = defer.ensureDeferred(client.encryption.stop(peer_jid))
         d.addCallback(lambda __: {})
         return d
 
@@ -400,7 +422,8 @@
         peer_jid = jid.JID(data['jid'])
         if not plg.directed:
             peer_jid = peer_jid.userhostJID()
-        d = client.encryption.start(peer_jid, plg.namespace, replace=True)
+        d = defer.ensureDeferred(
+            client.encryption.start(peer_jid, plg.namespace, replace=True))
         d.addCallback(lambda __: {})
         return d