view mod_storage_s3/mod_storage_s3.lua @ 5751:d563a6b0dfb7

mod_http_oauth2: Comment on authorization code storage
author Kim Alvefur <zash@zash.se>
date Fri, 01 Dec 2023 21:35:25 +0100
parents 72b0fa7e36dc
children 32bc648e3892
line wrap: on
line source

local http = require "prosody.net.http";
local array = require "prosody.util.array";
local async = require "prosody.util.async";
local dt = require "prosody.util.datetime";
local hashes = require "prosody.util.hashes";
local httputil = require "prosody.util.http";
local it = require "prosody.util.iterators";
local jid = require "prosody.util.jid";
local json = require "prosody.util.json";
local promise = require "prosody.util.promise";
local set = require "prosody.util.set";
local st = require "prosody.util.stanza";
local uuid = require "prosody.util.uuid";
local xml = require "prosody.util.xml";
local url = require "socket.url";

-- Generator for new archive item keys: uuid.v7 when the loaded util.uuid
-- provides it, otherwise uuid.generate.
local new_uuid = uuid.v7 or uuid.generate;
-- Hash primitives used by the request signing in aws_auth().
local hmac_sha256 = hashes.hmac_sha256;
local sha256 = hashes.sha256;

-- Storage driver table; registered with module:provides("storage", ...).
local driver = {};

-- Bucket name, endpoint base URL and region, each with a built-in default.
local bucket = module:get_option_string("s3_bucket", "prosody");
local base_uri = module:get_option_string("s3_base_uri", "http://localhost:9000");
local region = module:get_option_string("s3_region", "us-east-1");

-- Credentials have no default. aws_auth() concatenates secret_key into the
-- signing key, so signing raises an error if it is not configured.
local access_key = module:get_option_string("s3_access_key");
local secret_key = module:get_option_string("s3_secret_key");

-- Template for the Authorization header, filled with: access key, scope,
-- signed header names, signature.
local aws4_format = "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s";

--- "pre-request" handler that signs an outgoing request with AWS4-HMAC-SHA256.
-- Serialises options.body (a stanza becomes XML text, any other non-nil value
-- is JSON-encoded), writes the canonical query string to request.query and
-- replaces options.headers with the signed header set plus Authorization.
-- @param event table carrying the `request` and `options` of the HTTP call
local function aws_auth(event)
	local options = event.options;
	local request = event.request;
	local http_method = options.method or "GET";
	local query_params = options.query;

	-- Serialise the body first: its hash is part of what gets signed.
	local body = options.body;
	local content_type = nil;
	if st.is_stanza(body) then
		content_type, body = "application/xml", tostring(body);
	elseif body ~= nil then
		content_type, body = "application/json", json.encode(body);
	end
	options.body = body;

	local body_digest = sha256(body or "", true);

	-- One clock reading feeds both the full timestamp and the date-only scope.
	local timestamp = os.time();
	local amz_datetime = os.date("!%Y%m%dT%H%M%SZ", timestamp);
	local amz_date = os.date("!%Y%m%d", timestamp);

	-- Content-Type stays absent when there is no body (nil fields are not stored).
	local headers = {
		["Accept"] = "*/*";
		["Content-Type"] = content_type;
		["Host"] = request.authority;
		["User-Agent"] = "Prosody XMPP Server";
		["X-Amz-Content-Sha256"] = body_digest;
		["X-Amz-Date"] = amz_datetime;
	};

	local canonical_uri = url.build({ path = request.path });

	-- Query parameters are form-encoded in name order, with the hex digits of
	-- every percent-escape upper-cased.
	local canonical_query = "";
	if query_params then
		local params = array();
		for name, value in it.sorted_pairs(query_params) do
			params:push({ name = name; value = value });
		end
		params:sort(function (a, b) return a.name < b.name end);
		local encoded = httputil.formencode(params);
		canonical_query = encoded:gsub("%%%x%x", string.upper);
		request.query = canonical_query;
	end

	-- Every header is signed, in sorted order, under its lower-cased name.
	local header_lines = array();
	local header_names = array();
	for name, value in it.sorted_pairs(headers) do
		local lower_name = name:lower();
		header_lines:push(lower_name .. ":" .. value .. "\n");
		header_names:push(lower_name);
	end
	local canonical_headers = header_lines:concat();
	local signed_headers = header_names:concat(";");

	local scope = amz_date .. "/" .. region .. "/s3/aws4_request";

	local canonical_request = table.concat({
		http_method;
		canonical_uri;
		canonical_query;
		canonical_headers;
		signed_headers;
		body_digest;
	}, "\n");
	local request_digest = sha256(canonical_request, true);

	local string_to_sign = table.concat({
		"AWS4-HMAC-SHA256";
		amz_datetime;
		scope;
		request_digest;
	}, "\n");

	-- Signing key chain: secret -> date -> region -> service -> "aws4_request".
	-- This can be cached?
	local signing_key = hmac_sha256("AWS4" .. secret_key, amz_date);
	for _, part in ipairs({ region, "s3", "aws4_request" }) do
		signing_key = hmac_sha256(signing_key, part);
	end

	local signature = hmac_sha256(signing_key, string_to_sign, true);

	headers["Authorization"] = aws4_format:format(access_key, scope, signed_headers, signature);
	options.headers = headers;
