view libervia/backend/plugins/plugin_xep_0167/mapping.py @ 4318:27bb22eace65

tests (unit/email gateway): add test for XEP-0131 handling: rel 451
author Goffi <goffi@goffi.org>
date Sat, 28 Sep 2024 15:59:48 +0200
parents 33ecebb2c02f
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
from typing import Any, Dict, Optional

from twisted.words.xish import domish

from libervia.backend.core.constants import Const as C
from libervia.backend.core.log import getLogger

from .constants import NS_JINGLE_RTP

log = getLogger(__name__)

host = None


def senders_to_sdp(senders: str, session: dict) -> str:
    """Map a Jingle ``senders`` value to the matching SDP direction attribute.

    @param senders: Jingle ``senders`` attribute value ("both", "none", or a role
        name).
    @param session: Jingle session data; its "role" value is only read when
        ``senders`` is neither "both" nor "none".
    @return: SDP direction attribute line: "a=sendrecv", "a=inactive", "a=sendonly"
        or "a=recvonly".
    """
    if senders in ("both", "none"):
        return "a=sendrecv" if senders == "both" else "a=inactive"
    # a single role is allowed to send: the line is "recvonly" when that role is the
    # one stored in the session, "sendonly" otherwise
    return "a=recvonly" if session["role"] == senders else "a=sendonly"


def generate_sdp_from_session(session: dict, local: bool = False, port: int = 9) -> str:
    """Build the SDP text describing one side of a Jingle RTP session.

    @param session: A dictionary containing the session data. The following keys are
        read:

        {
            "local_jid": <jid of the local entity, used for the "o=" username>,
            "role": <session role, read when senders is restricted to one role>,
            "contents": {
                "<content_id>": {
                    "senders": <str: Jingle senders value>,
                    "application_data": {
                        "media": <str: "audio", "video" or "application">,
                        "local_data": <media_data dict>,
                        "peer_data": <media_data dict>,
                        ...
                    },
                    "transport_data": {
                        "local_ice_data": <ice_data dict>,
                        "peer_ice_data": <ice_data dict>,
                        ...
                    },
                    ...
                },
                ...
            }
        }
    @param local: Select which side is described: the local entity if True, the peer
        entity otherwise. Generally the local SDP is received from frontends and not
        needed in backend, except for debugging purpose.
    @param port: Port written in every "m=" line.

    @return: The generated SDP string, each line being terminated by CRLF.
    """
    contents = session["contents"]
    sdp_lines = ["v=0"]

    # Originator ("o=") line, right after the version line. Session ID and session
    # version are both fixed to "1" here.
    username = base64.b64encode(session["local_jid"].full().encode()).decode()
    net_type, addr_type, conn_address = "IN", "IP4", "0.0.0.0"
    connection = f"{net_type} {addr_type} {conn_address}"
    sdp_lines.append(f"o={username} 1 1 {connection}")

    # Mandatory "s=" and "t=" lines
    sdp_lines.extend(["s=-", "t=0 0"])

    # Stream direction: when every content shares the same senders value, that value
    # is used for all media sections, otherwise each content uses its own one.
    distinct_senders = {c["senders"] for c in contents.values()}
    senders = distinct_senders.pop() if len(distinct_senders) == 1 else None

    sdp_lines.extend(["a=msid-semantic:WMS *", "a=ice-options:trickle"])

    host.trigger.point(
        "XEP-0167_generate_sdp_session",
        session,
        local,
        sdp_lines,
        triggers_no_cancel=True,
    )

    # which side of the data we serialise
    app_data_key = "local_data" if local else "peer_data"
    ice_data_key = "local_ice_data" if local else "peer_ice_data"

    # media sections are emitted in sorted content name order
    ordered_contents = [(name, contents[name]) for name in sorted(contents)]

    for content_name, content_data in ordered_contents:
        application_data = content_data["application_data"]
        media_data = application_data[app_data_key]
        media = application_data["media"]
        payload_types = media_data.get("payload_types", {})

        # "m=" line
        if media == "application":
            sdp_lines.append(f"m={media} {port} UDP/DTLS/SCTP webrtc-datachannel")
        else:
            formats = " ".join(str(pt_id) for pt_id in payload_types)
            sdp_lines.append(f"m={media} {port} UDP/TLS/RTP/SAVPF {formats}")

        sdp_lines.append(f"c={connection}")
        sdp_lines.append(f"a=mid:{content_name}")

        # stream direction
        content_senders = content_data["senders"] if senders is None else senders
        sdp_lines.append(senders_to_sdp(content_senders, session))

        # "a=rtpmap", "a=ptime" and "a=fmtp" lines for each payload type
        for pt_id, pt in payload_types.items():
            name = pt["name"]
            clockrate = pt.get("clockrate", "")
            rtpmap_line = f"a=rtpmap:{pt_id} {name}/{clockrate}"
            channels = pt.get("channels")
            if channels:
                rtpmap_line += f"/{channels}"
            sdp_lines.append(rtpmap_line)

            if "ptime" in pt:
                sdp_lines.append(f"a=ptime:{pt['ptime']}")

            if "parameters" in pt:
                fmtp_params = ";".join(
                    f"{key}={value}" for key, value in pt["parameters"].items()
                )
                sdp_lines.append(f"a=fmtp:{pt_id} {fmtp_params}")

        if "bandwidth" in media_data:
            sdp_lines.append(f"a=b:{media_data['bandwidth']}")

        if media_data.get("rtcp-mux"):
            sdp_lines.append("a=rtcp-mux")

        # fingerprint, ICE credentials and candidates
        ice_data = content_data["transport_data"][ice_data_key]

        if "fingerprint" in ice_data:
            fingerprint_data = ice_data["fingerprint"]
            sdp_lines.append(
                f"a=fingerprint:{fingerprint_data['hash']} "
                f"{fingerprint_data['fingerprint']}"
            )
            sdp_lines.append(f"a=setup:{fingerprint_data['setup']}")

        sdp_lines.append(f"a=ice-ufrag:{ice_data['ufrag']}")
        sdp_lines.append(f"a=ice-pwd:{ice_data['pwd']}")

        for candidate in ice_data["candidates"]:
            sdp_lines.append("a=" + generate_candidate_line(candidate))

        # "a=crypto" lines
        if "encryption" in media_data:
            for enc_data in media_data["encryption"]:
                crypto_suite = enc_data["crypto-suite"]
                key_params = enc_data["key-params"]
                session_params = enc_data.get("session-params", "")
                tag = enc_data["tag"]

                crypto_tokens = [f"a=crypto:{tag} {crypto_suite} {key_params}"]
                if session_params:
                    crypto_tokens.append(f"{session_params}")
                sdp_lines.append(" ".join(crypto_tokens))

        host.trigger.point(
            "XEP-0167_generate_sdp_content",
            session,
            local,
            content_name,
            content_data,
            sdp_lines,
            application_data,
            app_data_key,
            media_data,
            media,
            triggers_no_cancel=True,
        )

    # every line, including the last one, is terminated by CRLF
    return "".join(f"{sdp_line}\r\n" for sdp_line in sdp_lines)


