view mod_rest/mod_rest.lua @ 3909:eb27e51cf2c9

mod_rest: Handle uncaught native errors correctly. Previously they would slip through and be returned over HTTP with 200 OK.
author Matthew Wild <mwild1@gmail.com>
date Wed, 26 Feb 2020 18:04:17 +0000
parents 3d0e8e32453c
children 49efd1323a1b
line wrap: on
line source

-- RESTful API
--
-- Copyright (c) 2019-2020 Kim Alvefur
--
-- This file is MIT/X11 licensed.

local errors = require "util.error";
local http = require "net.http";
local id = require "util.id";
local jid = require "util.jid";
local json = require "util.json";
local st = require "util.stanza";
local xml = require "util.xml";

-- On components the sender address may be taken from the submitted stanza.
local allow_any_source = module:get_host_type() == "component";
-- When true (the default), a client-supplied 'from' must belong to this host.
local validate_from_addresses = module:get_option_boolean("validate_from_addresses", true);

-- Mandatory setting; its first whitespace-delimited word names the auth scheme.
local secret = module:get_option_string("rest_credentials");
if not secret then
	error("rest_credentials is a required setting", 0);
end

local auth_type = secret:match("^%S+");
if not auth_type then
	error("Format of rest_credentials MUST be like 'Bearer secret'", 0);
end
if auth_type ~= "Bearer" and auth_type ~= "Basic" then
	error("Only 'Bearer' and 'Basic' are supported in rest_credentials", 0);
end

local jsonmap = module:require"jsonmap";
-- Default check: the Authorization header must equal the configured
-- rest_credentials value verbatim (e.g. "Bearer token").
-- Returns true on a match and false otherwise.
local function check_credentials(request)
	local presented = request.headers.authorization;
	return presented == secret;
end
-- When rest_credentials is exactly "Basic" on a "local" host, swap the static
-- comparison for HTTP Basic authentication checked through the usermanager.
if secret == "Basic" and module:get_host_type() == "local" then
	local um = require "core.usermanager";
	local encodings = require "util.encodings";
	local base64 = encodings.base64;

	--- Validate an HTTP Basic Authorization header.
	-- @param request HTTP request; request.headers.authorization must be a string
	-- @return the JID of the authenticated user (username@host) on success, false otherwise
	function check_credentials(request)
		local creds = string.match(request.headers.authorization, "^Basic%s+([A-Za-z0-9+/]+=?=?)%s*$");
		if not creds then return false; end
		creds = base64.decode(creds);
		if not creds then return false; end
		-- Split at the first colon: the username must be non-empty, the password may contain colons
		local username, password = string.match(creds, "^([^:]+):(.*)$");
		if not username then return false; end
		username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password);
		-- Reject when either value fails stringprep instead of passing nil on to the password check
		if not username or not password then return false; end
		if not um.test_password(username, module.host, password) then
			return false;
		end
		return jid.join(username, module.host);
	end
end

--- Turn an HTTP payload into a stanza according to its media type.
-- @param mimetype Content-Type header value, or nil; anything from the first
--        ';' or space onwards is ignored
-- @param data raw payload body
-- @return the result of the matching parser on success; nil (or a false
--         value) plus an error on failure. Unsupported or missing media
--         types yield nil, "unknown-payload-type".
local function parse(mimetype, data)
	-- Strip parameters and fold case: media type names are case-insensitive,
	-- so "Application/JSON" must be treated like "application/json".
	mimetype = mimetype and mimetype:match("^[^; ]*"):lower();
	if mimetype == "application/xmpp+xml" then
		return xml.parse(data);
	elseif mimetype == "application/json" then
		local parsed, err = json.decode(data);
		if not parsed then
			return parsed, err;
		end
		return jsonmap.json2st(parsed);
	elseif mimetype == "text/plain" then
		-- Plain text becomes the body of a chat message
		return st.message({ type = "chat" }, data);
	end
	return nil, "unknown-payload-type";
end

-- Media types this module can produce; the first entry is the fallback.
local supported_types = {
	"application/xmpp+xml",
	"application/json",
};

