view libervia/backend/plugins/plugin_xep_0447.py @ 4341:e9971a4b0627

remove uses of twisted.internet.defer.returnValue This function has been deprecated in Twisted 24.7.0.
author Povilas Kanapickas <povilas@radix.lt>
date Wed, 11 Dec 2024 01:17:09 +0200
parents 430d5d99a740
children
line wrap: on
line source

#!/usr/bin/env python3

# Copyright (C) 2009-2022 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from abc import ABC, abstractmethod
from collections import namedtuple
from functools import partial
import mimetypes
from pathlib import Path
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Final,
    List,
    Literal,
    NamedTuple,
    Optional,
    Self,
    Tuple,
    Union,
    cast,
)

from pydantic import BaseModel, Field
import treq
from twisted.internet import defer
from twisted.words.xish import domish

from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.plugins.plugin_xep_0103 import URLData, XEP_0103
from libervia.backend.plugins.plugin_xep_0358 import JinglePub, XEP_0358
from libervia.backend.plugins.plugin_xep_0446 import FileMetadata, XEP_0446
from libervia.backend.tools import stream
from libervia.backend.tools.web import treq_client_no_ssl

log = getLogger(__name__)


# Plugin metadata, keyed with the C.PI_* constants.
PLUGIN_INFO = {
    C.PI_NAME: "Stateless File Sharing",
    C.PI_IMPORT_NAME: "XEP-0447",
    C.PI_TYPE: "XEP",
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: ["XEP-0447"],
    # Plugins accessed with ``host.plugins[...]`` in XEP_0447.__init__.
    C.PI_DEPENDENCIES: [
        "XEP-0103",
        "XEP-0334",
        "XEP-0358",
        "XEP-0446",
        "ATTACH",
        "DOWNLOAD",
    ],
    # Optional plugin: retrieved with ``host.plugins.get(...)`` in
    # XEP_0447.__init__, attachments are not handled when it is missing.
    C.PI_RECOMMENDATIONS: ["XEP-0363"],
    C.PI_MAIN: "XEP_0447",
    C.PI_HANDLER: "no",
    C.PI_DESCRIPTION: _("""Implementation of XEP-0447 (Stateless File Sharing)"""),
}

# Namespace of the <file-sharing/> and <sources/> elements built and parsed below.
NS_SFS = "urn:xmpp:sfs:0"


# Abstract base model of the file sources which can be found in a <sources/>
# element. Each subclass must provide a ``type`` class attribute, and implement
# the conversions from and to XML elements.
class Source(ABC, BaseModel):

    # identifier of the kind of source, must be set by each subclass
    type: ClassVar[str]
    # class-level flag, False unless a subclass overrides it
    encrypted: ClassVar[bool] = False

    def __init_subclass__(cls) -> None:
        """Check that a newly defined subclass has a ``type`` class attribute

        @raise TypeError: no ``type`` attribute can be found on the subclass
        """
        super().__init_subclass__()
        if hasattr(cls, "type"):
            return
        raise TypeError(
            f'Can\'t instantiate {cls.__name__} without "type" class attribute.'
        )

    @classmethod
    @abstractmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Build a model instance from its XML representation

        @param element: element to parse
        @return: model corresponding to the element
        @raise exceptions.DataError: the element is invalid
        """

    @abstractmethod
    def to_element(self) -> domish.Element:
        """Serialise this model to XML

        @return: domish.Element representing the model
        """


class FileSharing(BaseModel):
    """
    Model for handling XEP-0447 <file-sharing> element.
    """

    file: FileMetadata
    sources: list[Source]
    disposition: str | None = Field(
        default=None,
        description="Disposition of the file, either 'attachment' or 'inline'.",
    )
    id: str | None = Field(
        default=None, description="Unique identifier for the file-sharing element."
    )
    # reference to the plugin instance, set on the class by XEP_0447.__init__
    _sfs: "XEP_0447 | None" = None

    def to_element(self) -> domish.Element:
        """Serialise this instance to a <file-sharing> element.

        The optional attributes are only set when they have a non-empty value. The
        file metadata element is added first, then a <sources> element holding one
        child per source.

        @return: <file-sharing> element.
        """
        root_elt = domish.Element((NS_SFS, "file-sharing"))

        for attr_name in ("disposition", "id"):
            attr_value = getattr(self, attr_name)
            if attr_value:
                root_elt[attr_name] = attr_value

        root_elt.addChild(self.file.to_element())

        sources_elt = root_elt.addElement("sources")
        for source in self.sources:
            sources_elt.addChild(source.to_element())

        return root_elt

    @classmethod
    def from_element(cls, file_sharing_elt: domish.Element) -> Self:
        """Parse a <file-sharing> element, given directly or through its parent.

        The plugin instance must have been stored in ``_sfs`` beforehand, as it is
        used to parse the sources.

        @param file_sharing_elt: The <file-sharing> element or a parent element.
        @return: FileSharing instance.
        @raise exceptions.NotFound: If the <file-sharing> element is not found.
        """
        assert cls._sfs is not None
        is_file_sharing = (
            file_sharing_elt.uri == NS_SFS and file_sharing_elt.name == "file-sharing"
        )
        if not is_file_sharing:
            # we may have been given a parent, use its first <file-sharing> child
            found_elt = next(file_sharing_elt.elements(NS_SFS, "file-sharing"), None)
            if found_elt is None:
                raise exceptions.NotFound("<file-sharing> element not found")
            file_sharing_elt = found_elt

        kwargs = {}
        # empty or missing attributes are left to the model defaults
        for attr_name in ("disposition", "id"):
            attr_value = file_sharing_elt.getAttribute(attr_name)
            if attr_value:
                kwargs[attr_name] = attr_value

        kwargs["file"] = FileMetadata.from_element(file_sharing_elt)
        kwargs["sources"] = cls._sfs.parse_sources_elt(file_sharing_elt)

        return cls(**kwargs)


# File source combining the URLData model with the Source interface. It is
# registered by XEP_0447.__init__ for the "url-data" element.
class URLDataSource(URLData, Source):
    type = "url"

    @classmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Parse the element by delegating to the next ``from_element`` in the MRO

        @param element: element to parse
        @return: parsed source
        """
        return super().from_element(element)

    def to_element(self) -> domish.Element:
        """Build the element by delegating to the next ``to_element`` in the MRO

        @return: element representing this source
        """
        return super().to_element()