def generate_candidate_line(candidate: dict) -> str:
    """Serialise ICE candidate data to a ``candidate:`` attribute value.

    The returned string does not include the leading ``a=``.

    @param candidate: ICE candidate data. "foundation", "component_id", "transport",
        "priority", "address", "port" and "type" are mandatory; "rel_addr" and
        "rel_port" (only used when both are present), "generation" and "network" are
        optional.
    @return: ICE candidate attribute line.
    """
    tokens = [
        f"candidate:{candidate['foundation']}",
        f"{candidate['component_id']}",
        f"{candidate['transport']}",
        f"{candidate['priority']}",
        f"{candidate['address']}",
        f"{candidate['port']}",
        "typ",
        f"{candidate['type']}",
    ]

    # related address is only meaningful with both address and port
    if "rel_addr" in candidate and "rel_port" in candidate:
        tokens += [
            "raddr",
            f"{candidate['rel_addr']}",
            "rport",
            f"{candidate['rel_port']}",
        ]

    for extension in ("generation", "network"):
        if extension in candidate:
            tokens += [extension, f"{candidate[extension]}"]

    return " ".join(tokens)


def parse_candidate(parts: list[str]) -> dict:
    """Parse parts of a ICE candidate

    @param parts: Whitespace separated parts of the candidate line, without the
        ``candidate:`` prefix (i.e. the first part is the foundation).
    @return: candidate data, with "foundation", "component_id" (int), "transport",
        "priority" (int), "address", "port" (int) and "type" keys, and optionally
        "rel_addr", "rel_port" (int), "generation" and "network".
    @raise IndexError: a mandatory part is missing, or a known extension keyword is
        not followed by a value.
    @raise ValueError: a numeric part is not a valid integer.
    """
    candidate = {
        "foundation": parts[0],
        "component_id": int(parts[1]),
        "transport": parts[2],
        "priority": int(parts[3]),
        "address": parts[4],
        "port": int(parts[5]),
        # parts[6] is the literal "typ" keyword
        "type": parts[7],
    }

    # Optional extensions are "<keyword> <value>" pairs. We walk them by position
    # (and not with ``parts.index``) so that a mandatory field or a value which
    # happens to be equal to a keyword can't be mistaken for this keyword.
    position = 8
    while position < len(parts):
        keyword = parts[position]
        if keyword not in ("raddr", "rport", "generation", "network"):
            # unknown token, we ignore it
            position += 1
            continue

        value = parts[position + 1]
        if keyword == "raddr":
            candidate["rel_addr"] = value
        elif keyword == "rport":
            candidate["rel_port"] = int(value)
        else:
            candidate[keyword] = value
        # the value has been consumed, we skip it
        position += 2

    return candidate


