view libervia/backend/plugins/plugin_xep_0167/mapping.py @ 4217:b53b6dc1f929

plugin XEP-0373: fix serialisation of `public_key_list`
author Goffi <goffi@goffi.org>
date Tue, 05 Mar 2024 17:31:36 +0100
parents b2709504586a
children e11b13418ba6
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
from typing import Any, Dict, Optional

from twisted.words.xish import domish

from libervia.backend.core.constants import Const as C
from libervia.backend.core.log import getLogger

from .constants import NS_JINGLE_RTP

log = getLogger(__name__)

# Backend host instance, used by the functions below to run trigger points
# (``host.trigger.point``). It is None at import time; parse_sdp, build_description
# and parse_description assert that it has been set before they run.
# NOTE(review): presumably assigned by the XEP-0167 plugin when it is loaded — confirm.
host = None


def senders_to_sdp(senders: str, session: dict) -> str:
    """Return the SDP direction attribute matching a Jingle "senders" value.

    @param senders: Jingle senders attribute: "both", "none", or the role of the
        only entity which sends the stream.
    @param session: session data; its "role" key is only read for one-way streams.
    @return: one of "a=sendrecv", "a=inactive", "a=sendonly" or "a=recvonly".
    """
    fixed_directions = {"both": "a=sendrecv", "none": "a=inactive"}
    direction = fixed_directions.get(senders)
    if direction is not None:
        return direction
    # one-way stream: we send if the sending role is ours, otherwise we receive
    return "a=sendonly" if session["role"] == senders else "a=recvonly"


def generate_sdp_from_session(
    session: dict, local: bool = False, port: int = 9
) -> str:
    """Generate an SDP string from session data.

    @param session: A dictionary containing the session data. It should have the
        following structure:

        {
            "local_jid": <jid of the local entity (its full() is used in "o=")>,
            "role": <str: local role, compared to the content "senders">,
            "contents": {
                "<content_id>": {
                    "senders": <str: Jingle senders attribute>,
                    "application_data": {
                        "media": <str: "audio" or "video">,
                        "local_data": <media_data dict>,
                        "peer_data": <media_data dict>,
                        ...
                    },
                    "transport_data": {
                        "local_ice_data": <ice_data dict>,
                        "peer_ice_data": <ice_data dict>,
                        ...
                    },
                    ...
                },
                ...
            }
        }
    @param local: A boolean value indicating whether to generate SDP for the local or
        peer entity. If True, the method will generate SDP for the local entity,
        otherwise for the peer entity. Generally the local SDP is received from frontends
        and not needed in backend, except for debugging purpose.
    @param port: The preferred port for communications, used in every "m=" line.

    @return: The generated SDP string, with CRLF line endings.
    """
    contents = session["contents"]
    sdp_lines = ["v=0"]

    # Add originator (o=) line after the version (v=) line
    username = base64.b64encode(session["local_jid"].full().encode()).decode()
    session_id = "1"  # Increment this for each session
    session_version = "1"  # Increment this when the session is updated
    network_type = "IN"
    address_type = "IP4"
    connection_address = "0.0.0.0"
    sdp_lines.append(
        f"o={username} {session_id} {session_version} {network_type} {address_type} "
        f"{connection_address}"
    )

    # Add the mandatory "s=" and "t=" lines
    sdp_lines.append("s=-")
    sdp_lines.append("t=0 0")

    sdp_lines.append("a=msid-semantic:WMS *")
    sdp_lines.append("a=ice-options:trickle")

    host.trigger.point(
        "XEP-0167_generate_sdp_session",
        session,
        local,
        sdp_lines,
        triggers_no_cancel=True
    )

    # which side of the session (ours or the peer's) is rendered
    app_data_key = "local_data" if local else "peer_data"
    ice_data_key = "local_ice_data" if local else "peer_ice_data"

    for content_name, content_data in contents.items():
        application_data = content_data["application_data"]
        media_data = application_data[app_data_key]
        media = application_data["media"]
        payload_types = media_data.get("payload_types", {})

        # Generate m= line
        transport = "UDP/TLS/RTP/SAVPF"
        payload_type_ids = " ".join(str(pt_id) for pt_id in payload_types)
        sdp_lines.append(f"m={media} {port} {transport} {payload_type_ids}")

        sdp_lines.append(f"c={network_type} {address_type} {connection_address}")

        sdp_lines.append(f"a=mid:{content_name}")

        # Stream direction is always set at media level, so contents with different
        # "senders" values are handled the same way as contents sharing one.
        sdp_lines.append(senders_to_sdp(content_data["senders"], session))

        # Generate a= lines for rtpmap, ptime and fmtp
        for pt_id, pt in payload_types.items():
            sdp_lines.extend(_payload_type_to_sdp(pt_id, pt))

        if "bandwidth" in media_data:
            # NOTE(review): standard SDP uses a "b=" line for bandwidth; "a=b:" is
            # kept because parse_sdp reads bandwidth back from this attribute.
            sdp_lines.append(f"a=b:{media_data['bandwidth']}")

        if media_data.get("rtcp-mux"):
            sdp_lines.append("a=rtcp-mux")

        # Generate a= lines for fingerprint, ICE ufrag, pwd and candidates
        ice_data = content_data["transport_data"][ice_data_key]
        sdp_lines.extend(_ice_data_to_sdp(ice_data))

        # Generate a= lines for encryption
        if "encryption" in media_data:
            for enc_data in media_data["encryption"]:
                sdp_lines.append(_crypto_to_sdp(enc_data))

        host.trigger.point(
            "XEP-0167_generate_sdp_content",
            session,
            local,
            content_name,
            content_data,
            sdp_lines,
            application_data,
            app_data_key,
            media_data,
            media,
            triggers_no_cancel=True
        )

    # Combine SDP lines and return the result
    return "\r\n".join(sdp_lines) + "\r\n"