--- Choose the response media type for an Accept header value.
-- The supported type that occurs earliest in the string wins; a type that
-- does not occur at all is ranked at position 1000. Ties keep the earlier
-- entry of supported_types. Assumes the header lists types in order of
-- preference.
-- @param accept Accept header value (string)
-- @return one of the entries of supported_types
local function decide_type(accept)
	local function rank(mimetype)
		return accept:find(mimetype, 1, true) or 1000;
	end

	local best = supported_types[1];
	for index = 2, #supported_types do
		local candidate = supported_types[index];
		if rank(candidate) < rank(best) then
			best = candidate;
		end
	end
	return best;
end

--- Serialize a stanza for the given media type.
-- @param type target media type
-- @param s stanza to serialize
-- @return JSON text for "application/json", the text of the <body> child
--         (or "") for "text/plain", and tostring(s) for anything else
local function encode(type, s)
	if type == "text/plain" then
		local body = s:get_child_text("body");
		return body or "";
	end
	if type == "application/json" then
		local mapped = jsonmap.st2json(s);
		return json.encode(mapped);
	end
	return tostring(s);
end

--- Handle an HTTP POST that carries a stanza to be sent over XMPP.
-- Authenticates the request, parses and validates the payload, rewrites its
-- attributes and hands it to the module for delivery.
-- @param event HTTP event with `request` and `response` fields
-- @return an HTTP status code (401, 202), an error object created with
--         errors.new for rejected payloads, or, for iq stanzas, the value of
--         module:send_iq(payload):next(...) whose callbacks yield the encoded reply
local function handle_post(event)
	local request, response = event.request, event.response;
	local from = module.host;

	-- Every 401 carries a challenge naming the expected scheme, both when the
	-- header is missing and when the supplied credentials are rejected.
	local function unauthorized()
		response.headers.www_authenticate = ("%s realm=%q"):format(auth_type, module.host.."/"..module.name);
		return 401;
	end

	if not request.headers.authorization then
		return unauthorized();
	end
	local authz = check_credentials(request);
	if not authz then
		return unauthorized();
	end
	if type(authz) == "string" then
		-- A string result is the authenticated JID; use it as the sender
		from = authz;
	end

	local payload, err = parse(request.headers.content_type, request.body);
	if not payload then
		-- parse fail
		return errors.new({ code = 400, text = "Failed to parse payload" }, { error = err, type = request.headers.content_type, data = request.body });
	end
	if payload.attr.xmlns then
		return errors.new({ code = 422, text = "'xmlns' attribute must be empty" });
	elseif payload.name ~= "message" and payload.name ~= "presence" and payload.name ~= "iq" then
		return errors.new({ code = 422, text = "Invalid stanza, must be 'message', 'presence' or 'iq'." });
	end
	local to = jid.prep(payload.attr.to);
	if not to then
		return errors.new({ code = 422, text = "Invalid destination JID" });
	end
	-- Only components may pick their own sender address
	if allow_any_source and payload.attr.from then
		from = jid.prep(payload.attr.from);
		if not from then
			return errors.new({ code = 422, text = "Invalid source JID" });
		end
		if validate_from_addresses and not jid.compare(from, module.host) then
			return errors.new({ code = 403, text = "Source JID must belong to current host" });
		end
	end
	-- Replace the attributes wholesale so only this known set survives
	payload.attr = {
		from = from,
		to = to,
		id = payload.attr.id or id.medium(),
		type = payload.attr.type,
		["xml:lang"] = payload.attr["xml:lang"],
	};
	module:log("debug", "Received[rest]: %s", payload:top_tag());
	-- The Accept header is consulted first, then the request's own content type
	local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type)
	if payload.name == "iq" then
		if payload.attr.type ~= "get" and payload.attr.type ~= "set" then
			return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" });
		elseif #payload.tags ~= 1 then
			return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" });
		end
		return module:send_iq(payload):next(
			function (result)
				module:log("debug", "Sending[rest]: %s", result.stanza:top_tag());
				response.headers.content_type = send_type;
				return encode(send_type, result.stanza);
			end,
			function (error)
				if not errors.is_err(error) then
					-- Not an error object: log it and convert it so it is not returned as a normal body
					module:log("error", "Uncaught native error: %s", error);
					return select(2, errors.coerce(nil, error));
				elseif error.context and error.context.stanza then
					-- The error wraps a stanza: return that stanza in the negotiated format
					response.headers.content_type = send_type;
					module:log("debug", "Sending[rest]: %s", error.context.stanza:top_tag());
					return encode(send_type, error.context.stanza);
				else
					return error;
				end
			end);
	else
		-- Stand-in origin so a stanza sent back to it is written to the HTTP response
		local origin = {};
		function origin.send(stanza)
			module:log("debug", "Sending[rest]: %s", stanza:top_tag());
			response.headers.content_type = send_type;
			response:send(encode(send_type, stanza));
			return true;
		end
		module:send(payload, origin);
		return 202;
	end
