changeset 3226:2f406b762788

core (memory/encryption): encryption session are now restored on client connection
author Goffi <goffi@goffi.org>
date Sun, 22 Mar 2020 18:39:12 +0100
parents 843a9279fb5a
children 6d19a99172d7
files sat/core/sat_main.py sat/core/xmpp.py sat/memory/encryption.py sat/plugins/plugin_sec_otr.py
diffstat 4 files changed, 61 insertions(+), 34 deletions(-) [+]
line wrap: on
line diff
--- a/sat/core/sat_main.py	Sun Mar 22 18:35:22 2020 +0100
+++ b/sat/core/sat_main.py	Sun Mar 22 18:39:12 2020 +0100
@@ -849,12 +849,14 @@
                                 profile_key=C.PROF_KEY_NONE):
         client = self.getClient(profile_key)
         to_jid = jid.JID(to_jid_s)
-        return client.encryption.start(to_jid, namespace or None, replace)
+        return defer.ensureDeferred(
+            client.encryption.start(to_jid, namespace or None, replace))
 
     def _messageEncryptionStop(self, to_jid_s, profile_key=C.PROF_KEY_NONE):
         client = self.getClient(profile_key)
         to_jid = jid.JID(to_jid_s)
-        return client.encryption.stop(to_jid)
+        return defer.ensureDeferred(
+            client.encryption.stop(to_jid))
 
     def _messageEncryptionGet(self, to_jid_s, profile_key=C.PROF_KEY_NONE):
         client = self.getClient(profile_key)
--- a/sat/core/xmpp.py	Sun Mar 22 18:35:22 2020 +0100
+++ b/sat/core/xmpp.py	Sun Mar 22 18:39:12 2020 +0100
@@ -209,6 +209,8 @@
             port, max_retries,
             )
 
+        await entity.encryption.loadSessions()
+
         entity._createSubProtocols()
 
         entity.fallBack = SatFallbackHandler(host)
--- a/sat/memory/encryption.py	Sun Mar 22 18:35:22 2020 +0100
+++ b/sat/memory/encryption.py	Sun Mar 22 18:39:12 2020 +0100
@@ -17,18 +17,20 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+import copy
 from functools import partial
+from collections import namedtuple
+from twisted.words.protocols.jabber import jid
+from twisted.internet import defer
+from twisted.python import failure
 from sat.core.i18n import D_, _
 from sat.core.constants import Const as C
 from sat.core import exceptions
-from collections import namedtuple
 from sat.core.log import getLogger
 from sat.tools.common import data_format
-from twisted.words.protocols.jabber import jid
-from twisted.internet import defer
-from twisted.python import failure
-import copy
-log = getLogger(__name__)
+from sat.tools import utils
+from sat.memory import persistent
+
 
 log = getLogger(__name__)
 
@@ -39,18 +41,38 @@
                                                    "directed"))
 
 
-class EncryptionHandler(object):
+class EncryptionHandler:
     """Class to handle encryption sessions for a client"""
     plugins = []  # plugin able to encrypt messages
 
     def __init__(self, client):
         self.client = client
         self._sessions = {}  # bare_jid ==> encryption_data
+        self._stored_session = persistent.PersistentDict(
+            "core:encryption", profile=client.profile)
 
     @property
     def host(self):
         return self.client.host_app
 
+    async def loadSessions(self):
+        """Load persistent sessions"""
+        await self._stored_session.load()
+        start_d_list = []
+        for entity_jid_s, namespace in self._stored_session.items():
+            entity = jid.JID(entity_jid_s)
+            start_d_list.append(defer.ensureDeferred(self.start(entity, namespace)))
+
+        if start_d_list:
+            result = await defer.DeferredList(start_d_list)
+            for idx, (success, err) in enumerate(result):
+                if not success:
+                    entity_jid_s, namespace = list(self._stored_session.items())[idx]
+                    log.warning(_(
+                        "Could not restart {namespace!r} encryption with {entity}: {err}"
+                        ).format(namespace=namespace, entity=entity_jid_s, err=err))
+            log.info(_("encryption sessions restored"))
+
     @classmethod
     def registerPlugin(cls, plg_instance, name, namespace, priority=0, directed=False):
         """Register a plugin handling an encryption algorithm
@@ -142,40 +164,41 @@
 
         return data_format.serialise(bridge_data)
 
-    def _startEncryption(self, plugin, entity):
+    async def _startEncryption(self, plugin, entity):
         """Start encryption with a plugin
 
         This method must be called just before adding a plugin session.
         StartEncryptionn method of plugin will be called if it exists.
         """
+        if not plugin.directed:
+            await self._stored_session.aset(entity.userhost(), plugin.namespace)
         try:
             start_encryption = plugin.instance.startEncryption
         except AttributeError:
-            log.debug("No startEncryption method found for {plugin}".format(
-                plugin = plugin.namespace))
-            return defer.succeed(None)
+            log.debug(f"No startEncryption method found for {plugin.namespace}")
         else:
             # we copy entity to avoid having the resource changed by stop_encryption
-            return defer.maybeDeferred(start_encryption, self.client, copy.copy(entity))
+            await utils.asDeferred(start_encryption, self.client, copy.copy(entity))
 
-    def _stopEncryption(self, plugin, entity):
+    async def _stopEncryption(self, plugin, entity):
         """Stop encryption with a plugin
 
         This method must be called just before removing a plugin session.
         StopEncryptionn method of plugin will be called if it exists.
         """
         try:
+            await self._stored_session.adel(entity.userhost())
+        except KeyError:
+            pass
+        try:
             stop_encryption = plugin.instance.stopEncryption
         except AttributeError:
-            log.debug("No stopEncryption method found for {plugin}".format(
-                plugin = plugin.namespace))
-            return defer.succeed(None)
+            log.debug(f"No stopEncryption method found for {plugin.namespace}")
         else:
             # we copy entity to avoid having the resource changed by stop_encryption
-            return defer.maybeDeferred(stop_encryption, self.client, copy.copy(entity))
+            return utils.asDeferred(stop_encryption, self.client, copy.copy(entity))
 
-    @defer.inlineCallbacks
-    def start(self, entity, namespace=None, replace=False):
+    async def start(self, entity, namespace=None, replace=False):
         """Start an encryption session with an entity
 
         @param entity(jid.JID): entity to start an encryption session with
