diff sat/plugins/plugin_xep_0384.py @ 2648:0f76813afc57

plugin XEP-0384: OMEMO implementation first draft: this is the initial implementation of OMEMO encryption using python omemo module. /!\ This implementation is not yet working /!\
author Goffi <goffi@goffi.org>
date Sun, 29 Jul 2018 19:24:21 +0200
parents
children e7bfbded652a
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sat/plugins/plugin_xep_0384.py	Sun Jul 29 19:24:21 2018 +0200
@@ -0,0 +1,564 @@
+#!/usr/bin/env python2
+# -*- coding: utf-8 -*-
+
+# SAT plugin for OMEMO encryption
+# Copyright (C) 2009-2018 Jérôme Poisson (goffi@goffi.org)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from sat.core.i18n import _
+from sat.core.constants import Const as C
+from sat.core.log import getLogger
+from sat.core import exceptions
+from twisted.internet import defer
+from twisted.words.xish import domish
+from twisted.words.protocols.jabber import jid
+from twisted.words.protocols.jabber import error
+from sat.memory import persistent
+from functools import partial
+import random
+import base64
+try:
+    import omemo
+    from omemo.extendedpublicbundle import ExtendedPublicBundle
+    from omemo import wireformat
+except ImportError:
+    raise exceptions.MissingModule(
+        u'Missing module omemo, please download/install it. You can use '
+        u'"pip install omemo"'
+    )
+
+log = getLogger(__name__)
+
+
+PLUGIN_INFO = {
+    C.PI_NAME: u"OMEMO",
+    C.PI_IMPORT_NAME: u"OMEMO",
+    C.PI_TYPE: u"SEC",
+    C.PI_PROTOCOLS: [u"XEP-0384"],
+    C.PI_DEPENDENCIES: [u"XEP-0280", u"XEP-0334", u"XEP-0060"],
+    C.PI_MAIN: u"OMEMO",
+    C.PI_HANDLER: u"no",
+    C.PI_DESCRIPTION: _(u"""Implementation of OMEMO"""),
+}
+
+NS_OMEMO = "eu.siacs.conversations.axolotl"
+NS_OMEMO_DEVICES = NS_OMEMO + ".devicelist"
+NS_OMEMO_BUNDLE = NS_OMEMO + ".bundles:{device_id}"
+KEY_STATE = "STATE"
+KEY_DEVICE_ID = "DEVICE_ID"
+KEY_SESSION = "SESSION"
+KEY_ACTIVE_DEVICES = "DEVICES"
+KEY_INACTIVE_DEVICES = "DEVICES"
+
+
+def b64enc(data):
+    return base64.b64encode(bytes(bytearray(data))).decode("ASCII")
+
+
+class OmemoStorage(omemo.Storage):
+
+    def __init__(self, persistent_dict):
+        """
+        @param persistent_dict(persistent.LazyPersistentBinaryDict): object which will
+            store data in SàT database
+        """
+        self.data = persistent_dict
+
+    @property
+    def is_async(self):
+        return True
+
+    def setCb(self, deferred, callback):
+        """Associate Deferred and callback
+
+        callback of omemo.Storage expect a boolean with success state then result
+        Deferred on the other hand use 2 methods for callback and errback
+        This method use partial to call callback with boolean then result when
+        Deferred is called
+        """
+        deferred.addCallback(partial(callback, True))
+        deferred.addErrback(partial(callback, False))
+
+    def loadState(self, callback):
+        d = self.data.get(KEY_STATE)
+        self.setCb(d, callback)
+
+    def storeState(self, callback, state, device_id):
+        d = self.data.force(KEY_STATE, {'state': state, 'device_id': device_id})
+        self.setCb(d, callback)
+
+    def loadSession(self, callback, jid, device_id):
+        key = u'\n'.join([KEY_SESSION, jid, unicode(device_id)])
+        d = self.data.get(key)
+        self.setCb(d, callback)
+
+    def storeSession(self, callback, jid, device_id, session):
+        key = u'\n'.join([KEY_SESSION, jid, unicode(device_id)])
+        d = self.data.force(key, session)
+        self.setCb(d, callback)
+
+    def loadActiveDevices(self, callback, jid):
+        key = u'\n'.join([KEY_ACTIVE_DEVICES, jid])
+        d = self.data.get(key, {})
+        self.setCb(d, callback)
+
+    def loadInactiveDevices(self, callback, jid):
+        key = u'\n'.join([KEY_INACTIVE_DEVICES, jid])
+        d = self.data.get(key, {})
+        self.setCb(d, callback)
+
+    def storeActiveDevices(self, callback, jid, devices):
+        key = u'\n'.join([KEY_ACTIVE_DEVICES, jid])
+        d = self.data.force(key, devices)
+        self.setCb(d, callback)
+
+    def storeInactiveDevices(self, callback, jid, devices):
+        key = u'\n'.join([KEY_INACTIVE_DEVICES, jid])
+        d = self.data.force(key, devices)
+        self.setCb(d, callback)
+
+    def isTrusted(self, callback, jid, device):
+        trusted = True
+        callback(True, trusted)
+
+
+class SatOTPKPolicy(omemo.OTPKPolicy):
+
+    @staticmethod
+    def decideOTPK(preKeyMessages):
+        # Always just delete the OTPK.
+        # This is the behaviour described in the original X3DH specification.
+        return True
+
+
+class OmemoSession(object):
+    """Wrapper to use omemo.OmemoSession with Deferred"""
+
+    def __init__(self, session):
+        self._session = session
+
+    @property
+    def state(self):
+        return self._session.state
+
+    @staticmethod
+    def promise2Deferred(promise_):
+        """Create a Deferred and fire it when promise is resolved
+
+        @param promise_(promise.Promise): promise to convert
+        @return (defer.Deferred): deferred instance linked to the promise
+        """
+        d = defer.Deferred()
+        promise_.then(d.callback, d.errback)
+        return d
+
+    @classmethod
+    def create(cls, client, omemo_storage, device_id):
+        omemo_session_p = client._xep_0384_session = omemo.SessionManager.create(
+            client.jid.userhost(), omemo_storage, SatOTPKPolicy, my_device_id=device_id)
+        d = cls.promise2Deferred(omemo_session_p)
+        d.addCallback(lambda session: cls(session))
+        return d
+
+    def newDeviceList(self, devices, jid=None):
+        if jid is not None:
+            jid = jid.userhost()
+        new_device_p = self._session.newDeviceList(devices, jid)
+        return self.promise2Deferred(new_device_p)
+
+    def getDevices(self, bare_jid=None):
+        get_devices_p = self._session.getDevices(bare_jid=bare_jid)
+        return self.promise2Deferred(get_devices_p)
+
+    def buildSession(self, bare_jid, device, bundle):
+        bare_jid = bare_jid.userhost()
+        build_session_p = self._session.buildSession(bare_jid, device, bundle)
+        return self.promise2Deferred(build_session_p)
+
+    def encryptMessage(self, bare_jids, message, bundles=None, devices=None,
+            always_trust = False):
+        """Encrypt a message
+
+        @param bare_jids(iterable[jid.JID]): destinees of the message
+        @param message(unicode): message to encode
+        @param bundles(dict[jid.JID, dict[int, ExtendedPublicBundle]):
+            entities => devices => bundles map
+        @param devices(iterable[int], None): devices to encode for
+        @param always_trust(bool): TODO
+        @return D(dict): encryption data
+        """
+        if isinstance(bare_jids, jid.JID):
+            bare_jids = bare_jids.userhost()
+        else:
+            bare_jids = [e.userhost() for e in bare_jids]
+        if bundles is not None:
+            bundles = {e.userhost(): v for e, v in bundles.iteritems()}
+        encrypt_mess_p = self._session.encryptMessage(
+            bare_jids=bare_jids,
+            plaintext=message.encode('utf-8'),
+            bundles=bundles,
+            devices=devices,
+            always_trust=always_trust)
+        return self.promise2Deferred(encrypt_mess_p)
+
+    def decryptMessage(self, bare_jid, device, iv, message, is_pre_key_message,
+                       payload=None, from_storage=False):
+        bare_jid = bare_jid.userhost()
+        decrypt_mess_p = self._session.decryptMessage(
+            bare_jid=bare_jid,
+            device=device,
+            iv=iv,
+            message=message,
+            is_pre_key_message=is_pre_key_message,
+            payload=payload,
+            from_storage=from_storage)
+        return self.promise2Deferred(decrypt_mess_p)
+
+
+class OMEMO(object):
+    def __init__(self, host):
+        log.info(_(u"OMEMO plugin initialization"))
+        self.host = host
+        self._p_hints = host.plugins[u"XEP-0334"]
+        self._p_carbons = host.plugins[u"XEP-0280"]
+        self._p = host.plugins[u"XEP-0060"]
+        host.trigger.add("MessageReceived", self._messageReceivedTrigger, priority=100050)
+        host.trigger.add("sendMessageData", self._sendMessageDataTrigger)
+        self.host.registerEncryptionPlugin(self, "OMEMO", NS_OMEMO, 100)
+
+    @defer.inlineCallbacks
+    def profileConnected(self, client):
+        # we first need to get devices ids (including our own)
+        persistent_dict = persistent.LazyPersistentBinaryDict("XEP-0384", client.profile)
+        # all known devices of profile
+        devices = yield self.getDevices(client)
+        # and our own device id
+        device_id = yield persistent_dict.get(KEY_DEVICE_ID)
+        if device_id is None:
+            # we have a new device, we create device_id
+            device_id = random.randint(1, 2**31-1)
+            # we check that it's really unique
+            while device_id in devices:
+                device_id = random.randint(1, 2**31-1)
+            # and we save it
+            persistent_dict[KEY_DEVICE_ID] = device_id
+
+        if device_id not in devices:
+            devices.add(device_id)
+            yield self.setDevices(client, devices)
+
+        omemo_storage = OmemoStorage(persistent_dict)
+        omemo_session = yield OmemoSession.create(client, omemo_storage, device_id)
+        client._xep_0384_session = omemo_session
+        client._xep_0384_device_id = device_id
+        yield omemo_session.newDeviceList(devices)
+        if omemo_session.state.changed:
+            log.info(_(u"Saving public bundle for this device ({device_id})").format(
+                device_id=device_id))
+            bundle = omemo_session.state.getPublicBundle()
+            yield self.setBundle(client, bundle, device_id)
+
+    ## XMPP PEP nodes manipulation
+
+    # devices
+
+    @defer.inlineCallbacks
+    def getDevices(self, client, entity_jid=None):
+        """Retrieve list of registered OMEMO devices
+
+        @param entity_jid(jid.JID, None): get devices from this entity
+            None to get our own devices
+        @return (set(int)): list of devices
+        """
+        if entity_jid is not None:
+            assert not entity_jid.resource
+        devices = set()
+        try:
+            items, metadata = yield self._p.getItems(client, entity_jid, NS_OMEMO_DEVICES)
+        except error.StanzaError as e:
+            if e.condition == 'item-not-found':
+                log.info(_(u"there is no node to handle OMEMO devices"))
+                defer.returnValue(devices)
+
+        if len(items) > 1:
+            log.warning(_(u"OMEMO devices list is stored in more that one items, "
+                          u"this is not expected"))
+        if items:
+            try:
+                list_elt = next(items[0].elements(NS_OMEMO, 'list'))
+            except StopIteration:
+                log.warning(_(u"no list element found in OMEMO devices list"))
+                return
+            for device_elt in list_elt.elements(NS_OMEMO, 'device'):
+                try:
+                    device_id = int(device_elt['id'])
+                except KeyError:
+                    log.warning(_(u'device element is missing "id" attribute: {elt}')
+                                .format(elt=device_elt.toXml()))
+                except ValueError:
+                    log.warning(_(u'invalid device id: {device_id}').format(
+                        device_id=device_elt['id']))
+                else:
+                    devices.add(device_id)
+        defer.returnValue(devices)
+
+    def setDevicesEb(self, failure_):
+        log.warning(_(u"Can't set devices: {reason}").format(reason=failure_))
+
+    def setDevices(self, client, devices):
+        list_elt = domish.Element((NS_OMEMO, 'list'))
+        for device in devices:
+            device_elt = list_elt.addElement('device')
+            device_elt['id'] = unicode(device)
+        d = self._p.sendItem(
+            client, None, NS_OMEMO_DEVICES, list_elt, item_id=self._p.ID_SINGLETON)
+        d.addErrback(self.setDevicesEb)
+        return d
+
+    # bundles
+
+    @defer.inlineCallbacks
+    def getBundles(self, client, entity_jid, devices_ids):
+        """Retrieve public bundles of an entity devices
+
+        @param entity_jid(jid.JID): bare jid of entity
+        @param devices_id(iterable[int]): ids of the devices bundles to retrieve
+        @return (dict[int, ExtendedPublicBundle]): bundles collection
+            key is device_id
+            value is parsed bundle
+        """
+        assert not entity_jid.resource
+        bundles = {}
+        for device_id in devices_ids:
+            node = NS_OMEMO_BUNDLE.format(device_id=device_id)
+            try:
+                items, metadata = yield self._p.getItems(client, entity_jid, node)
+            except Exception as e:
+                log.warning(_(u"Can't get bundle for device {device_id}: {reason}")
+                            .format(device_id=device_id, reason=e))
+                continue
+            if not items:
+                log.warning(_(u"no item found in node {node}, can't get public bundle "
+                              u"for device {device_id}").format(node=node,
+                                                                device_id=device_id))
+                continue
+            if len(items) > 1:
+                log.warning(_(u"more than one item found in {node},"
+                              u"this is not expected").format(node=node))
+            item = items[0]
+            try:
+                bundle_elt = next(item.elements(NS_OMEMO, 'bundle'))
+                signedPreKeyPublic_elt = next(bundle_elt.elements(
+                    NS_OMEMO, 'signedPreKeyPublic'))
+                signedPreKeySignature_elt = next(bundle_elt.elements(
+                    NS_OMEMO, 'signedPreKeySignature'))
+                identityKey_elt = next(bundle_elt.elements(
+                    NS_OMEMO, 'identityKey'))
+                prekeys_elt =  next(bundle_elt.elements(
+                    NS_OMEMO, 'prekeys'))
+            except StopIteration:
+                log.warning(_(u"invalid bundle for device {device_id}, ignoring").format(
+                    device_id=device_id))
+                continue
+
+            try:
+                spkPublic = base64.b64decode(unicode(signedPreKeyPublic_elt))
+                spkSignature = base64.b64decode(
+                    unicode(signedPreKeySignature_elt))
+
+                identityKey = base64.b64decode(unicode(identityKey_elt))
+                spk = {
+                    "key": wireformat.decodePublicKey(spkPublic),
+                    "id": int(signedPreKeyPublic_elt['signedPreKeyId'])
+                }
+                ik = wireformat.decodePublicKey(identityKey)
+                otpks = []
+                for preKeyPublic_elt in prekeys_elt.elements(NS_OMEMO, 'preKeyPublic'):
+                    preKeyPublic = base64.b64decode(unicode(preKeyPublic_elt))
+                    otpk = {
+                        "key": wireformat.decodePublicKey(preKeyPublic),
+                        "id": int(preKeyPublic_elt['preKeyId'])
+                    }
+                    otpks.append(otpk)
+
+            except Exception as e:
+                log.warning(_(u"error while decoding key for device {devide_id}: {msg}")
+                            .format(device_id=device_id, msg=e))
+                continue
+
+            bundles[device_id] = ExtendedPublicBundle(ik, spk, spkSignature, otpks)
+
+        defer.returnValue(bundles)
+
+    def setBundleEb(self, failure_):
+        log.warning(_(u"Can't set bundle: {reason}").format(reason=failure_))
+
+    def setBundle(self, client, bundle, device_id):
+        """Set public bundle for this device.
+
+        @param bundle(ExtendedPublicBundle): bundle to publish
+        """
+        log.debug(_(u"updating bundle for {device_id}").format(device_id=device_id))
+        bundle_elt = domish.Element((NS_OMEMO, 'bundle'))
+        signedPreKeyPublic_elt = bundle_elt.addElement(
+            "signedPreKeyPublic",
+            content=b64enc(wireformat.encodePublicKey(bundle.spk['key'])))
+        signedPreKeyPublic_elt['signedPreKeyId'] = unicode(bundle.spk['id'])
+
+        bundle_elt.addElement(
+            "signedPreKeySignature",
+            content=b64enc(bundle.spk_signature))
+
+        bundle_elt.addElement(
+            "identityKey",
+            content=b64enc(wireformat.encodePublicKey(bundle.ik)))
+
+        prekeys_elt = bundle_elt.addElement('prekeys')
+        for otpk in bundle.otpks:
+            preKeyPublic_elt = prekeys_elt.addElement(
+                'preKeyPublic',
+                content=b64enc(wireformat.encodePublicKey(otpk["key"])))
+            preKeyPublic_elt['preKeyId'] = unicode(otpk['id'])
+
+        node = NS_OMEMO_BUNDLE.format(device_id=device_id)
+        d = self._p.sendItem(client, None, node, bundle_elt, item_id=self._p.ID_SINGLETON)
+        d.addErrback(self.setBundleEb)
+        return d
+
+    ## triggers
+
+    @defer.inlineCallbacks
+    def encryptMessage(self, client, entity_bare_jid, message):
+        omemo_session = client._xep_0384_session
+        devices = yield self.getDevices(client, entity_bare_jid)
+        omemo_session.newDeviceList(devices, entity_bare_jid)
+        bundles = yield self.getBundles(client, entity_bare_jid, devices)
+        encrypted = yield omemo_session.encryptMessage(
+            entity_bare_jid,
+            message,
+            {entity_bare_jid: bundles})
+        defer.returnValue(encrypted)
+
+    @defer.inlineCallbacks
+    def _messageReceivedTrigger(self, client, message_elt, post_treat):
+        if message_elt.getAttribute("type") == C.MESS_TYPE_GROUPCHAT:
+            defer.returnValue(True)
+        try:
+            encrypted_elt = next(message_elt.elements(NS_OMEMO, u"encrypted"))
+        except StopIteration:
+            # no OMEMO message here
+            defer.returnValue(True)
+
+        # we have an encrypted message let's decrypt it
+        # from_jid = jid.JID(message_elt['from'])
+        omemo_session = client._xep_0384_session
+        device_id = client._xep_0384_device_id
+        try:
+            header_elt = next(encrypted_elt.elements(NS_OMEMO, u'header'))
+            iv_elt = next(header_elt.elements(NS_OMEMO, u'iv'))
+        except StopIteration:
+            log.warning(_(u"Invalid OMEMO encrypted stanza, ignoring: {xml}")
+                .format(xml=message_elt.toXml()))
+            defer.returnValue(False)
+        try:
+            key_elt = next((e for e in header_elt.elements(NS_OMEMO, u'key')
+                            if int(e[u'rid']) == device_id))
+        except StopIteration:
+            log.warning(_(u"This OMEMO encrypted stanza has not been encrypted"
+                          u"for our device ({device_id}): {xml}").format(
+                          device_id=device_id, xml=encrypted_elt.toXml()))
+            defer.returnValue(False)
+        except ValueError as e:
+            log.warning(_(u"Invalid recipient ID: {msg}".format(msg=e)))
+            defer.returnValue(False)
+        is_pre_key = C.bool(key_elt.getAttribute('prekey', 'false'))
+        payload_elt = next(encrypted_elt.elements(NS_OMEMO, u'payload'), None)
+
+        try:
+            cipher, plaintext = yield omemo_session.decryptMessage(
+                bare_jid=client.jid.userhostJID(),
+                device=device_id,
+                iv=base64.b64decode(bytes(iv_elt)),
+                message=base64.b64decode(bytes(key_elt)),
+                is_pre_key_message=is_pre_key,
+                payload=base64.b64decode(bytes(payload_elt))
+                    if payload_elt is not None else None,
+                from_storage=False
+            )
+        except Exception as e:
+            log.error(_(u"Can't decrypt message: {reason}\n{xml}").format(
+                reason=e, xml=message_elt.toXml()))
+            defer.returnValue(False)
+        if omemo_session.state.changed:
+            bundle = omemo_session.state.getPublicBundle()
+            # we don't wait for the Deferred (i.e. no yield) on purpose
+            # there is no need to block the whole message workflow while
+            # updating the bundle
+            self.setBundle(client, bundle, device_id)
+
+        message_elt.children.remove(encrypted_elt)
+        if plaintext:
+            message_elt.addElement("body", content=plaintext.decode('utf-8'))
+        defer.returnValue(True)
+
+    @defer.inlineCallbacks
+    def _sendMessageDataTrigger(self, client, mess_data):
+        encryption = mess_data.get(C.MESS_KEY_ENCRYPTION)
+        if encryption is None or encryption['plugin'].namespace != NS_OMEMO:
+            return
+        message_elt = mess_data["xml"]
+        to_jid = mess_data["to"].userhostJID()
+        log.debug(u"encrypting message")
+        body = None
+        for child in list(message_elt.children):
+            if child.name == "body":
+                # we remove all unencrypted body,
+                # and will only encrypt the first one
+                if body is None:
+                    body = child
+                message_elt.children.remove(child)
+            elif child.name == "html":
+                # we don't want any XHTML-IM element
+                message_elt.children.remove(child)
+
+        if body is None:
+            log.warning(u"No message found")
+            return
+
+        encryption_data = yield self.encryptMessage(client, to_jid, unicode(body))
+
+        encrypted_elt = message_elt.addElement((NS_OMEMO, 'encrypted'))
+        header_elt = encrypted_elt.addElement('header')
+        header_elt['sid'] = unicode(encryption_data['sid'])
+        bare_jid_s = to_jid.userhost()
+
+        for message in (m for m in encryption_data['messages']
+                        if m['bare_jid'] == bare_jid_s):
+            key_elt = header_elt.addElement(
+                'key',
+                content=b64enc(message['message']))
+            key_elt['rid'] = unicode(message['rid'])
+            if message['pre_key']:
+                key_elt['prekey'] = 'true'
+
+        header_elt.addElement(
+            'iv',
+            content=b64enc(encryption_data['iv']))
+        try:
+            encrypted_elt.addElement(
+                'payload',
+                content=b64enc(encryption_data['payload']))
+        except KeyError:
+            pass