view libervia/backend/plugins/plugin_xep_0448.py @ 4232:0fbe5c605eb6

tests (unit/webrtc,XEP-0176, XEP-0234): Fix tests and add webrtc file transfer tests: fix 441
author Goffi <goffi@goffi.org>
date Sat, 06 Apr 2024 12:59:50 +0200
parents 4b842c1fb686
children 0d7bb4df2343
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin for handling stateless file sharing encryption
# Copyright (C) 2009-2022 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
from functools import partial
from pathlib import Path
import secrets
from textwrap import dedent
from typing import Any, Dict, Optional, Tuple, Union

from cryptography.exceptions import AlreadyFinalized
from cryptography.hazmat import backends
from cryptography.hazmat.primitives import ciphers
from cryptography.hazmat.primitives.ciphers import CipherContext, modes
from cryptography.hazmat.primitives.padding import PKCS7, PaddingContext
import treq
from twisted.internet import defer
from twisted.words.protocols.jabber.xmlstream import XMPPHandler
from twisted.words.xish import domish
from wokkel import disco, iwokkel
from zope.interface import implementer

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.tools import stream
from libervia.backend.tools.web import treq_client_no_ssl

# logger of this module
log = getLogger(__name__)

# name used to register this plugin, and to flag uploads which must be encrypted by it
IMPORT_NAME = "XEP-0448"

# metadata describing the plugin
PLUGIN_INFO = {
    C.PI_NAME: "Encryption for Stateless File Sharing",
    C.PI_IMPORT_NAME: IMPORT_NAME,
    C.PI_TYPE: C.PLUG_TYPE_EXP,
    C.PI_PROTOCOLS: ["XEP-0448"],
    C.PI_DEPENDENCIES: [
        "XEP-0103",
        "XEP-0300",
        "XEP-0334",
        "XEP-0363",
        "XEP-0384",
        "XEP-0447",
        "DOWNLOAD",
        "ATTACH",
    ],
    C.PI_MAIN: "XEP_0448",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: dedent(_("""\
    Implementation of e2e encryption for media sharing
    """)),
}

# namespace of the protocol itself
NS_ESFS = "urn:xmpp:esfs:0"
# namespaces of the ciphers which can be found in the "cipher" attribute
NS_AES_128_GCM = "urn:xmpp:ciphers:aes-128-gcm-nopadding:0"
NS_AES_256_GCM = "urn:xmpp:ciphers:aes-256-gcm-nopadding:0"
NS_AES_256_CBC = "urn:xmpp:ciphers:aes-256-cbc-pkcs7:0"