def _sdp_direction_to_senders(direction: str, role: str) -> str:
    """Convert a SDP direction attribute name to a Jingle ``senders`` value.

    @param direction: one of "sendrecv", "sendonly", "recvonly" or "inactive".
    @param role: role of the entity which produced the SDP.
    @return: "both", "none", or the role which is allowed to send.
    """
    if direction == "sendrecv":
        return "both"
    if direction == "sendonly":
        return role
    if direction == "recvonly":
        return "responder" if role == "initiator" else "initiator"
    return "none"


def parse_sdp(sdp: str, role: str) -> dict:
    """Parse SDP string.

    @param sdp: The SDP string to parse. Lines may be terminated with CRLF or with a
        bare LF.
    @param role: Role of the entities which produces the SDP.
    @return: A dictionary containing parsed session data. It has a "metadata" key,
        and one key per media type found in "m=" lines, whose value is a dict with
        "application_data", "transport_data" and "senders" keys (and "id" when a
        "a=mid" attribute is found).
    @raise ValueError: a line has an invalid format.
    @raise IndexError: a line is incomplete.
    @raise NotImplementedError: an "application" media is not a WebRTC data channel.
    """
    assert host is not None
    # SDP lines are normally terminated by CRLF, but we also accept a bare LF:
    # otherwise the whole payload would be seen as a single line and silently
    # produce empty data.
    lines = sdp.strip().replace("\r\n", "\n").split("\n")
    metadata: Dict[str, Any] = {}
    call_data = {"metadata": metadata}
    session_attributes: Dict[str, Any] = {}

    media_type = None
    media_data: Optional[Dict[str, Any]] = None
    application_data: Optional[Dict[str, Any]] = None
    transport_data: Optional[Dict[str, Any]] = None
    fingerprint_data: Optional[Dict[str, str]] = None
    ice_pwd: Optional[str] = None
    ice_ufrag: Optional[str] = None
    payload_types: Optional[Dict[int, dict]] = None
    # NOTE(review): ``senders`` is only reassigned when ``application_data`` is None
    #   at media level, which looks unreachable, so the "senders" value stored at
    #   the root of each media always seems to be "both" — to be confirmed.
    senders: str = "both"

    for line in lines:
        try:
            parts = line.split()
            # the first 2 characters are the line type (e.g. "m=", "a=")
            prefix = parts[0][:2]
            parts[0] = parts[0][2:]

            if prefix == "m=":
                # new media section
                media_type = parts[0]
                port = int(parts[1])
                application_data = {"media": media_type}
                if media_type in ("video", "audio"):
                    payload_types = application_data["payload_types"] = {}
                    for payload_type_id in [int(pt_id) for pt_id in parts[3:]]:
                        payload_type = {"id": payload_type_id}
                        payload_types[payload_type_id] = payload_type
                elif media_type == "application":
                    if parts[3] != "webrtc-datachannel":
                        raise NotImplementedError(
                            f"{media_type!r} application is not supported"
                        )

                transport_data = {"port": port}
                # values found in a previous media section are used as default
                if fingerprint_data is not None:
                    transport_data["fingerprint"] = fingerprint_data
                if ice_pwd is not None:
                    transport_data["pwd"] = ice_pwd
                if ice_ufrag is not None:
                    transport_data["ufrag"] = ice_ufrag
                media_data = call_data[media_type] = {
                    "application_data": application_data,
                    "transport_data": transport_data,
                    "senders": senders,
                }

                # Apply session attributes to the new media
                for attr, value in session_attributes.items():
                    if attr in ("fingerprint", "ice-pwd", "ice-ufrag"):
                        transport_data[attr] = value
                    elif attr in ("sendrecv", "sendonly", "recvonly", "inactive"):
                        application_data["senders"] = value
                    else:
                        application_data[attr] = value

            elif prefix == "a=":
                # for "a=<attribute>:<value> ...", parts[0] becomes the first value
                if ":" in parts[0]:
                    attribute, parts[0] = parts[0].split(":", 1)
                else:
                    attribute = parts[0]

                if media_type is None:
                    # This is a session-level attribute
                    if attribute == "fingerprint":
                        algorithm, fingerprint = parts[0], parts[1]
                        session_attributes["fingerprint"] = {
                            "hash": algorithm,
                            "fingerprint": fingerprint,
                        }
                    elif attribute == "ice-pwd":
                        session_attributes["ice-pwd"] = parts[0]
                    elif attribute == "ice-ufrag":
                        session_attributes["ice-ufrag"] = parts[0]
                    elif attribute in ("sendrecv", "sendonly", "recvonly", "inactive"):
                        session_attributes[attribute] = _sdp_direction_to_senders(
                            attribute, role
                        )
                    else:
                        session_attributes[attribute] = parts[0] if parts else True
                else:
                    # This is a media-level attribute
                    if attribute == "mid":
                        assert media_data is not None
                        try:
                            media_data["id"] = parts[0]
                        except IndexError:
                            log.warning(f"invalid media ID: {line}")

                    elif attribute == "rtpmap":
                        assert application_data is not None
                        assert payload_types is not None
                        pt_id = int(parts[0])
                        # <encoding name>/<clock rate>[/<channels>]
                        codec_info = parts[1].split("/")
                        codec = codec_info[0]
                        clockrate = int(codec_info[1])
                        payload_type = {
                            "id": pt_id,
                            "name": codec,
                            "clockrate": clockrate,
                        }
                        if len(codec_info) > 2:
                            channels = int(codec_info[2])
                            payload_type["channels"] = channels

                        payload_types.setdefault(pt_id, {}).update(payload_type)

                    elif attribute == "fmtp":
                        assert payload_types is not None
                        pt_id = int(parts[0])
                        params = parts[1].split(";")
                        try:
                            payload_type = payload_types[pt_id]
                        except KeyError:
                            raise ValueError(
                                f"Can't find payload type {pt_id}: {line}"
                            )

                        try:
                            # we split on the first "=" only, as the value itself
                            # may contain some (e.g. base64 padding)
                            payload_type["parameters"] = {
                                name: value
                                for name, value in (
                                    param.split("=", 1) for param in params
                                )
                            }
                        except ValueError:
                            # At least one parameter is not in "name=value" form,
                            # we keep them all raw. The key spelling is kept as is
                            # because it is part of the returned data.
                            payload_type.setdefault("exra-parameters", []).extend(params)

                    elif attribute == "candidate":
                        assert transport_data is not None
                        candidate = parse_candidate(parts)

                        transport_data.setdefault("candidates", []).append(candidate)

                    elif attribute == "fingerprint":
                        algorithm, fingerprint = parts[0], parts[1]
                        fingerprint_data = {"hash": algorithm, "fingerprint": fingerprint}
                        if transport_data is not None:
                            transport_data.setdefault("fingerprint", {}).update(
                                fingerprint_data
                            )
                    elif attribute == "setup":
                        assert transport_data is not None
                        setup = parts[0]
                        transport_data.setdefault("fingerprint", {})["setup"] = setup

                    elif attribute == "b":
                        assert application_data is not None
                        bandwidth = int(parts[0])
                        application_data["bandwidth"] = bandwidth

                    elif attribute == "rtcp-mux":
                        assert application_data is not None
                        application_data["rtcp-mux"] = True

                    elif attribute == "ice-ufrag":
                        if transport_data is not None:
                            transport_data["ufrag"] = parts[0]

                    elif attribute == "ice-pwd":
                        if transport_data is not None:
                            transport_data["pwd"] = parts[0]

                    elif attribute in ("sendrecv", "sendonly", "recvonly", "inactive"):
                        value = _sdp_direction_to_senders(attribute, role)

                        if application_data is None:
                            senders = value
                        else:
                            application_data["senders"] = value

                # the trigger is run for every "a=" line, session or media level
                host.trigger.point(
                    "XEP-0167_parse_sdp_a",
                    attribute,
                    parts,
                    call_data,
                    metadata,
                    media_type,
                    application_data,
                    transport_data,
                    triggers_no_cancel=True,
                )

        except ValueError as e:
            raise ValueError(f"Could not parse line. Invalid format ({e}): {line}") from e
        except IndexError as e:
            raise IndexError(f"Incomplete line. Missing data: {line}") from e

    # we remove private data (data starting with _, used by some plugins (e.g. XEP-0294)
    # to handle session data at media level))
    for key in [k for k in call_data if k.startswith("_")]:
        log.debug(f"cleaning remaining private data {key!r}")
        del call_data[key]

    all_media = {k: v for k, v in call_data.items() if k in ("audio", "video")}
    if len(all_media) > 1:
        media_with_candidates = [
            m for m in all_media.values() if "candidates" in m["transport_data"]
        ]
        if len(media_with_candidates) == 1:
            log.warning(
                "SDP contains ICE candidates only for one media stream. This is non-standard behavior."
            )

    return call_data


