diff sat/tools/web.py @ 3822:65bac82e4049

core (tools/web): helper method to download files: this method is for internal file download, progression mechanism is not used. rel 368
author Goffi <goffi@goffi.org>
date Wed, 29 Jun 2022 12:07:45 +0200
parents 7550ae9cfbac
children 524856bd7b19
line wrap: on
line diff
--- a/sat/tools/web.py	Wed Jun 29 12:06:21 2022 +0200
+++ b/sat/tools/web.py	Wed Jun 29 12:07:45 2022 +0200
@@ -16,13 +16,20 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from typing import Optional, Union
+from pathlib import Path
+from io import BufferedIOBase
+
 from OpenSSL import SSL
-from zope.interface import implementer
+import treq
 from treq.client import HTTPClient
+from twisted.internet import reactor, ssl
 from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
-from twisted.internet import reactor, ssl
 from twisted.web import iweb
 from twisted.web import client as http_client
+from zope.interface import implementer
+
+from sat.core import exceptions
 from sat.core.log import getLogger
 
 
@@ -65,3 +72,55 @@
 #: following treq doesn't check TLS, obviously it is unsecure and should not be used
 #: without explicit warning
 treq_client_no_ssl = HTTPClient(http_client.Agent(reactor, NoCheckContextFactory()))
+
+
+async def downloadFile(
+    url: str,
+    dest: Union[str, Path, BufferedIOBase],
+    max_size: Optional[int] = None
+) -> None:
+    """Helper method to download a file
+
+    This is for internal download, for high level download with progression, use
+    ``plugin_misc_download``.
+
+    Inspired from
+    https://treq.readthedocs.io/en/latest/howto.html#handling-streaming-responses
+
+    @param url: URL of the file to download
+    @param dest: destination filename or file-like object
+        if it's a file-like object, you'll have to close it yourself
+    @param max_size: if set, an exceptions.DataError will be raised if the downloaded file
+        is bigger than given value (in bytes).
+    @raise exceptions.DataError: more than ``max_size`` bytes have been received
+    """
+    if isinstance(dest, BufferedIOBase):
+        f = dest
+        must_close = False
+    else:
+        dest = Path(dest)
+        f = dest.open("wb")
+        must_close = True
+    written = 0
+
+    def write(data: bytes) -> None:
+        nonlocal written
+        written += len(data)
+        if max_size is not None and written > max_size:
+            raise exceptions.DataError(
+                f"downloaded file is bigger than expected ({max_size})"
+            )
+        f.write(data)
+
+    try:
+        # ``write`` (not ``f.write``) is the collector, so that ``max_size`` is enforced
+        await treq.get(url, unbuffered=True).addCallback(treq.collect, write)
+    except exceptions.DataError:
+        log.warning("download cancelled due to file oversized")
+        raise
+    except Exception as e:
+        log.error(f"Can't write file {dest}: {e}")
+        raise
+    finally:
+        if must_close:
+            f.close()