diff libervia/web/server/websockets.py @ 1518:eb00d593801d

refactoring: rename `libervia` to `libervia.web` + update imports following backend changes
author Goffi <goffi@goffi.org>
date Fri, 02 Jun 2023 16:49:28 +0200
parents libervia/server/websockets.py@ff95501abe74
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/libervia/web/server/websockets.py	Fri Jun 02 16:49:28 2023 +0200
@@ -0,0 +1,224 @@
+#!/usr/bin/env python3
+
+# Libervia: a Salut à Toi frontend
+# Copyright (C) 2011-2021 Jérôme Poisson <goffi@goffi.org>
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+
+import json
+from typing import Optional
+
+from autobahn.twisted import websocket
+from autobahn.twisted import resource as resource
+from autobahn.websocket import types
+from libervia.backend.core import exceptions
+from libervia.backend.core.i18n import _
+from libervia.backend.core.log import getLogger
+
+from . import session_iface
+from .constants import Const as C
+
+log = getLogger(__name__)
+
+host = None
+
+
+class LiberviaPageWSProtocol(websocket.WebSocketServerProtocol):
+
+    def __init__(self):
+        super().__init__()
+        self._init_ok: bool = False
+        self.__profile: Optional[str] = None
+        self.__session: Optional[session_iface.WebSession] = None
+
+    @property
+    def init_ok(self):
+        return self._init_ok
+
+    def send(self, data_type: str, data: dict) -> None:
+        """Send data to frontend"""
+        if not self._init_ok and data_type != "error":
+            raise exceptions.InternalError(
+                "send called when not initialized, this should not happend! Please use "
+                "WebSession.send which takes care of sending correctly the data to all "
+                "sessions."
+            )
+
+        data_root = {
+            "type": data_type,
+            "data": data
+        }
+        self.sendMessage(json.dumps(data_root, ensure_ascii=False).encode())
+
+    def close(self) -> None:
+        log.debug(f"closing websocket for profile {self.__profile}")
+
+    def error(self, error_type: str, msg: str) -> None:
+        """Send an error message to frontend and log it locally"""
+        log.warning(
+            f"websocket error {error_type}: {msg}"
+        )
+        self.send("error", {
+            "type": error_type,
+            "msg": msg,
+        })
+
+    def onConnect(self, request):
+        if "libervia-page" not in request.protocols:
+            raise types.ConnectionDeny(
+                types.ConnectionDeny.NOT_IMPLEMENTED, "No supported protocol"
+            )
+        self._init_ok = False
+        cookies = {}
+        for cookie in request.headers.get("cookie", "").split(";"):
+            k, __, v = cookie.partition("=")
+            cookies[k.strip()] = v.strip()
+        session_uid = (
+            cookies.get("TWISTED_SECURE_SESSION")
+            or cookies.get("TWISTED_SESSION")
+            or ""
+        )
+        if not session_uid:
+            raise types.ConnectionDeny(
+                types.ConnectionDeny.FORBIDDEN, "No session set"
+            )
+        try:
+            session = host.site.getSession(session_uid.encode())
+        except KeyError:
+            raise types.ConnectionDeny(
+                types.ConnectionDeny.FORBIDDEN, "Invalid session"
+            )
+
+        session.touch()
+        session_data = session.getComponent(session_iface.IWebSession)
+        if session_data.ws_socket is not None:
+            log.warning(f"Session socket is already set {session_data.ws_socket=} {self=}], force closing it")
+            try:
+                session_data.ws_socket.send(
+                    "force_close", {"reason": "duplicate connection detected"}
+                )
+            except Exception as e:
+                log.warning(f"Can't force close old connection: {e}")
+        session_data.ws_socket = self
+        self.__session = session_data
+        self.__profile = session_data.profile or C.SERVICE_PROFILE
+        log.debug(f"websocket connection connected for profile {self.__profile}")
+        return "libervia-page"
+
+    def on_open(self):
+        log.debug("websocket connection opened")
+
+    def onMessage(self, payload: bytes, isBinary: bool) -> None:
+        if self.__session is None:
+            raise exceptions.InternalError("empty session, this should never happen")
+        try:
+            data_full = json.loads(payload.decode())
+            data_type = data_full["type"]
+            data = data_full["data"]
+        except ValueError as e:
+            self.error(
+                "bad_request",
+                f"Not valid JSON, ignoring data ({e}): {payload!r}"
+            )
+            return
+        except KeyError:
+            self.error(
+                "bad_request",
+                'Invalid request (missing "type" or "data")'
+            )
+            return
+
+        if data_type == "init":
+            if self._init_ok:
+                self.error(
+                    "bad_request",
+                    "double init"
+                )
+                self.sendClose(4400, "Bad Request")
+                return
+
+            try:
+                profile = data["profile"] or C.SERVICE_PROFILE
+                token = data["token"]
+            except KeyError:
+                self.error(
+                    "bad_request",
+                    "Invalid init data (missing profile or token)"
+                )
+                self.sendClose(4400, "Bad Request")
+                return
+            if ((
+                profile != self.__profile
+                or (token != self.__session.ws_token and profile != C.SERVICE_PROFILE)
+            )):
+                log.debug(
+                    f"profile got {profile}, was expecting {self.__profile}, "
+                    f"token got {token}, was expecting {self.__session.ws_token}, "
+                )
+                self.error(
+                    "Unauthorized",
+                    "Invalid profile or token"
+                )
+                self.sendClose(4401, "Unauthorized")
+                return
+            else:
+                log.debug(f"websocket connection initialized for {profile}")
+                self._init_ok = True
+                # we now send all cached data, if any
+                while True:
+                    try:
+                        session_kw = self.__session.ws_buffer.popleft()
+                    except IndexError:
+                        break
+                    else:
+                        self.send(**session_kw)
+
+        if not self._init_ok:
+            self.error(
+                "Unauthorized",
+                "session not authorized"
+            )
+            self.sendClose(4401, "Unauthorized")
+            return
+
+    def on_close(self, wasClean, code, reason):
+        log.debug(f"closing websocket (profile: {self.__profile}, reason: {reason})")
+        if self.__profile is None:
+            log.error("self.__profile should not be None")
+            self.__profile = C.SERVICE_PROFILE
+
+        if self.__session is None:
+            log.warning("closing a socket without attached session")
+        elif self.__session.ws_socket != self:
+            log.error("session socket is not linked to our instance")
+        else:
+            log.debug(f"reseting websocket session for {self.__profile}")
+            self.__session.ws_socket = None
+        sessions = session_iface.WebSession.get_profile_sessions(self.__profile)
+        log.debug(f"websocket connection for profile {self.__profile} closed")
+        self.__profile = None
+
+    @classmethod
+    def get_base_url(cls, secure):
+        return "ws{sec}://localhost:{port}".format(
+            sec="s" if secure else "",
+            port=host.options["port_https" if secure else "port"],
+        )
+
+    @classmethod
+    def get_resource(cls, secure):
+        factory = websocket.WebSocketServerFactory(cls.get_base_url(secure))
+        factory.protocol = cls
+        return resource.WebSocketResource(factory)