changeset 1203:251eba911d4d

server (websockets): fixed websocket handling on HTTPS connections: Original request used to retrieve a page was stored on dynamic pages, but after the end of it, the channel was deleted, resulting in a isSecure() always returning False, and troubles in chain leading to the the use of the wrong session object. This patch fixes this by reworking the way original request is used, and creating a new wrapping class allowing to keep an API similar to iweb.IRequest, with data coming from both the original request and the websocket request. fix 327
author Goffi <goffi@goffi.org>
date Sun, 14 Jul 2019 14:45:51 +0200
parents 3f791fbc1643
children a2df53dfbf46
files libervia/pages/chat/page_meta.py libervia/server/pages.py libervia/server/server.py libervia/server/websockets.py
diffstat 4 files changed, 100 insertions(+), 38 deletions(-) [+]
line wrap: on
line diff
--- a/libervia/pages/chat/page_meta.py	Fri Jul 12 14:58:11 2019 +0200
+++ b/libervia/pages/chat/page_meta.py	Sun Jul 14 14:45:51 2019 +0200
@@ -55,7 +55,7 @@
         join_ret = yield self.host.bridgeCall(
             u"mucJoin", target_jid.userhost(), "", "", profile
         )
-        already_joined, room_jid_s, occupants, user_nick, room_subject, dummy = join_ret
+        already_joined, room_jid_s, occupants, user_nick, room_subject, __ = join_ret
         template_data[u"subject"] = room_subject
         own_jid = jid.JID(room_jid_s)
         own_jid.resource = user_nick
@@ -81,7 +81,7 @@
         identities[author] = yield self.host.bridgeCall(u"identityGet", author, profile)
 
     template_data[u"messages"] = data_objects.Messages(history)
-    template_data[u"identities"] = identities
+    rdata[u'identities'] = template_data[u"identities"] = identities
     template_data[u"target_jid"] = target_jid
     template_data[u"chat_type"] = chat_type
 
@@ -116,11 +116,10 @@
 def on_signal(self, request, signal, *args):
     if signal == "messageNew":
         rdata = self.getRData(request)
-        template_data = request.template_data
         template_data_update = {u"msg": data_objects.Message((args))}
         target_jid = rdata["target"]
-        identities = template_data["identities"]
-        uid, timestamp, from_jid_s, to_jid_s, message, subject, mess_type, extra, dummy = (
+        identities = rdata["identities"]
+        uid, timestamp, from_jid_s, to_jid_s, message, subject, mess_type, extra, __ = (
             args
         )
         from_jid = jid.JID(from_jid_s)
--- a/libervia/server/pages.py	Fri Jul 12 14:58:11 2019 +0200
+++ b/libervia/server/pages.py	Sun Jul 14 14:45:51 2019 +0200
@@ -470,6 +470,16 @@
                 .format( *uri_tuple))
         self.uri_callbacks[uri_tuple] = (self, get_uri_cb)
 
+    def getSignalId(self, request):
+        """Retrieve signal_id for a request
+
+        signal_id is used for dynamic page, to associate a initial request with a
+        signal handler. For WebsocketRequest, signal_id attribute is used (which must
+        be orginal request's id)
+        For server.Request it's id(request)
+        """
+        return getattr(request, 'signal_id', id(request))
+
     def registerSignal(self, request, signal, check_profile=True):
         r"""register a signal handler
 
@@ -491,11 +501,12 @@
         if not self.dynamic:
             log.error(_(u"You can't register signal if page is not dynamic"))
             return
-        LiberviaPage.signals_handlers.setdefault(signal, {})[id(request)] = (
+        signal_id = self.getSignalId(request)
+        LiberviaPage.signals_handlers.setdefault(signal, {})[signal_id] = [
             self,
             request,
             check_profile,
-        )
+        ]
         request._signals_registered.append(signal)
 
     def getConfig(self, key, default=None, value_type=None):
@@ -1025,6 +1036,13 @@
         we send all cached signals
         """
         assert request._signals_cache is not None
+        # we need to replace corresponding original requests by this websocket request
+        # in signals_handlers
+        signal_id = request.signal_id
+        for signal_handlers_map in self.__class__.signals_handlers.itervalues():
+            if signal_id in signal_handlers_map:
+                signal_handlers_map[signal_id][1] = request
+
         cache = request._signals_cache
         request._signals_cache = None
         for request, signal, args in cache:
