# HG changeset patch # User Kim Alvefur # Date 1699718489 -3600 # Node ID b4632d5f840b489b6ccf0fd84e9642df1ebf4698 # Parent 8afa0fb8a73e2dd4c4bab0fc0067827b313a91a0 mod_storage_s3: Move request signing into a net.http hook diff -r 8afa0fb8a73e -r b4632d5f840b mod_storage_s3/mod_storage_s3.lua --- a/mod_storage_s3/mod_storage_s3.lua Fri Nov 10 00:26:17 2023 +0100 +++ b/mod_storage_s3/mod_storage_s3.lua Sat Nov 11 17:01:29 2023 +0100 @@ -25,22 +25,13 @@ local access_key = module:get_option_string("s3_access_key"); local secret_key = module:get_option_string("s3_secret_key"); -function driver:open(store, typ) - local mt = self[typ or "keyval"] - if not mt then - return nil, "unsupported-store"; - end - return setmetatable({ store = store; bucket = bucket; type = typ }, mt); -end - -local keyval = { }; -driver.keyval = { __index = keyval; __name = module.name .. " keyval store" }; - local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s"; -local function new_request(method, path, query, payload) - local request = url.parse(base_uri); - request.path = path; +local function aws_auth(event) + local request, options = event.request, event.options; + local method = options.method or "GET"; + local query = options.query; + local payload = options.body; local payload_type = nil; if st.is_stanza(payload) then @@ -50,6 +41,7 @@ payload_type = "application/json"; payload = json.encode(payload); end + options.body = payload; local payload_hash = sha256(payload or "", true); @@ -112,7 +104,27 @@ headers["Authorization"] = string.format(aws4_format, access_key, scope, signed_headers, signature); - return http.request(url.build(request), { method = method; headers = headers; body = payload }); + options.headers = headers; +end + +function driver:open(store, typ) + local mt = self[typ or "keyval"] + if not mt then + return nil, "unsupported-store"; + end + local httpclient = http.new({}); + httpclient.events.add_handler("pre-request", aws_auth); + return setmetatable({ store = store; bucket = bucket; type = typ; http = httpclient }, mt); +end + +local keyval = { }; +driver.keyval = { __index = keyval; __name = module.name .. " keyval store" }; + +local function new_request(self, method, path, query, payload) + local request = url.parse(base_uri); + request.path = path; + + return self.http:request(url.build(request), { method = method; body = payload; query = query }); end -- coerce result back into Prosody data type @@ -147,22 +159,22 @@ end function keyval:get(user) - return async.wait_for(new_request("GET", self:_path(user)):next(on_result)); + return async.wait_for(new_request(self, "GET", self:_path(user)):next(on_result)); end function keyval:set(user, data) if data == nil or (type(data) == "table" and next(data) == nil) then - return async.wait_for(new_request("DELETE", self:_path(user))); + return async.wait_for(new_request(self, "DELETE", self:_path(user))); end - return async.wait_for(new_request("PUT", self:_path(user), nil, data)); + return async.wait_for(new_request(self, "PUT", self:_path(user), nil, data)); end function keyval:users() local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true }); local prefix = url.build_path({ jid.escape(module.host); jid.escape(self.store); is_directory = true }); - local list_result, err = async.wait_for(new_request("GET", bucket_path, { prefix = prefix })) + local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix })) if err or list_result.code ~= 200 then return nil, err; end @@ -208,7 +220,7 @@ wrapper:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up(); wrapper:add_direct_child(value); key = key or new_uuid(); - return async.wait_for(new_request("PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r) + return async.wait_for(new_request(self, "PUT", self:_path(username, nil, when, with, key), nil, wrapper):next(function(r) if r.code == 200 then return key; else @@ -232,7 +244,7 @@ end prefix = url.build_path(prefix); - local list_result, err = async.wait_for(new_request("GET", bucket_path, { + local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix; ["max-keys"] = query["max"] and tostring(query["max"]); })); @@ -276,7 +288,7 @@ return nil; end -- luacheck: ignore 431/err - local value, err = async.wait_for(new_request("GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result)); + local value, err = async.wait_for(new_request(self, "GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result)); if not value then module:log("error", "%s", err); return nil;