diff sat/memory/encryption.py @ 2733:e347e32aa07f

core (memory/encryption): new encryptionNamespaceGet and encryptionTrustUIGet methods: - encryptionNamespaceGet retrieves algorithm namespace from its short name - encryptionTrustUIGet retrieves trust mangement XMLUI from encryption plugin - new markAsUntrusted internal helper method, to add untrusted flag to message data
author Goffi <goffi@goffi.org>
date Wed, 02 Jan 2019 18:22:30 +0100
parents 4e130cc9bfc0
children da59ff099b32
line wrap: on
line diff
--- a/sat/memory/encryption.py	Thu Dec 27 11:40:04 2018 +0100
+++ b/sat/memory/encryption.py	Wed Jan 02 18:22:30 2019 +0100
@@ -51,8 +51,10 @@
 
         @param plg_instance(object): instance of the plugin
             it must have the following methods:
-                - startEncryption(jid.JID): start an encryption session with a bare jid
-                - stopEncryption(jid.JID): stop an encryption session with a bare jid
+                - getTrustUI(entity): return a XMLUI for trust management
+                    entity(jid.JID): entity to manage
+                    The returned XMLUI must be a form
+
         @param name(unicode): human readable name of the encryption algorithm
         @param namespace(unicode): namespace of the encryption algorithm
         @param priority(int): priority of this plugin to encrypt an message when not
@@ -84,6 +86,15 @@
         return cls.plugins
 
     @classmethod
+    def getPlugin(cls, namespace):
+        try:
+            return next(p for p in cls.plugins if p.namespace == namespace)
+        except StopIteration:
+            raise exceptions.NotFound(_(
+                u"Can't find requested encryption plugin: {namespace}").format(
+                    namespace=namespace))
+
+    @classmethod
     def getNSFromName(cls, name):
         """Retrieve plugin namespace from its name
 
@@ -94,7 +105,9 @@
         for p in cls.plugins:
             if p.name.lower() == name.lower():
                 return p.namespace
-        raise exceptions.NotFound
+        raise exceptions.NotFound(_(
+            u"Can't find a plugin with the name \"{name}\".".format(
+                name=name)))
 
     def getBridgeData(self, session):
         """Retrieve session data serialized for bridge.
@@ -117,7 +130,8 @@
 
         @param entity(jid.JID): entity to start an encryption session with
             must be bare jid is the algorithm encrypt for all devices
-        @param namespace(unicode, None): namespace of the encryption algorithm to use
+        @param namespace(unicode, None): namespace of the encryption algorithm
+            to use.
             None to select automatically an algorithm
         @param replace(bool): if True and an encrypted session already exists,
             it will be replaced by the new one
@@ -129,12 +143,7 @@
         if namespace is None:
             plugin = self.plugins[0]
         else:
-            try:
-                plugin = next(p for p in self.plugins if p.namespace == namespace)
-            except StopIteration:
-                raise exceptions.NotFound(_(
-                    u"Can't find requested encryption plugin: {namespace}").format(
-                        namespace=namespace))
+            plugin = self.getPlugin(namespace)
 
         bare_jid = entity.userhostJID()
         if bare_jid in self._sessions:
@@ -142,7 +151,8 @@
             former_plugin = self._sessions[bare_jid]['plugin']
             if former_plugin.namespace == namespace:
                 log.info(_(u"Session with {bare_jid} is already encrypted with {name}. "
-                           u"Nothing to do.").format(bare_jid=bare_jid, name=plugin.name))
+                           u"Nothing to do.").format(
+                               bare_jid=bare_jid, name=former_plugin.name))
                 return
 
             if replace:
@@ -255,9 +265,39 @@
             None if there is not encryption for this session with this jid
         """
         if entity.resource:
-            raise exceptions.InternalError(u"Full jid given when expecting bare jid")
+            raise ValueError(u"Full jid given when expecting bare jid")
         return self._sessions.get(entity)
 
+    def getTrustUI(self, entity_jid, namespace=None):
+        """Retrieve encryption UI
+
+        @param entity_jid(jid.JID): get the UI for this entity
+            must be a bare jid
+        @param namespace(unicode): namespace of the algorithm to manage
+            if None use current algorithm
+        @return D(xmlui): XMLUI for trust management
+            the xmlui is a form
+            None if there is not encryption for this session with this jid
+        @raise exceptions.NotFound: no algorithm/plugin found
+        @raise NotImplementedError: plugin doesn't handle UI management
+        """
+        if namespace is None:
+            session = self.getSession(entity_jid)
+            if not session:
+                raise exceptions.NotFound(
+                    u"No encryption session currently active for {entity_jid}"
+                    .format(entity_jid=entity_jid.full()))
+            plugin = session['plugin']
+        else:
+            plugin = self.getPlugin(namespace)
+        try:
+            get_trust_ui = plugin.instance.getTrustUI
+        except AttributeError:
+            raise NotImplementedError(
+                u"Encryption plugin doesn't handle trust management UI")
+        else:
+            return get_trust_ui(self.client, entity_jid)
+
     ## Triggers ##
 
     def setEncryptionFlag(self, mess_data):
@@ -283,3 +323,13 @@
         """
         mess_data['encrypted'] = True
         return mess_data
+
+    def markAsUntrusted(self, mess_data):
+        """Helper methor to mark a message as sent from an untrusted entity.
+
+        This should be used in the post_treat workflow of MessageReceived trigger of
+        the plugin
+        @param mess_data(dict): message data as used in post treat workflow
+        """
+        mess_data['untrusted'] = True
+        return mess_data