@@ -1036,8 +1054,9 @@
         we remove signal handler
         """
         for signal in request._signals_registered:
+            signal_id = self.getSignalId(request)
             try:
-                del LiberviaPage.signals_handlers[signal][id(request)]
+                del LiberviaPage.signals_handlers[signal][signal_id]
             except KeyError:
                 log.error(_(u"Can't find signal handler for [{signal}], this should not "
                             u"happen").format(signal=signal))
@@ -1208,11 +1227,11 @@
         request.template_data["websocket"] = WebsocketMeta(
             socket_url, socket_token, socket_debug
         )
-        self.host.registerWSToken(socket_token, self, request)
         # we will keep track of handlers to remove
         request._signals_registered = []
         # we will cache registered signals until socket is opened
         request._signals_cache = []
+        self.host.registerWSToken(socket_token, self, request)
 
     def _prepare_render(self, __, request):
         return defer.maybeDeferred(self.prepare_render, self, request)
--- a/libervia/server/server.py	Fri Jul 12 14:58:11 2019 +0200
+++ b/libervia/server/server.py	Sun Jul 14 14:45:51 2019 +0200
@@ -27,6 +27,7 @@
 import urlparse
 import urllib
 import time
+import copy
 from twisted.application import service
 from twisted.internet import reactor, defer, inotify
 from twisted.web import server
@@ -2658,7 +2659,11 @@
         return self.getExtBaseURL(request, path=scheme, scheme=scheme)
 
     def registerWSToken(self, token, page, request):
-        websockets.LiberviaPageWSProtocol.registerToken(token, page, request)
+        # we make a shallow copy of request to avoid losing request.channel when
+        # connection is lost (which would result as request.isSecure() being always
+        # False). See #327
+        request._signal_id = id(request)
+        websockets.LiberviaPageWSProtocol.registerToken(token, page, copy.copy(request))
 
     ## Various utils ##
 
--- a/libervia/server/websockets.py	Fri Jul 12 14:58:11 2019 +0200
+++ b/libervia/server/websockets.py	Sun Jul 14 14:45:51 2019 +0200
@@ -17,21 +17,78 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+
+import json
+from twisted.internet import error
+from autobahn.twisted import websocket
+from autobahn.twisted import resource as resource
+from autobahn.websocket import types
+from sat.core import exceptions
 from sat.core.i18n import _
 from sat.core.log import getLogger
 
 log = getLogger(__name__)
-from sat.core import exceptions
-
-from autobahn.twisted import websocket
-from autobahn.twisted import resource as resource
-from autobahn.websocket import types
-
-import json
 
 LIBERVIA_PROTOCOL = "libervia_page"
 
 
+class WebsocketRequest(object):
+    """Wrapper around autobahn's ConnectionRequest and Twisted's server.Request
+
+    This is used to have a common interface in Libervia page with request object
+    """
+
+    def __init__(self, ws_protocol, connection_request, server_request):
+        """
+        @param connection_request: websocket request
+        @param serveur_request: original request of the page
+        """
+        self.ws_protocol = ws_protocol
+        self.ws_request = connection_request
+        if self.isSecure():
+            cookie_string = "TWISTED_SECURE_SESSION"
+        else:
+            cookie_string = "TWISTED_SESSION"
+        cookie_value = server_request.getCookie(cookie_string)
+        try:
+            raw_cookies = ws_protocol.http_headers['cookie']
+        except KeyError:
+            raise ValueError(u"missing expected cookie header")
+        self.cookies = {k:v for k,v in (c.split('=') for c in raw_cookies.split(';'))}
+        if self.cookies[cookie_string] != cookie_value:
+            raise exceptions.PermissionError(
+                u"Bad cookie value, this should never happen.\n"
+                u"headers: {headers}".format(headers=ws_protocol.http_headers))
+
+        self.template_data = server_request.template_data
+        self.data = server_request.data
+        self.session = server_request.getSession()
+        self._signals_registered = server_request._signals_registered
+        self._signals_cache = server_request._signals_cache
+        # signal id is needed to link original request with signal handler
+        self.signal_id = server_request._signal_id
+
+    def isSecure(self):
+        return self.ws_protocol.factory.isSecure
+
+    def getSession(self, sessionInterface=None):
+        try:
+            self.session.touch()
+        except (error.AlreadyCalled, error.AlreadyCancelled):
+            # Session has already expired.
+            self.session = None
+
+        if sessionInterface:
+            return self.session.getComponent(sessionInterface)
+
+        return self.session
+
+    def sendData(self, type_, **data):
+        assert "type" not in data
+        data["type"] = type_
+        self.ws_protocol.sendMessage(json.dumps(data, ensure_ascii=False).encode("utf8"))
+
+
 class LiberviaPageWSProtocol(websocket.WebSocketServerProtocol):
     host = None
     tokens_map = {}
@@ -54,8 +111,9 @@
                 types.ConnectionDeny.FORBIDDEN, u"Bad token, please reload page"
             )
         self.token = token
-        self.page = self.tokens_map[token]["page"]
-        self.request = self.tokens_map[token]["request"]
+        token_map = self.tokens_map.pop(token)
+        self.page = token_map["page"]
+        self.request = WebsocketRequest(self, request, token_map["request"])
         return protocol
 
     def onOpen(self):
@@ -66,7 +124,6 @@
                 )
             )
         )
-        self.request.sendData = self.sendJSONData
         self.page.onSocketOpen(self.request)
 
     def onMessage(self, payload, isBinary):
@@ -94,21 +151,8 @@
             cb(page, self.request, data_json)
 
     def onClose(self, wasClean, code, reason):
-        try:
-            token = self.token
-        except AttributeError:
-            log.warning(_(u"Websocket closed but no token is associated"))
-            return
-
         self.page.onSocketClose(self.request)
 
-        try:
-            del self.tokens_map[token]
-            del self.request.sendData
-        except (KeyError, AttributeError):
-            raise exceptions.InternalError(
-                _(u"Token or sendData doesn't exist, this should never happen!")
-            )
         log.debug(
             _(
                 u"Websocket closed for {page} (token: {token}). {reason}".format(
@@ -121,11 +165,6 @@
             )
         )
 
-    def sendJSONData(self, type_, **data):
-        assert "type" not in data
-        data["type"] = type_
-        self.sendMessage(json.dumps(data, ensure_ascii=False).encode("utf8"))
-
     @classmethod
     def getBaseURL(cls, host, secure):
         return u"ws{sec}://localhost:{port}".format(