@@ -209,7 +232,7 @@
                 # there is a conflict, but replacement is requested
                 # so we stop previous encryption to use new one
                 del self._sessions[bare_jid]
-                yield self._stopEncryption(former_plugin, entity)
+                await self._stopEncryption(former_plugin, entity)
             else:
                 msg = (_("Session with {bare_jid} is already encrypted with {name}. "
                          "Please stop encryption session before changing algorithm.")
@@ -233,7 +256,7 @@
         elif entity.resource:
             raise ValueError(_("{name} encryption must be used with bare jids."))
 
-        yield self._startEncryption(plugin, entity)
+        await self._startEncryption(plugin, entity)
         self._sessions[entity.userhostJID()] = data
         log.info(_("Encryption session has been set for {entity_jid} with "
                    "{encryption_name}").format(
@@ -254,14 +277,13 @@
 
         self.client.feedback(bare_jid, msg)
 
-    @defer.inlineCallbacks
-    def stop(self, entity, namespace=None):
+    async def stop(self, entity, namespace=None):
         """Stop an encryption session with an entity
 
         @param entity(jid.JID): entity with who the encryption session must be stopped
             must be bare jid if the algorithm encrypt for all devices
         @param namespace(unicode): namespace of the session to stop
-            when specified, used to check we stop the right encryption session
+            when specified, used to check that we stop the right encryption session
         """
         session = self.getSession(entity.userhostJID())
         if not session:
@@ -295,12 +317,12 @@
                     # we stop the whole session
                     # see comment below for deleting session before stopping encryption
                     del self._sessions[entity.userhostJID()]
-                    yield self._stopEncryption(plugin, entity)
+                    await self._stopEncryption(plugin, entity)
         else:
             # plugin's stopEncryption may call stop again (that's the case with OTR)
             # so we need to remove plugin from session before calling self._stopEncryption
             del self._sessions[entity.userhostJID()]
-            yield self._stopEncryption(plugin, entity)
+            await self._stopEncryption(plugin, entity)
 
         log.info(_("encryption session stopped with entity {entity}").format(
             entity=entity.full()))
@@ -390,7 +412,7 @@
     def _onMenuUnencrypted(cls, data, host, profile):
         client = host.getClient(profile)
         peer_jid = jid.JID(data['jid']).userhostJID()
-        d = client.encryption.stop(peer_jid)
+        d = defer.ensureDeferred(client.encryption.stop(peer_jid))
         d.addCallback(lambda __: {})
         return d
 
@@ -400,7 +422,8 @@
         peer_jid = jid.JID(data['jid'])
         if not plg.directed:
             peer_jid = peer_jid.userhostJID()
-        d = client.encryption.start(peer_jid, plg.namespace, replace=True)
+        d = defer.ensureDeferred(
+            client.encryption.start(peer_jid, plg.namespace, replace=True))
         d.addCallback(lambda __: {})
         return d
 
--- a/sat/plugins/plugin_sec_otr.py	Sun Mar 22 18:35:22 2020 +0100
+++ b/sat/plugins/plugin_sec_otr.py	Sun Mar 22 18:39:12 2020 +0100
@@ -169,12 +169,12 @@
             feedback = _("/!\\ conversation with %(other_jid)s is now UNENCRYPTED") % {
                 "other_jid": self.peer.full()
             }
-            d = client.encryption.stop(self.peer, NS_OTR)
+            d = defer.ensureDeferred(client.encryption.stop(self.peer, NS_OTR))
             d.addCallback(self.stopCb, feedback=feedback)
             d.addErrback(self.stopEb)
             return
         elif state == potr.context.STATE_ENCRYPTED:
-            client.encryption.start(self.peer, NS_OTR)
+            defer.ensureDeferred(client.encryption.start(self.peer, NS_OTR))
             try:
                 trusted = self.isTrusted()
             except TypeError:
@@ -201,7 +201,7 @@
             feedback = D_("OTR conversation with {other_jid} is FINISHED").format(
                 other_jid=self.peer.full()
             )
-            d = client.encryption.stop(self.peer, NS_OTR)
+            d = defer.ensureDeferred(client.encryption.stop(self.peer, NS_OTR))
             d.addCallback(self.stopCb, feedback=feedback)
             d.addErrback(self.stopEb)
             return
@@ -808,7 +808,7 @@
         otrctx = client._otr_context_manager.getContextForUser(to_jid)
 
         if otrctx.state != potr.context.STATE_PLAINTEXT:
-            client.encryption.start(to_jid, NS_OTR)
+            defer.ensureDeferred(client.encryption.start(to_jid, NS_OTR))
             client.encryption.setEncryptionFlag(mess_data)
             if not mess_data["to"].resource:
                 # if not resource was given, we force it here