end

-- Accept stanzas submitted over HTTP: POST requests go to handle_post
module:depends("http");

local http_routes = {
	POST = handle_post;
};

module:provides("http", { route = http_routes });

-- Forward stanzas from XMPP to HTTP and return any reply
local rest_url = module:get_option_string("rest_callback_url", nil);
if rest_url then
	-- Content type for outgoing callback requests; "json" is accepted as shorthand
	local send_type = module:get_option_string("rest_callback_content_type", "application/xmpp+xml");
	if send_type == "json" then
		send_type = "application/json";
	end

	-- Receives the outcome of the initial OPTIONS probe of the callback URL.
	-- A code of 0 is reported as a connection failure; a 200 with an Accept
	-- header replaces the content type used for later requests.
	local function on_probe_result(body, code, response)
		if code == 0 then
			return module:log_status("error", "Could not connect to callback URL %q: %s", rest_url, body);
		end
		module:set_status("info", "Connected");
		if code ~= 200 or not response.headers.accept then
			return;
		end
		send_type = decide_type(response.headers.accept);
		module:log("debug", "Set 'rest_callback_content_type' = %q based on Accept header", send_type);
	end

	module:set_status("info", "Not yet connected");
	http.request(rest_url, { method = "OPTIONS" }, on_probe_result);

	-- Builds one registry entry; every status code gets its own table.
	local function xmpp_error(condition, error_type)
		return { condition = condition; type = error_type };
	end

	-- Maps HTTP status codes returned by the callback to XMPP error
	-- conditions and types.
	local code2err = {
		-- Client errors
		[400] = xmpp_error("bad-request", "modify");
		[401] = xmpp_error("not-authorized", "auth");
		[402] = xmpp_error("not-authorized", "auth");
		[403] = xmpp_error("forbidden", "auth");
		[404] = xmpp_error("item-not-found", "cancel");
		[406] = xmpp_error("not-acceptable", "modify");
		[408] = xmpp_error("remote-server-timeout", "wait");
		[409] = xmpp_error("conflict", "cancel");
		[410] = xmpp_error("gone", "cancel");
		[411] = xmpp_error("bad-request", "modify");
		[412] = xmpp_error("bad-request", "modify");
		[413] = xmpp_error("resource-constraint", "modify");
		[414] = xmpp_error("resource-constraint", "modify");
		[415] = xmpp_error("bad-request", "modify");
		[429] = xmpp_error("resource-constraint", "wait");
		[431] = xmpp_error("resource-constraint", "wait");

		-- Server errors
		[500] = xmpp_error("internal-server-error", "cancel");
		[501] = xmpp_error("feature-not-implemented", "modify");
		[502] = xmpp_error("remote-server-timeout", "wait");
		[503] = xmpp_error("service-unavailable", "cancel");
		[504] = xmpp_error("remote-server-timeout", "wait");
		[507] = xmpp_error("resource-constraint", "wait");
	};

	--- Forward a stanza received over XMPP to the REST callback URL and relay
	-- the callback's answer (if any) back to the sender.
	-- @param event stanza event with `stanza` and `origin` fields
	-- @return true, so the stanza counts as handled; the reply is sent later
	--         from the HTTP callback through origin.send
	local function handle_stanza(event)
		local stanza, origin = event.stanza, event.origin;
		local reply_needed = stanza.name == "iq";
		local receipt;

		if stanza.attr.type == "error" then
			reply_needed = false;
		end

		-- A message asking for a delivery receipt gets a reply carrying <received/>.
		-- get_child takes (name, xmlns), and the message id lives in stanza.attr.
		if stanza.name == "message" and stanza.attr.id and stanza:get_child("request", "urn:xmpp:receipts") then
			reply_needed = true;
			receipt = st.stanza("received", { xmlns = "urn:xmpp:receipts", id = stanza.attr.id });
		end

		-- Serialize before the stanza is reduced to its top-level element below
		local request_body = encode(send_type, stanza);

		-- Keep only the top level element and let the rest be GC'd
		stanza = st.clone(stanza, true);

		module:log("debug", "Sending[rest]: %s", stanza:top_tag());
		http.request(rest_url, {
				body = request_body,
				headers = {
					["Content-Type"] = send_type,
					["Content-Language"] = stanza.attr["xml:lang"],
					Accept = table.concat(supported_types, ", ");
				},
			}, function (body, code, response)
				if code == 0 then
					-- Treated as a connection failure; body is used as the error text
					module:log_status("error", "Could not connect to callback URL %q: %s", rest_url, body);
					origin.send(st.error_reply(stanza, "wait", "recipient-unavailable", body));
					return;
				else
					module:set_status("info", "Connected");
				end
				local reply;

				if code == 202 or code == 204 then
					if not reply_needed then
						-- Delivered, no reply
						return;
					end
				else
					local parsed, err = parse(response.headers["content-type"], body);
					if not parsed then
						module:log("warn", "Failed parsing data from REST callback: %s, %q", err, body);
					elseif parsed.name ~= stanza.name then
						module:log("warn", "REST callback responded with the wrong stanza type, got %s but expected %s", parsed.name, stanza.name);
					else
						-- Address the reply back to the sender, keeping only known attributes
						parsed.attr = {
							from = stanza.attr.to,
							to = stanza.attr.from,
							id = parsed.attr.id or id.medium();
							type = parsed.attr.type,
							["xml:lang"] = parsed.attr["xml:lang"],
						};
						if parsed.name == "message" and parsed.attr.type == "groupchat" then
							parsed.attr.to = jid.bare(stanza.attr.from);
						end
						-- NOTE(review): this tests the type of the original stanza, not of
						-- the parsed reply; confirm whether parsed.attr.type was intended.
						if not stanza.attr.type and parsed:get_child("error") then
							parsed.attr.type = "error";
						end
						if parsed.attr.type == "error" then
							parsed.attr.id = stanza.attr.id;
						elseif parsed.name == "iq" then
							parsed.attr.id = stanza.attr.id;
							parsed.attr.type = "result";
						end
						reply = parsed;
					end
				end

				if not reply then
					-- No usable stanza came back: build a reply from the status code
					local code_hundreds = code - (code % 100);
					if code_hundreds == 200 then
						reply = st.reply(stanza);
						if stanza.name ~= "iq" then
							reply.attr.id = id.medium();
						end
						-- TODO presence/status=body ?
					elseif code2err[code] then
						reply = st.error_reply(stanza, errors.new(code, nil, code2err));
					elseif code_hundreds == 400 then
						reply = st.error_reply(stanza, "modify", "bad-request", body);
					elseif code_hundreds == 500 then
						reply = st.error_reply(stanza, "cancel", "internal-server-error", body);
					else
						reply = st.error_reply(stanza, "cancel", "undefined-condition", body);
					end
				end

				if receipt then
					reply:add_direct_child(receipt);
				end

				module:log("debug", "Received[rest]: %s", reply:top_tag());

				origin.send(reply);
			end);

		return true;
	end

	-- Components forward stanzas for every kind of address; normal
	-- VirtualHosts only forward stanzas addressed to the host itself, so
	-- that not everything is overridden there.
	local targets;
	if module:get_host_type() == "component" then
		targets = { "bare", "full", "host" };
	else
		targets = { "host" };
	end

	for _, target in ipairs(targets) do
		for _, kind in ipairs({ "iq", "message", "presence" }) do
			module:hook(kind .. "/" .. target, handle_stanza, -1);
		end
	end
end

local http_server = require "net.http.server";

-- Renders HTTP error responses as JSON when decide_type picks
-- "application/json" for the request's Accept header; otherwise returns
-- nothing so the error is left to other handlers.
local function json_http_error(event)
	local request, response = event.request, event.response;
	local accept = request and request.headers.accept or "";
	if decide_type(accept) ~= "application/json" then
		return;
	end
	if response then
		response.headers.content_type = "application/json";
	end
	local payload = {
		type = "error",
		error = event.error,
		code = event.code,
	};
	return json.encode(payload);
end

module:hook_object_event(http_server, "http-error", json_http_error, 10);