end

--- Open a store backed by the configured bucket.
-- Each store gets its own HTTP client whose requests are signed by aws_auth.
-- @param store store name
-- @param typ store type; the method table is looked up under typ or "keyval"
-- @return the store object, or nil and "unsupported-store" for unknown types
function driver:open(store, typ)
	local store_mt = self[typ or "keyval"];
	if store_mt then
		local client = http.new({ connection_pooling = true });
		client.events.add_handler("pre-request", aws_auth);
		local instance = { store = store; bucket = bucket; type = typ; http = client };
		return setmetatable(instance, store_mt);
	end
	return nil, "unsupported-store";
end

-- Method table for key-value stores; driver:open() finds its metatable as
-- driver.keyval (the default when no store type is given).
local keyval = { };
driver.keyval = { __index = keyval; __name = module.name .. " keyval store" };

--- Start an HTTP request against the configured S3 endpoint.
-- @param self store object whose `http` client issues the request
-- @param method HTTP method name
-- @param path absolute object or bucket path; replaces the path of base_uri
-- @param query optional table of query parameters
-- @param payload optional body (serialised later by the pre-request handler)
-- @return whatever self.http:request returns; callers use it as a promise
local function new_request(self, method, path, query, payload)
	local target = url.parse(base_uri);
	target.path = path;

	local options = { method = method; body = payload; query = query };
	return self.http:request(url.build(target), options);
end

--- Convert an HTTP response into a Prosody data value.
-- A 404 answer to a GET means "nothing stored" and yields nil; any other
-- status of 400 or above raises the response body as an error. Otherwise the
-- body is decoded according to its Content-Type, falling back to the raw body
-- (with a warning) for types that are not recognised.
local function on_result(response)
	local code = response.code;
	if code == 404 and response.request.method == "GET" then
		return nil;
	elseif code >= 400 then
		error(response.body);
	end

	local body = response.body;
	local content_type = response.headers["content-type"];
	if content_type == "application/json" then
		return json.decode(body);
	end
	if content_type == "application/xml" then
		return xml.parse(body);
	end
	if content_type == "application/x-www-form-urlencoded" then
		return httputil.formdecode(body);
	end

	module:log("warn", "Unknown response data type %s", content_type);
	return body;
end

--- Build the absolute object path /<bucket>/<host>/<key>/<store>.
-- @param key user/key name; nil is stored under "@"
function keyval:_path(key)
	local host_part = jid.escape(module.host);
	local key_part = jid.escape(key or "@");
	local store_part = jid.escape(self.store);
	return url.build_path({ is_absolute = true; bucket; host_part; key_part; store_part });
end

--- Fetch and decode the value stored for `user`.
-- Waits for the GET request and passes its response through on_result.
function keyval:get(user)
	local pending = new_request(self, "GET", self:_path(user));
	return async.wait_for(pending:next(on_result));
end

--- Store, replace or delete the value kept for `user`.
-- nil or an empty table deletes the object; any other value is sent as the
-- body of a PUT request.
-- @param user user/key name (nil addresses the host-wide "@" entry)
-- @param data value to store
-- @return the HTTP response on success; on failure, whatever async.wait_for
--   reports for a rejected promise (elsewhere in this file that is read as
--   nil plus an error)
function keyval:set(user, data)
	-- The request promise resolves for HTTP error responses as well, so a
	-- rejected write must be turned into a failure here instead of being
	-- handed back to the caller as a truthy "success" value.
	local function check_status(response)
		if response.code >= 400 then
			error(response.body);
		end
		return response;
	end

	local method, payload = "PUT", data;
	if data == nil or (type(data) == "table" and next(data) == nil) then
		method, payload = "DELETE", nil;
	end

	return async.wait_for(new_request(self, method, self:_path(user), nil, payload):next(check_status));
end

