view libervia/web/server/websockets.py @ 1543:f00497c00e38

pages (chat): fix `own_jid` exposure: `own_jid` is a `JID` instance, and must be casted to str to be exposed.
author Goffi <goffi@goffi.org>
date Thu, 06 Jul 2023 12:05:48 +0200
parents eb00d593801d
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: a Salut à Toi frontend
# Copyright (C) 2011-2021 Jérôme Poisson <goffi@goffi.org>

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import json
from typing import Optional

from autobahn.twisted import websocket
from autobahn.twisted import resource as resource
from autobahn.websocket import types
from libervia.backend.core import exceptions
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger

from . import session_iface
from .constants import Const as C

# module-level logger, named after this module
log = getLogger(__name__)

# Application object used below to look up sessions (``host.site``) and to read
# port configuration (``host.options``). Nothing in this module assigns it;
# presumably the server sets it at startup, before any websocket connection is
# handled — TODO confirm against the code instantiating the resources.
host = None


class LiberviaPageWSProtocol(websocket.WebSocketServerProtocol):
    """Server side of the ``libervia-page`` websocket sub-protocol

    Life cycle of a connection:
        - ``onConnect`` checks the requested sub-protocol and the session cookie,
          then attaches this instance to the session data (``ws_socket``);
        - the frontend must then send an ``init`` message with its profile and
          token; once validated, data buffered in the session are flushed;
        - ``onClose`` detaches this instance from the session data.

    Messages are JSON objects of the form ``{"type": <str>, "data": <object>}``.
    """

    def __init__(self):
        super().__init__()
        # True once a valid "init" message has been received
        self._init_ok: bool = False
        # profile linked to the session, set in onConnect
        self.__profile: Optional[str] = None
        # session data this socket is attached to, set in onConnect
        self.__session: Optional[session_iface.WebSession] = None

    @property
    def init_ok(self):
        """True once the frontend has sent a valid ``init`` message"""
        return self._init_ok

    def send(self, data_type: str, data: dict) -> None:
        """Send data to frontend

        @param data_type: type of the message
            only "error" may be sent before the connection is initialised
        @param data: payload, must be serialisable to JSON
        @raise exceptions.InternalError: the connection is not initialised yet
            and ``data_type`` is not "error"
        """
        if not self._init_ok and data_type != "error":
            raise exceptions.InternalError(
                "send called when not initialized, this should not happend! Please use "
                "WebSession.send which takes care of sending correctly the data to all "
                "sessions."
            )

        data_root = {
            "type": data_type,
            "data": data
        }
        self.sendMessage(json.dumps(data_root, ensure_ascii=False).encode())

    def close(self) -> None:
        """Log that the websocket is being closed

        NOTE(review): this only logs, it doesn't send a close frame nor drop the
        connection. Confirm whether callers expect the socket to be really closed.
        """
        log.debug(f"closing websocket for profile {self.__profile}")

    def error(self, error_type: str, msg: str) -> None:
        """Send an error message to frontend and log it locally

        @param error_type: short identifier of the error
        @param msg: human readable description of the error
        """
        log.warning(
            f"websocket error {error_type}: {msg}"
        )
        self.send("error", {
            "type": error_type,
            "msg": msg,
        })

    def onConnect(self, request):
        """Validate the handshake and attach this socket to the session

        @param request: connection request, its ``protocols`` and ``headers`` are
            used
        @return: the accepted sub-protocol ("libervia-page")
        @raise types.ConnectionDeny: the sub-protocol is not requested, or there is
            no session cookie, or the session is unknown
        """
        if "libervia-page" not in request.protocols:
            raise types.ConnectionDeny(
                types.ConnectionDeny.NOT_IMPLEMENTED, "No supported protocol"
            )
        self._init_ok = False
        # minimal cookie header parsing: "k1=v1; k2=v2"
        cookies = {}
        for cookie in request.headers.get("cookie", "").split(";"):
            k, __, v = cookie.partition("=")
            cookies[k.strip()] = v.strip()
        # the secure cookie has priority over the non secure one
        session_uid = (
            cookies.get("TWISTED_SECURE_SESSION")
            or cookies.get("TWISTED_SESSION")
            or ""
        )
        if not session_uid:
            raise types.ConnectionDeny(
                types.ConnectionDeny.FORBIDDEN, "No session set"
            )
        try:
            session = host.site.getSession(session_uid.encode())
        except KeyError:
            raise types.ConnectionDeny(
                types.ConnectionDeny.FORBIDDEN, "Invalid session"
            )

        session.touch()
        session_data = session.getComponent(session_iface.IWebSession)
        if session_data.ws_socket is not None:
            # a socket is already attached to this session: we ask it to close
            # on a best-effort basis, and replace it with the new one
            log.warning(f"Session socket is already set {session_data.ws_socket=} {self=}], force closing it")
            try:
                session_data.ws_socket.send(
                    "force_close", {"reason": "duplicate connection detected"}
                )
            except Exception as e:
                log.warning(f"Can't force close old connection: {e}")
        session_data.ws_socket = self
        self.__session = session_data
        self.__profile = session_data.profile or C.SERVICE_PROFILE
        log.debug(f"websocket connection connected for profile {self.__profile}")
        return "libervia-page"

    def onOpen(self):
        """Log that the websocket handshake is completed

        The camelCase name is the one called by Autobahn, like ``onConnect`` and
        ``onMessage``.
        """
        log.debug("websocket connection opened")

    # kept for backward compatibility with the former (never called by Autobahn)
    # snake_case name
    on_open = onOpen

    def onMessage(self, payload: bytes, isBinary: bool) -> None:
        """Handle a message coming from the frontend

        Only the ``init`` message type is handled: it checks profile and token,
        marks the connection as initialised, and flushes the data buffered in the
        session. Any message received while the connection is not initialised
        results in an error sent to the frontend and the connection being closed.

        @param payload: UTF-8 encoded JSON object with "type" and "data" keys
        @param isBinary: unused
        @raise exceptions.InternalError: no session is attached to this socket
        """
        if self.__session is None:
            raise exceptions.InternalError("empty session, this should never happen")
        try:
            data_full = json.loads(payload.decode())
            data_type = data_full["type"]
            data = data_full["data"]
        except ValueError as e:
            # also covers UnicodeDecodeError, which is a ValueError subclass
            self.error(
                "bad_request",
                f"Not valid JSON, ignoring data ({e}): {payload!r}"
            )
            return
        except (KeyError, TypeError):
            # TypeError is raised when the payload is valid JSON but not an object
            # (e.g. a list, a string, a number or null)
            self.error(
                "bad_request",
                'Invalid request (missing "type" or "data")'
            )
            return

        if data_type == "init":
            if self._init_ok:
                self.error(
                    "bad_request",
                    "double init"
                )
                self.sendClose(4400, "Bad Request")
                return

            try:
                profile = data["profile"] or C.SERVICE_PROFILE
                token = data["token"]
            except (KeyError, TypeError):
                # TypeError is raised when "data" is not a mapping
                self.error(
                    "bad_request",
                    "Invalid init data (missing profile or token)"
                )
                self.sendClose(4400, "Bad Request")
                return
            # the token is not checked for the service profile
            if ((
                profile != self.__profile
                or (token != self.__session.ws_token and profile != C.SERVICE_PROFILE)
            )):
                log.debug(
                    f"profile got {profile}, was expecting {self.__profile}, "
                    f"token got {token}, was expecting {self.__session.ws_token}, "
                )
                self.error(
                    "Unauthorized",
                    "Invalid profile or token"
                )
                self.sendClose(4401, "Unauthorized")
                return
            else:
                log.debug(f"websocket connection initialized for {profile}")
                self._init_ok = True
                # we now send all cached data, if any
                while True:
                    try:
                        session_kw = self.__session.ws_buffer.popleft()
                    except IndexError:
                        break
                    else:
                        self.send(**session_kw)

        if not self._init_ok:
            self.error(
                "Unauthorized",
                "session not authorized"
            )
            self.sendClose(4401, "Unauthorized")
            return

    def onClose(self, wasClean, code, reason):
        """Detach this socket from its session when the connection is closed

        The camelCase name is the one called by Autobahn, like ``onConnect`` and
        ``onMessage``.

        @param wasClean: unused
        @param code: unused
        @param reason: reason of the closing, only used for logging
        """
        log.debug(f"closing websocket (profile: {self.__profile}, reason: {reason})")
        if self.__profile is None:
            log.error("self.__profile should not be None")
            self.__profile = C.SERVICE_PROFILE

        if self.__session is None:
            log.warning("closing a socket without attached session")
        elif self.__session.ws_socket != self:
            # the session may have been attached to a newer socket in onConnect,
            # in which case we must not reset it
            log.error("session socket is not linked to our instance")
        else:
            log.debug(f"reseting websocket session for {self.__profile}")
            self.__session.ws_socket = None
        log.debug(f"websocket connection for profile {self.__profile} closed")
        self.__profile = None

    # kept for backward compatibility with the former (never called by Autobahn)
    # snake_case name
    on_close = onClose

    @classmethod
    def get_base_url(cls, secure):
        """Build the local websocket URL used by the factory

        @param secure: True to use "wss" scheme and the HTTPS port
        @return: URL of the form "ws[s]://localhost:<port>"
        """
        return "ws{sec}://localhost:{port}".format(
            sec="s" if secure else "",
            port=host.options["port_https" if secure else "port"],
        )

    @classmethod
    def get_resource(cls, secure):
        """Create the web resource serving this protocol

        @param secure: True to build the resource for the secure (wss) URL
        @return: websocket resource with a factory using this class as protocol
        """
        factory = websocket.WebSocketServerFactory(cls.get_base_url(secure))
        factory.protocol = cls
        return resource.WebSocketResource(factory)