Mercurial > prosody-modules
comparison mod_rest/mod_rest.lua @ 3910:49efd1323a1b
mod_rest: Add support for token authentication
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Wed, 26 Feb 2020 18:36:40 +0000 |
parents | eb27e51cf2c9 |
children | 064c32a5be7c |
comparison
equal
deleted
inserted
replaced
3909:eb27e51cf2c9 | 3910:49efd1323a1b |
---|---|
2 -- | 2 -- |
3 -- Copyright (c) 2019-2020 Kim Alvefur | 3 -- Copyright (c) 2019-2020 Kim Alvefur |
4 -- | 4 -- |
5 -- This file is MIT/X11 licensed. | 5 -- This file is MIT/X11 licensed. |
6 | 6 |
7 local encodings = require "util.encodings"; | |
8 local base64 = encodings.base64; | |
7 local errors = require "util.error"; | 9 local errors = require "util.error"; |
8 local http = require "net.http"; | 10 local http = require "net.http"; |
9 local id = require "util.id"; | 11 local id = require "util.id"; |
10 local jid = require "util.jid"; | 12 local jid = require "util.jid"; |
11 local json = require "util.json"; | 13 local json = require "util.json"; |
12 local st = require "util.stanza"; | 14 local st = require "util.stanza"; |
15 local um = require "core.usermanager"; | |
13 local xml = require "util.xml"; | 16 local xml = require "util.xml"; |
14 | 17 |
15 local allow_any_source = module:get_host_type() == "component"; | |
16 local validate_from_addresses = module:get_option_boolean("validate_from_addresses", true); | |
17 local secret = assert(module:get_option_string("rest_credentials"), "rest_credentials is a required setting"); | |
18 local auth_type = assert(secret:match("^%S+"), "Format of rest_credentials MUST be like 'Bearer secret'"); | |
19 assert(auth_type == "Bearer" or auth_type == "Basic", "Only 'Bearer' and 'Basic' are supported in rest_credentials"); | |
20 | |
21 local jsonmap = module:require"jsonmap"; | 18 local jsonmap = module:require"jsonmap"; |
19 | |
20 local tokens = module:depends("authtokens"); | |
21 | |
22 local auth_mechanisms = module:get_option_set("rest_auth_mechanisms", { "Basic", "Bearer" }); | |
23 | |
24 local www_authenticate_header; | |
25 do | |
26 local header, realm = {}, module.host.."/"..module.name; | |
27 for mech in auth_mechanisms do | |
28 header[#header+1] = ("%s realm=%q"):format(mech, realm); | |
29 end | |
30 www_authenticate_header = table.concat(header, ", "); | |
31 end | |
32 | |
22 -- Bearer token | 33 -- Bearer token |
23 local function check_credentials(request) | 34 local function check_credentials(request) |
24 return request.headers.authorization == secret; | 35 local auth_type, auth_data = string.match(request.headers.authorization, "^(%S+)%s(.+)$"); |
25 end | 36 if not (auth_type and auth_data) or not auth_mechanisms:contains(auth_type) then |
26 if secret == "Basic" and module:get_host_type() == "local" then | 37 return false; |
27 local um = require "core.usermanager"; | 38 end |
28 local encodings = require "util.encodings"; | 39 |
29 local base64 = encodings.base64; | 40 if auth_type == "Basic" then |
30 | 41 local creds = base64.decode(auth_data); |
31 function check_credentials(request) | |
32 local creds = string.match(request.headers.authorization, "^Basic%s+([A-Za-z0-9+/]+=?=?)%s*$"); | |
33 if not creds then return false; end | |
34 creds = base64.decode(creds); | |
35 if not creds then return false; end | 42 if not creds then return false; end |
36 local username, password = string.match(creds, "^([^:]+):(.*)$"); | 43 local username, password = string.match(creds, "^([^:]+):(.*)$"); |
37 if not username then return false; end | 44 if not username then return false; end |
38 username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password); | 45 username, password = encodings.stringprep.nodeprep(username), encodings.stringprep.saslprep(password); |
39 if not username then return false; end | 46 if not username then return false; end |
40 if not um.test_password(username, module.host, password) then | 47 if not um.test_password(username, module.host, password) then |
41 return false; | 48 return false; |
42 end | 49 end |
43 return jid.join(username, module.host); | 50 return { username = username, host = module.host }; |
44 end | 51 elseif auth_type == "Bearer" then |
52 local token_info = tokens.get_token_info(auth_data); | |
53 if not token_info or not token_info.session then | |
54 return false; | |
55 end | |
56 return token_info.session; | |
57 end | |
58 return nil; | |
45 end | 59 end |
46 | 60 |
47 local function parse(mimetype, data) | 61 local function parse(mimetype, data) |
48 mimetype = mimetype and mimetype:match("^[^; ]*"); | 62 mimetype = mimetype and mimetype:match("^[^; ]*"); |
49 if mimetype == "application/xmpp+xml" then | 63 if mimetype == "application/xmpp+xml" then |
82 return tostring(s); | 96 return tostring(s); |
83 end | 97 end |
84 | 98 |
85 local function handle_post(event) | 99 local function handle_post(event) |
86 local request, response = event.request, event.response; | 100 local request, response = event.request, event.response; |
87 local from = module.host; | 101 local from; |
102 local origin; | |
103 | |
88 if not request.headers.authorization then | 104 if not request.headers.authorization then |
89 response.headers.www_authenticate = ("%s realm=%q"):format(auth_type, module.host.."/"..module.name); | 105 response.headers.www_authenticate = www_authenticate_header; |
90 return 401; | 106 return 401; |
91 else | 107 else |
92 local authz = check_credentials(request); | 108 origin = check_credentials(request); |
93 if not authz then | 109 if not origin then |
94 return 401; | 110 return 401; |
95 end | 111 end |
96 if type(authz) == "string" then | 112 from = jid.join(origin.username, origin.host, origin.resource); |
97 from = authz; | |
98 end | |
99 end | 113 end |
100 local payload, err = parse(request.headers.content_type, request.body); | 114 local payload, err = parse(request.headers.content_type, request.body); |
101 if not payload then | 115 if not payload then |
102 -- parse fail | 116 -- parse fail |
103 return errors.new({ code = 400, text = "Failed to parse payload" }, { error = err, type = request.headers.content_type, data = request.body }); | 117 return errors.new({ code = 400, text = "Failed to parse payload" }, { error = err, type = request.headers.content_type, data = request.body }); |
109 end | 123 end |
110 local to = jid.prep(payload.attr.to); | 124 local to = jid.prep(payload.attr.to); |
111 if not to then | 125 if not to then |
112 return errors.new({ code = 422, text = "Invalid destination JID" }); | 126 return errors.new({ code = 422, text = "Invalid destination JID" }); |
113 end | 127 end |
114 if allow_any_source and payload.attr.from then | 128 if payload.attr.from then |
115 from = jid.prep(payload.attr.from); | 129 local requested_from = jid.prep(payload.attr.from); |
116 if not from then | 130 if not requested_from then |
117 return errors.new({ code = 422, text = "Invalid source JID" }); | 131 return errors.new({ code = 422, text = "Invalid source JID" }); |
118 end | 132 end |
119 if validate_from_addresses and not jid.compare(from, module.host) then | 133 if jid.compare(requested_from, from) then |
120 return errors.new({ code = 403, text = "Source JID must belong to current host" }); | 134 from = requested_from; |
135 else | |
136 return errors.new({ code = 403, text = "Not authorized to send from "..requested_from }); | |
121 end | 137 end |
122 end | 138 end |
123 payload.attr = { | 139 payload.attr = { |
124 from = from, | 140 from = from, |
125 to = to, | 141 to = to, |
128 ["xml:lang"] = payload.attr["xml:lang"], | 144 ["xml:lang"] = payload.attr["xml:lang"], |
129 }; | 145 }; |
130 module:log("debug", "Received[rest]: %s", payload:top_tag()); | 146 module:log("debug", "Received[rest]: %s", payload:top_tag()); |
131 local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type) | 147 local send_type = decide_type((request.headers.accept or "") ..",".. request.headers.content_type) |
132 if payload.name == "iq" then | 148 if payload.name == "iq" then |
149 function origin.send(stanza) | |
150 prosody.core_route_stanza(nil, stanza); | |
151 end | |
133 if payload.attr.type ~= "get" and payload.attr.type ~= "set" then | 152 if payload.attr.type ~= "get" and payload.attr.type ~= "set" then |
134 return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" }); | 153 return errors.new({ code = 422, text = "'iq' stanza must be of type 'get' or 'set'" }); |
135 elseif #payload.tags ~= 1 then | 154 elseif #payload.tags ~= 1 then |
136 return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" }); | 155 return errors.new({ code = 422, text = "'iq' stanza must have exactly one child tag" }); |
137 end | 156 end |
138 return module:send_iq(payload):next( | 157 return module:send_iq(payload, origin):next( |
139 function (result) | 158 function (result) |
140 module:log("debug", "Sending[rest]: %s", result.stanza:top_tag()); | 159 module:log("debug", "Sending[rest]: %s", result.stanza:top_tag()); |
141 response.headers.content_type = send_type; | 160 response.headers.content_type = send_type; |
142 return encode(send_type, result.stanza); | 161 return encode(send_type, result.stanza); |
143 end, | 162 end, |
152 else | 171 else |
153 return error; | 172 return error; |
154 end | 173 end |
155 end); | 174 end); |
156 else | 175 else |
157 local origin = {}; | |
158 function origin.send(stanza) | 176 function origin.send(stanza) |
159 module:log("debug", "Sending[rest]: %s", stanza:top_tag()); | 177 module:log("debug", "Sending[rest]: %s", stanza:top_tag()); |
160 response.headers.content_type = send_type; | 178 response.headers.content_type = send_type; |
161 response:send(encode(send_type, stanza)); | 179 response:send(encode(send_type, stanza)); |
162 return true; | 180 return true; |