--- Iterate over the users that have an object in this store.
-- Lists the bucket under the host prefix and keeps the keys whose third path
-- segment is this store's name. Only the first page of a truncated listing
-- is used (a warning is logged).
-- @return iterator function yielding one user name per call, or nil and an
--   error message when the listing request fails
function keyval:users()
	local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
	local prefix = url.build_path({ jid.escape(module.host); is_directory = true });
	local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, { prefix = prefix }));
	if err or not list_result then
		return nil, err;
	elseif list_result.code ~= 200 then
		-- Previously this path returned nil without any error text.
		return nil, string.format("S3 list request failed with HTTP status %s", tostring(list_result.code));
	end

	-- xml.parse is expected to return nil plus a message for malformed input;
	-- guard so a bad body is reported instead of raising on a nil index.
	local list_bucket_result, parse_err = xml.parse(list_result.body);
	if not list_bucket_result then
		return nil, parse_err or "unable to parse S3 list response";
	end

	if list_bucket_result:get_child_text("IsTruncated") == "true" then
		local max_keys = list_bucket_result:get_child_text("MaxKeys");
		module:log("warn", "Paging truncated results not implemented, max %s %s returned", max_keys, self.store);
	end

	local keys = array();
	local store_part = jid.escape(self.store);
	for content in list_bucket_result:childtags("Contents") do
		-- Key layout: <host>/<user>/<store>[/...]
		local key = url.parse_path(content:get_child_text("Key"));
		if key[3] == store_part then
			keys:push(jid.unescape(key[2]));
		end
	end

	-- Users are handed out from the end of the collected list.
	return function()
		return keys:pop();
	end
end

-- Method table for archive stores; driver:open() finds its metatable as
-- driver.archive.
local archive = {};
driver.archive = { __index = archive };

-- Capability flags exposed on archive stores.
archive.caps = {
	full_id_range = true; -- both before and after used
	ids = true; -- archive:find() filters on query.ids
};

--- Build the absolute object path
-- /<bucket>/<host>/<user>/<store>/<date>/<with>/<key>.
-- @param username archive owner; nil is stored under "@"
-- @param date date segment; when nil it is derived from `when` via dt.date
-- @param when timestamp used only if `date` is not given
-- @param with remote JID; it is run through jid.prep, and "@" is used when
--   it is missing or prep yields nothing
-- @param key item key; omitted from the path when nil
function archive:_path(username, date, when, with, key)
	local day = date;
	if not day then
		day = dt.date(when);
	end

	local peer = "@";
	if with then
		peer = jid.prep(with) or "@";
	end

	return url.build_path({
		bucket;
		jid.escape(module.host);
		jid.escape(username or "@");
		jid.escape(self.store);
		day;
		jid.escape(peer);
		key;
		is_absolute = true;
	});
end


-- PUT /<bucket>/<host>/<user>/<store>/<date>/<with>/<key>
--- Append one item to the archive of `username`.
-- @param key item key; a new UUID is generated when nil
-- @param value stanza to store
-- @param when timestamp of the item
-- @param with remote JID the item is associated with
-- @return the item key once the PUT is answered with 200; any other status
--   raises the response body as an error inside the promise chain
function archive:append(username, key, value, when, with)
	-- The full timestamp travels inside the stored document as a delay
	-- element: Minio had trouble with timestamps, probably the ':'
	-- characters, in paths.
	local envelope = st.stanza("wrapper");
	envelope:tag("delay", { xmlns = "urn:xmpp:delay"; stamp = dt.datetime(when) }):up();
	envelope:add_direct_child(value);

	local item_key = key or new_uuid();
	local put = new_request(self, "PUT", self:_path(username, nil, when, with, item_key), nil, envelope);
	return async.wait_for(put:next(function(response)
		if response.code ~= 200 then
			error(response.body);
		end
		return item_key;
	end));
end