# File source combining the JinglePub model with the Source interface. It is
# registered by XEP_0447.__init__ for the "jinglepub" element.
class JinglePubSource(JinglePub, Source):
    type = "jingle"

    @classmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Parse the element by delegating to the next ``from_element`` in the MRO

        @param element: element to parse
        @return: parsed source
        """
        return super().from_element(element)

    def to_element(self) -> domish.Element:
        """Build the element by delegating to the next ``to_element`` in the MRO

        @return: element representing this source
        """
        return super().to_element()


class XEP_0447:
    namespace = NS_SFS

    def __init__(self, host):
        """Register the namespace, the source handlers and the hooks of the plugin

        @param host: object giving access to the other plugins (``host.plugins``),
            to namespace registration and to triggers
        """
        self.host = host
        log.info(_("XEP-0447 (Stateless File Sharing) plugin initialization"))
        host.register_namespace("sfs", NS_SFS)
        # FileSharing.from_element uses this reference to parse the sources
        FileSharing._sfs = self
        # map a (namespace, element name) tuple to the Source subclass parsing it
        self._sources_handlers: dict[tuple[str, str], type[Source]] = {}
        self._u = cast(XEP_0103, host.plugins["XEP-0103"])
        self._jp = cast(XEP_0358, host.plugins["XEP-0358"])
        self._hints = host.plugins["XEP-0334"]
        self._m = cast(XEP_0446, host.plugins["XEP-0446"])
        # XEP-0363 is optional: when it is missing, this is None and
        # can_handle_attachment returns False
        self._http_upload = host.plugins.get("XEP-0363")
        self._attach = host.plugins["ATTACH"]
        self._attach.register(self.can_handle_attachment, self.attach, priority=1000)
        # sources handled out of the box
        self.register_source(
            self._u.namespace,
            "url-data",
            URLDataSource,
        )
        self.register_source(
            self._jp.namespace,
            "jinglepub",
            JinglePubSource,
        )
        # downloads for the XEP-0103 namespace are handled by self.download
        host.plugins["DOWNLOAD"].register_download_handler(
            self._u.namespace, self.download
        )
        host.trigger.add("message_received", self._message_received_trigger)

    def register_source(
        self,
        namespace: str,
        element_name: str,
        source: type[Source],
    ) -> None:
        """Associate a Source class with a child element of <sources/>

        @param namespace: namespace of the element supported
        @param element_name: name of the element supported
        @param source: Source subclass whose ``from_element`` is used to parse the
            matching elements
        @raise exceptions.ConflictError: a source is already registered for this
            namespace and element name
        """
        handlers = self._sources_handlers
        handler_key = (namespace, element_name)
        if handler_key not in handlers:
            handlers[handler_key] = source
            return
        raise exceptions.ConflictError(
            f"There is already a resource handler for namespace {namespace!r} and "
            f"name {element_name!r}"
        )

    async def download(
        self,
        client: SatXMPPEntity,
        attachment: Dict[str, Any],
        source: Dict[str, Any],
        dest_path: Union[Path, str],
        extra: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, defer.Deferred]:
        """Download a file from the URL of a source

        @param client: client session, given to the destination file object
        @param attachment: attachment data, its "size" is used when it can be
            converted to an integer
        @param source: source data, must have an "url" key
        @param dest_path: path where the downloaded data is written
        @param extra: options, the TLS certificate is not checked when
            "ignore_tls_errors" is set
        @return: progress ID, and a Deferred. On a 200 response the Deferred fires
            once the data are collected and the file is closed, otherwise it is
            handed to the DOWNLOAD plugin ``errback_download``.
        @raise ValueError: there is no URL in ``source``
        """
        # TODO: handle url-data headers
        if extra is None:
            extra = {}
        try:
            download_url = source["url"]
        except KeyError:
            raise ValueError(f"{source} has missing URL")

        if extra.get("ignore_tls_errors", False):
            log.warning("TLS certificate check disabled, this is highly insecure")
            treq_client = treq_client_no_ssl
        else:
            treq_client = treq

        try:
            file_size = int(attachment["size"])
        except (KeyError, TypeError, ValueError):
            # "size" is missing, None or not a number: fall back to the
            # Content-Length advertised by the server
            head_data = await treq_client.head(download_url)
            content_length = head_data.headers.getRawHeaders("content-length")
            try:
                file_size = int(content_length[0])
            except (TypeError, IndexError, ValueError):
                # getRawHeaders returns None when the header is absent
                # NOTE(review): SatFile is expected to accept size=None for an
                #   unknown size — confirm against libervia.backend.tools.stream
                log.warning("file size is unknown (no usable Content-Length header)")
                file_size = None

        file_obj = stream.SatFile(
            self.host,
            client,
            dest_path,
            mode="wb",
            size=file_size,
        )

        progress_id = file_obj.uid

        resp = await treq_client.get(download_url, unbuffered=True)
        if resp.code == 200:
            d = treq.collect(resp, file_obj.write)
            d.addCallback(lambda __: file_obj.close())
        else:
            d = defer.Deferred()
            self.host.plugins["DOWNLOAD"].errback_download(file_obj, d, resp)
        return progress_id, d

    async def can_handle_attachment(self, client, data):
        """Tell if attachments can be handled for this client

        @param client: client session, used to look for an HTTP upload entity
        @param data: message data (not used here)
        @return: True if the XEP-0363 plugin is available and an HTTP upload entity
            has been found
        """
        http_upload = self._http_upload
        if http_upload is None:
            return False
        try:
            await http_upload.get_http_upload_entity(client)
        except exceptions.NotFound:
            return False
        return True

    def get_sources_elt(
        self, children: Optional[List[domish.Element]] = None
    ) -> domish.Element:
        """Build a <sources> element, with optional children

        @param children: elements to add to <sources>, in order
        @return: the new <sources> element
        """
        sources_elt = domish.Element((NS_SFS, "sources"))
        for child_elt in children or ():
            sources_elt.addChild(child_elt)
        return sources_elt

    def get_file_sharing_elt(
        self,
        sources: List[Dict[str, Any]],
        disposition: Optional[str] = None,
        name: Optional[str] = None,
        media_type: Optional[str] = None,
        desc: Optional[str] = None,
        size: Optional[int] = None,
        file_hash: Optional[Tuple[str, str]] = None,
        date: Optional[Union[float, int]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        length: Optional[int] = None,
        thumbnail: Optional[str] = None,
        **kwargs,
    ) -> domish.Element:
        """Generate the <file-sharing/> element

        The element holds the file metadata, followed by a <sources/> element with
        one child per item of ``sources``.

        @param sources: data of each source. Only sources with an "url" key are
            handled, the whole dict is then used as keyword arguments to generate
            the URL data.
        @param disposition: value of the "disposition" attribute, which is not set
            when None
        @param name: name of the file, also used to guess ``media_type`` when the
            latter is not specified
        @param media_type: media type of the file
        @param desc: file metadata, given as is to the metadata generation
        @param size: file metadata, given as is to the metadata generation
        @param file_hash: file metadata, given as is to the metadata generation
        @param date: file metadata, given as is to the metadata generation
        @param width: file metadata, given as is to the metadata generation
        @param height: file metadata, given as is to the metadata generation
        @param length: file metadata, given as is to the metadata generation
        @param thumbnail: file metadata, given as is to the metadata generation
        @param kwargs: accepted but not used
        @return: ``<file-sharing/>`` element
        @raise NotImplementedError: a source has no "url" key
        """
        root_elt = domish.Element((NS_SFS, "file-sharing"))
        if disposition is not None:
            root_elt["disposition"] = disposition
        if name and media_type is None:
            media_type, __ = mimetypes.guess_type(name, strict=False)

        file_metadata = self._m.generate_file_metadata(
            name=name,
            media_type=media_type,
            desc=desc,
            size=size,
            file_hash=file_hash,
            date=date,
            width=width,
            height=height,
            length=length,
            thumbnail=thumbnail,
        )
        root_elt.addChild(file_metadata.to_element())

        sources_elt = self.get_sources_elt()
        root_elt.addChild(sources_elt)
        for source_data in sources:
            if "url" not in source_data:
                raise NotImplementedError(f"source data not implemented: {source_data}")
            url_data = self._u.generate_url_data(**source_data)
            sources_elt.addChild(url_data.to_element())

        return root_elt

    def parse_sources_elt(self, sources_elt: domish.Element) -> List[Source]:
        """Parse <sources/> element

        Children without namespace, or without registered source handler, are
        skipped with a warning.

        @param sources_elt: <sources/> element, or a direct parent element
        @return: list of found sources data
        @raise exceptions.NotFound: Can't find <sources/> element
        """
        if sources_elt.name != "sources" or sources_elt.uri != NS_SFS:
            # we may have been given a parent, use its first <sources/> child
            found_elt = next(sources_elt.elements(NS_SFS, "sources"), None)
            if found_elt is None:
                raise exceptions.NotFound(
                    f"<sources/> element is missing: {sources_elt.toXml()}"
                )
            sources_elt = found_elt

        sources = []
        for elt in sources_elt.elements():
            if not elt.uri:
                log.warning(f"ignoring source element {elt.toXml()}")
                continue
            try:
                source_handler = self._sources_handlers[(elt.uri, elt.name)]
            except KeyError:
                log.warning(f"unmanaged file sharing element: {elt.toXml()}")
                continue
            sources.append(source_handler.from_element(elt))
        return sources

    def parse_file_sharing_elt(self, file_sharing_elt: domish.Element) -> Dict[str, Any]:
        """Extract file-sharing data from a <file-sharing/> element

        @param file_sharing_elt: <file-sharing/> element, or an element having one
            as child
        @return: file-sharing data. It a dict whose keys correspond to
            [get_file_sharing_elt] parameters. Metadata are left out if they can't
            be found, "disposition" is only set when the attribute exists, and
            "sources" holds the parsed sources.
        @raise exceptions.NotFound: no <file-sharing/> element can be found
        @raise ValueError: there is no <sources/> element
        """
        is_file_sharing = (
            file_sharing_elt.uri == NS_SFS and file_sharing_elt.name == "file-sharing"
        )
        if not is_file_sharing:
            candidates = file_sharing_elt.elements(NS_SFS, "file-sharing")
            try:
                file_sharing_elt = next(candidates)
            except StopIteration:
                raise exceptions.NotFound

        try:
            file_metadata = self._m.parse_file_metadata_elt(file_sharing_elt)
            data = file_metadata.model_dump()
        except exceptions.NotFound:
            # metadata are optional here, we continue with empty data
            data = {}

        disposition = file_sharing_elt.getAttribute("disposition")
        if disposition is not None:
            data["disposition"] = disposition

        try:
            parsed_sources = self.parse_sources_elt(file_sharing_elt)
        except exceptions.NotFound as e:
            raise ValueError(str(e))
        data["sources"] = parsed_sources

        return data

    @staticmethod
    def _is_source_encrypted(source: Any) -> bool:
        """Tell if a parsed source is flagged as encrypted

        @param source: a Source model (the ``encrypted`` class attribute is used),
            or a dict based source (the C.MESS_KEY_ENCRYPTED key is used)
        @return: True if the source is flagged as encrypted
        """
        if isinstance(source, dict):
            return bool(source.get(C.MESS_KEY_ENCRYPTED, False))
        return bool(getattr(source, "encrypted", False))

    def _add_file_sharing_attachments(
        self, client: SatXMPPEntity, message_elt: domish.Element, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Check <message> for a shared file, and add it as an attachment

        @param client: client session, used to check if the message is encrypted
        @param message_elt: received <message> element
        @param data: parsed message data, attachments are added in place to
            ``data["extra"]``
        @return: the ``data`` received as argument
        """
        # XXX: XEP-0447 doesn't support several attachments in a single message, for now
        #   however that should be fixed in future version, and so we accept several
        #   <file-sharing> element in a message.
        for file_sharing_elt in message_elt.elements(NS_SFS, "file-sharing"):
            # we parse the element of the current iteration: parsing the parent
            # <message> would return the first <file-sharing> for every iteration
            attachment = self.parse_file_sharing_elt(file_sharing_elt)

            if any(
                self._is_source_encrypted(s) for s in attachment["sources"]
            ) and client.encryption.isEncrypted(data):
                # we don't add the encrypted flag if the message itself is not encrypted,
                # because the decryption key is part of the link, so sending it over
                # unencrypted channel is like having no encryption at all.
                attachment[C.MESS_KEY_ENCRYPTED] = True

            attachments = data["extra"].setdefault(C.KEY_ATTACHMENTS, [])
            attachments.append(attachment)

        return data

    async def attach(self, client, data):
        """Attach files to a message being sent, with one file per <message/>

        At most one attachment is kept with the message in ``data``: it is uploaded,
        then a <file-sharing/> element is added to ``data["xml"]``. Each remaining
        attachment is sent in a separate message.

        @param client: client session, used to send the additional messages
        @param data: message data, modified in place
        @raise exceptions.InternalError: more than one attachment is left for this
            message
        @raise exceptions.CancelError: there is nothing left to send with this
            message
        """
        # XXX: for now, XEP-0447 only allow to send one file per <message/>, thus we need
        #   to send each file in a separate message
        attachments = data["extra"][C.KEY_ATTACHMENTS]
        if not data["message"] or data["message"] == {"": ""}:
            # no body: the first attachment stays with this message, the other ones
            # are sent separately
            extra_attachments = attachments[1:]
            del attachments[1:]
        else:
            # we have a message body: every attachment (including the first one) is
            # sent separately, and none is kept with this message
            extra_attachments = attachments[:]
            attachments.clear()
            del data["extra"][C.KEY_ATTACHMENTS]

        if attachments:
            if len(attachments) > 1:
                raise exceptions.InternalError(
                    "There should not be more that one attachment at this point"
                )
            await self._attach.upload_files(client, data)
            self._hints.add_hint_elements(data["xml"], [self._hints.HINT_STORE])
            for attachment in attachments:
                # the hash is only used when both the algorithm and the value are set
                try:
                    file_hash = (attachment["hash_algo"], attachment["hash"])
                except KeyError:
                    file_hash = None
                file_sharing_elt = self.get_file_sharing_elt(
                    [{"url": attachment["url"]}],
                    name=attachment.get("name"),
                    size=attachment.get("size"),
                    desc=attachment.get("desc"),
                    media_type=attachment.get("media_type"),
                    file_hash=file_hash,
                )
                data["xml"].addChild(file_sharing_elt)

        for attachment in extra_attachments:
            # we send all remaining attachment in a separate message
            await client.sendMessage(
                to_jid=data["to"],
                message={"": ""},
                subject=data["subject"],
                mess_type=data["type"],
                extra={C.KEY_ATTACHMENTS: [attachment]},
            )

        if (
            not data["extra"]
            and (not data["message"] or data["message"] == {"": ""})
            and not data["subject"]
        ):
            # nothing left to send, we can cancel the message
            raise exceptions.CancelError("Cancelled by XEP_0447 attachment handling")

    def _message_received_trigger(self, client, message_elt, post_treat):
        """Schedule the detection of shared files on a received message

        @param client: client session
        @param message_elt: received <message> element
        @param post_treat: object on which the attachments detection is added as a
            callback
        @return: always True (presumably to let the normal workflow continue, to be
            confirmed against the trigger manager)
        """
        # we use a post_treat callback instead of "message_parse" trigger because we need
        # to check if the "encrypted" flag is set to decide if we add the same flag to the
        # attachment
        add_attachments = partial(
            self._add_file_sharing_attachments, client, message_elt
        )
        post_treat.addCallback(add_attachments)
        return True