# HG changeset patch # User Matthew Wild # Date 1582742200 0 # Node ID 49efd1323a1bf12e97d3569e121f6e301ad5252b # Parent eb27e51cf2c95c9810339959f6744c42657fe2b4 mod_rest: Add support for token authentication diff -r eb27e51cf2c9 -r 49efd1323a1b mod_rest/mod_rest.lua --- a/mod_rest/mod_rest.lua Wed Feb 26 18:04:17 2020 +0000 +++ b/mod_rest/mod_rest.lua Wed Feb 26 18:36:40 2020 +0000 @@ -4,34 +4,41 @@ -- -- This file is MIT/X11 licensed. +local encodings = require "util.encodings"; +local base64 = encodings.base64; local errors = require "util.error"; local http = require "net.http"; local id = require "util.id"; local jid = require "util.jid"; local json = require "util.json"; local st = require "util.stanza"; +local um = require "core.usermanager"; local xml = require "util.xml"; -local allow_any_source = module:get_host_type() == "component"; -local validate_from_addresses = module:get_option_boolean("validate_from_addresses", true); -local secret = assert(module:get_option_string("rest_credentials"), "rest_credentials is a required setting"); -local auth_type = assert(secret:match("^%S+"), "Format of rest_credentials MUST be like 'Bearer secret'"); -assert(auth_type == "Bearer" or auth_type == "Basic", "Only 'Bearer' and 'Basic' are supported in rest_credentials"); +local jsonmap = module:require"jsonmap"; + +local tokens = module:depends("authtokens"); + +local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" }); -local jsonmap = module:require"jsonmap"; +local www_authenticate_header; +do + local header, realm = {}, module.host.."/"..module.name; + for mech in auth_mechanisms do + header[#header+1] = ("%s realm=%q"):format(mech, realm); + end + www_authenticate_header = table.concat(header, ", "); +end + -- Bearer token local function check_credentials(request) - return request.headers.authorization == secret; -end -if secret == "Basic" and module:get_host_type() == "local" then - local um = require "core.usermanager"; - local encodings = require "util.encodings"; - local base64 = encodings.base64; + local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$"); + if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then + return false; + end - function check_credentials(request) - local creds = string.match(request.headers.authorization, "^Basic%s+([A-Za-z0-9+/]+=?=?)%s*$"); - if not creds then return false; end - creds = base64.decode(creds); + if auth_type == "Basic" then + local creds = base64.decode(auth_data); if not creds then return false; end local username, password = string.match(creds, "^([^:]+):(.*)$"); if not username then return false; end @@ -40,8 +47,15 @@ if not um.test_password(username, module.host, password) then return false; end - return jid.join(username, module.host); + return { username = username, host = module.host }; + elseif auth_type == "Bearer" then + local token_info = tokens.get_token_info(auth_data); + if not token_info or not token_info.session then + return false; + end + return token_info.session; end + return nil; end local function parse(mimetype, data) @@ -84,18 +98,18 @@ local function handle_post(event) local request, response = event.request, event.response; - local from = module.host; + local from; + local origin; + if not request.headers.authorization then - response.headers.www_authenticate = ("%s realm=%q"):format(auth_type, module.host.."/"..module.name); + response.headers.www_authenticate = www_authenticate_header; return 401; else - local authz = check_credentials(request); - if not authz then + origin = check_credentials(request); + if not origin then return 401; end - if type(authz) == "string" then - from = authz; - end + from = jid.join(origin.username, origin.host, origin.resource); end local payload, err = parse(request.headers.content_type, request.body); if not payload then @@ -111,13 +125,15 @@ if not to then return errors.new({ code = 422, text = "Invalid destination JID" }); end - if allow_any_source and payload.attr.from then - from = jid.prep(payload.attr.from); - if not from then + if payload.attr.from then + local requested_from = jid.prep(payload.attr.from); + if not requested_from then return errors.new({ code = 422, text = "Invalid source JID" }); end - if validate_from_addresses and not jid.compare(from, module.host) then - return errors.new({ code = 403, text = "Source JID must belong to current host" }); + if jid.compare(requested_from, from) then + from = requested_from; + else + return errors.new({ code = 403, text = "Not authorized to send from "..requested_from }); end end payload.attr = { @@ -130,12 +146,15 @@ module:log("debug", "Received[rest]: %s", payload:top_tag()); local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type) if payload.name == "iq" then + function origin.send(stanza) + prosody.core_route_stanza(nil, stanza); + end if payload.attr.type ~= "get" and payload.attr.type ~= "set" then return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" }); elseif #payload.tags ~= 1 then return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" }); end - return module:send_iq(payload):next( + return module:send_iq(payload, origin):next( function (result) module:log("debug", "Sending[rest]: %s", result.stanza:top_tag()); response.headers.content_type = send_type; @@ -154,7 +173,6 @@ end end); else - local origin = {}; function origin.send(stanza) module:log("debug", "Sending[rest]: %s", stanza:top_tag()); response.headers.content_type = send_type;