--- Query the archive of `username`.
-- Lists the bucket under the user's store prefix, filters the listing on the
-- client side and returns an iterator that downloads one item per call.
-- Only the first page of a truncated listing is used (a warning is logged).
-- @param username archive owner; nil addresses the "@" entry
-- @param query optional table; the fields read here are start, end, with,
--   max, reverse, before, after and ids. The table is not modified.
-- @return iterator yielding key, value, when, with — or nil and an error
--   message when the listing request fails
function archive:find(username, query)
	query = query or {};

	local bucket_path = url.build_path({ is_absolute = true; bucket; is_directory = true });
	local prefix = { jid.escape(module.host); jid.escape(username or "@"); jid.escape(self.store); is_directory = true };

	-- When the whole range falls on one day the listing can be narrowed to
	-- that day's prefix (and further to the contact, if one is given).
	if query["start"] and query["end"] and dt.date(query["start"]) == dt.date(query["end"]) then
		table.insert(prefix, dt.date(query["start"]));
		if query["with"] then
			table.insert(prefix, jid.escape(query["with"]));
		end
	end

	local list_result, err = async.wait_for(new_request(self, "GET", bucket_path, {
		prefix = url.build_path(prefix);
		["max-keys"] = query["max"] and tostring(query["max"]);
	}));
	if err or not list_result then
		return nil, err;
	elseif list_result.code ~= 200 then
		-- Previously this path returned nil without any error text.
		return nil, string.format("S3 list request failed with HTTP status %s", tostring(list_result.code));
	end

	-- xml.parse is expected to return nil plus a message for malformed input;
	-- guard so a bad body is reported instead of raising on a nil index.
	local list_bucket_result, parse_err = xml.parse(list_result.body);
	if not list_bucket_result then
		return nil, parse_err or "unable to parse S3 list response";
	end

	if list_bucket_result:get_child_text("IsTruncated") == "true" then
		local max_keys = list_bucket_result:get_child_text("MaxKeys");
		module:log("warn", "Paging truncated results not implemented, max %s %s returned", max_keys, self.store);
	end

	-- In reverse mode the listing is walked backwards, so the roles of the
	-- "before" and "after" markers swap. Locals are used so that the caller's
	-- query table is left untouched.
	local reverse = query["reverse"];
	local before, after = query["before"], query["after"];
	local iterwrap = function(...)
		return ...;
	end
	if reverse then
		before, after = after, before;
		iterwrap = it.reverse;
	end

	local ids = query["ids"] and set.new(query["ids"]);
	-- Without an "after" marker, collection starts with the first entry.
	local found = not after;
	local keys = array();
	for content in iterwrap(list_bucket_result:childtags("Contents")) do
		-- Key layout: <host>/<user>/<store>/<date>/<with>/<id>
		local date, with, id = table.unpack(url.parse_path(content:get_child_text("Key")), 4);
		-- The item time used for filtering is the object's LastModified.
		local when = dt.parse(content:get_child_text("LastModified"));
		with = jid.unescape(with);
		if found and before == id then
			break
		end
		if found
		and (not query["with"] or query["with"] == with)
		and (not query["start"] or query["start"] <= when)
		and (not query["end"] or query["end"] >= when)
		and (not ids or ids:contains(id)) then
			keys:push({ key = id; date = date; when = when; with = with });
		end
		if not found and id == after then
			found = true;
		end
	end

	keys:sort(function(a, b)
		if a.date ~= b.date then
			return a.date < b.date;
		end
		if a.when ~= b.when then
			return a.when < b.when;
		end
		-- Items carry their id in the `key` field; comparing a non-existent
		-- `id` field raised an error whenever date and time were equal.
		return a.key < b.key;
	end);
	if reverse then
		keys:reverse();
	end

	local i = 0;
	local function get_next()
		i = i + 1;
		local item = keys[i];
		if item == nil then
			return nil;
		end
		local value, fetch_err = async.wait_for(new_request(self, "GET", self:_path(username or "@", item.date, nil, item.with, item.key)):next(on_result));
		if not value then
			-- A failed or empty download ends the iteration.
			module:log("error", "%s", fetch_err);
			return nil;
		end
		return item.key, value, item.when, item.with;
	end
	return get_next;
end

--- Iterate over the users that have items in this archive store.
-- Reuses the key-value listing and passes it through it.unique, since one
-- user can own many archive objects.
function archive:users()
	local list_users = keyval.users;
	return it.unique(list_users(self));
end

--- Count the entries of a table (every key visited by pairs).
-- @param t table to count
-- @return number of key/value pairs in t
local function count(t)
	local total = 0;
	for _ in pairs(t) do
		total = total + 1;
	end
	return total;
end

--- Delete every archive item of `username` that matches `query`.
-- @param username archive owner; nil addresses the "@" entry
-- @param query same filter table as accepted by archive:find
-- @return the number of deleted items, or nil and an error when the lookup
--   or any of the DELETE requests fails
function archive:delete(username, query)
	-- find() returns nil plus an error when the listing fails; iterating that
	-- result directly would raise "attempt to call a nil value".
	local matches, err = self:find(username, query);
	if not matches then
		return nil, err;
	end

	-- The request promise resolves for HTTP error responses as well, so a
	-- refused DELETE must fail the batch rather than be counted as deleted.
	local function check_status(response)
		if response.code >= 400 then
			error(response.body);
		end
		return response;
	end

	local deletions = {};
	for key, _, when, with in matches do
		-- NOTE(review): the path's date segment is rebuilt from `when` as
		-- reported by find(); presumably this always equals the date the
		-- item was stored under — confirm for items whose timestamps differ.
		deletions[key] = new_request(self, "DELETE", self:_path(username or "@", dt.date(when), nil, with, key)):next(check_status);
	end
	return async.wait_for(promise.all(deletions):next(count));
end

-- Register the driver table as this module's "storage" provider.
module:provides("storage", driver);