view libervia/frontends/tools/webrtc_file.py @ 4288:f46891f2c9cb

plugin XEP-0166: handle `content-add` action + expose `get_transport`: - `content-add` is now handled at this plugin level (implementation needs to be done in apps and transports plugins). - `get_transport` is now exposed. rel 447
author Goffi <goffi@goffi.org>
date Mon, 29 Jul 2024 03:30:58 +0200
parents 0d7bb4df2343
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia WebRTC implementation
# Copyright (C) 2009-2024 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import asyncio
import atexit
from functools import partial
import logging
from pathlib import Path
from typing import Any, Callable, IO

import gi

gi.require_versions({"Gst": "1.0", "GstWebRTC": "1.0"})
from gi.repository import GLib, GstWebRTC

from libervia.backend.core import exceptions
from libervia.backend.core.i18n import _
from libervia.backend.tools.common import data_format, utils
from libervia.frontends.tools import aio, jid, webrtc
from libervia.frontends.tools.webrtc_models import CallData


# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Size, in bytes (64 KiB), of each chunk read from the file and pushed to the data
# channel by ``WebRTCFileSender._send_file``.
WEBRTC_CHUNK_SIZE = 64 * 1024


class WebRTCFileSender:
    """Send a local file to a peer through a WebRTC data channel.

    The WebRTC call is created with ``send_file_webrtc``. Once the data channel is
    open, the file is sent by chunks of ``WEBRTC_CHUNK_SIZE`` bytes, then the data
    channel is closed and the call is ended.
    """

    def __init__(
        self,
        bridge,
        profile: str,
        on_call_start_cb: Callable[[dict], Any] | None = None,
        end_call_cb: Callable[[], Any] | None = None,
    ) -> None:
        """Initializes the File Sender.

        @param bridge: An async Bridge instance.
        @param profile: The profile name to be used.
        @param on_call_start_cb: A blocking or async callable that accepts a dict as its
            only argument.
        @param end_call_cb: A callable to be invoked at the end of a call.
        """

        self.bridge = bridge
        self.profile = profile
        self.on_call_start_cb = on_call_start_cb
        self.end_call_cb = end_call_cb
        self.loop = asyncio.get_event_loop()
        # Set by ``send_file_webrtc``. It is initialised here so that a data channel
        # callback running before the call object has been stored finds ``None``
        # (which ``_on_dc_close`` handles) instead of raising ``AttributeError``.
        self.webrtc_call: webrtc.WebRTCCall | None = None

    async def _on_webrtc_call_start(
        self,
        file_path: Path,
        file_name: str | None,
        callee: str,
        call_data: dict,
        profile: str,
    ) -> str:
        """Initiate the Jingle file transfer once the WebRTC call is starting.

        This method is given as ``call_start_cb`` to ``make_webrtc_call``, with
        ``file_path`` and ``file_name`` already bound.

        @param file_path: Local path of the file to send; used for its size and as
            fallback for the file name.
        @param file_name: Name of the file as sent to the peer.
            If None or empty string, the name of ``file_path`` is used.
        @param callee: Entity to send the file to.
        @param call_data: WebRTC call data, forwarded to the backend.
        @param profile: Unused, ``self.profile`` is used instead.
        @return: The ``session_id`` found in the data returned by the backend.
        """
        file_data_s = await self.bridge.file_jingle_send(
            str(callee),
            "",
            file_name or file_path.name,
            "",
            data_format.serialise(
                {
                    "webrtc": True,
                    "call_data": call_data,
                    "size": file_path.stat().st_size,
                }
            ),
            self.profile,
        )
        file_data = data_format.deserialise(file_data_s)

        if self.on_call_start_cb is not None:
            await aio.maybe_async(self.on_call_start_cb(file_data))
        return file_data["session_id"]

    async def _send_file(
        self, file_path: Path, data_channel: GstWebRTC.WebRTCDataChannel
    ) -> None:
        """Send file to Data Channel by chunks

        The data channel is closed when the whole file has been sent, or if an error
        happens while sending.

        @param file_path: Local path of the file to send.
        @param data_channel: Opened data channel to write the chunks to.
        """
        try:
            with file_path.open("rb") as file:
                while data := file.read(WEBRTC_CHUNK_SIZE):
                    data_channel.send_data(GLib.Bytes(data))
                    # We give control back to the loop to avoid freezing everything.
                    await asyncio.sleep(0)
        finally:
            webrtc_call = self.webrtc_call
            # we connect to the "on-close" signal to wait for the data channel to be
            # actually closed before closing the call and quitting the app.
            data_channel.connect("on-close", partial(self._on_dc_close, webrtc_call))
            data_channel.close()

    def _on_dc_close(self, webrtc_call, data_channel: GstWebRTC.WebRTCDataChannel):
        """Called when the data channel is closed: schedule the end of the call.

        @param webrtc_call: Call which was registered when the file sending finished,
            or None if no call was registered at that time.
        @param data_channel: The data channel which has been closed.
        """
        if webrtc_call is None:
            log.warning(
                "Data channel has been closed but no WebRTC call is registered, the "
                "call can't be ended."
            )
            return
        # NOTE(review): this signal handler is presumably not run in the asyncio
        # loop's thread, hence the use of ``run_from_thread`` — confirm.
        aio.run_from_thread(self._end_call_and_quit, webrtc_call, loop=self.loop)

    async def _end_call_and_quit(self, webrtc_call):
        """End the WebRTC call, then invoke ``end_call_cb`` if it is set.

        @param webrtc_call: The call to end.
        """
        await webrtc_call.webrtc.end_call()
        if self.end_call_cb is not None:
            await aio.maybe_async(self.end_call_cb())

    def _on_dc_open(
        self, file_path: Path, data_channel: GstWebRTC.WebRTCDataChannel
    ) -> None:
        """Called when datachannel is open

        Schedules ``_send_file`` in ``self.loop``.

        @param file_path: Local path of the file to send.
        @param data_channel: The data channel which has just been opened.
        """
        aio.run_from_thread(self._send_file, file_path, data_channel, loop=self.loop)

    async def send_file_webrtc(
        self,
        file_path: Path | str,
        callee: jid.JID,
        file_name: str | None = None,
    ) -> None:
        """Send a file using WebRTC to the given callee JID.

        The created call is stored in ``self.webrtc_call``.

        @param file_path: The local path to the file to send.
        @param callee: The JID of the recipient to send the file to.
        @param file_name: Name of the file as sent to the peer.
            If None or empty string, name will be retrieved from file path.
        """
        file_path = Path(file_path)
        call_data = CallData(callee=callee)
        self.webrtc_call = await webrtc.WebRTCCall.make_webrtc_call(
            self.bridge,
            self.profile,
            call_data,
            sources_data=webrtc.SourcesDataChannel(
                dc_open_cb=partial(self._on_dc_open, file_path)
            ),
            call_start_cb=partial(
                self._on_webrtc_call_start,
                file_path,
                file_name,
            ),
        )