class XEP_0448:

    def __init__(self, host):
        """Keep references to needed plugins and register handlers and triggers

        @param host: backend instance, used to access plugins and triggers
        """
        self.host = host
        log.info(_("XEP_0448 plugin initialization"))
        host.register_namespace("esfs", NS_ESFS)

        plugins = host.plugins
        self._u = plugins["XEP-0103"]
        self._h = plugins["XEP-0300"]
        self._hints = plugins["XEP-0334"]
        self._http_upload = plugins["XEP-0363"]
        self._o = plugins["XEP-0384"]

        # <encrypted> sources found in file sharing elements are parsed by this plugin
        self._sfs = plugins["XEP-0447"]
        self._sfs.register_source_handler(
            NS_ESFS,
            "encrypted",
            self.parse_encrypted_elt,
            encrypted=True,
        )

        # attachments of encrypted messages are handled by this plugin
        self._attach = plugins["ATTACH"]
        self._attach.register(
            self.can_handle_attachment,
            self.attach,
            encrypted=True,
            priority=1000,
        )

        plugins["DOWNLOAD"].register_download_handler(NS_ESFS, self.download)

        # HTTP upload triggers are used to encrypt the file on the fly
        for trigger_name, trigger_cb in (
            ("XEP-0363_upload_pre_slot", self._upload_pre_slot),
            ("XEP-0363_upload", self._upload_trigger),
        ):
            host.trigger.add(trigger_name, trigger_cb)

    def get_handler(self, client):
        """Build the XMPP handler of this plugin

        @param client: client session (not used here)
        @return: a new XEP0448Handler instance
        """
        handler = XEP0448Handler()
        return handler

    def parse_encrypted_elt(self, encrypted_elt: domish.Element) -> Dict[str, Any]:
        """Parse an <encrypted> element and return corresponding source data

        @param encrypted_elt: element to parse
        @return: first source found in the element, with "type" set to NS_ESFS and
            "encrypted_data" holding "cipher", "key" and "iv", plus "hash_algo" and
            "hash" when a hash element is found
        @raise exceptions.NotFound: no source is found in the element
        @raise exceptions.DataError: the element is invalid

        """
        sources = self._sfs.parse_sources_elt(encrypted_elt)
        if not sources:
            raise exceptions.NotFound(
                f"sources are missing in {encrypted_elt.toXml()}"
            )
        if len(sources) > 1:
            log.debug(
                "more that one sources has been found, this is not expected, only the "
                "first one will be used"
            )
        source = sources[0]
        source["type"] = NS_ESFS
        try:
            encrypted_data = source["encrypted_data"] = {
                "cipher": encrypted_elt["cipher"],
                "key": str(next(encrypted_elt.elements(NS_ESFS, "key"))),
                "iv": str(next(encrypted_elt.elements(NS_ESFS, "iv"))),
            }
        except (KeyError, StopIteration):
            # KeyError: "cipher" attribute is missing
            # StopIteration: <key> or <iv> child is missing
            raise exceptions.DataError(
                f"invalid <encrypted/> element: {encrypted_elt.toXml()}"
            )
        try:
            hash_algo, hash_value = self._h.parse_hash_elt(encrypted_elt)
        except exceptions.NotFound:
            # the hash is optional, it is simply not set in the result
            pass
        else:
            encrypted_data["hash_algo"] = hash_algo
            encrypted_data["hash"] = base64.b64encode(hash_value.encode()).decode()
        return source

    async def download(
        self,
        client: SatXMPPEntity,
        attachment: Dict[str, Any],
        source: Dict[str, Any],
        dest_path: Union[Path, str],
        extra: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, defer.Deferred]:
        """Download an encrypted file and decrypt it on the fly

        @param client: client session
        @param attachment: attachment metadata. If a valid "size" is set, it is used as
            size of the destination file, otherwise a HEAD request is done to get it.
        @param source: source data, it must have an "url" and an "encrypted_data" dict
            with "cipher", and base64 encoded "iv" and "key"
        @param dest_path: where the decrypted file must be written
        @param extra: extra options. If "ignore_tls_errors" is set, TLS certificate is
            not checked.
        @return: progress ID (uid of the destination file) and a Deferred linked to the
            download
        @raise ValueError: encryption data or URL are missing in source, or the size of
            the file can't be retrieved
        @raise exceptions.CancelError: the cipher is not supported
        """
        # TODO: check hash
        if extra is None:
            extra = {}
        try:
            encrypted_data = source["encrypted_data"]
            cipher = encrypted_data["cipher"]
            iv = base64.b64decode(encrypted_data["iv"])
            key = base64.b64decode(encrypted_data["key"])
        except KeyError as e:
            raise ValueError(f"{source} has incomplete encryption data: {e}") from e
        try:
            download_url = source["url"]
        except KeyError as e:
            raise ValueError(f"{source} has missing URL") from e

        if extra.get('ignore_tls_errors', False):
            log.warning(
                "TLS certificate check disabled, this is highly insecure"
            )
            treq_client = treq_client_no_ssl
        else:
            treq_client = treq

        try:
            file_size = int(attachment["size"])
        except (KeyError, TypeError, ValueError):
            # size is missing or unusable (TypeError is raised when it is None), we
            # deduce it from the size of the encrypted file
            head_data = await treq_client.head(download_url)
            raw_lengths = head_data.headers.getRawHeaders('content-length')
            if not raw_lengths:
                raise ValueError(
                    "can't get the size of the file to download: Content-Length header "
                    "is missing"
                )
            content_length = int(raw_lengths[0])
            # the 128 bits tag is put at the end
            # NOTE(review): this matches the GCM ciphers; with the CBC cipher there is
            #   no trailing tag, so this size is presumably only an approximation in
            #   that case — to be confirmed
            file_size = content_length - 16

        file_obj = stream.SatFile(
            self.host,
            client,
            dest_path,
            mode="wb",
            size=file_size,
        )

        if cipher in (NS_AES_128_GCM, NS_AES_256_GCM):
            decryptor = ciphers.Cipher(
                ciphers.algorithms.AES(key),
                modes.GCM(iv),
                backend=backends.default_backend(),
            ).decryptor()
            decrypt_cb = partial(
                self.gcm_decrypt,
                client=client,
                file_obj=file_obj,
                decryptor=decryptor,
            )
            # finalization is done by gcm_decrypt itself, once the tag is received
            finalize_cb = None
        elif cipher == NS_AES_256_CBC:
            cipher_algo = ciphers.algorithms.AES(key)
            decryptor = ciphers.Cipher(
                cipher_algo,
                modes.CBC(iv),
                backend=backends.default_backend(),
            ).decryptor()
            unpadder = PKCS7(cipher_algo.block_size).unpadder()
            decrypt_cb = partial(
                self.cbc_decrypt,
                client=client,
                file_obj=file_obj,
                decryptor=decryptor,
                unpadder=unpadder
            )
            finalize_cb = partial(
                self.cbc_decrypt_finalize,
                file_obj=file_obj,
                decryptor=decryptor,
                unpadder=unpadder
            )
        else:
            msg = f"cipher {cipher!r} is not supported"
            file_obj.close(error=msg)
            log.warning(msg)
            raise exceptions.CancelError(msg)

        progress_id = file_obj.uid

        resp = await treq_client.get(download_url, unbuffered=True)
        if resp.code == 200:
            # each received chunk of data is decrypted and written to the file
            d = treq.collect(resp, decrypt_cb)
            if finalize_cb is not None:
                d.addCallback(lambda __: finalize_cb())
        else:
            d = defer.Deferred()
            self.host.plugins["DOWNLOAD"].errback_download(file_obj, d, resp)
        return progress_id, d

    async def can_handle_attachment(self, client, data):
        """Tell if attachments of this message can be handled by this plugin

        @param client: client session
        @param data: message data, "to" is used to check the encryption in use
        @return: True if TWOMEMO encryption is used with the recipient and an HTTP
            upload entity is found
        """
        # FIXME: check if SCE is supported without checking which e2ee algo is used
        active_namespace = client.encryption.get_namespace(data["to"])
        if active_namespace == self._o.NS_TWOMEMO:
            try:
                await self._http_upload.get_http_upload_entity(client)
            except exceptions.NotFound:
                return False
            return True
        # we need SCE, and it is currently supported only by TWOMEMO, thus we can't
        # handle the attachment if it's not activated
        return False

    async def _upload_cb(self, client, filepath, filename, extra):
        """Upload a file with HTTP upload, asking for its encryption

        A random IV and a random key are generated and stored in "encryption_data",
        in ``extra`` and in the attachment. The real filename is kept in the
        attachment, "encrypted" is the name used for the upload.

        @param client: client session
        @param filepath: path of the file to upload
        @param filename: real name of the file
        @param extra: upload options, "attachment" must be set
        @return: result of the HTTP upload plugin's file_http_upload
        """
        attachment = extra["attachment"]
        extra["encryption"] = IMPORT_NAME
        # the same dict is shared between the attachment and extra
        encryption_data = {
            "algorithm": C.ENC_AES_GCM,
            "iv": secrets.token_bytes(12),
            "key": secrets.token_bytes(32),
        }
        attachment["encryption_data"] = encryption_data
        extra["encryption_data"] = encryption_data
        attachment["filename"] = filename
        upload = self._http_upload.file_http_upload(
            client=client,
            filepath=filepath,
            filename="encrypted",
            extra=extra,
        )
        return await upload

    async def attach(self, client, data):
        """Upload attachments and add the matching elements to the message

        Only one attachment is sent with the message in ``data``, and only if it has
        no body. Each other attachment is sent in its own message.

        @param client: client session
        @param data: message data, attachments are in data["extra"]
        @raise exceptions.InternalError: more than one attachment is left for the
            current message
        @raise exceptions.CancelError: there is nothing left to send in the current
            message
        """
        # XXX: for now, XEP-0447/XEP-0448 only allow to send one file per <message/>, thus
        #   we need to send each file in a separate message, in the same way as for
        #   plugin_sec_aesgcm.
        attachments = data["extra"][C.KEY_ATTACHMENTS]
        if data['message'] and data['message'] != {'': ''}:
            # we have a message, we must send first attachment separately
            extra_attachments = list(attachments)
            attachments.clear()
            del data["extra"][C.KEY_ATTACHMENTS]
        else:
            # no body, the first attachment goes with the current message
            extra_attachments = attachments[1:]
            del attachments[1:]

        if attachments:
            if len(attachments) > 1:
                raise exceptions.InternalError(
                    "There should not be more that one attachment at this point"
                )
            await self._attach.upload_files(client, data, upload_cb=self._upload_cb)
            self._hints.add_hint_elements(data["xml"], [self._hints.HINT_STORE])
            for uploaded in attachments:
                encryption_data = uploaded.pop("encryption_data")
                # hash of the clear file
                file_hash = (uploaded["hash_algo"], uploaded["hash"])
                file_sharing_elt = self._sfs.get_file_sharing_elt(
                    [],
                    name=uploaded["filename"],
                    size=uploaded["size"],
                    file_hash=file_hash,
                )
                encrypted_elt = file_sharing_elt.sources.addElement(
                    (NS_ESFS, "encrypted")
                )
                encrypted_elt["cipher"] = NS_AES_256_GCM
                for child_name in ("key", "iv"):
                    b64_value = base64.b64encode(encryption_data[child_name]).decode()
                    encrypted_elt.addElement(child_name, content=b64_value)
                # hash of the encrypted file
                encrypted_hash_elt = self._h.build_hash_elt(
                    uploaded["encrypted_hash"],
                    uploaded["encrypted_hash_algo"],
                )
                encrypted_elt.addChild(encrypted_hash_elt)
                url_data_elt = self._u.get_url_data_elt(uploaded["url"])
                encrypted_elt.addChild(self._sfs.get_sources_elt([url_data_elt]))
                data["xml"].addChild(file_sharing_elt)

        for remaining in extra_attachments:
            # we send all remaining attachment in a separate message
            await client.sendMessage(
                to_jid=data['to'],
                message={'': ''},
                subject=data['subject'],
                mess_type=data['type'],
                extra={C.KEY_ATTACHMENTS: [remaining]},
            )

        nothing_left = (
            not data['extra']
            and (not data['message'] or data['message'] == {'': ''})
            and not data['subject']
        )
        if nothing_left:
            # nothing left to send, we can cancel the message
            raise exceptions.CancelError("Cancelled by XEP_0448 attachment handling")

    def gcm_decrypt(
        self,
        data: bytes,
        client: SatXMPPEntity,
        file_obj: stream.SatFile,
        decryptor: CipherContext
    ) -> None:
        """Decrypt a chunk of AES-GCM data and write the result to the file

        The encrypted stream is the ciphertext followed by the 128 bits tag.
        ``file_obj.size`` is the size of the clear file: everything received after
        this amount of bytes is the tag, which may be split between several chunks.
        Once the tag is complete, decryption is finalized and the file is closed.

        @param data: chunk of encrypted data
        @param client: client session (not used here)
        @param file_obj: destination file, its size must be the size of clear data
        @param decryptor: AES-GCM decryption context
        @raise ValueError: more than 16 bytes have been received after the ciphertext
        """
        tag_length = 16
        size = file_obj.size
        position = file_obj.tell()
        if position + len(data) <= size:  # type: ignore
            # we only have ciphertext in this chunk
            file_obj.write(decryptor.update(data))
            return

        # we're reaching end of file with this bunch of data
        bytes_left = size - position  # type: ignore
        if bytes_left > 0:
            file_obj.write(decryptor.update(data[:bytes_left]))
            tag_part = data[bytes_left:]
        else:
            tag_part = data

        # a tag received in several chunks is accumulated here between calls. It is
        # not stored on the decryptor, as it may not accept arbitrary attributes.
        try:
            partial_tags = self._gcm_partial_tags
        except AttributeError:
            partial_tags = self._gcm_partial_tags = {}
        tag = partial_tags.pop(file_obj.uid, b"") + tag_part

        if len(tag) < tag_length:
            # the tag is incomplete, we'll get the rest in next chunk(s)
            partial_tags[file_obj.uid] = tag
            return
        if len(tag) > tag_length:
            # the tag must be 128 bits
            raise ValueError(f"Invalid tag: {tag}")

        remain = decryptor.finalize_with_tag(tag)
        file_obj.write(remain)
        file_obj.close()

    def cbc_decrypt(
        self,
        data: bytes,
        client: SatXMPPEntity,
        file_obj: stream.SatFile,
        decryptor: CipherContext,
        unpadder: PaddingContext
    ) -> None:
        """Decrypt a chunk of AES-CBC data, unpad it and write it to the file

        @param data: chunk of encrypted data
        @param client: client session (not used here)
        @param file_obj: destination file
        @param decryptor: AES-CBC decryption context
        @param unpadder: context used to remove the padding
        """
        padded = decryptor.update(data)
        clear = unpadder.update(padded)
        file_obj.write(clear)

    def cbc_decrypt_finalize(
        self,
        file_obj: stream.SatFile,
        decryptor: CipherContext,
        unpadder: PaddingContext
    ) -> None:
        """Finalize AES-CBC decryption, write the last data and close the file

        @param file_obj: destination file
        @param decryptor: AES-CBC decryption context
        @param unpadder: context used to remove the padding
        """
        # last data are written before the unpadder is finalized
        last_clear = unpadder.update(decryptor.finalize())
        file_obj.write(last_clear)
        unpadded_end = unpadder.finalize()
        file_obj.write(unpadded_end)
        file_obj.close()

    def _upload_pre_slot(self, client, extra, file_metadata):
        """Trigger adapting the announced size of a file encrypted by this plugin

        @param client: client session (not used here)
        @param extra: upload options, "encryption" tells if this plugin is in use
        @param file_metadata: metadata of the file, "size" is increased in place
        @return: always True
        """
        if extra.get('encryption') == IMPORT_NAME:
            # the tag is appended to the file
            file_metadata["size"] += 16
        return True

    def _encrypt(self, data: bytes, encryptor: CipherContext, attachment: dict) -> bytes:
        """Encrypt a chunk of data, and update hashes of clear and encrypted data

        When ``data`` is empty, encryption is finalized: the tag is returned after
        the last encrypted data, and hashers are replaced in ``attachment`` by the
        final "hash" and "encrypted_hash".

        @param data: chunk of clear data, empty when the end of file is reached
        @param encryptor: encryption context
        @param attachment: attachment data, holding "hasher" and "encrypted_hasher"
        @return: encrypted data, or empty bytes once encryption is already finalized
        """
        if not data:
            try:
                # end of file is reached, me must finalize
                last_data = encryptor.finalize()
                closing = last_data + encryptor.tag
                attachment["hash"] = attachment.pop("hasher").hexdigest()
                encrypted_hasher = attachment.pop("encrypted_hasher")
                encrypted_hasher.update(closing)
                attachment["encrypted_hash"] = encrypted_hasher.hexdigest()
                return closing
            except AlreadyFinalized:
                # as we have already finalized, we can now send EOF
                return b''

        attachment["hasher"].update(data)
        encrypted = encryptor.update(data)
        attachment["encrypted_hasher"].update(encrypted)
        return encrypted

    def _upload_trigger(self, client, extra, sat_file, file_producer, slot):
        """Trigger setting on the fly AES-GCM encryption of an uploaded file

        Nothing is done if the upload has not been flagged by this plugin.

        @param client: client session (not used here)
        @param extra: upload options, holding "attachment" and "encryption_data"
        @param sat_file: file to upload, its data_cb is set to encrypt the data
        @param file_producer: producer of the file, its length is adapted
        @param slot: upload slot (not used here)
        @return: always True
        @raise exceptions.InternalError: a data_cb is already set on sat_file
        """
        if extra.get('encryption') != IMPORT_NAME:
            return True
        attachment = extra["attachment"]
        encryption_data = extra["encryption_data"]
        log.debug("encrypting file with AES-GCM")
        iv, key = encryption_data["iv"], encryption_data["key"]

        # encrypted data size will be bigger than original file size
        # so we need to check with final data length to avoid a warning on close()
        sat_file.check_size_with_read = True

        # file_producer get length directly from file, and this cause trouble as
        # we have to change the size because of encryption. So we adapt it here,
        # else the producer would stop reading prematurely
        file_producer.length = sat_file.size

        aes_gcm = ciphers.Cipher(
            ciphers.algorithms.AES(key),
            modes.GCM(iv),
            backend=backends.default_backend(),
        )
        encryptor = aes_gcm.encryptor()

        if sat_file.data_cb is not None:
            raise exceptions.InternalError(
                f"data_cb was expected to be None, it is set to {sat_file.data_cb}")

        # one hasher for the clear data, and one for the encrypted data
        hash_fields = {
            "hash_algo": self._h.ALGO_DEFAULT,
            "hasher": self._h.get_hasher(),
            "encrypted_hash_algo": self._h.ALGO_DEFAULT,
            "encrypted_hasher": self._h.get_hasher(),
        }
        attachment.update(hash_fields)

        # with data_cb we encrypt the file on the fly
        sat_file.data_cb = partial(
            self._encrypt,
            encryptor=encryptor,
            attachment=attachment,
        )
        return True


@implementer(iwokkel.IDisco)
class XEP0448Handler(XMPPHandler):
    """Handler answering disco requests for this plugin"""

    def getDiscoInfo(self, requestor, target, nodeIdentifier=""):
        """Return the feature of this plugin (NS_ESFS)"""
        esfs_feature = disco.DiscoFeature(NS_ESFS)
        return [esfs_feature]

    def getDiscoItems(self, requestor, target, nodeIdentifier=""):
        """Return an empty list, no item is exposed"""
        return list()