def _payload_type_to_sdp(pt_id: Any, pt: dict) -> list:
    """Build the "a=rtpmap", "a=ptime" and "a=fmtp" lines of one payload type.

    An "a=fmtp" line is only emitted when there is at least one parameter, as an
    fmtp attribute without parameters is not valid SDP.
    """
    lines = []
    rtpmap_line = f"a=rtpmap:{pt_id} {pt['name']}/{pt.get('clockrate', '')}"
    channels = pt.get("channels")
    if channels:
        rtpmap_line += f"/{channels}"
    lines.append(rtpmap_line)

    if "ptime" in pt:
        lines.append(f"a=ptime:{pt['ptime']}")

    parameters = pt.get("parameters")
    if parameters:
        fmtp_params = ";".join(f"{k}={v}" for k, v in parameters.items())
        lines.append(f"a=fmtp:{pt_id} {fmtp_params}")
    return lines


def _ice_data_to_sdp(ice_data: dict) -> list:
    """Build the DTLS fingerprint/setup, ICE credentials and candidate lines."""
    lines = []
    if "fingerprint" in ice_data:
        fingerprint_data = ice_data["fingerprint"]
        lines.append(
            f"a=fingerprint:{fingerprint_data['hash']} "
            f"{fingerprint_data['fingerprint']}"
        )
        lines.append(f"a=setup:{fingerprint_data['setup']}")

    lines.append(f"a=ice-ufrag:{ice_data['ufrag']}")
    lines.append(f"a=ice-pwd:{ice_data['pwd']}")
    lines.extend(_candidate_to_sdp(candidate) for candidate in ice_data["candidates"])
    return lines


def _candidate_to_sdp(candidate: dict) -> str:
    """Build the "a=candidate" line of one ICE candidate."""
    candidate_line = (
        f"a=candidate:{candidate['foundation']} {candidate['component_id']} "
        f"{candidate['transport']} {candidate['priority']} "
        f"{candidate['address']} {candidate['port']} typ {candidate['type']}"
    )
    if "rel_addr" in candidate and "rel_port" in candidate:
        candidate_line += (
            f" raddr {candidate['rel_addr']} rport {candidate['rel_port']}"
        )
    for extension in ("generation", "network"):
        if extension in candidate:
            candidate_line += f" {extension} {candidate[extension]}"
    return candidate_line


