comparison mod_external_services/mod_external_services.lua @ 4666:dbc7ba3cc27c

mod_external_services: Filter services by requested credentials using a Set. Please don't be accidentally quadratic.
author Kim Alvefur <zash@zash.se>
date Mon, 30 Aug 2021 20:19:09 +0200
parents f0ffa8cf3ce6
children 1990611691cf
comparison
equal deleted inserted replaced
4665:f0ffa8cf3ce6 4666:dbc7ba3cc27c
3 local base64 = require "util.encodings".base64; 3 local base64 = require "util.encodings".base64;
4 local hashes = require "util.hashes"; 4 local hashes = require "util.hashes";
5 local st = require "util.stanza"; 5 local st = require "util.stanza";
6 local jid = require "util.jid"; 6 local jid = require "util.jid";
7 local array = require "util.array"; 7 local array = require "util.array";
8 local set = require "util.set";
8 9
9 local default_host = module:get_option_string("external_service_host", module.host); 10 local default_host = module:get_option_string("external_service_host", module.host);
10 local default_port = module:get_option_number("external_service_port"); 11 local default_port = module:get_option_number("external_service_port");
11 local default_secret = module:get_option_string("external_service_secret"); 12 local default_secret = module:get_option_string("external_service_secret");
12 local default_ttl = module:get_option_number("external_service_ttl", 86400); 13 local default_ttl = module:get_option_number("external_service_ttl", 86400);
177 local services = ( configured_services + extras ) / prepare; 178 local services = ( configured_services + extras ) / prepare;
178 services:filter(function (item) 179 services:filter(function (item)
179 return item.restricted; 180 return item.restricted;
180 end) 181 end)
181 182
182 local requested_credentials = {}; 183 local requested_credentials = set.new();
183 for service in action:childtags("service") do 184 for service in action:childtags("service") do
184 if not service.attr.type or not service.attr.host then 185 if not service.attr.type or not service.attr.host then
185 origin.send(st.error_reply(stanza, "modify", "bad-request")); 186 origin.send(st.error_reply(stanza, "modify", "bad-request"));
186 return true; 187 return true;
187 end 188 end
188 189
189 table.insert(requested_credentials, { 190 requested_credentials:add(string.format("%s:%s:%d", service.attr.type, service.attr.host,
190 type = service.attr.type; 191 tonumber(service.attr.port) or 0));
191 host = service.attr.host;
192 port = tonumber(service.attr.port);
193 });
194 end 192 end
195 193
196 setmetatable(services, services_mt); 194 setmetatable(services, services_mt);
197 setmetatable(requested_credentials, services_mt);
198 195
199 module:fire_event("external_service/credentials", { 196 module:fire_event("external_service/credentials", {
200 origin = origin; 197 origin = origin;
201 stanza = stanza; 198 stanza = stanza;
202 reply = reply; 199 reply = reply;
203 requested_credentials = requested_credentials; 200 requested_credentials = requested_credentials;
204 services = services; 201 services = services;
205 }); 202 });
206 203
207 for req_srv in action:childtags("service") do 204 services:filter(function (srv)
208 for _, srv in ipairs(services) do 205 local port_key = string.format("%s:%s:%d", srv.type, srv.host, srv.port or 0);
209 if srv.type == req_srv.attr.type and srv.host == req_srv.attr.host 206 local portless_key = string.format("%s:%s:%d", srv.type, srv.host, 0);
210 and not req_srv.attr.port or srv.port == tonumber(req_srv.attr.port) then 207 return requested_credentials:contains(port_key) or requested_credentials:contains(portless_key);
211 reply:tag("service", { 208 end);
212 type = srv.type; 209
213 transport = srv.transport; 210 for _, srv in ipairs(services) do
214 host = srv.host; 211 reply:tag("service", {
215 port = srv.port and string.format("%d", srv.port) or nil; 212 type = srv.type;
216 username = srv.username; 213 transport = srv.transport;
217 password = srv.password; 214 host = srv.host;
218 expires = srv.expires and dt.datetime(srv.expires) or nil; 215 port = srv.port and string.format("%d", srv.port) or nil;
219 restricted = srv.restricted and "1" or nil; 216 username = srv.username;
220 }):up(); 217 password = srv.password;
221 end 218 expires = srv.expires and dt.datetime(srv.expires) or nil;
222 end 219 restricted = srv.restricted and "1" or nil;
220 }):up();
223 end 221 end
224 222
225 origin.send(reply); 223 origin.send(reply);
226 return true; 224 return true;
227 end 225 end