view mod_storage_s3/mod_storage_s3.lua @ 5715:8488ebde5739

mod_http_oauth2: Skip consent screen if requested by client and same scopes already granted. This follows the intent behind the OpenID Connect 'prompt' parameter when it does not include the 'consent' keyword, that is, the client wishes to skip the consent screen. If the user has already granted the exact same scopes to the exact same client in the past, then one can assume that they may grant it again.
author Kim Alvefur <zash@zash.se>
date Tue, 14 Nov 2023 23:03:37 +0100
parents 9de7a1b36efb
children 80702e33ba71
line wrap: on
line source

local http = require "prosody.net.http";
local array = require "prosody.util.array";
local async = require "prosody.util.async";
local dt = require "prosody.util.datetime";
local hashes = require "prosody.util.hashes";
local httputil = require "prosody.util.http";
local it = require "prosody.util.iterators";
local jid = require "prosody.util.jid";
local json = require "prosody.util.json";
local promise = require "prosody.util.promise";
local st = require "prosody.util.stanza";
local uuid = require "prosody.util.uuid";
local xml = require "prosody.util.xml";
local url = require "socket.url";

-- Use uuid.v7 when util.uuid provides it, otherwise fall back to uuid.generate.
local new_uuid = uuid.v7 or uuid.generate;
-- Short aliases for the hash functions used by the request signing code.
local hmac_sha256 = hashes.hmac_sha256;
local sha256 = hashes.sha256;

-- Storage driver table; handed to module:provides("storage", ...) at the end of the file.
local driver = {};

-- Bucket name, endpoint and region, each with a default for when the option is unset.
local bucket = module:get_option_string("s3_bucket", "prosody");
local base_uri = module:get_option_string("s3_base_uri", "http://localhost:9000");
local region = module:get_option_string("s3_region", "us-east-1");

-- Credentials used for signing. No default is given here, so aws_auth receives
-- whatever get_option_string returns when they are not configured.
local access_key = module:get_option_string("s3_access_key");
local secret_key = module:get_option_string("s3_secret_key");

-- Template for the Authorization header: access key / scope, signed header list, signature.
local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s";

--- "pre-request" handler (registered in driver:open) that serialises the
-- request body and signs the request using the AWS4-HMAC-SHA256 scheme.
--
-- Side effects on the event:
--   * options.body is replaced by its serialised string form (or left nil)
--   * options.headers is replaced by the header table built here
--   * request.query is set to the canonical query string when options.query is given
--
-- @param event table with `request` (provides `authority` and `path`) and
--   `options` (provides `method`, `query`, `body`)
local function aws_auth(event)
	local request, options = event.request, event.options;
	local method = options.method or "GET";
	local query = options.query;
	local payload = options.body;

	-- Serialise the body: stanzas become XML text, any other non-nil value becomes JSON.
	local payload_type = nil;
	if st.is_stanza(payload) then
		payload_type = "application/xml";
		payload = tostring(payload);
	elseif payload ~= nil then
		payload_type = "application/json";
		payload = json.encode(payload);
	end
	options.body = payload;

	-- Hash of the serialised body; a missing body is hashed as the empty string.
	-- NOTE(review): the second argument presumably selects hex output -- confirm against util.hashes.
	local payload_hash = sha256(payload or "", true);

	-- Both timestamps derive from the same instant and are formatted in UTC ("!" prefix).
	local now = os.time();
	local aws_datetime = os.date("!%Y%m%dT%H%M%SZ", now);
	local aws_date = os.date("!%Y%m%d", now);

	local headers = {
		["Accept"] = "*/*";
		["Authorization"] = nil; -- placeholder only; assigned after the signature is computed
		["Content-Type"] = payload_type; -- nil (and therefore absent) when there is no body
		["Host"] = request.authority;
		["User-Agent"] = "Prosody XMPP Server";
		["X-Amz-Content-Sha256"] = payload_hash;
		["X-Amz-Date"] = aws_datetime;
	};

	local canonical_uri = url.build({ path = request.path });
	local canonical_query = "";
	local canonical_headers = array();
	local signed_headers = array()

	if query then
		-- Build an ordered list of { name, value } pairs and form-encode it.
		-- NOTE(review): it.sorted_pairs presumably already yields keys in sorted order,
		-- which would make the explicit sort below redundant -- confirm.
		local sorted_query = array();
		for name, value in it.sorted_pairs(query) do
			sorted_query:push({ name = name; value = value });
		end
		sorted_query:sort(function (a,b) return a.name < b.name end)
		-- Upper-case every percent-escape ("%xx" -> "%XX") in the encoded query.
		canonical_query = httputil.formencode(sorted_query):gsub("%%%x%x", string.upper);
		-- The request is sent with exactly the query string that gets signed.
		request.query = canonical_query;
	end

	-- Every header present in the table is signed: lower-cased name, "name:value\n" lines.
	for header_name, header_value in it.sorted_pairs(headers) do
		header_name = header_name:lower();
		canonical_headers:push(header_name .. ":" .. header_value .. "\n");
		signed_headers:push(header_name);
	end

	canonical_headers = canonical_headers:concat();
	signed_headers = signed_headers:concat(";");

	-- Credential scope: date/region/service/terminator; the service is fixed to "s3".
	local scope = aws_date .. "/" .. region .. "/s3/aws4_request";

	-- canonical_headers already ends in "\n", so the extra "\n" below yields the
	-- blank line separating the header block from the signed header list.
	local canonical_request = method .. "\n"
		.. canonical_uri .. "\n"
		.. canonical_query .. "\n"
		.. canonical_headers .. "\n"
		.. signed_headers .. "\n"
		.. payload_hash;

	-- The string that is actually signed: algorithm, timestamp, scope, hash of the canonical request.
	local signature_payload = "AWS4-HMAC-SHA256" .. "\n" .. aws_datetime .. "\n" .. scope .. "\n" .. sha256(canonical_request, true);

	-- This can be cached?
	-- Signing key derivation: a chain of HMACs over date, region, service and terminator.
	-- The chain depends only on secret_key, aws_date and region.
	local date_key = hmac_sha256("AWS4" .. secret_key, aws_date);
	local date_region_key = hmac_sha256(date_key, region);
	local date_region_service_key = hmac_sha256(date_region_key, "s3");
	local signing_key = hmac_sha256(date_region_service_key, "aws4_request");

	local signature = hmac_sha256(signing_key, signature_payload, true);

	headers["Authorization"] = string.format(aws4_format, access_key, scope, signed_headers, signature);

	options.headers = headers;