def build_description(media: str, media_data: dict, session: dict) -> domish.Element:
    """Generate <description> element from media data

    @param media: media type ("audio" or "video")

    @param media_data: A dictionary containing the media description data.
        The keys and values are described below:

        - ssrc (str, optional): The synchronization source identifier.
        - payload_types (dict): A dictionary mapping payload type ID to a dictionary
          representing the payload type.
          Each payload type dictionary may contain the following keys:
            - channels (str, optional): Number of audio channels.
            - clockrate (str, optional): Clock rate of the media.
            - maxptime (str, optional): Maximum packet time.
            - name (str, optional): Name of the codec.
            - ptime (str, optional): Preferred packet time.
            - parameters (dict, optional): A dictionary of codec-specific parameters.
              Key-value pairs represent the parameter name and value, respectively.
        - bandwidth (str, optional): The bandwidth type.
        - rtcp-mux (bool, optional): Indicates whether RTCP multiplexing is enabled or
          not.
        - encryption (list, optional): A list of dictionaries, each representing an
          encryption method.
          Each dictionary may contain the following keys:
            - tag (str): The unique identifier of the encryption method.
            - crypto-suite (str): The encryption suite in use.
            - key-params (str): Key parameters for the encryption suite.
            - session-params (str, optional): Session parameters for the encryption
              suite.

    @param session: Jingle session data. It is not used directly here, it is only
        forwarded to the "XEP-0167_build_description" trigger.

    @return: A <description> element.
    """
    # FIXME: to be removed once host is accessible from global var
    assert host is not None

    desc_elt = domish.Element((NS_JINGLE_RTP, "description"), attribs={"media": media})

    for pt_id, pt_data in media_data.get("payload_types", {}).items():
        payload_type_elt = desc_elt.addElement("payload-type")
        payload_type_elt["id"] = str(pt_id)
        for attr in ["channels", "clockrate", "maxptime", "name", "ptime"]:
            if attr in pt_data:
                # values may be numbers, XML attributes must be strings
                payload_type_elt[attr] = str(pt_data[attr])

        if "parameters" in pt_data:
            for param_name, param_value in pt_data["parameters"].items():
                param_elt = payload_type_elt.addElement("parameter")
                param_elt["name"] = param_name
                param_elt["value"] = param_value
        host.trigger.point(
            "XEP-0167_build_description_payload_type",
            desc_elt,
            media_data,
            pt_data,
            payload_type_elt,
            triggers_no_cancel=True,
        )

    if "bandwidth" in media_data:
        bandwidth_elt = desc_elt.addElement("bandwidth")
        # The bandwidth is stored as an int when it comes from ``parse_sdp``, it is
        # converted like the payload type attributes so the element can be
        # serialised.
        bandwidth_elt["type"] = str(media_data["bandwidth"])

    if media_data.get("rtcp-mux"):
        desc_elt.addElement("rtcp-mux")

    # Add encryption element
    if "encryption" in media_data:
        encryption_elt = desc_elt.addElement("encryption")
        # we always want require encryption if the `encryption` data is present
        encryption_elt["required"] = "1"
        for enc_data in media_data["encryption"]:
            crypto_elt = encryption_elt.addElement("crypto")
            for attr in ["tag", "crypto-suite", "key-params", "session-params"]:
                if attr in enc_data:
                    crypto_elt[attr] = enc_data[attr]

    host.trigger.point(
        "XEP-0167_build_description",
        desc_elt,
        media_data,
        session,
        triggers_no_cancel=True,
    )

    return desc_elt


