Mercurial > libervia-backend
comparison libervia/backend/tools/web.py @ 4071:4b842c1fb686
refactoring: renamed `sat` package to `libervia.backend`
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 02 Jun 2023 11:49:51 +0200 |
parents | sat/tools/web.py@524856bd7b19 |
children | 0d7bb4df2343 |
comparison
equal
deleted
inserted
replaced
4070:d10748475025 | 4071:4b842c1fb686 |
---|---|
1 #!/usr/bin/env python3 | |
2 | |
3 # Libervia: an XMPP client | |
4 # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) | |
5 | |
6 # This program is free software: you can redistribute it and/or modify | |
7 # it under the terms of the GNU Affero General Public License as published by | |
8 # the Free Software Foundation, either version 3 of the License, or | |
9 # (at your option) any later version. | |
10 | |
11 # This program is distributed in the hope that it will be useful, | |
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
14 # GNU Affero General Public License for more details. | |
15 | |
16 # You should have received a copy of the GNU Affero General Public License | |
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. | |
18 | |
19 from typing import Optional, Union | |
20 from pathlib import Path | |
21 from io import BufferedIOBase | |
22 | |
23 from OpenSSL import SSL | |
24 import treq | |
25 from treq.client import HTTPClient | |
26 from twisted.internet import reactor, ssl | |
27 from twisted.internet.interfaces import IOpenSSLClientConnectionCreator | |
28 from twisted.web import iweb | |
29 from twisted.web import client as http_client | |
30 from zope.interface import implementer | |
31 | |
32 from libervia.backend.core import exceptions | |
33 from libervia.backend.core.log import getLogger | |
34 | |
35 | |
# Module-level logger, created with this module's name.
log = getLogger(__name__)


# Alias of OpenSSL's ``SSL.Error`` exposed in this module's namespace.
# NOTE(review): presumably so callers can catch TLS errors without importing
# OpenSSL themselves -- confirm against callers.
SSLError = SSL.Error
40 | |
41 | |
@implementer(IOpenSSLClientConnectionCreator)
class NoCheckConnectionCreator:
    """TLS client connection creator built around a pre-made OpenSSL context.

    Connections are created directly from the context given at construction time.
    """

    def __init__(self, hostname, ctx):
        """Store the OpenSSL context used for every connection.

        @param hostname: accepted for signature compatibility, not used here
        @param ctx: OpenSSL context from which connections are created
        """
        self._ctx = ctx

    def clientConnectionForTLS(self, tlsProtocol):
        """Create the OpenSSL connection for the given TLS protocol.

        @param tlsProtocol: protocol instance, attached to the connection as app data
        @return: new ``SSL.Connection`` bound to the stored context
        """
        tls_connection = SSL.Connection(self._ctx, None)
        tls_connection.set_app_data(tlsProtocol)
        return tls_connection
52 | |
53 | |
@implementer(iweb.IPolicyForHTTPS)
class NoCheckContextFactory:
    """HTTPS policy which doesn't check TLS certificates

    /!\\ using this class is obviously a security flaw: it must only be used with
    explicit agreement from the end user
    """

    def creatorForNetloc(self, hostname, port):
        """Return a connection creator doing no certificate check.

        A warning is logged each time this method is called.

        @param hostname: host of the requested location
        @param port: port of the requested location
        @return: NoCheckConnectionCreator using a context built without trust root
        """
        message = "TLS check disabled for {host} on port {port}".format(
            host=hostname, port=port
        )
        log.warning(message)
        unchecked_context = ssl.CertificateOptions(trustRoot=None).getContext()
        return NoCheckConnectionCreator(hostname, unchecked_context)
70 | |
71 | |
#: The following treq client doesn't check TLS certificates: it is obviously insecure
#: and should not be used without explicitly warning the user.
treq_client_no_ssl = HTTPClient(http_client.Agent(reactor, NoCheckContextFactory()))
75 | |
76 | |
async def download_file(
    url: str,
    dest: Union[str, Path, BufferedIOBase],
    max_size: Optional[int] = None
) -> None:
    """Helper method to download a file

    This is for internal download, for high level download with progression, use
    ``plugin_misc_download``.

    Inspired from
    https://treq.readthedocs.io/en/latest/howto.html#handling-streaming-responses

    @param url: URL of the file to download
    @param dest: destination filename or file-like object
        if it's a file-like object, you'll have to close it yourself
    @param max_size: if set, an exceptions.DataError will be raised if the downloaded file
        is bigger than given value (in bytes).
    @raise exceptions.DataError: more than ``max_size`` bytes have been received
    """
    if isinstance(dest, BufferedIOBase):
        f = dest
        must_close = False
    else:
        dest = Path(dest)
        f = dest.open("wb")
        must_close = True
    written = 0

    def write(data: bytes) -> None:
        """Write a received chunk to the destination, enforcing ``max_size``."""
        nonlocal written
        if max_size is not None:
            written += len(data)
            if written > max_size:
                raise exceptions.DataError(
                    f"downloaded file is bigger than expected ({max_size})"
                )
        f.write(data)

    try:
        # the request is started inside the ``try`` so that the file opened above is
        # closed even if the request fails immediately
        d = treq.get(url, unbuffered=True)
        # ``write`` (and not ``f.write``) must be the collector, otherwise ``max_size``
        # is never checked
        d.addCallback(treq.collect, write)
        await d
    except exceptions.DataError:
        log.warning("download cancelled due to file oversized")
        raise
    except Exception as e:
        log.error(f"Can't write file {dest}: {e}")
        raise
    finally:
        # only close the file if it was opened by this function
        if must_close:
            f.close()