changeset 5695:b4632d5f840b

mod_storage_s3: Move request signing into a net.http hook
author Kim Alvefur <zash@zash.se>
date Sat, 11 Nov 2023 17:01:29 +0100
parents 8afa0fb8a73e
children 66986f5271c3
files mod_storage_s3/mod_storage_s3.lua
diffstat 1 files changed, 34 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/mod_storage_s3/mod_storage_s3.lua	Fri Nov 10 00:26:17 2023 +0100
+++ b/mod_storage_s3/mod_storage_s3.lua	Sat Nov 11 17:01:29 2023 +0100
@@ -25,22 +25,13 @@
 local access_key = module:get_option_string("s3_access_key");
 local secret_key = module:get_option_string("s3_secret_key");
 
-function driver:open(store, typ)
-	local mt = self[typ or "keyval"]
-	if not mt then
-		return nil, "unsupported-store";
-	end
-	return setmetatable({ store = store; bucket = bucket; type = typ }, mt);
-end
-
-local keyval = { };
-driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };
-
 local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s";
 
-local function new_request(method, path, query, payload)
-	local request = url.parse(base_uri);
-	request.path = path;
+local function aws_auth(event)
+	local request, options = event.request, event.options;
+	local method = options.method or "GET";
+	local query = options.query;
+	local payload = options.body;
 
 	local payload_type = nil;
 	if st.is_stanza(payload) then
@@ -50,6 +41,7 @@
 		payload_type = "application/json";
 		payload = json.encode(payload);
 	end
+	options.body = payload;
 
 	local payload_hash = sha256(payload or "", true);
 
@@ -112,7 +104,27 @@
 
 	headers["Authorization"] = string.format(aws4_format, access_key, scope, signed_headers, signature);
 
-	return http.request(url.build(request), { method = method; headers = headers; body = payload });
+	options.headers = headers;
+end
+
+function driver:open(store, typ)
+	local mt = self[typ or "keyval"]
+	if not mt then
+		return nil, "unsupported-store";
+	end
+	local httpclient = http.new({});
+	httpclient.events.add_handler("pre-request", aws_auth);
+	return setmetatable({ store = store; bucket = bucket; type = typ; http = httpclient }, mt);
+end
+
+local keyval = { };
+driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };
+
+local function new_request(self, method, path, query, payload)
+	local request = url.parse(base_uri);
+	request.path = path;
+
+	return self.http:request(url.build(request), { method = method; body = payload; query = query });
 end
 
 -- coerce result back into Prosody data type
@@ -147,22 +159,22 @@
 end
 
 function keyval:get(user)
-	return async.wait_for(new_request("GET", self:_path(user)):next(on_result));
+	return async.wait_for(new_request(self, "GET", self:_path(user)):next(on_result));
 end
 
 function keyval:set(user, data)
 
 	if data == nil or (type(data) == "table" and next(data) == nil) then
-		return async.wait_for(new_request("DELETE", self:_path(user)));
+		return async.wait_for(new_request(self, "DELETE", self:_path(user)));
 	end
 
-	return async.wait_for(new_request("PUT", self:_path(user), nil, data));
+	return async.wait_for(new_request(self, "PUT", self:_path(user), nil, data));
 end
 
 function keyval:users()
 	local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
 	local prefix = url.build_path({ jid.escape(module.host); jid.escape(self.store); is_directory = true });
-	local list_result, err = async.wait_for(new_request("GET", bucket_path, { prefix = prefix }))
+	local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix }))
 	if err or list_result.code ~= 200 then
 		return nil, err;
 	end
@@ -208,7 +220,7 @@
 	wrapper:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up();
 	wrapper:add_direct_child(value);
 	key = key or new_uuid();
-	return async.wait_for(new_request("PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r)
+	return async.wait_for(new_request(self, "PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r)
 		if r.code == 200 then
 			return key;
 		else
@@ -232,7 +244,7 @@
 	end
 
 	prefix = url.build_path(prefix);
-	local list_result, err = async.wait_for(new_request("GET", bucket_path, {
+	local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, {
 		prefix = prefix;
 		["max-keys"] = query["max"] and tostring(query["max"]);
 	}));
@@ -276,7 +288,7 @@
 			return nil;
 		end
 		-- luacheck: ignore 431/err
-		local value, err = async.wait_for(new_request("GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
+		local value, err = async.wait_for(new_request(self, "GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
 		if not value then
 			module:log("error", "%s", err);
 			return nil;