def _crypto_to_sdp(enc_data: dict) -> str:
    """Build the "a=crypto" line of one encryption method."""
    crypto_line = (
        f"a=crypto:{enc_data['tag']} {enc_data['crypto-suite']} "
        f"{enc_data['key-params']}"
    )
    session_params = enc_data.get("session-params", "")
    if session_params:
        crypto_line += f" {session_params}"
    return crypto_line


def parse_sdp(sdp: str) -> dict:
    """Parse SDP string.

    Session-level "a=fingerprint", "a=ice-ufrag" and "a=ice-pwd" attributes are used
    as defaults for the media sections which follow them; media-level values
    override them.

    @param sdp: The SDP string to parse. Lines are normally CRLF-terminated, but
        bare LF line endings are accepted too.

    @return: A dictionary containing parsed session data: "metadata" (session-level
        data) and one entry per media type (e.g. "audio", "video"), each with
        "application_data" and "transport_data" keys, plus "id" if the media has an
        "a=mid" attribute.
    @raise ValueError: a line has an invalid format.
    @raise IndexError: a line is missing mandatory data.
    """
    # FIXME: to be removed once host is accessible from global var
    assert host is not None
    # splitlines() accepts both CRLF (mandated by SDP) and bare LF line endings
    lines = sdp.strip().splitlines()
    # session metadata
    metadata: Dict[str, Any] = {}
    call_data = {"metadata": metadata}

    media_type = None
    media_data: Optional[Dict[str, Any]] = None
    application_data: Optional[Dict[str, Any]] = None
    transport_data: Optional[Dict[str, Any]] = None
    # defaults copied into the transport data of each new media section
    fingerprint_data: Optional[Dict[str, str]] = None
    ice_pwd: Optional[str] = None
    ice_ufrag: Optional[str] = None
    payload_types: Optional[Dict[int, dict]] = None

    for line in lines:
        try:
            parts = line.split()
            if not parts:
                # tolerate blank lines
                continue
            prefix = parts[0][:2]  # Extract the 'a=', 'm=', etc., prefix
            parts[0] = parts[0][2:]  # Remove the prefix from the first element

            if prefix == "m=":
                media_type = parts[0]
                port = int(parts[1])
                payload_types = {
                    pt_id: {"id": pt_id} for pt_id in map(int, parts[3:])
                }

                application_data = {"media": media_type, "payload_types": payload_types}
                transport_data = {"port": port}
                if fingerprint_data is not None:
                    # A copy is used so that attributes of this media ("setup",
                    # "fingerprint") can't modify the data of other media.
                    transport_data["fingerprint"] = dict(fingerprint_data)
                if ice_pwd is not None:
                    transport_data["pwd"] = ice_pwd
                if ice_ufrag is not None:
                    transport_data["ufrag"] = ice_ufrag
                media_data = call_data[media_type] = {
                    "application_data": application_data,
                    "transport_data": transport_data,
                }

            elif prefix == "a=":
                if ":" in parts[0]:
                    attribute, parts[0] = parts[0].split(":", 1)
                else:
                    attribute = parts[0]

                if (
                    media_type is None
                    or application_data is None
                    or transport_data is None
                ) and not (
                    attribute
                    in (
                        "sendrecv",
                        "sendonly",
                        "recvonly",
                        "inactive",
                        "fingerprint",
                        "group",
                        "ice-options",
                        "msid-semantic",
                        "ice-pwd",
                        "ice-ufrag",
                    )
                ):
                    log.warning(
                        "Received attribute before media description, this is "
                        f"invalid: {line}"
                    )
                    continue

                if attribute == "mid":
                    assert media_data is not None
                    try:
                        media_data["id"] = parts[0]
                    except IndexError:
                        log.warning(f"invalid media ID: {line}")

                elif attribute == "rtpmap":
                    assert application_data is not None
                    assert payload_types is not None
                    pt_id = int(parts[0])
                    codec_info = parts[1].split("/")
                    codec = codec_info[0]
                    clockrate = int(codec_info[1])
                    payload_type = {
                        "id": pt_id,
                        "name": codec,
                        "clockrate": clockrate,
                    }
                    # Handle optional channel count
                    if len(codec_info) > 2:
                        channels = int(codec_info[2])
                        payload_type["channels"] = channels

                    payload_types.setdefault(pt_id, {}).update(payload_type)

                elif attribute == "fmtp":
                    assert payload_types is not None
                    pt_id = int(parts[0])
                    # parts[1] is mandatory; following parts are joined back so that
                    # "; " separated parameters are not lost
                    fmtp_value = " ".join([parts[1], *parts[2:]])
                    params = [
                        param.strip()
                        for param in fmtp_value.split(";")
                        if param.strip()
                    ]
                    try:
                        payload_type = payload_types[pt_id]
                    except KeyError:
                        raise ValueError(
                            f"Can't find payload type {pt_id}: {line}"
                        )

                    try:
                        # split on the first "=" only: values may contain "="
                        # (e.g. base64 padding)
                        payload_type["parameters"] = {
                            name: value
                            for name, value in (
                                param.split("=", 1) for param in params
                            )
                        }
                    except ValueError:
                        # at least one parameter is not a "name=value" pair
                        payload_type.setdefault("exra-parameters", []).extend(params)

                elif attribute == "candidate":
                    assert transport_data is not None
                    candidate = {
                        "foundation": parts[0],
                        "component_id": int(parts[1]),
                        "transport": parts[2],
                        "priority": int(parts[3]),
                        "address": parts[4],
                        "port": int(parts[5]),
                        "type": parts[7],
                    }

                    # extensions are "name value" pairs; the value is the token
                    # following the name (IndexError if it is missing)
                    for idx in range(8, len(parts)):
                        part = parts[idx]
                        if part == "raddr":
                            candidate["rel_addr"] = parts[idx + 1]
                        elif part == "rport":
                            candidate["rel_port"] = int(parts[idx + 1])
                        elif part == "generation":
                            candidate["generation"] = parts[idx + 1]
                        elif part == "network":
                            candidate["network"] = parts[idx + 1]

                    transport_data.setdefault("candidates", []).append(candidate)

                elif attribute == "fingerprint":
                    algorithm, fingerprint = parts[0], parts[1]
                    # NOTE: the last seen fingerprint (session or media level) is used
                    # as default for the following media sections.
                    fingerprint_data = {"hash": algorithm, "fingerprint": fingerprint}
                    if transport_data is not None:
                        transport_data.setdefault("fingerprint", {}).update(
                            fingerprint_data
                        )
                elif attribute == "setup":
                    assert transport_data is not None
                    setup = parts[0]
                    transport_data.setdefault("fingerprint", {})["setup"] = setup

                elif attribute == "b":
                    assert application_data is not None
                    bandwidth = int(parts[0])
                    application_data["bandwidth"] = bandwidth

                elif attribute == "rtcp-mux":
                    assert application_data is not None
                    application_data["rtcp-mux"] = True

                elif attribute == "ice-ufrag":
                    if transport_data is None:
                        # session level: default for the following media
                        ice_ufrag = parts[0]
                    else:
                        transport_data["ufrag"] = parts[0]

                elif attribute == "ice-pwd":
                    if transport_data is None:
                        # session level: default for the following media
                        ice_pwd = parts[0]
                    else:
                        transport_data["pwd"] = parts[0]

                host.trigger.point(
                    "XEP-0167_parse_sdp_a",
                    attribute,
                    parts,
                    call_data,
                    metadata,
                    media_type,
                    application_data,
                    transport_data,
                    triggers_no_cancel=True
                )

        except ValueError as e:
            raise ValueError(f"Could not parse line. Invalid format ({e}): {line}") from e
        except IndexError as e:
            raise IndexError(f"Incomplete line. Missing data: {line}") from e

    # we remove private data (data starting with _, used by some plugins (e.g. XEP-0294)
    # to handle session data at media level))
    for key in [k for k in call_data if k.startswith("_")]:
        log.debug(f"cleaning remaining private data {key!r}")
        del call_data[key]

    # FIXME: is this really useful?
    # ICE candidates may only be specified for the first media, this
    # duplicate the candidate for the other in this case
    all_media = {k: v for k, v in call_data.items() if k in ("audio", "video")}
    if len(all_media) > 1 and not all(
        "candidates" in c["transport_data"] for c in all_media.values()
    ):
        first_content = next(iter(all_media.values()))
        ice_candidates = first_content["transport_data"].get("candidates", [])
        for content in list(all_media.values())[1:]:
            # each media gets its own list, so that later additions don't leak
            content["transport_data"].setdefault("candidates", list(ice_candidates))

    return call_data


