diff src/server/server.py @ 1006:d0b27d1e2d50

server: moved code to retrieve external server from legacy blog to server.py, and use it to find websocket URL
author Goffi <goffi@goffi.org>
date Sat, 06 Jan 2018 12:37:56 +0100
parents 05cc33d8e328
children 1593e00078d2
line wrap: on
line diff
--- a/src/server/server.py	Fri Jan 05 16:30:05 2018 +0100
+++ b/src/server/server.py	Sat Jan 06 12:37:56 2018 +0100
@@ -1786,6 +1786,69 @@
         # FIXME: check that no information is leaked (c.f. https://twistedmatrix.com/documents/current/web/howto/using-twistedweb.html#request-encoders)
         self.root.putChild(path, web_resource.EncodingResourceWrapper(resource, [server.GzipEncoderFactory()]))
 
+    def getExtBaseURLData(self, request):
+        """Retrieve external base URL Data
+
+        this method tried to retrieve the base URL found by external user
+        It does by checking in this order:
+            - base_url_ext option from configuration
+            - proxy x-forwarder-host headers
+            - URL of the request
+        @return (urlparse.SplitResult): SplitResult instance with only scheme and netloc filled
+        """
+        ext_data = self.base_url_ext_data
+        url_path = request.URLPath()
+        if not ext_data.scheme or not ext_data.netloc:
+            # ext_data is not specified, we check headers
+            if request.requestHeaders.hasHeader('x-forwarded-host'):
+                # we are behing a proxy
+                # we fill proxy_scheme and proxy_netloc value
+                proxy_host = request.requestHeaders.getRawHeaders('x-forwarded-host')[0]
+                try:
+                    proxy_server = request.requestHeaders.getRawHeaders('x-forwarded-server')[0]
+                except TypeError:
+                    # no x-forwarded-server found, we use proxy_host
+                    proxy_netloc = proxy_host
+                else:
+                    # if the proxy host has a port, we use it with server name
+                    proxy_port = urlparse.urlsplit(u'//{}'.format(proxy_host)).port
+                    proxy_netloc = u'{}:{}'.format(proxy_server, proxy_port) if proxy_port is not None else proxy_server
+                proxy_netloc = proxy_netloc.decode('utf-8')
+                try:
+                    proxy_scheme = request.requestHeaders.getRawHeaders('x-forwarded-proto')[0].decode('utf-8')
+                except TypeError:
+                    proxy_scheme = None
+            else:
+                proxy_scheme, proxy_netloc = None, None
+        else:
+            proxy_scheme, proxy_netloc = None, None
+
+        return urlparse.SplitResult(
+            ext_data.scheme or proxy_scheme or url_path.scheme.decode('utf-8'),
+            ext_data.netloc or proxy_netloc or url_path.netloc.decode('utf-8'),
+            ext_data.path or u'/',
+            '', '')
+
+    def getExtBaseURL(self, request, path='', query='', fragment='', scheme=None):
+        """Get external URL according to given elements
+
+        external URL is the URL seen by external user
+        @param path(unicode): same as for urlsplit.urlsplit
+            path will be prefixed to follow found external URL if suitable
+        @param params(unicode): same as for urlsplit.urlsplit
+        @param query(unicode): same as for urlsplit.urlsplit
+        @param fragment(unicode): same as for urlsplit.urlsplit
+        @param scheme(unicode, None): if not None, will override scheme from base URL
+        @return (unicode): external URL
+        """
+        split_result = self.getExtBaseURLData(request)
+        return urlparse.urlunsplit((
+            split_result.scheme.decode('utf-8') if scheme is None else scheme,
+            split_result.netloc.decode('utf-8'),
+            os.path.join(split_result.path, path),
+            query, fragment))
+
+
     ## Sessions ##
 
     def purgeSession(self, request):
@@ -1814,24 +1877,13 @@
     ## Websocket (dynamic pages) ##
 
     def getWebsocketURL(self, request):
-        if request.isSecure():
-            ws = 'wss'
-        else:
-            ws = 'ws'
-
-        if self.base_url_ext:
-            base_url = self.base_url_ext
+        base_url_split = self.getExtBaseURLData(request)
+        if base_url_split.scheme.endswith('s'):
+            scheme = u'wss'
         else:
-            o = self.options
-            if request.isSecure():
-                port = o['port_https_ext'] or o['port_https']
-            else:
-                port = o['port']
-            base_url = request.getRequestHostname().decode('utf-8') + u':' + unicode(port)+ u'/'
+            scheme = u'ws'
 
-        return u'{ws}://{base_url}{ws}'.format(
-            ws = ws,
-            base_url = base_url)
+        return self.getExtBaseURL(request, path=scheme, scheme=scheme)
 
     def registerWSToken(self, token, page, request):
         websockets.LiberviaPageWSProtocol.registerToken(token, page, request)