end

--- Open a store backed by this driver.
-- @param store name of the store
-- @param typ store type; the "keyval" metatable is used when omitted
-- @return a store object with its own HTTP client whose requests are passed
--   through aws_auth, or nil and "unsupported-store" for an unknown type
function driver:open(store, typ)
	local store_mt = self[typ or "keyval"];
	if store_mt then
		local client = http.new({ connection_pooling = true });
		-- Sign every outgoing request before it is sent.
		client.events.add_handler("pre-request", aws_auth);
		local instance = { store = store; bucket = bucket; type = typ; http = client };
		return setmetatable(instance, store_mt);
	end
	return nil, "unsupported-store";
end

-- Method table for key-value stores; driver.keyval is the metatable that
-- driver:open applies to store objects of type "keyval" (also the default type).
local keyval = { };
driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };

--- Start an HTTP request against the configured endpoint (base_uri).
-- @param self store object; its `http` client performs the request
-- @param method HTTP method name
-- @param path absolute path that replaces the path of base_uri
-- @param query optional table of query parameters
-- @param payload optional request body
-- @return whatever self.http:request returns (callers treat it as a promise)
local function new_request(self, method, path, query, payload)
	local target = url.parse(base_uri);
	target.path = path;
	local options = { method = method; body = payload; query = query };
	return self.http:request(url.build(target), options);
end

--- Turn an HTTP response back into a Prosody data type.
-- A 404 answer to a GET means "no such item" and yields nil; any other status
-- of 400 or above raises the response body as an error. Otherwise the body is
-- decoded according to its Content-Type; unknown types are logged and the raw
-- body is returned.
local function on_result(response)
	local status = response.code;
	if status == 404 and response.request.method == "GET" then
		return nil;
	elseif status >= 400 then
		error(response.body);
	end

	local body = response.body;
	local content_type = response.headers["content-type"];
	if content_type == "application/json" then
		return json.decode(body);
	end
	if content_type == "application/xml" then
		return xml.parse(body);
	end
	if content_type == "application/x-www-form-urlencoded" then
		return httputil.formdecode(body);
	end

	module:log("warn", "Unknown response data type %s", content_type);
	return body;
end

--- Build the absolute object path /bucket/host/store/key for an item.
-- Host, store and key are escaped with jid.escape; "@" stands in for a missing key.
function keyval:_path(key)
	local segments = {
		bucket;
		jid.escape(module.host);
		jid.escape(self.store);
		jid.escape(key or "@");
	};
	segments.is_absolute = true;
	return url.build_path(segments);
end

--- Fetch and decode the item stored for `user`.
-- Waits for the GET request to finish and returns what on_result produces for the response.
function keyval:get(user)
	local pending = new_request(self, "GET", self:_path(user));
	return async.wait_for(pending:next(on_result));