def build_description(media: str, media_data: dict, session: dict) -> domish.Element:
    """Generate <description> element from media data

    @param media: media type ("audio" or "video")

    @param media_data: A dictionary containing the media description data.
        The keys and values are described below:

        - ssrc (str, optional): The synchronization source identifier (not handled
          here; presumably handled by plugins through the trigger points).
        - payload_types (dict): Payload types, mapping each payload type ID to a
          dictionary which may contain the following keys:
            - channels (str, optional): Number of audio channels.
            - clockrate (str, optional): Clock rate of the media.
            - id (str): The unique identifier of the payload type.
            - maxptime (str, optional): Maximum packet time.
            - name (str, optional): Name of the codec.
            - ptime (str, optional): Preferred packet time.
            - parameters (dict, optional): A dictionary of codec-specific parameters.
              Key-value pairs represent the parameter name and value, respectively.
        - bandwidth (str, optional): The bandwidth type.
        - rtcp-mux (bool, optional): Indicates whether RTCP multiplexing is enabled or
          not.
        - encryption (list, optional): A list of dictionaries, each representing an
          encryption method.
          Each dictionary may contain the following keys:
            - tag (str): The unique identifier of the encryption method.
            - crypto-suite (str): The encryption suite in use.
            - key-params (str): Key parameters for the encryption suite.
            - session-params (str, optional): Session parameters for the encryption
              suite.

        Non-string values (e.g. the int "bandwidth" produced by parse_sdp) are
        converted with str() so that the element can be serialised.
    @param session: session data, only passed to the "XEP-0167_build_description"
        trigger.

    @return: A <description> element.
    """
    # FIXME: to be removed once host is accessible from global var
    assert host is not None

    desc_elt = domish.Element((NS_JINGLE_RTP, "description"), attribs={"media": media})

    for pt_id, pt_data in media_data.get("payload_types", {}).items():
        payload_type_elt = desc_elt.addElement("payload-type")
        payload_type_elt["id"] = str(pt_id)
        for attr in ("channels", "clockrate", "maxptime", "name", "ptime"):
            if attr in pt_data:
                payload_type_elt[attr] = str(pt_data[attr])

        for param_name, param_value in (pt_data.get("parameters") or {}).items():
            param_elt = payload_type_elt.addElement("parameter")
            param_elt["name"] = str(param_name)
            param_elt["value"] = str(param_value)

        host.trigger.point(
            "XEP-0167_build_description_payload_type",
            desc_elt,
            media_data,
            pt_data,
            payload_type_elt,
            triggers_no_cancel=True
        )

    if "bandwidth" in media_data:
        bandwidth_elt = desc_elt.addElement("bandwidth")
        bandwidth_elt["type"] = str(media_data["bandwidth"])

    if media_data.get("rtcp-mux"):
        desc_elt.addElement("rtcp-mux")

    # Add encryption element
    if "encryption" in media_data:
        encryption_elt = desc_elt.addElement("encryption")
        # we always want require encryption if the `encryption` data is present
        encryption_elt["required"] = "1"
        for enc_data in media_data["encryption"]:
            crypto_elt = encryption_elt.addElement("crypto")
            for attr in ("tag", "crypto-suite", "key-params", "session-params"):
                if attr in enc_data:
                    crypto_elt[attr] = str(enc_data[attr])

    host.trigger.point(
        "XEP-0167_build_description",
        desc_elt,
        media_data,
        session,
        triggers_no_cancel=True
    )

    return desc_elt


