diff sat/plugins/plugin_xep_0384.py @ 2654:e7bfbded652a

plugin XEP-0384, install: adapted plugin to omemo module changes + added omemo module to dependencies: - omemo module logs are now integrated in SàT - OmemoStorage adapated to changes - now really delete OTPK each time, behaviour may be modified in future - fixed bad arguments used during decryption - adapted to other changes OMEMO is now working \o/ It still needs some adjustements though: bundles/devices are for the moment requested on each message encryption, and fingerprints management is not implemented yet.
author Goffi <goffi@goffi.org>
date Sat, 11 Aug 2018 18:24:55 +0200
parents 0f76813afc57
children 0bef44f8e8ca
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0384.py	Sat Aug 11 18:24:55 2018 +0200
+++ b/sat/plugins/plugin_xep_0384.py	Sat Aug 11 18:24:55 2018 +0200
@@ -27,6 +27,7 @@
 from twisted.words.protocols.jabber import error
 from sat.memory import persistent
 from functools import partial
+import logging
 import random
 import base64
 try:
@@ -41,7 +42,6 @@
 
 log = getLogger(__name__)
 
-
 PLUGIN_INFO = {
     C.PI_NAME: u"OMEMO",
     C.PI_IMPORT_NAME: u"OMEMO",
@@ -63,17 +63,36 @@
 KEY_INACTIVE_DEVICES = "DEVICES"
 
 
+# we want to manage log emitted by omemo module ourseves
+
+class SatHandler(logging.Handler):
+
+    def emit(self, record):
+        log.log(record.levelname, record.getMessage())
+
+    @staticmethod
+    def install():
+        omemo_sm_logger = logging.getLogger("omemo.SessionManager")
+        omemo_sm_logger.propagate = False
+        omemo_sm_logger.addHandler(SatHandler())
+
+
+SatHandler.install()
+
+
 def b64enc(data):
     return base64.b64encode(bytes(bytearray(data))).decode("ASCII")
 
 
 class OmemoStorage(omemo.Storage):
 
-    def __init__(self, persistent_dict):
+    def __init__(self, client, device_id, persistent_dict):
         """
         @param persistent_dict(persistent.LazyPersistentBinaryDict): object which will
             store data in SàT database
         """
+        self.own_bare_jid_s = client.jid.userhost()
+        self.device_id = device_id
         self.data = persistent_dict
 
     @property
@@ -91,45 +110,54 @@
         deferred.addCallback(partial(callback, True))
         deferred.addErrback(partial(callback, False))
 
+    def loadOwnData(self, callback):
+        callback(True, {'own_bare_jid': self.own_bare_jid_s,
+                        'own_device_id': self.device_id})
+
+    def storeOwnData(self, callback, own_bare_jid, own_device_id):
+        if own_bare_jid != self.own_bare_jid_s or own_device_id != self.device_id:
+            raise exceptions.InternalError('bare jid or device id inconsistency!')
+        callback(True, None)
+
     def loadState(self, callback):
         d = self.data.get(KEY_STATE)
         self.setCb(d, callback)
 
-    def storeState(self, callback, state, device_id):
-        d = self.data.force(KEY_STATE, {'state': state, 'device_id': device_id})
+    def storeState(self, callback, state):
+        d = self.data.force(KEY_STATE, state)
         self.setCb(d, callback)
 
-    def loadSession(self, callback, jid, device_id):
-        key = u'\n'.join([KEY_SESSION, jid, unicode(device_id)])
+    def loadSession(self, callback, bare_jid, device_id):
+        key = u'\n'.join([KEY_SESSION, bare_jid, unicode(device_id)])
         d = self.data.get(key)
         self.setCb(d, callback)
 
-    def storeSession(self, callback, jid, device_id, session):
-        key = u'\n'.join([KEY_SESSION, jid, unicode(device_id)])
+    def storeSession(self, callback, bare_jid, device_id, session):
+        key = u'\n'.join([KEY_SESSION, bare_jid, unicode(device_id)])
         d = self.data.force(key, session)
         self.setCb(d, callback)
 
-    def loadActiveDevices(self, callback, jid):
-        key = u'\n'.join([KEY_ACTIVE_DEVICES, jid])
+    def loadActiveDevices(self, callback, bare_jid):
+        key = u'\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
         d = self.data.get(key, {})
         self.setCb(d, callback)
 
-    def loadInactiveDevices(self, callback, jid):
-        key = u'\n'.join([KEY_INACTIVE_DEVICES, jid])
+    def loadInactiveDevices(self, callback, bare_jid):
+        key = u'\n'.join([KEY_INACTIVE_DEVICES, bare_jid])
         d = self.data.get(key, {})
         self.setCb(d, callback)
 
-    def storeActiveDevices(self, callback, jid, devices):
-        key = u'\n'.join([KEY_ACTIVE_DEVICES, jid])
+    def storeActiveDevices(self, callback, bare_jid, devices):
+        key = u'\n'.join([KEY_ACTIVE_DEVICES, bare_jid])
         d = self.data.force(key, devices)
         self.setCb(d, callback)
 
-    def storeInactiveDevices(self, callback, jid, devices):
-        key = u'\n'.join([KEY_INACTIVE_DEVICES, jid])
+    def storeInactiveDevices(self, callback, bare_jid, devices):
+        key = u'\n'.join([KEY_INACTIVE_DEVICES, bare_jid])
         d = self.data.force(key, devices)
         self.setCb(d, callback)
 
-    def isTrusted(self, callback, jid, device):
+    def isTrusted(self, callback, bare_jid, device):
         trusted = True
         callback(True, trusted)
 
@@ -138,9 +166,8 @@
 
     @staticmethod
     def decideOTPK(preKeyMessages):
-        # Always just delete the OTPK.
-        # This is the behaviour described in the original X3DH specification.
-        return True
+        # always delete
+        return False
 
 
 class OmemoSession(object):
@@ -165,16 +192,18 @@
         return d
 
     @classmethod
-    def create(cls, client, omemo_storage, device_id):
+    def create(cls, client, storage, my_device_id = None):
         omemo_session_p = client._xep_0384_session = omemo.SessionManager.create(
-            client.jid.userhost(), omemo_storage, SatOTPKPolicy, my_device_id=device_id)
+            storage,
+            SatOTPKPolicy,
+            client.jid.userhost(),
+            my_device_id)
         d = cls.promise2Deferred(omemo_session_p)
         d.addCallback(lambda session: cls(session))
         return d
 