class WebRTCFileReceiver:
    """Receive a file from a peer through a WebRTC data channel.

    The destination file is opened by ``receive_file_webrtc``, filled with the chunks
    received on the data channel, and closed when the data channel is closed.
    """

    def __init__(
        self, bridge, profile: str, on_close_cb: Callable[[], Any] | None = None
    ) -> None:
        """Initializes the File Receiver.

        @param bridge: An async Bridge instance.
        @param profile: The profile name to be used.
        @param on_close_cb: Called when the Data Channel is closed.
        """
        self.bridge = bridge
        self.profile = profile
        self.on_close_cb = on_close_cb
        self.loop = asyncio.get_event_loop()
        # Both attributes are set by ``receive_file_webrtc``.
        self.file_data: dict | None = None
        self.fd: IO[bytes] | None = None

    @staticmethod
    def format_confirm_msg(
        action_data: dict, peer_jid: jid.JID, peer_name: str | None = None
    ) -> str:
        """Format a user-friendly confirmation message.

        File data will be retrieved from ``action_data`` and used to format a
        user-friendly file confirmation message.
        @param action_data: Data as returned by the "FILE" ``action_new`` signal.
            The optional ``file_data`` dict is read for its ``name``, ``size`` and
            ``desc`` keys.
        @param peer_jid: JID of the entity which wants to send the file.
        @param peer_name: Human readable name of the peer.
            If None or empty string, only ``peer_jid`` is shown.
        @return: User-friendly confirmation message.
        """
        file_data = action_data.get("file_data", {})

        file_name = file_data.get("name")
        file_size = file_data.get("size")

        if file_name:
            file_name_msg = 'wants to send you the file "{file_name}"'.format(
                file_name=file_name
            )
        else:
            file_name_msg = "wants to send you an unnamed file"

        if file_size is not None:
            file_size_msg = "which has a size of {file_size_human}".format(
                file_size_human=utils.get_human_size(file_size)
            )
        else:
            file_size_msg = "which has an unknown size"

        file_description = file_data.get("desc")
        if file_description:
            description_msg = " Description: {}.".format(file_description)
        else:
            description_msg = ""

        if not peer_name:
            peer_name = str(peer_jid)
        else:
            peer_name = f"{peer_name} ({peer_jid})"

        return _(
            "{peer_name} {file_name_msg} {file_size_msg}.{description_msg} "
            "Do you accept?"
        ).format(
            peer_name=peer_name,
            file_name_msg=file_name_msg,
            file_size_msg=file_size_msg,
            description_msg=description_msg,
        )

    def _on_dc_message_data(self, fd, data_channel, glib_data) -> None:
        """A data chunk of the file has been received.

        @param fd: Destination file object, the chunk is written to it.
        @param data_channel: Data channel which received the chunk.
        @param glib_data: Received chunk; its content is read with ``get_data()``.
        """
        fd.write(glib_data.get_data())

    def _on_dc_close(self, data_channel) -> None:
        """Data channel is closed

        The file download should be complete, we close it.
        """
        # NOTE(review): this signal handler is presumably not run in the asyncio
        # loop's thread, hence the use of ``run_from_thread`` — confirm.
        aio.run_from_thread(self._on_close, loop=self.loop)

    async def _on_close(self) -> None:
        """Close the destination file, then invoke ``on_close_cb`` if it is set."""
        assert self.fd is not None
        self.fd.close()
        if self.on_close_cb is not None:
            await aio.maybe_async(self.on_close_cb())

    def _on_data_channel(self, webrtcbin, data_channel) -> None:
        """The data channel has been opened.

        Received data is written to the file object currently in ``self.fd``, and
        ``_on_dc_close`` is called when the data channel is closed.
        """
        data_channel.connect(
            "on-message-data", partial(self._on_dc_message_data, self.fd)
        )
        data_channel.connect("on-close", self._on_dc_close)

    def _on_fd_clean(self, fd) -> None:
        """Close opened file object if not already done.

        If the file object was not closed, an error message is logged.

        @param fd: File object to check, may be None.
        """
        if fd is None:
            return
        if not fd.closed:
            log.error(
                f"The file {fd.name!r} was not closed properly, which might "
                "indicate an incomplete download."
            )
            fd.close()

    async def receive_file_webrtc(
        self,
        from_jid: jid.JID,
        session_id: str,
        file_path: Path,
        file_data: dict,
    ) -> None:
        """Receives a file via WebRTC and saves it to the specified path.

        @param from_jid: The JID of the entity sending the file.
        @param session_id: The Jingle FT Session ID.
        @param file_path: The local path where the received file will be saved.
            If a file already exists at this path, it will be overwritten.
        @param file_data: Additional data about the file being transferred.
        @raise exceptions.InternalError: ``file_path`` exists but is not a file.
        """
        if file_path.exists() and not file_path.is_file():
            raise exceptions.InternalError(
                f"{file_path} is not a valid destination path."
            )
        self.fd = file_path.open("wb")
        # Safety net: if the program exits while the download is still running, the
        # file is closed and an error is logged.
        atexit.register(self._on_fd_clean, self.fd)
        self.file_data = file_data
        call_data = CallData(callee=from_jid, sid=session_id)
        try:
            await webrtc.WebRTCCall.make_webrtc_call(
                self.bridge,
                self.profile,
                call_data,
                sinks_data=webrtc.SinksDataChannel(
                    dc_on_data_channel=self._on_data_channel,
                ),
            )
        except BaseException:
            # The call could not be set up (or was cancelled): don't leave the
            # destination file open until the program exits.
            self.fd.close()
            raise