def parse_description(desc_elt: domish.Element) -> dict:
    """Parse <description> to a dict

    @param desc_elt: <description> element
    @return: media data as in [build_description]. "rtcp-mux" is always set, and
        "encryption_required" is only set when an <encryption> element is present.
        Payload types without a valid numeric ID are skipped with a warning.
    """
    # FIXME: to be removed once host is accessible from global var
    assert host is not None

    media_data = {}
    if desc_elt.hasAttribute("ssrc"):
        media_data.setdefault("ssrc", {})[desc_elt["ssrc"]] = {}

    payload_types = {}
    for payload_type_elt in desc_elt.elements(NS_JINGLE_RTP, "payload-type"):
        try:
            pt_id = int(payload_type_elt["id"])
        except (KeyError, ValueError):
            # a payload type can't be referenced without a numeric ID
            log.warning(
                "missing or invalid ID in payload type, ignoring: "
                f"{payload_type_elt.toXml()}"
            )
            continue

        payload_type_data = {
            attr: payload_type_elt[attr]
            for attr in ("channels", "clockrate", "maxptime", "name", "ptime")
            if payload_type_elt.hasAttribute(attr)
        }

        parameters = {}
        for param_elt in payload_type_elt.elements(NS_JINGLE_RTP, "parameter"):
            param_name = param_elt.getAttribute("name")
            param_value = param_elt.getAttribute("value")
            if not param_name or param_value is None:
                log.warning(f"invalid parameter: {param_elt.toXml()}")
                continue
            parameters[param_name] = param_value

        if parameters:
            payload_type_data["parameters"] = parameters

        host.trigger.point(
            "XEP-0167_parse_description_payload_type",
            desc_elt,
            media_data,
            payload_type_elt,
            payload_type_data,
            triggers_no_cancel=True
        )
        payload_types[pt_id] = payload_type_data

    media_data["payload_types"] = payload_types

    # bandwidth
    bandwidth_elt = next(desc_elt.elements(NS_JINGLE_RTP, "bandwidth"), None)
    if bandwidth_elt is not None:
        bandwidth = bandwidth_elt.getAttribute("type")
        if not bandwidth:
            log.warning(f"invalid bandwidth: {bandwidth_elt.toXml()}")
        else:
            media_data["bandwidth"] = bandwidth

    # rtcp-mux
    rtcp_mux_elt = next(desc_elt.elements(NS_JINGLE_RTP, "rtcp-mux"), None)
    media_data["rtcp-mux"] = rtcp_mux_elt is not None

    # Encryption
    encryption_data = []
    encryption_elt = next(desc_elt.elements(NS_JINGLE_RTP, "encryption"), None)
    if encryption_elt is not None:
        media_data["encryption_required"] = C.bool(
            encryption_elt.getAttribute("required", C.BOOL_FALSE)
        )

        for crypto_elt in encryption_elt.elements(NS_JINGLE_RTP, "crypto"):
            crypto_data = {
                attr: crypto_elt[attr]
                for attr in ("crypto-suite", "key-params", "session-params", "tag")
                if crypto_elt.hasAttribute(attr)
            }
            encryption_data.append(crypto_data)

    if encryption_data:
        media_data["encryption"] = encryption_data

    host.trigger.point(
        "XEP-0167_parse_description",
        desc_elt,
        media_data,
        triggers_no_cancel=True
    )

    return media_data