view libervia/backend/plugins/plugin_xep_0448.py @ 4342:17fa953c8cd7

core (types): improve `SatXMPPEntity` core type and type hints.
author Goffi <goffi@goffi.org>
date Mon, 13 Jan 2025 01:23:10 +0100
parents 111dce64dcb5
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin for handling stateless file sharing encryption
# Copyright (C) 2009-2022 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
from functools import partial
from pathlib import Path
import secrets
from textwrap import dedent
from typing import Any, Dict, Optional, Self, Tuple, Union, cast

from cryptography.exceptions import AlreadyFinalized
from cryptography.hazmat import backends
from cryptography.hazmat.primitives import ciphers
from cryptography.hazmat.primitives.ciphers import CipherContext, modes
from cryptography.hazmat.primitives.padding import PKCS7, PaddingContext
from pydantic import BaseModel, ValidationError
import treq
from twisted.internet import defer
from twisted.words.protocols.jabber.xmlstream import XMPPHandler
from twisted.words.xish import domish
from wokkel import disco, iwokkel
from zope.interface import implementer

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.plugins.plugin_misc_download import DownloadPlugin
from libervia.backend.plugins.plugin_xep_0103 import XEP_0103
from libervia.backend.plugins.plugin_xep_0300 import NS_HASHES, XEP_0300, Hash
from libervia.backend.plugins.plugin_xep_0447 import XEP_0447, Source
from libervia.backend.tools import stream
from libervia.backend.tools.web import treq_client_no_ssl

log = getLogger(__name__)

# Import name of this plugin. It is also the marker written to extra["encryption"]
# by ``XEP_0448._upload_cb`` and checked by the upload triggers to know whether an
# upload must be encrypted by this plugin.
IMPORT_NAME = "XEP-0448"

# Plugin metadata. The entries of C.PI_DEPENDENCIES are the plugins which are looked
# up in ``host.plugins`` by ``XEP_0448.__init__``.
PLUGIN_INFO = {
    C.PI_NAME: "Encryption for Stateless File Sharing",
    C.PI_IMPORT_NAME: IMPORT_NAME,
    C.PI_TYPE: C.PLUG_TYPE_EXP,
    C.PI_PROTOCOLS: ["XEP-0448"],
    C.PI_DEPENDENCIES: [
        "XEP-0103",
        "XEP-0300",
        "XEP-0334",
        "XEP-0363",
        "XEP-0384",
        "XEP-0447",
        "DOWNLOAD",
        "ATTACH",
    ],
    C.PI_MAIN: "XEP_0448",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: dedent(
        _(
            """\
    Implementation of e2e encryption for media sharing
    """
        )
    ),
}

# Namespace of the <encrypted/> element and of its <key/> and <iv/> children. It is
# also the feature advertised by the disco handler of this plugin.
NS_ESFS = "urn:xmpp:esfs:0"
# Cipher identifiers found in the "cipher" attribute of <encrypted/>.
# ``XEP_0448.download`` accepts the three of them, ``XEP_0448.attach`` always emits
# NS_AES_256_GCM.
NS_AES_128_GCM = "urn:xmpp:ciphers:aes-128-gcm-nopadding:0"
NS_AES_256_GCM = "urn:xmpp:ciphers:aes-256-gcm-nopadding:0"
NS_AES_256_CBC = "urn:xmpp:ciphers:aes-256-cbc-pkcs7:0"