end

--- Store `data` for `user`, or delete the item when `data` is nil or an empty table.
-- The HTTP status is checked so that a rejected write is reported instead of
-- looking like a success.
-- @param user key of the item
-- @param data value to store (serialised by the pre-request handler)
-- @return the HTTP response on success; otherwise what async.wait_for returns
--   for a failed promise (the error is raised from the response body)
function keyval:set(user, data)
	local path = self:_path(user);
	local is_delete = data == nil or (type(data) == "table" and next(data) == nil);

	local pending;
	if is_delete then
		pending = new_request(self, "DELETE", path);
	else
		pending = new_request(self, "PUT", path, nil, data);
	end

	return async.wait_for(pending:next(function (response)
		-- Deleting something that does not exist still leaves the store in the requested state.
		local already_gone = is_delete and response.code == 404;
		if response.code >= 400 and not already_gone then
			error(response.body);
		end
		return response;
	end));
end

--- List the keys (users) that have an item in this store.
-- Issues a bucket listing restricted to the "host/store/" prefix and takes the
-- third path segment of every returned object key.
-- @return an iterator function yielding one unescaped key per call, or
--   nil and an error when the listing request fails or cannot be parsed
function keyval:users()
	local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
	local prefix = url.build_path({ jid.escape(module.host); jid.escape(self.store); is_directory = true });
	local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix }));
	if not list_result then
		return nil, err;
	end
	if list_result.code ~= 200 then
		-- Surface the server's answer instead of returning nil without any error.
		return nil, list_result.body or "listing-failed";
	end
	local list_bucket_result, parse_err = xml.parse(list_result.body);
	if not list_bucket_result then
		return nil, parse_err;
	end
	if list_bucket_result:get_child_text("IsTruncated") == "true" then
		local max_keys = list_bucket_result:get_child_text("MaxKeys");
		module:log("warn", "Paging truncated results not implemented, max %s %s returned", max_keys, self.store);
	end
	local keys = array();
	for content in list_bucket_result:childtags("Contents") do
		-- Object keys look like host/store/key[/...]; the third segment is the key.
		local key = url.parse_path(content:get_child_text("Key"));
		keys:push(jid.unescape(key[3]));
	end
	return function()
		return keys:pop();
	end
end

-- Method table for archive stores; driver.archive is the metatable that
-- driver:open applies to store objects of type "archive".
local archive = {};
driver.archive = { __index = archive };

-- Capability table of the archive store; currently empty.
archive.caps = {
};

--- Build the absolute object path /bucket/host/store/username/with/date[/key].
-- @param username archive owner
-- @param date day string to use as-is; when nil it is derived from `when` via dt.date
-- @param when timestamp used only if `date` is not given
-- @param with JID of the other party; normalised with jid.prep and then escaped
-- @param key optional item id appended as the final segment
function archive:_path(username, date, when, with, key)
	local day = date or dt.date(when);
	local segments = {
		is_absolute = true;
		bucket;
		jid.escape(module.host);
		jid.escape(self.store);
		jid.escape(username);
		jid.escape(jid.prep(with));
		day;
		key;
	};
	return url.build_path(segments);
end


-- PUT .../with/when/id
--- Store one archive item.
-- The item is wrapped in a <wrapper/> element together with a <delay/> element
-- carrying the full timestamp, then uploaded under the path built by archive:_path.
-- @return the item key (a fresh one from new_uuid when none was passed) once the
--   server answers 200; any other status raises the response body as an error
function archive:append(username, key, value, when, with)
	local item_key = key or new_uuid();

	local envelope = st.stanza("wrapper");
	-- Minio had trouble with timestamps, probably the ':' characters, in paths.
	envelope:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up();
	envelope:add_direct_child(value);

	local function on_stored(response)
		if response.code ~= 200 then
			error(response.body);
		end
		return item_key;
	end

	local path = self:_path(username, nil, when, with, item_key);
	return async.wait_for(new_request(self, "PUT", path, nil, envelope):next(on_stored));
end