-    def newDeviceList(self, devices, jid=None):
-        if jid is not None:
-            jid = jid.userhost()
+    def newDeviceList(self, devices, jid):
+        jid = jid.userhost()
         new_device_p = self._session.newDeviceList(devices, jid)
         return self.promise2Deferred(new_device_p)
 
@@ -236,7 +265,7 @@
         self._p = host.plugins[u"XEP-0060"]
         host.trigger.add("MessageReceived", self._messageReceivedTrigger, priority=100050)
         host.trigger.add("sendMessageData", self._sendMessageDataTrigger)
-        self.host.registerEncryptionPlugin(self, "OMEMO", NS_OMEMO, 100)
+        self.host.registerEncryptionPlugin(self, u"OMEMO", NS_OMEMO, 100)
 
     @defer.inlineCallbacks
     def profileConnected(self, client):
@@ -259,11 +288,11 @@
             devices.add(device_id)
             yield self.setDevices(client, devices)
 
-        omemo_storage = OmemoStorage(persistent_dict)
+        omemo_storage = OmemoStorage(client, device_id, persistent_dict)
         omemo_session = yield OmemoSession.create(client, omemo_storage, device_id)
         client._xep_0384_session = omemo_session
         client._xep_0384_device_id = device_id
-        yield omemo_session.newDeviceList(devices)
+        yield omemo_session.newDeviceList(devices, client.jid)
         if omemo_session.state.changed:
             log.info(_(u"Saving public bundle for this device ({device_id})").format(
                 device_id=device_id))
@@ -355,7 +384,7 @@
                                                                 device_id=device_id))
                 continue
             if len(items) > 1:
-                log.warning(_(u"more than one item found in {node},"
+                log.warning(_(u"more than one item found in {node}, "
                               u"this is not expected").format(node=node))
             item = items[0]
             try:
@@ -394,7 +423,7 @@
                     otpks.append(otpk)
 
             except Exception as e:
-                log.warning(_(u"error while decoding key for device {devide_id}: {msg}")
+                log.warning(_(u"error while decoding key for device {device_id}: {msg}")
                             .format(device_id=device_id, msg=e))
                 continue
 
@@ -462,7 +491,7 @@
             defer.returnValue(True)
 
         # we have an encrypted message let's decrypt it
-        # from_jid = jid.JID(message_elt['from'])
+        from_jid = jid.JID(message_elt['from'])
         omemo_session = client._xep_0384_session
         device_id = client._xep_0384_device_id
         try:
@@ -473,6 +502,13 @@
                 .format(xml=message_elt.toXml()))
             defer.returnValue(False)
         try:
+            s_device_id = header_elt['sid']
+        except KeyError:
+            log.warning(_(u"Invalid OMEMO encrypted stanza, missing sender device ID, "
+                          u"ignoring: {xml}")
+                .format(xml=message_elt.toXml()))
+            defer.returnValue(False)
+        try:
             key_elt = next((e for e in header_elt.elements(NS_OMEMO, u'key')
                             if int(e[u'rid']) == device_id))
         except StopIteration:
@@ -488,8 +524,8 @@
 
         try:
             cipher, plaintext = yield omemo_session.decryptMessage(
-                bare_jid=client.jid.userhostJID(),
-                device=device_id,
+                bare_jid=from_jid.userhostJID(),
+                device=s_device_id,
                 iv=base64.b64decode(bytes(iv_elt)),
                 message=base64.b64decode(bytes(key_elt)),
                 is_pre_key_message=is_pre_key,