Mercurial > libervia-web
view libervia/web/server/websockets.py @ 1530:b338c31d5251
browser: integrate the `jid` module
author | Goffi <goffi@goffi.org> |
---|---|
date | Thu, 08 Jun 2023 23:32:47 +0200 |
parents | eb00d593801d |
children |
line wrap: on
line source
#!/usr/bin/env python3

# Libervia: a Salut à Toi frontend
# Copyright (C) 2011-2021 Jérôme Poisson <goffi@goffi.org>

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Websocket endpoint used by Libervia pages (``libervia-page`` subprotocol)."""

import json
from typing import Optional

from autobahn.twisted import websocket
from autobahn.twisted import resource as resource
from autobahn.websocket import types
from libervia.backend.core import exceptions
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger

from . import session_iface
from .constants import Const as C

log = getLogger(__name__)

# Server host object; this module reads ``host.site`` (session lookup) and
# ``host.options`` (listening ports).
# NOTE(review): it is None at import time and presumably assigned by the server
# at startup — confirm it is always set before a connection can be accepted.
host = None


class LiberviaPageWSProtocol(websocket.WebSocketServerProtocol):
    """Server side of the ``libervia-page`` websocket subprotocol.

    During the handshake the connection is bound to the web session found in the
    request cookies. The frontend must then send an ``init`` message carrying the
    session profile and websocket token; only after that is data sent to it.
    """

    def __init__(self):
        super().__init__()
        # becomes True once a valid "init" message has been received
        self._init_ok: bool = False
        # profile of the attached web session (service profile if it has none)
        self.__profile: Optional[str] = None
        # web session data this socket is attached to (set in onConnect)
        self.__session: Optional[session_iface.WebSession] = None

    @property
    def init_ok(self) -> bool:
        """Whether the frontend has successfully sent its ``init`` message."""
        return self._init_ok

    def send(self, data_type: str, data: dict) -> None:
        """Send data to frontend.

        The message is the UTF-8 encoded JSON object
        ``{"type": data_type, "data": data}``.

        @param data_type: type of the message
        @param data: message payload, must be JSON serialisable
        @raise exceptions.InternalError: the connection is not initialised yet
            and ``data_type`` is not "error"
        """
        if not self._init_ok and data_type != "error":
            raise exceptions.InternalError(
                "send called when not initialized, this should not happen! Please use "
                "WebSession.send which takes care of sending correctly the data to all "
                "sessions."
            )

        data_root = {
            "type": data_type,
            "data": data
        }
        self.sendMessage(json.dumps(data_root, ensure_ascii=False).encode())

    def close(self) -> None:
        """Log the closing of this websocket.

        NOTE(review): this only logs; it does not close the connection itself
        (``sendClose`` does that).
        """
        log.debug(f"closing websocket for profile {self.__profile}")

    def error(self, error_type: str, msg: str) -> None:
        """Send an error message to frontend and log it locally.

        Error messages are allowed even before the connection is initialised.

        @param error_type: short error identifier sent as ``data["type"]``
        @param msg: human readable description sent as ``data["msg"]``
        """
        log.warning(f"websocket error {error_type}: {msg}")
        self.send("error", {
            "type": error_type,
            "msg": msg,
        })

    def onConnect(self, request):
        """Bind the incoming connection to the web session found in its cookies.

        Called by autobahn during the opening handshake. If the session already
        has a socket, that socket is asked to close ("force_close") and replaced
        by this one.

        @param request: autobahn connection request
        @return: the accepted subprotocol ("libervia-page")
        @raise types.ConnectionDeny: the "libervia-page" subprotocol is not
            requested, or no valid session cookie is present
        """
        if "libervia-page" not in request.protocols:
            raise types.ConnectionDeny(
                types.ConnectionDeny.NOT_IMPLEMENTED, "No supported protocol"
            )
        self._init_ok = False
        cookies = {}
        for cookie in request.headers.get("cookie", "").split(";"):
            k, __, v = cookie.partition("=")
            cookies[k.strip()] = v.strip()
        session_uid = (
            cookies.get("TWISTED_SECURE_SESSION")
            or cookies.get("TWISTED_SESSION")
            or ""
        )
        if not session_uid:
            raise types.ConnectionDeny(
                types.ConnectionDeny.FORBIDDEN, "No session set"
            )
        try:
            session = host.site.getSession(session_uid.encode())
        except KeyError:
            raise types.ConnectionDeny(
                types.ConnectionDeny.FORBIDDEN, "Invalid session"
            ) from None

        session.touch()
        session_data = session.getComponent(session_iface.IWebSession)
        if session_data.ws_socket is not None:
            log.warning(
                f"Session socket is already set {session_data.ws_socket=} {self=}, "
                "force closing it"
            )
            # only one socket per session: ask the previous one to go away
            try:
                session_data.ws_socket.send(
                    "force_close", {"reason": "duplicate connection detected"}
                )
            except Exception as e:
                # best effort: e.g. the old socket is not initialised (send()
                # raises InternalError) or is already disconnected
                log.warning(f"Can't force close old connection: {e}")
        session_data.ws_socket = self
        self.__session = session_data
        self.__profile = session_data.profile or C.SERVICE_PROFILE
        log.debug(f"websocket connection connected for profile {self.__profile}")
        return "libervia-page"

    def onOpen(self) -> None:
        """Autobahn callback invoked once the websocket is open.

        Autobahn only dispatches to this camelCase name; delegate to ``on_open``.
        """
        self.on_open()

    def on_open(self) -> None:
        """Log that the websocket connection is open."""
        log.debug("websocket connection opened")

    def onMessage(self, payload: bytes, isBinary: bool) -> None:
        """Handle a message from the frontend.

        Messages must be JSON objects with "type" and "data" keys. The only type
        handled here is "init": its data must hold "profile" and "token", which
        must match the attached session (the token is not checked for the
        service profile). On success, data buffered in the session's
        ``ws_buffer`` is flushed to the frontend. Any message received on a
        non-initialised connection closes it with code 4401.

        @param payload: raw message
        @param isBinary: autobahn binary flag (not used)
        @raise exceptions.InternalError: no session is attached to this socket
        """
        if self.__session is None:
            raise exceptions.InternalError("empty session, this should never happen")
        try:
            data_full = json.loads(payload.decode())
        except ValueError as e:
            # also covers UnicodeDecodeError, which subclasses ValueError
            self.error(
                "bad_request",
                f"Not valid JSON, ignoring data ({e}): {payload!r}"
            )
            return
        try:
            data_type = data_full["type"]
            data = data_full["data"]
        except (KeyError, TypeError):
            # TypeError: valid JSON which is not an object (list, string, null…)
            self.error(
                "bad_request",
                'Invalid request (missing "type" or "data")'
            )
            return

        if data_type == "init":
            if self._init_ok:
                self.error("bad_request", "double init")
                self.sendClose(4400, "Bad Request")
                return

            try:
                profile = data["profile"] or C.SERVICE_PROFILE
                token = data["token"]
            except (KeyError, TypeError):
                # TypeError: "data" is not a JSON object
                self.error(
                    "bad_request",
                    "Invalid init data (missing profile or token)"
                )
                self.sendClose(4400, "Bad Request")
                return

            # the profile must always match; the token is only checked for
            # non-service profiles
            if (
                profile != self.__profile
                or (token != self.__session.ws_token and profile != C.SERVICE_PROFILE)
            ):
                log.debug(
                    f"profile got {profile}, was expecting {self.__profile}, "
                    f"token got {token}, was expecting {self.__session.ws_token}, "
                )
                self.error("Unauthorized", "Invalid profile or token")
                self.sendClose(4401, "Unauthorized")
                return

            log.debug(f"websocket connection initialized for {profile}")
            self._init_ok = True
            # we now send all cached data, if any; each buffered entry holds the
            # keyword arguments of a send() call
            while True:
                try:
                    session_kw = self.__session.ws_buffer.popleft()
                except IndexError:
                    break
                self.send(**session_kw)

        if not self._init_ok:
            self.error("Unauthorized", "session not authorized")
            self.sendClose(4401, "Unauthorized")

    def onClose(self, wasClean, code, reason) -> None:
        """Autobahn callback invoked when the connection is closed.

        Autobahn only dispatches to this camelCase name; without it the session
        cleanup in ``on_close`` never ran and stale sockets stayed attached to
        their web session.
        """
        self.on_close(wasClean, code, reason)

    def on_close(self, wasClean, code, reason) -> None:
        """Detach this socket from its web session.

        @param wasClean: whether the connection was closed cleanly
        @param code: close status code
        @param reason: close reason
        """
        log.debug(f"closing websocket (profile: {self.__profile}, reason: {reason})")
        if self.__profile is None:
            log.error("self.__profile should not be None")
            self.__profile = C.SERVICE_PROFILE

        if self.__session is None:
            log.warning("closing a socket without attached session")
        elif self.__session.ws_socket is not self:
            # e.g. a newer connection replaced this one in onConnect: leave the
            # newer socket attached
            log.error("session socket is not linked to our instance")
        else:
            log.debug(f"resetting websocket session for {self.__profile}")
            self.__session.ws_socket = None
        log.debug(f"websocket connection for profile {self.__profile} closed")
        self.__profile = None

    @classmethod
    def get_base_url(cls, secure):
        """Return the websocket URL of the local server.

        @param secure: if True, use ``wss`` and the HTTPS port option, else ``ws``
            and the HTTP port option
        @return: URL such as ``ws://localhost:<port>``
        """
        return "ws{sec}://localhost:{port}".format(
            sec="s" if secure else "",
            port=host.options["port_https" if secure else "port"],
        )

    @classmethod
    def get_resource(cls, secure):
        """Build a Twisted web resource serving this protocol.

        @param secure: passed to ``get_base_url`` to choose the URL scheme/port
        @return: autobahn WebSocketResource wrapping a factory for this class
        """
        factory = websocket.WebSocketServerFactory(cls.get_base_url(secure))
        factory.protocol = cls
        return resource.WebSocketResource(factory)