Mercurial > prosody-modules
view mod_rest/mod_rest.lua @ 4942:e7b9bc629ecc
mod_rest: Add special handling to catch MAM results from remote hosts
Makes MAM queries to remote hosts works.
As the comment says, MAM results from users' local archives or local
MUCs are returned via origin.send() which is provided in the event and
thus already worked. Results from remote hosts go via normal stanza
routing and events, which need this extra handling to catch.
This pattern of iq-set, message+, iq-result is generally limited to MAM.
Closest similar thing might be MUC join, but to really handle that you
would need the webhook callback mechanism.
author | Kim Alvefur <zash@zash.se> |
---|---|
date | Mon, 16 May 2022 19:47:09 +0200 |
parents | c83b009b5bc5 |
children | 83a54f4af94c |
line wrap: on
line source
-- RESTful API -- -- Copyright (c) 2019-2022 Kim Alvefur -- -- This file is MIT/X11 licensed. local encodings = require "util.encodings"; local base64 = encodings.base64; local errors = require "util.error"; local http = require "net.http"; local id = require "util.id"; local jid = require "util.jid"; local json = require "util.json"; local st = require "util.stanza"; local um = require "core.usermanager"; local xml = require "util.xml"; local have_cbor, cbor = pcall(require, "cbor"); local jsonmap = module:require"jsonmap"; local tokens = module:depends("tokenauth"); local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" }); local www_authenticate_header; do local header, realm = {}, module.host.."/"..module.name; for mech in auth_mechanisms do header[#header+1] = ("%s realm=%q"):format(mech, realm); end www_authenticate_header = table.concat(header, ", "); end local function check_credentials(request) local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$"); if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then return false; end if auth_type == "Basic" then local creds = base64.decode(auth_data); if not creds then return false; end local username, password = string.match(creds, "^([^:]+):(.*)$"); if not username then return false; end username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password); if not username then return false; end if not um.test_password(username, module.host, password) then return false; end return { username = username, host = module.host }; elseif auth_type == "Bearer" then local token_info = tokens.get_token_info(auth_data); if not token_info or not token_info.session then return false; end return token_info.session; end return nil; end if module:get_option_string("authentication") == "anonymous" and module:get_option_boolean("anonymous_rest") then www_authenticate_header = nil; function check_credentials(request) -- luacheck: ignore 212/request return { username = id.medium():lower(); host = module.host; } end end local function event_suffix(jid_to) local node, _, resource = jid.split(jid_to); if node then if resource then return '/full'; else return '/bare'; end else return '/host'; end end -- TODO This ought to be handled some way other than duplicating this -- core.stanza_router code here. local function compat_preevents(origin, stanza) --> boolean : handled local to = stanza.attr.to; local node, host, resource = jid.split(to); local to_type, to_self; if node then if resource then to_type = '/full'; else to_type = '/bare'; if node == origin.username and host == origin.host then stanza.attr.to = nil; to_self = true; end end else if host then to_type = '/host'; else to_type = '/bare'; to_self = true; end end local event_data = { origin = origin; stanza = stanza; to_self = to_self }; local result = module:fire_event("pre-stanza", event_data); if result ~= nil then return true end if module:fire_event('pre-' .. stanza.name .. to_type, event_data) then return true; end -- do preprocessing return false end -- (table, string) -> table local function amend_from_path(data, path) local st_kind, st_type, st_to = path:match("^([mpi]%w+)/(%w+)/(.*)$"); if not st_kind then return; end if st_kind == "iq" and st_type ~= "get" and st_type ~= "set" then -- GET /iq/disco/jid data = { kind = "iq"; [st_type] = st_type == "ping" or data or {}; }; else data.kind = st_kind; data.type = st_type; end if st_to and st_to ~= "" then data.to = st_to; end return data; end local function parse(mimetype, data, path) --> Stanza, error enum mimetype = mimetype and mimetype:match("^[^; ]*"); if mimetype == "application/xmpp+xml" then return xml.parse(data); elseif mimetype == "application/json" then local parsed, err = json.decode(data); if not parsed then return parsed, err; end if path then parsed = amend_from_path(parsed, path); if not parsed then return nil, "invalid-path"; end end return jsonmap.json2st(parsed); elseif mimetype == "application/cbor" and have_cbor then local parsed, err = cbor.decode(data); if not parsed then return parsed, err; end return jsonmap.json2st(parsed); elseif mimetype == "application/x-www-form-urlencoded"then local parsed = http.formdecode(data); if type(parsed) == "string" then -- This should reject GET /iq/query/to?messagebody if path then return nil, "invalid-query"; end return parse("text/plain", parsed); end for i = #parsed, 1, -1 do parsed[i] = nil; end if path then parsed = amend_from_path(parsed, path); if not parsed then return nil, "invalid-path"; end end return jsonmap.json2st(parsed); elseif mimetype == "text/plain" then if not path then return st.message({ type = "chat" }, data); end local parsed = {}; if path then parsed = amend_from_path(parsed, path); if not parsed then return nil, "invalid-path"; end end if parsed.kind == "message" then parsed.body = data; elseif parsed.kind == "presence" then parsed.show = data; else return nil, "invalid-path"; end return jsonmap.json2st(parsed); elseif not mimetype and path then local parsed = amend_from_path({}, path); if not parsed then return nil, "invalid-path"; end return jsonmap.json2st(parsed); end return nil, "unknown-payload-type"; end local function decide_type(accept, supported_types) -- assumes the accept header is sorted local ret = supported_types[1]; for i = 2, #supported_types do if (accept:find(supported_types[i], 1, true) or 1000) < (accept:find(ret, 1, true) or 1000) then ret = supported_types[i]; end end return ret; end local supported_inputs = { "application/xmpp+xml", "application/json", "application/x-www-form-urlencoded", "text/plain", }; local supported_outputs = { "application/xmpp+xml", "application/json", "application/x-www-form-urlencoded", }; if have_cbor then table.insert(supported_inputs, "application/cbor"); table.insert(supported_outputs, "application/cbor"); end -- Only { string : string } can be form-encoded, discard the rest -- (jsonmap also discards anything unknown or unsupported) local function flatten(t) local form = {}; for k, v in pairs(t) do if type(v) == "string" then form[k] = v; elseif type(v) == "number" then form[k] = tostring(v); elseif v == true then form[k] = ""; end end return form; end local function encode(type, s) if type == "text/plain" then return s:get_child_text("body") or ""; elseif type == "application/xmpp+xml" then return tostring(s); end local mapped, err = jsonmap.st2json(s); if not mapped then return mapped, err; end if type == "application/json" then return json.encode(mapped); elseif type == "application/x-www-form-urlencoded" then return http.formencode(flatten(mapped)); elseif type == "application/cbor" then return cbor.encode(mapped); end error "unsupported encoding"; end local post_errors = errors.init("mod_rest", { noauthz = { code = 401, type = "auth", condition = "not-authorized", text = "No credentials provided" }, unauthz = { code = 403, type = "auth", condition = "not-authorized", text = "Credentials not accepted" }, parse = { code = 400, condition = "not-well-formed", text = "Failed to parse payload", }, xmlns = { code = 422, condition = "invalid-namespace", text = "'xmlns' attribute must be empty", }, name = { code = 422, condition = "unsupported-stanza-type", text = "Invalid stanza, must be 'message', 'presence' or 'iq'.", }, to = { code = 422, condition = "improper-addressing", text = "Invalid destination JID", }, from = { code = 422, condition = "invalid-from", text = "Invalid source JID", }, from_auth = { code = 403, condition = "not-authorized", text = "Not authorized to send stanza with requested 'from'", }, iq_type = { code = 422, condition = "invalid-xml", text = "'iq' stanza must be of type 'get' or 'set'", }, iq_tags = { code = 422, condition = "bad-format", text = "'iq' stanza must have exactly one child tag", }, mediatype = { code = 415, condition = "bad-format", text = "Unsupported media type" }, }); -- GET → iq-get local function parse_request(request, path) if path and request.method == "GET" then -- e.g. /version/{to} if request.url.query then return parse("application/x-www-form-urlencoded", request.url.query, "iq/"..path); end return parse(nil, nil, "iq/"..path); else return parse(request.headers.content_type, request.body, path); end end local function handle_request(event, path) local request, response = event.request, event.response; local from; local origin; local echo = path == "echo"; if echo then path = nil; end if not request.headers.authorization and www_authenticate_header then response.headers.www_authenticate = www_authenticate_header; return post_errors.new("noauthz"); else origin = check_credentials(request); if not origin then return post_errors.new("unauthz"); end from = jid.join(origin.username, origin.host, origin.resource); origin.type = "c2s"; end local payload, err = parse_request(request, path); if not payload then -- parse fail local ctx = { error = err, type = request.headers.content_type, data = request.body, }; if err == "unknown-payload-type" then return post_errors.new("mediatype", ctx); end return post_errors.new("parse", ctx); end if (payload.attr.xmlns or "jabber:client") ~= "jabber:client" then return post_errors.new("xmlns"); elseif payload.name ~= "message" and payload.name ~= "presence" and payload.name ~= "iq" then return post_errors.new("name"); end local to = jid.prep(payload.attr.to); if payload.attr.to and not to then return post_errors.new("to"); end if payload.attr.from then local requested_from = jid.prep(payload.attr.from); if not requested_from then return post_errors.new("from"); end if jid.compare(requested_from, from) then from = requested_from; else return post_errors.new("from_auth"); end end payload.attr = { from = from, to = to, id = payload.attr.id or id.medium(), type = payload.attr.type, ["xml:lang"] = payload.attr["xml:lang"], }; module:log("debug", "Received[rest]: %s", payload:top_tag()); local send_type = decide_type((request.headers.accept or "") ..",".. (request.headers.content_type or ""), supported_outputs) if echo then local ret, err = errors.coerce(encode(send_type, payload)); if not ret then return err; end response.headers.content_type = send_type; return ret; end if payload.name == "iq" then local responses = st.stanza("xmpp"); function origin.send(stanza) responses:add_direct_child(stanza); end if compat_preevents(origin, payload) then return 202; end if payload.attr.type ~= "get" and payload.attr.type ~= "set" then return post_errors.new("iq_type"); elseif #payload.tags ~= 1 then return post_errors.new("iq_tags"); end -- special handling of multiple responses to MAM queries primarily from -- remote hosts, local go directly to origin.send() local archive_event_name = "message"..event_suffix(from); local archive_handler; local archive_query = payload:get_child("query", "urn:xmpp:mam:2"); if archive_query then archive_handler = function(result_event) if result_event.stanza:find("{urn:xmpp:mam:2}result/@queryid") == archive_query.attr.queryid then origin.send(result_event.stanza); return true; end end module:hook(archive_event_name, archive_handler, 1); end local p = module:send_iq(payload, origin):next( function (result) module:log("debug", "Sending[rest]: %s", result.stanza:top_tag()); response.headers.content_type = send_type; if responses[2] then return encode(send_type, responses); end return encode(send_type, result.stanza); end, function (error) if not errors.is_err(error) then module:log("error", "Uncaught native error: %s", error); return select(2, errors.coerce(nil, error)); elseif error.context and error.context.stanza then response.headers.content_type = send_type; module:log("debug", "Sending[rest]: %s", error.context.stanza:top_tag()); return encode(send_type, error.context.stanza); else return error; end end); if archive_handler then p:finally(function () module:unhook(archive_event_name, archive_handler); end) end return p; else function origin.send(stanza) module:log("debug", "Sending[rest]: %s", stanza:top_tag()); response.headers.content_type = send_type; response:send(encode(send_type, stanza)); return true; end if compat_preevents(origin, payload) then return 202; end module:send(payload, origin); return 202; end end module:depends("http"); local demo_handlers = {}; if module:get_option_path("rest_demo_resources", nil) then demo_handlers = module:require"apidemo"; end -- Handle stanzas submitted via HTTP module:provides("http", { route = { POST = handle_request; ["POST /*"] = handle_request; ["GET /*"] = handle_request; -- Only if api_demo_resources are set ["GET /"] = demo_handlers.redirect; ["GET /demo/"] = demo_handlers.main_page; ["GET /demo/openapi.yaml"] = demo_handlers.schema; ["GET /demo/*"] = demo_handlers.resources; }; }); -- Forward stanzas from XMPP to HTTP and return any reply local rest_url = module:get_option_string("rest_callback_url", nil); if rest_url then local function get_url() return rest_url; end if rest_url:find("%b{}") then local httputil = require "util.http"; local render_url = require"util.interpolation".new("%b{}", httputil.urlencode); function get_url(stanza) local at = stanza.attr; return render_url(rest_url, { kind = stanza.name, type = at.type, to = at.to, from = at.from }); end end local send_type = module:get_option_string("rest_callback_content_type", "application/xmpp+xml"); if send_type == "json" then send_type = "application/json"; end module:set_status("info", "Not yet connected"); http.request(get_url(st.stanza("meta", { type = "info", to = module.host, from = module.host })), { method = "OPTIONS", }, function (body, code, response) if code == 0 then module:log_status("error", "Could not connect to callback URL %q: %s", rest_url, body); elseif code == 200 then module:set_status("info", "Connected"); if response.headers.accept then send_type = decide_type(response.headers.accept, supported_outputs); module:log("debug", "Set 'rest_callback_content_type' = %q based on Accept header", send_type); end else module:log_status("warn", "Unexpected response code %d from OPTIONS probe", code); module:log("warn", "Endpoint said: %s", body); end end); local code2err = require "net.http.errors".registry; local function handle_stanza(event) local stanza, origin = event.stanza, event.origin; local reply_allowed = stanza.attr.type ~= "error"; local reply_needed = reply_allowed and stanza.name == "iq"; local receipt; if reply_allowed and stanza.name == "message" and stanza.attr.id and stanza:get_child("urn:xmpp:receipts", "request") then reply_needed = true; receipt = st.stanza("received", { xmlns = "urn:xmpp:receipts", id = stanza.id }); end local request_body = encode(send_type, stanza); -- Keep only the top level element and let the rest be GC'd stanza = st.clone(stanza, true); module:log("debug", "Sending[rest]: %s", stanza:top_tag()); http.request(get_url(stanza), { body = request_body, headers = { ["Content-Type"] = send_type, ["Content-Language"] = stanza.attr["xml:lang"], Accept = table.concat(supported_inputs, ", "); }, }):next(function (response) module:set_status("info", "Connected"); local reply; local code, body = response.code, response.body; if not reply_allowed then return; elseif code == 202 or code == 204 then if not reply_needed then -- Delivered, no reply return; end else local parsed, err = parse(response.headers["content-type"], body); if not parsed then module:log("warn", "Failed parsing data from REST callback: %s, %q", err, body); elseif parsed.name ~= stanza.name then module:log("warn", "REST callback responded with the wrong stanza type, got %s but expected %s", parsed.name, stanza.name); else parsed.attr = { from = stanza.attr.to, to = stanza.attr.from, id = parsed.attr.id or id.medium(); type = parsed.attr.type, ["xml:lang"] = parsed.attr["xml:lang"], }; if parsed.name == "message" and parsed.attr.type == "groupchat" then parsed.attr.to = jid.bare(stanza.attr.from); end if not stanza.attr.type and parsed:get_child("error") then parsed.attr.type = "error"; end if parsed.attr.type == "error" then parsed.attr.id = stanza.attr.id; elseif parsed.name == "iq" then parsed.attr.id = stanza.attr.id; parsed.attr.type = "result"; end reply = parsed; end end if not reply then local code_hundreds = code - (code % 100); if code_hundreds == 200 then reply = st.reply(stanza); if stanza.name ~= "iq" then reply.attr.id = id.medium(); end -- TODO presence/status=body ? elseif code2err[code] then reply = st.error_reply(stanza, errors.new(code, nil, code2err)); elseif code_hundreds == 400 then reply = st.error_reply(stanza, "modify", "bad-request", body); elseif code_hundreds == 500 then reply = st.error_reply(stanza, "cancel", "internal-server-error", body); else reply = st.error_reply(stanza, "cancel", "undefined-condition", body); end end if receipt then reply:add_direct_child(receipt); end module:log("debug", "Received[rest]: %s", reply:top_tag()); origin.send(reply); end, function (err) module:log_status("error", "Could not connect to callback URL %q: %s", rest_url, err); origin.send(st.error_reply(stanza, "wait", "recipient-unavailable", err.text)); end):catch(function (err) module:log("error", "Error[rest]: %s", err); end); return true; end local send_kinds = module:get_option_set("rest_callback_stanzas", { "message", "presence", "iq" }); local event_presets = { -- Don't override everything on normal VirtualHosts by default ["local"] = { "host" }, -- Comonents get to handle all kinds of stanzas ["component"] = { "bare", "full", "host" }, }; local hook_events = module:get_option_set("rest_callback_events", event_presets[module:get_host_type()]); for kind in send_kinds do for event in hook_events do module:hook(kind.."/"..event, handle_stanza, -1); end end end local supported_errors = { "text/html", "application/xmpp+xml", "application/json", }; local http_server = require "net.http.server"; module:hook_object_event(http_server, "http-error", function (event) local request, response = event.request, event.response; local response_as = decide_type(request and request.headers.accept or "", supported_errors); if response_as == "application/xmpp+xml" then if response then response.headers.content_type = "application/xmpp+xml"; end local stream_error = st.stanza("error", { xmlns = "http://etherx.jabber.org/streams" }); if event.error then stream_error:tag(event.error.condition, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }):up(); if event.error.text then stream_error:text_tag("text", event.error.text, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }); end end return tostring(stream_error); elseif response_as == "application/json" then if response then response.headers.content_type = "application/json"; end return json.encode({ type = "error", error = event.error, code = event.code, }); end end, 1);