def parse_description(desc_elt: domish.Element) -> dict:
    """Parse <description> to a dict

    Attribute values are kept as found in the element; only the payload type IDs
    (keys of "payload_types") are converted to int.

    @param desc_elt: <description> element
    @return: media data as in [build_description]. "payload_types" and "rtcp-mux"
        are always set, "encryption_required" is set when an <encryption> element is
        present.
    """
    # FIXME: to be removed once host is accessible from global var
    assert host is not None

    media_data = {}
    if desc_elt.hasAttribute("ssrc"):
        media_data.setdefault("ssrc", {})[desc_elt["ssrc"]] = {}

    # payload types
    payload_types = {}
    for payload_type_elt in desc_elt.elements(NS_JINGLE_RTP, "payload-type"):
        payload_type_data = {
            attr: payload_type_elt[attr]
            for attr in [
                "channels",
                "clockrate",
                "maxptime",
                "name",
                "ptime",
            ]
            if payload_type_elt.hasAttribute(attr)
        }
        try:
            pt_id = int(payload_type_elt["id"])
        except KeyError:
            log.warning(
                f"missing ID in payload type, ignoring: {payload_type_elt.toXml()}"
            )
            continue
        except ValueError:
            # the element comes from the network, a bad ID must not abort the whole
            # parsing
            log.warning(
                f"invalid ID in payload type, ignoring: {payload_type_elt.toXml()}"
            )
            continue

        parameters = {}
        for param_elt in payload_type_elt.elements(NS_JINGLE_RTP, "parameter"):
            param_name = param_elt.getAttribute("name")
            param_value = param_elt.getAttribute("value")
            # an empty value is accepted, but not a missing one
            if not param_name or param_value is None:
                log.warning(f"invalid parameter: {param_elt.toXml()}")
                continue
            parameters[param_name] = param_value

        if parameters:
            payload_type_data["parameters"] = parameters

        host.trigger.point(
            "XEP-0167_parse_description_payload_type",
            desc_elt,
            media_data,
            payload_type_elt,
            payload_type_data,
            triggers_no_cancel=True,
        )
        payload_types[pt_id] = payload_type_data

    media_data["payload_types"] = payload_types

    # bandwidth
    try:
        bandwidth_elt = next(desc_elt.elements(NS_JINGLE_RTP, "bandwidth"))
    except StopIteration:
        pass
    else:
        bandwidth = bandwidth_elt.getAttribute("type")
        if not bandwidth:
            log.warning(f"invalid bandwidth: {bandwidth_elt.toXml()}")
        else:
            media_data["bandwidth"] = bandwidth

    # rtcp-mux
    rtcp_mux_elt = next(desc_elt.elements(NS_JINGLE_RTP, "rtcp-mux"), None)
    media_data["rtcp-mux"] = rtcp_mux_elt is not None

    # Encryption
    encryption_data = []
    encryption_elt = next(desc_elt.elements(NS_JINGLE_RTP, "encryption"), None)
    if encryption_elt is not None:
        media_data["encryption_required"] = C.bool(
            encryption_elt.getAttribute("required", C.BOOL_FALSE)
        )

        for crypto_elt in encryption_elt.elements(NS_JINGLE_RTP, "crypto"):
            crypto_data = {
                attr: crypto_elt[attr]
                for attr in [
                    "crypto-suite",
                    "key-params",
                    "session-params",
                    "tag",
                ]
                if crypto_elt.hasAttribute(attr)
            }
            encryption_data.append(crypto_data)

    # "encryption" is only set when at least one <crypto> element has been found
    if encryption_data:
        media_data["encryption"] = encryption_data

    host.trigger.point(
        "XEP-0167_parse_description", desc_elt, media_data, triggers_no_cancel=True
    )

    return media_data