class EncryptedSource(Source):
    # Model of an <encrypted/> source (namespace NS_ESFS).
    #
    # ``cipher`` is the value of the "cipher" attribute, ``key`` and ``iv`` are the
    # text content of the <key/> and <iv/> children (kept as found in the element,
    # they are not decoded here), ``hashes`` are the hashes found in the element and
    # ``sources`` the sources wrapped by the element.
    type = "encrypted"
    encrypted = True
    cipher: str
    key: str
    iv: str
    hashes: list[Hash]
    sources: list[Source]
    # Plugin instances, set on the class by ``XEP_0448.__init__``.
    _hash: XEP_0300 | None = None
    _sfs: XEP_0447 | None = None

    @classmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Parse an <encrypted> element and return corresponding EncryptedData model

        @param element: <encrypted/> element to parse
        @return: model built from the element
        @raise exceptions.DataError: the element is invalid: the "cipher" attribute,
            the <key/> or the <iv/> child is missing, there is no source, or one of the
            sources is itself an encrypted source

        """
        assert cls._hash is not None, "_hash attribute is not set"
        assert cls._sfs is not None, "_sfs attribute is not set"
        try:
            cipher = element["cipher"]
            key = str(next(element.elements(NS_ESFS, "key")))
            iv = str(next(element.elements(NS_ESFS, "iv")))
        except (KeyError, StopIteration) as e:
            # The message must show the faulty element to be of any use when debugging.
            raise exceptions.DataError(
                f"invalid <encrypted/> element: {element.toXml()}"
            ) from e
        sources = cls._sfs.parse_sources_elt(element)
        if not sources:
            raise exceptions.DataError(f"Sources are missing in {element.toXml()}")

        # Nested encryption is refused.
        if any(isinstance(source, cls) for source in sources):
            raise exceptions.DataError(
                "EncryptedData is used as a source of another EncryptedData"
            )

        encrypted_data = {
            "cipher": cipher,
            "key": key,
            "iv": iv,
            "hashes": Hash.from_parent(element),
            "sources": sources,
        }

        return cls(**encrypted_data)

    def to_element(self) -> domish.Element:
        """Convert EncryptedData model to an <encrypted> element

        The element gets the "cipher" attribute, the <key/> and <iv/> children and one
        child per hash.

        @return: domish.Element representing the encrypted data

        """
        assert self._hash is not None, "_hash attribute is not set"
        encrypted_elt = domish.Element((NS_ESFS, "encrypted"))
        encrypted_elt["cipher"] = self.cipher
        encrypted_elt.addElement("key").addContent(self.key)
        encrypted_elt.addElement("iv").addContent(self.iv)
        for hash_ in self.hashes:
            encrypted_elt.addChild(hash_.to_element())

        # NOTE(review): ``self.sources`` is not serialised in the returned element,
        #   confirm that this is intended.
        return encrypted_elt


class XEP_0448:

    def __init__(self, host):
        """Register the plugin against the host and the plugins it depends on

        The "esfs" namespace, the <encrypted/> source parser, the attachment handler,
        the download handler and the two XEP-0363 upload triggers are registered here.

        @param host: host instance giving access to plugins and triggers
        """
        self.host = host
        log.info(_("XEP_0448 plugin initialization"))
        host.register_namespace("esfs", NS_ESFS)

        plugins = host.plugins
        self._u = cast(XEP_0103, plugins["XEP-0103"])
        self._h = cast(XEP_0300, plugins["XEP-0300"])
        self._hints = plugins["XEP-0334"]
        self._http_upload = plugins["XEP-0363"]
        self._o = plugins["XEP-0384"]
        self._sfs = cast(XEP_0447, plugins["XEP-0447"])
        # <encrypted/> elements found in sources are parsed with EncryptedSource
        self._sfs.register_source(NS_ESFS, "encrypted", EncryptedSource)

        self._attach = plugins["ATTACH"]
        self._attach.register(
            self.can_handle_attachment, self.attach, encrypted=True, priority=1000
        )

        # EncryptedSource needs those plugins to parse and build elements
        EncryptedSource._hash = self._h
        EncryptedSource._sfs = self._sfs

        download_plugin = cast(DownloadPlugin, plugins["DOWNLOAD"])
        download_plugin.register_download_handler(NS_ESFS, self.download)

        for trigger_name, trigger_cb in (
            ("XEP-0363_upload_pre_slot", self._upload_pre_slot),
            ("XEP-0363_upload", self._upload_trigger),
        ):
            host.trigger.add(trigger_name, trigger_cb)

    def get_handler(self, client):
        """Build the handler advertising this plugin

        @param client: unused
        @return: a new XEP0448Handler instance
        """
        handler = XEP0448Handler()
        return handler

    async def download(
        self,
        client: SatXMPPEntity,
        attachment: dict[str, Any],
        source: dict[str, Any],
        dest_path: Union[Path, str],
        extra: dict[str, Any] | None = None,
    ) -> tuple[str, defer.Deferred]:
        """Download a file from an encrypted source and decrypt it on the fly

        @param client: entity given to the destination file and to the decrypt
            callbacks
        @param attachment: attachment metadata, its "size" is used as size of the
            destination file when it is set and can be converted to int
        @param source: encrypted source data: its "type" must be "encrypted", its
            "cipher", base64 encoded "iv" and "key" are used for decryption, and the file
            is retrieved from the "url" of the first item of its "sources"
        @param dest_path: where the decrypted file is written
        @param extra: options, TLS certificate check is disabled if "ignore_tls_errors"
            is set to a true value
        @return: uid of the destination file, used as progress ID, and a Deferred
            linked to the download
        @raise ValueError: encryption data are incomplete or URL is missing
        @raise exceptions.CancelError: the cipher is not supported
        """
        # TODO: check hash
        if extra is None:
            extra = {}
        assert source["type"] == "encrypted"
        try:
            cipher = source["cipher"]
            iv = base64.b64decode(source["iv"])
            key = base64.b64decode(source["key"])
        except KeyError as e:
            raise ValueError(f"{source} has incomplete encryption data: {e}") from e

        # only the first wrapped source is used
        try:
            download_url = source["sources"][0]["url"]
        except (IndexError, KeyError) as e:
            raise ValueError(f"{source} has missing URL") from e

        if extra.get("ignore_tls_errors", False):
            log.warning("TLS certificate check disabled, this is highly insecure")
            treq_client = treq_client_no_ssl
        else:
            treq_client = treq

        try:
            file_size = int(attachment["size"])
        except (KeyError, ValueError):
            # no usable size in attachment metadata, we deduce it from the size of the
            # file to download
            head_data = await treq_client.head(download_url)
            content_length = int(head_data.headers.getRawHeaders("content-length")[0])
            # the 128 bits tag is put at the end
            file_size = content_length - 16

        file_obj = stream.SatFile(
            self.host,
            client,
            dest_path,
            mode="wb",
            size=file_size,
        )

        if cipher in (NS_AES_128_GCM, NS_AES_256_GCM):
            decryptor = ciphers.Cipher(
                ciphers.algorithms.AES(key),
                modes.GCM(iv),
                backend=backends.default_backend(),
            ).decryptor()
            decrypt_cb = partial(
                self.gcm_decrypt,
                client=client,
                file_obj=file_obj,
                decryptor=decryptor,
            )
            # gcm_decrypt finalizes the decryption and closes the file by itself
            finalize_cb = None
        elif cipher == NS_AES_256_CBC:
            cipher_algo = ciphers.algorithms.AES(key)
            decryptor = ciphers.Cipher(
                cipher_algo,
                modes.CBC(iv),
                backend=backends.default_backend(),
            ).decryptor()
            unpadder = PKCS7(cipher_algo.block_size).unpadder()
            decrypt_cb = partial(
                self.cbc_decrypt,
                client=client,
                file_obj=file_obj,
                decryptor=decryptor,
                unpadder=unpadder,
            )
            # with CBC, padding removal and file closing are done once all data have
            # been received
            finalize_cb = partial(
                self.cbc_decrypt_finalize,
                file_obj=file_obj,
                decryptor=decryptor,
                unpadder=unpadder,
            )
        else:
            msg = f"cipher {cipher!r} is not supported"
            file_obj.close(error=msg)
            log.warning(msg)
            raise exceptions.CancelError(msg)

        progress_id = file_obj.uid

        resp = await treq_client.get(download_url, unbuffered=True)
        if resp.code == 200:
            # decrypt_cb is called with each bunch of received data
            d = treq.collect(resp, partial(decrypt_cb))
            if finalize_cb is not None:
                d.addCallback(lambda __: finalize_cb())
        else:
            # the failure is handled by the DOWNLOAD plugin, which gets the Deferred
            # returned to the caller
            d = defer.Deferred()
            self.host.plugins["DOWNLOAD"].errback_download(file_obj, d, resp)
        return progress_id, d

    async def can_handle_attachment(self, client, data):
        """Tell if attachments of this message can be handled by this plugin

        @param client: entity sending the message
        @param data: message data, its "to" key is the recipient
        @return: True if the encryption session with the recipient uses TWOMEMO and an
            HTTP upload entity is found
        """
        # FIXME: check if SCE is supported without checking which e2ee algo is used
        active_namespace = client.encryption.get_namespace(data["to"])
        if active_namespace != self._o.NS_TWOMEMO:
            # SCE is required, and only TWOMEMO currently supports it: without it,
            # the attachment can't be handled here
            return False

        try:
            await self._http_upload.get_http_upload_entity(client)
        except exceptions.NotFound:
            return False
        return True

    async def _upload_cb(self, client, filepath, filename, extra):
        """Upload a file through HTTP upload, flagging it for encryption

        A random 12 bytes IV and a random 32 bytes key are generated and stored, with
        the algorithm, in both ``extra`` and the attachment. The real file name is kept
        in the attachment only, the file is uploaded under the name "encrypted".

        @param client: entity doing the upload
        @param filepath: path of the file to upload
        @param filename: real name of the file
        @param extra: upload options, must contain the "attachment"
        @return: result of the HTTP upload plugin's ``file_http_upload``
        """
        attachment = extra["attachment"]
        extra["encryption"] = IMPORT_NAME
        encryption_data = {
            "algorithm": C.ENC_AES_GCM,
            "iv": secrets.token_bytes(12),
            "key": secrets.token_bytes(32),
        }
        # the very same dict is shared between the attachment and extra
        attachment["encryption_data"] = encryption_data
        extra["encryption_data"] = encryption_data
        attachment["filename"] = filename
        upload_result = await self._http_upload.file_http_upload(
            client=client, filepath=filepath, filename="encrypted", extra=extra
        )
        return upload_result

    async def attach(self, client, data):
        """Upload the attachment of a message and add the file sharing element to it

        At most one attachment is kept in the message, each remaining attachment is
        sent in its own message with ``client.sendMessage``.

        @param client: entity sending the message
        @param data: message data. "extra" (and its C.KEY_ATTACHMENTS list), "message",
            "subject", "to", "type" and "xml" are used; the attachments list, "extra"
            and "xml" may be modified in place
        @raise exceptions.InternalError: more than one attachment is left for this
            message
        @raise exceptions.CancelError: there is nothing left to send in this message
        """
        # XXX: for now, XEP-0447/XEP-0448 only allow to send one file per <message/>, thus
        #   we need to send each file in a separate message, in the same way as for
        #   plugin_sec_aesgcm.
        attachments = data["extra"][C.KEY_ATTACHMENTS]
        if not data["message"] or data["message"] == {"": ""}:
            # no body: the first attachment stays in this message, others are sent
            # separately
            extra_attachments = attachments[1:]
            del attachments[1:]
        else:
            # we have a message, we must send first attachment separately
            extra_attachments = attachments[:]
            attachments.clear()
            del data["extra"][C.KEY_ATTACHMENTS]

        if attachments:
            if len(attachments) > 1:
                raise exceptions.InternalError(
                    "There should not be more that one attachment at this point"
                )
            await self._attach.upload_files(client, data, upload_cb=self._upload_cb)
            self._hints.add_hint_elements(data["xml"], [self._hints.HINT_STORE])
            for attachment in attachments:
                # "encryption_data" has been set by _upload_cb
                encryption_data = attachment.pop("encryption_data")
                file_hash = (attachment["hash_algo"], attachment["hash"])
                # no source is given here, the <encrypted/> element added below is the
                # only source
                file_sharing_elt = self._sfs.get_file_sharing_elt(
                    [],
                    name=attachment["filename"],
                    size=attachment["size"],
                    file_hash=file_hash,
                )
                sources_elt = file_sharing_elt.sources
                assert sources_elt is not None
                encrypted_elt = sources_elt.addElement((NS_ESFS, "encrypted"))
                encrypted_elt["cipher"] = NS_AES_256_GCM
                # key and IV are raw bytes, they are sent base64 encoded
                encrypted_elt.addElement(
                    "key", content=base64.b64encode(encryption_data["key"]).decode()
                )
                encrypted_elt.addElement(
                    "iv", content=base64.b64encode(encryption_data["iv"]).decode()
                )
                # hash of the encrypted file
                encrypted_elt.addChild(
                    self._h.build_hash_elt(
                        attachment["encrypted_hash"], attachment["encrypted_hash_algo"]
                    )
                )
                # the URL of the uploaded file is the source wrapped by <encrypted/>
                encrypted_elt.addChild(
                    self._sfs.get_sources_elt(
                        [self._u.generate_url_data(attachment["url"]).to_element()]
                    )
                )
                data["xml"].addChild(file_sharing_elt)

        for attachment in extra_attachments:
            # we send all remaining attachment in a separate message
            await client.sendMessage(
                to_jid=data["to"],
                message={"": ""},
                subject=data["subject"],
                mess_type=data["type"],
                extra={C.KEY_ATTACHMENTS: [attachment]},
            )

        if (
            not data["extra"]
            and (not data["message"] or data["message"] == {"": ""})
            and not data["subject"]
        ):
            # nothing left to send, we can cancel the message
            raise exceptions.CancelError("Cancelled by XEP_0448 attachment handling")

    def gcm_decrypt(
        self,
        data: bytes,
        client: SatXMPPEntity,
        file_obj: stream.SatFile,
        decryptor: CipherContext,
    ) -> None:
        if file_obj.tell() + len(data) > file_obj.size:  # type: ignore
            # we're reaching end of file with this bunch of data
            # we may still have a last bunch if the tag is incomplete
            bytes_left = file_obj.size - file_obj.tell()  # type: ignore
            if bytes_left > 0:
                decrypted = decryptor.update(data[:bytes_left])
                file_obj.write(decrypted)
                tag = data[bytes_left:]
            else:
                tag = data
            if len(tag) < 16:
                # the tag is incomplete, either we'll get the rest in next data bunch
                # or we have already the other part from last bunch of data
                try:
                    # we store partial tag in decryptor._sat_tag
                    tag = decryptor._sat_tag + tag
                except AttributeError:
                    # no other part, we'll get the rest at next bunch
                    decryptor.sat_tag = tag
                else:
                    # we have the complete tag, it must be 128 bits
                    if len(tag) != 16:
                        raise ValueError(f"Invalid tag: {tag}")
            remain = decryptor.finalize_with_tag(tag)
            file_obj.write(remain)
            file_obj.close()
        else:
            decrypted = decryptor.update(data)
            file_obj.write(decrypted)

    def cbc_decrypt(
        self,
        data: bytes,
        client: SatXMPPEntity,
        file_obj: stream.SatFile,
        decryptor: CipherContext,
        unpadder: PaddingContext,
    ) -> None:
        decrypted = decryptor.update(data)
        file_obj.write(unpadder.update(decrypted))

    def cbc_decrypt_finalize(
        self, file_obj: stream.SatFile, decryptor: CipherContext, unpadder: PaddingContext
    ) -> None:
        decrypted = decryptor.finalize()
        file_obj.write(unpadder.update(decrypted))
        file_obj.write(unpadder.finalize())
        file_obj.close()

    def _upload_pre_slot(self, client, extra, file_metadata):
        """Trigger adapting the size of the file when it is encrypted by this plugin

        @param client: unused
        @param extra: upload options, "encryption" tells if this plugin is used
        @param file_metadata: metadata of the file, "size" is increased by 16
        @return: always True
        """
        if extra.get("encryption") == IMPORT_NAME:
            # the tag is appended to the file
            file_metadata["size"] += 16
        return True

    def _encrypt(self, data: bytes, encryptor: CipherContext, attachment: dict) -> bytes:
        if data:
            attachment["hasher"].update(data)
            ret = encryptor.update(data)
            attachment["encrypted_hasher"].update(ret)
            return ret
        else:
            try:
                # end of file is reached, me must finalize
                fin = encryptor.finalize()
                tag = encryptor.tag
                ret = fin + tag
                hasher = attachment.pop("hasher")
                attachment["hash"] = hasher.hexdigest()
                encrypted_hasher = attachment.pop("encrypted_hasher")
                encrypted_hasher.update(ret)
                attachment["encrypted_hash"] = encrypted_hasher.hexdigest()
                return ret
            except AlreadyFinalized:
                # as we have already finalized, we can now send EOF
                return b""

    def _upload_trigger(self, client, extra, sat_file, file_producer, slot):
        """Trigger setting on the fly AES-GCM encryption of an uploaded file

        Nothing is done unless ``extra["encryption"]`` is the import name of this
        plugin.

        @param client: unused
        @param extra: upload options, "attachment" and "encryption_data" (with its
            "iv" and "key") are used when encryption is requested
        @param sat_file: file to upload, its ``data_cb`` is set to the encryption
            callback
        @param file_producer: producer of the file, its length is aligned on the size
            of ``sat_file``
        @param slot: unused
        @return: always True
        @raise exceptions.InternalError: ``sat_file.data_cb`` is already set
        """
        if extra.get("encryption") != IMPORT_NAME:
            return True

        attachment = extra["attachment"]
        encryption_data = extra["encryption_data"]
        log.debug("encrypting file with AES-GCM")
        iv, key = encryption_data["iv"], encryption_data["key"]

        # encrypted data are bigger than the original file, final data length must be
        # used for the check to avoid a warning on close()
        sat_file.check_size_with_read = True

        # file_producer takes its length from the file itself, which is a problem as
        # encryption changes the size: without this the producer would stop reading
        # prematurely
        file_producer.length = sat_file.size

        gcm_cipher = ciphers.Cipher(
            ciphers.algorithms.AES(key),
            modes.GCM(iv),
            backend=backends.default_backend(),
        )
        encryptor = gcm_cipher.encryptor()

        current_cb = sat_file.data_cb
        if current_cb is not None:
            raise exceptions.InternalError(
                f"data_cb was expected to be None, it is set to {current_cb}"
            )

        # one hasher for clear data, another one for encrypted data
        attachment.update(
            hash_algo=self._h.ALGO_DEFAULT,
            hasher=self._h.get_hasher(),
            encrypted_hash_algo=self._h.ALGO_DEFAULT,
            encrypted_hasher=self._h.get_hasher(),
        )

        # data_cb is what encrypts the file on the fly
        sat_file.data_cb = partial(
            self._encrypt, encryptor=encryptor, attachment=attachment
        )
        return True


@implementer(iwokkel.IDisco)
class XEP0448Handler(XMPPHandler):
    """Service discovery handler advertising the NS_ESFS feature"""

    def getDiscoInfo(self, requestor, target, nodeIdentifier=""):
        """Return the NS_ESFS feature, whatever the arguments are"""
        esfs_feature = disco.DiscoFeature(NS_ESFS)
        return [esfs_feature]

    def getDiscoItems(self, requestor, target, nodeIdentifier=""):
        """Return an empty list: there is no item to advertise"""
        no_items = []
        return no_items