--- Query the archive of `username`.
--
-- Object keys are laid out as host/store/username/with/date/id (see archive:_path),
-- so the listing is narrowed by prefix where possible and the remaining filters
-- are applied to the listed keys. Each matching item is fetched lazily when the
-- returned iterator is advanced, and its exact timestamp is checked then.
--
-- Fields read from `query` (all optional): "with", "start", "end", "before",
-- "after", "reverse", "max". The table passed by the caller is not modified.
--
-- @param username archive owner; "@" is used when nil
-- @param query table of filters, or nil for no filtering
-- @return an iterator yielding id, stored stanza, timestamp, with -- or
--   nil and an error when the listing fails
function archive:find(username, query)
	query = query or {};

	local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
	local prefix = { jid.escape(module.host); jid.escape(self.store); is_directory = true };
	table.insert(prefix, jid.escape(username or "@"));

	local start_time, end_time = query["start"], query["end"];
	local start_date = start_time and dt.date(start_time);
	local end_date = end_time and dt.date(end_time);

	local with = nil;
	if query["with"] then
		with = jid.prep(query["with"]);
		if not with then
			-- Items are stored under the normalised JID, so a JID that cannot be
			-- normalised cannot match anything.
			return function () return nil; end
		end
		-- Must be the same segment archive:_path writes, or the prefix never matches.
		table.insert(prefix, jid.escape(with));
		if start_date and start_date == end_date then
			-- The whole range lies within one day: narrow the listing to that day.
			table.insert(prefix, start_date);
		end
	end

	prefix = url.build_path(prefix);
	-- NOTE(review): "max" limits the listing before the filters below are applied,
	-- so fewer than "max" items may be returned -- confirm this is intended.
	local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, {
		prefix = prefix;
		["max-keys"] = query["max"] and tostring(query["max"]);
	}));
	if not list_result then
		return nil, err;
	end
	if list_result.code ~= 200 then
		return nil, list_result.body or "listing-failed";
	end
	local list_bucket_result, parse_err = xml.parse(list_result.body);
	if not list_bucket_result then
		return nil, parse_err;
	end
	if list_bucket_result:get_child_text("IsTruncated") == "true" then
		local max_keys = list_bucket_result:get_child_text("MaxKeys");
		module:log("warn", "Paging truncated results not implemented, max %s %s returned", max_keys, self.store);
	end

	local keys = array();
	local iterwrap = function(...)
		return ...;
	end
	-- When walking backwards, the roles of "before" and "after" swap. Local copies
	-- are used so the caller's query table stays untouched.
	local before_id, after_id = query["before"], query["after"];
	if query["reverse"] then
		before_id, after_id = after_id, before_id;
		iterwrap = it.reverse;
	end

	-- Collecting starts right away unless an "after" id has to be passed first.
	local found = not after_id;
	for content in iterwrap(list_bucket_result:childtags("Contents")) do
		local key = url.parse_path(content:get_child_text("Key"));
		local item_with, item_date, item_id = key[4], key[5], key[6];
		if found and before_id == item_id then
			break
		end
		-- Dates are "YYYY-MM-DD" strings, so string comparison orders them correctly.
		if found
		and (not with or with == jid.unescape(item_with))
		and (not start_date or start_date <= item_date)
		and (not end_date or item_date <= end_date) then
			keys:push({ key = item_id; date = item_date; with = jid.unescape(item_with) });
		end
		if not found and item_id == after_id then
			found = true;
		end
	end

	local i = 0;
	local function get_next()
		i = i + 1;
		local item = keys[i];
		if item == nil then
			return nil;
		end
		-- luacheck: ignore 431/err
		local value, err = async.wait_for(new_request(self, "GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
		if not value then
			module:log("error", "%s", err);
			return nil;
		end
		local when = dt.parse(value:get_child_attr("delay", "urn:xmpp:delay", "stamp"));

		if (not start_time or start_time <= when) and (not end_time or when <= end_time) then
			-- tags[1] is the <delay/> element added by archive:append, tags[2] the payload.
			return item.key, value.tags[2], when, item.with;
		else
			-- date was correct but not the time
			return get_next();
		end
	end
	return get_next;
end

--- List the users that have items in this archive store.
-- Reuses the key-value listing and passes its result through it.unique.
function archive:users()
	local usernames, err = keyval.users(self);
	return it.unique(usernames, err);
end

--- Count the entries of a table (array and hash part alike).
local function count(t)
	local total = 0;
	for _key in pairs(t) do
		total = total + 1;
	end
	return total;
end

--- Delete every archive item of `username` that matches `query`.
-- One DELETE request is started per item found by archive:find; the call then
-- waits for all of them to finish.
-- @return the number of delete requests that completed, or nil and an error
--   when the lookup of matching items fails
function archive:delete(username, query)
	local items, err = self:find(username, query);
	if not items then
		-- find() reports failures as nil, err; iterating over nil would raise instead.
		return nil, err;
	end
	local deletions = {};
	for key, _, when, with in items do
		deletions[key] = new_request(self, "DELETE", self:_path(username or "@", dt.date(when), nil, with, key));
	end
	return async.wait_for(promise.all(deletions):next(count));
end

-- Register the driver table with Prosody as a "storage" provider.
module:provides("storage", driver);