diff sat/memory/encryption.py @ 2651:ebcff5423465

core (memory/encryption): start improvments, stop and getSession: - "start" won't do anything if session is already encrypted with requested algorithm, and will raise a ConflictError if it's already encrypted but with an other algorithm - implemented "stop", with an optional namespace to check we are stopping the expected algorithm - "getSession" retrieve the current encryption session of a jid, if any
author Goffi <goffi@goffi.org>
date Sat, 11 Aug 2018 18:24:55 +0200
parents 712cb4ff3e13
children 4e130cc9bfc0
line wrap: on
line diff
--- a/sat/memory/encryption.py	Sat Aug 11 18:24:52 2018 +0200
+++ b/sat/memory/encryption.py	Sat Aug 11 18:24:55 2018 +0200
@@ -61,6 +61,7 @@
             priority=priority)
         cls.plugins.append(plg)
         cls.plugins.sort(key=lambda p: p.priority)
+        log.info(_(u"Encryption plugin registered: {name}").format(name=name))
 
     def start(self, entity, namespace=None):
         """Start an encrypted session with an entity
@@ -74,16 +75,6 @@
             raise exceptions.NotFound(_(u"No encryption plugin is registered, "
                                         u"an encryption session can't be started"))
 
-        bare_jid = entity.userhostJID()
-        if bare_jid in self._sessions:
-            plg = self._sessions[bare_jid]['plugin']
-
-            msg = (_(u"Session with {bare_jid} is already encrypted with {name}."
-                     u"Please stop encryption session before changing algorithm.")
-                   .format(bare_jid=bare_jid, name=plg.name))
-            log.warning(msg)
-            raise exceptions.ConflictError(msg)
-
         if namespace is None:
             plg = self.plugins[0]
         else:
@@ -94,16 +85,81 @@
                     u"Can't find requested encryption plugin: {namespace}").format(
                         namespace=namespace))
 
+        bare_jid = entity.userhostJID()
+        if bare_jid in self._sessions:
+            plg = self._sessions[bare_jid]['plugin']
+            if plg.namespace == namespace:
+                log.info(_(u"Session with {bare_jid} is already encrypted with {name}."
+                     u"Nothing to do.")
+                   .format(bare_jid=bare_jid, name=plg.name))
+                return
+
+            msg = (_(u"Session with {bare_jid} is already encrypted with {name}. "
+                     u"Please stop encryption session before changing algorithm.")
+                   .format(bare_jid=bare_jid, name=plg.name))
+            log.warning(msg)
+            raise exceptions.ConflictError(msg)
+
         data = {"plugin": plg}
         if entity.resource:
             # indicate that we encrypt only for some devices
             data['directed_devices'] = [entity.resource]
 
         self._sessions[entity.userhostJID()] = data
-        log.info(_(u"Encryption session as been set for {bare_jid} with "
+        log.info(_(u"Encryption session has been set for {bare_jid} with "
                    u"{encryption_name}").format(
                    bare_jid=bare_jid.userhost(), encryption_name=plg.name))
 
+    def stop(self, entity, namespace=None):
+        """Stop an encrypted session with an entity
+
+        @param entity(jid.JID): entity with who the encrypted session must be stopped
+            must be bare jid is the algorithm encrypt for all devices
+        @param namespace(unicode): namespace of the session to stop
+            when specified, used to check we stop the right encryption session
+        """
+        session = self.getSession(entity.userhostJID())
+        if not session:
+            raise exceptions.NotFound(_(u"There is no encrypted session with this "
+                                        u"entity."))
+        if namespace is not None and session[u'plugin'].namespace != namespace:
+            raise exceptions.InternalError(_(
+                u"The encrypted session is not run with the expected plugin: encrypted "
+                u"with {current_name} and was expecting {expected_name}").format(
+                current_name=session[u'plugin'].namespace,
+                expected_name=namespace))
+        if entity.resource:
+            try:
+                directed_devices = session[u'directed_devices']
+            except KeyError:
+                raise exceptions.NotFound(_(
+                    u"There is a session for the whole entity (i.e. all devices of the "
+                    u"entity), not a directed one. Please use bare jid if you want to "
+                    u"stop the whole encryption with this entity."))
+
+            try:
+                directed_devices.remove(entity.resource)
+            except ValueError:
+                raise exceptions.NotFound(_(u"There is no directed session with this "
+                                            u"entity."))
+        else:
+            del self._sessions[entity]
+
+        log.info(_(u"Encrypted session stopped with entity {entity}").format(
+            entity=entity.full()))
+
+    def getSession(self, entity):
+        """Get encryption session for this contact
+
+        @param entity(jid.JID): get the session for this entity
+            must be a bare jid
+        @return (dict, None): encrypted session data
+            None if there is not encryption for this session with this jid
+        """
+        if entity.resource:
+            raise exceptions.InternalError(u"Full jid given when expecting bare jid")
+        return self._sessions.get(entity)
+
     ## Triggers ##
 
     def setEncryptionFlag(self, mess_data):