Mercurial > prosody-modules
view mod_firewall/conditions.lib.lua @ 5468:14b5446e22e1
mod_http_oauth2: Fix returning errors from response handlers
This would either redirect the user back to the client along with the
error code, or show the error HTML template.
Previously this would just show some JSON to the user.
author | Kim Alvefur <zash@zash.se> |
---|---|
date | Thu, 18 May 2023 12:57:23 +0200 |
parents | 84997bc3f92e |
children | 8226ac08484e |
line wrap: on
line source
--luacheck: globals meta metaq idsafe

-- Condition handlers for firewall rules.
--
-- Every handler in 'condition_handlers' takes the text that was given for the
-- condition and returns a string of Lua code implementing the check. Most
-- handlers also return a second value: a list of names that the generated
-- code relies on (presumably resolved by whatever compiles the rule; confirm
-- against the loader of this file).
--
-- 'meta', 'metaq' and 'idsafe' are not defined here; they are expected to be
-- supplied as globals by the code that loads this file.

local condition_handlers = {};

local jid = require "util.jid";
local unpack = table.unpack or unpack;

-- Helper to convert user-input strings (yes/true//no/false) to a bool.
-- Anything other than "yes"/"true" (case-insensitive) is treated as false.
local function string_to_boolean(s)
	s = s:lower();
	return s == "yes" or s == "true";
end

-- Return a code string for a condition that checks whether the contents
-- of variable with the name 'name' matches any of the values in the
-- comma/space/pipe delimited list 'values'.
local function compile_comparison_list(name, values)
	local conditions = {};
	for value in values:gmatch("[^%s,|]+") do
		table.insert(conditions, ("%s == %q"):format(name, value));
	end
	return table.concat(conditions, " or ");
end

-- KIND: compare the 'name' variable against a list of stanza kinds
function condition_handlers.KIND(kind)
	assert(kind, "Expected stanza kind to match against");
	return compile_comparison_list("name", kind), { "name" };
end

-- Glob-style wildcards and the Lua pattern fragments they translate to
local wildcard_equivs = { ["*"] = ".*", ["?"] = "." };

-- Return code that checks one part of a JID (held in the variable named
-- 'part') against 'match':
--   nil          -> the part must be absent
--   <*>          -> the part must be present
--   <<pattern>>  -> the part must match the raw Lua pattern
--   <glob>       -> the part must match the glob (* and ? wildcards)
--   anything else is compared for equality
local function compile_jid_match_part(part, match)
	if not match then
		return part.." == nil";
	end
	local pattern = match:match("^<(.*)>$");
	if pattern then
		if pattern == "*" then
			return part;
		end
		if pattern:find("^<.*>$") then
			-- Double angle brackets: use the contents as a Lua pattern, verbatim
			pattern = pattern:match("^<(.*)>$");
		else
			-- Escape all punctuation, then turn the escaped wildcards back
			-- into their pattern equivalents
			pattern = pattern:gsub("%p", "%%%0"):gsub("%%(%p)", wildcard_equivs);
		end
		return ("(%s and %s:find(%q))"):format(part, part, "^"..pattern.."$");
	else
		return ("%s == %q"):format(part, match);
	end
end

-- Return code that checks the <which>_node, <which>_host and (only when the
-- JID to match includes one) <which>_resource variables against 'match_jid'.
local function compile_jid_match(which, match_jid)
	local match_node, match_host, match_resource = jid.split(match_jid);
	local conditions = {};
	conditions[#conditions+1] = compile_jid_match_part(which.."_node", match_node);
	conditions[#conditions+1] = compile_jid_match_part(which.."_host", match_host);
	if match_resource then
		conditions[#conditions+1] = compile_jid_match_part(which.."_resource", match_resource);
	end
	return table.concat(conditions, " and ");
end

function condition_handlers.TO(to)
	return compile_jid_match("to", to), { "split_to" };
end

function condition_handlers.FROM(from)
	return compile_jid_match("from", from), { "split_from" };
end

function condition_handlers.FROM_EXACTLY(from)
	local metadeps = {};
	return ("from == %s"):format(metaq(from, metadeps)), { "from", unpack(metadeps) };
end

function condition_handlers.TO_EXACTLY(to)
	local metadeps = {};
	return ("to == %s"):format(metaq(to, metadeps)), { "to", unpack(metadeps) };
end

function condition_handlers.TO_SELF()
	-- Intentionally not using 'to' here, as that defaults to bare JID when nil
	return ("stanza.attr.to == nil");
end

-- TYPE: compare the stanza type against a list of types. A missing type is
-- treated as 'normal' for messages and 'available' for presence.
function condition_handlers.TYPE(type)
	assert(type, "Expected 'type' value to match against");
	return compile_comparison_list("(type or (name == 'message' and 'normal') or (name == 'presence' and 'available'))", type),
		{ "type", "name" };
end

-- Return code that is true when the 'which' side ("to" or "from") is in the
-- zone (by host, full JID or bare JID) while the opposite side is not.
local function zone_check(zone, which)
	local zone_var = zone;
	if zone == "$local" then zone_var = "_local" end
	local which_not = which == "from" and "to" or "from";
	return ("(zone_%s[%s_host] or zone_%s[%s] or zone_%s[bare_%s]) "
		.."and not(zone_%s[%s_host] or zone_%s[%s] or zone_%s[bare_%s])"
		)
		:format(zone_var, which, zone_var, which, zone_var, which,
		zone_var, which_not, zone_var, which_not, zone_var, which_not), {
			"split_to", "split_from", "bare_to", "bare_from", "zone:"..zone
		};
end

function condition_handlers.ENTERING(zone)
	return zone_check(zone, "to");
end

function condition_handlers.LEAVING(zone)
	return zone_check(zone, "from");
end

-- IN ROSTER? (parameter is deprecated)
function condition_handlers.IN_ROSTER(yes_no)
	local in_roster_requirement = string_to_boolean(yes_no or "yes"); -- COMPAT w/ older scripts
	return "not "..(in_roster_requirement and "not" or "").." roster_entry", { "roster_entry" };
end

function condition_handlers.IN_ROSTER_GROUP(group)
	return ("not not (roster_entry and roster_entry.groups[%q])"):format(group), { "roster_entry" };
end

function condition_handlers.SUBSCRIBED()
	return "(bare_to == bare_from or to_node and rostermanager.is_contact_subscribed(to_node, to_host, bare_from))",
		{ "rostermanager", "split_to", "bare_to", "bare_from" };
end

function condition_handlers.PENDING_SUBSCRIPTION_FROM_SENDER()
	return "(bare_to == bare_from or to_node and rostermanager.is_contact_pending_in(to_node, to_host, bare_from))",
		{ "rostermanager", "split_to", "bare_to", "bare_from" };
end

function condition_handlers.PAYLOAD(payload_ns)
	return ("stanza:get_child(nil, %q)"):format(payload_ns);
end

-- INSPECT: evaluate a stanza path, optionally comparing the result.
--   path           -> true if stanza:find(path) returns something
--   path=value     -> exact match
--   path~=value    -> Lua pattern match
--   path/=value    -> literal substring search
-- Adding '$' to the comparison (e.g. '$=', '~$=') passes the value
-- through meta() before it is used.
function condition_handlers.INSPECT(path)
	if path:find("=") then
		local query, match_type, value = path:match("(.-)([~/$]*)=(.*)");
		if not(query:match("#$") or query:match("@[^/]+")) then
			error("Stanza path does not return a string (append # for text content or @name for value of named attribute)", 0);
		end
		local meta_deps = {};
		local quoted_value = ("%q"):format(value);
		if match_type:find("$", 1, true) then
			match_type = match_type:gsub("%$", "");
			quoted_value = meta(quoted_value, meta_deps);
		end
		if match_type == "~" then -- Lua pattern match
			return ("(stanza:find(%q) or ''):match(%s)"):format(query, quoted_value), meta_deps;
		elseif match_type == "/" then -- find literal substring
			return ("(stanza:find(%q) or ''):find(%s, 1, true)"):format(query, quoted_value), meta_deps;
		elseif match_type == "" then -- exact match
			return ("stanza:find(%q) == %s"):format(query, quoted_value), meta_deps;
		else
			error("Unrecognised comparison '"..match_type.."='", 0);
		end
	end
	return ("stanza:find(%q)"):format(path);
end

function condition_handlers.FROM_GROUP(group_name)
	return ("group_contains(%q, bare_from)"):format(group_name), { "group_contains", "bare_from" };
end

function condition_handlers.TO_GROUP(group_name)
	return ("group_contains(%q, bare_to)"):format(group_name), { "group_contains", "bare_to" };
end

-- CROSSING GROUPS: true unless sender and recipient are both members of at
-- least one of the comma-separated groups.
function condition_handlers.CROSSING_GROUPS(group_names)
	local code = {};
	for segment in group_names:gmatch("[^,]+") do
		local group_name = segment:match("^%s*(.-)%s*$"); -- Trim leading/trailing whitespace
		if group_name ~= "" then
			-- One clause per group: both ends of the stanza are inside this group
			table.insert(code, ("(group_contains(%q, bare_to) and group_contains(%q, bare_from))"):format(group_name, group_name));
		end
	end
	assert(#code > 0, "Expected a list of group names");
	-- Parenthesise the whole disjunction: 'not' binds tighter than 'or', so
	-- without the parentheses only the first clause would be negated.
	return "not ("..table.concat(code, " or ")..")", { "group_contains", "bare_to", "bare_from" };
end

-- COMPAT w/0.12: Deprecated
function condition_handlers.FROM_ADMIN_OF(host)
	-- "*" means no specific host: pass a literal nil in the generated code.
	-- (Spelled as a string because %s rejects a nil argument on Lua 5.1.)
	return ("is_admin(bare_from, %s)"):format(host ~= "*" and metaq(host) or "nil"), { "is_admin", "bare_from" };
end

-- COMPAT w/0.12: Deprecated
function condition_handlers.TO_ADMIN_OF(host)
	-- "*" means no specific host: pass a literal nil in the generated code.
	-- (Spelled as a string because %s rejects a nil argument on Lua 5.1.)
	return ("is_admin(bare_to, %s)"):format(host ~= "*" and metaq(host) or "nil"), { "is_admin", "bare_to" };
end

-- COMPAT w/0.12: Deprecated
function condition_handlers.FROM_ADMIN()
	return ("is_admin(bare_from, current_host)"), { "is_admin", "bare_from", "current_host" };
end

-- COMPAT w/0.12: Deprecated
function condition_handlers.TO_ADMIN()
	return ("is_admin(bare_to, current_host)"), { "is_admin", "bare_to", "current_host" };
end

-- MAY: permission_to_check
function condition_handlers.MAY(permission_to_check)
	return ("module:may(%q, event)"):format(permission_to_check);
end

function condition_handlers.TO_ROLE(role_name)
	return ("get_jid_role(bare_to, current_host) == %q"):format(role_name), { "get_jid_role", "current_host", "bare_to" };
end

function condition_handlers.FROM_ROLE(role_name)
	return ("get_jid_role(bare_from, current_host) == %q"):format(role_name), { "get_jid_role", "current_host", "bare_from" };
end

-- Day name (first three letters, lower case) -> number compared with 'current_day'.
-- NOTE(review): 'sun = 0' does not line up with mon..sat = 2..7. If
-- 'current_day' is os.date("*t").wday (Sunday == 1) then Sunday can never
-- match an exact DAY check; confirm how 'current_day' is derived before
-- changing this value.
local day_numbers = { sun = 0, mon = 2, tue = 3, wed = 4, thu = 5, fri = 6, sat = 7 };

-- Return code comparing the current time with hour:minute.
-- 'op' is ">" for a range start (inclusive) or "<" for a range end (exclusive).
local function current_time_check(op, hour, minute)
	hour, minute = tonumber(hour), tonumber(minute);
	local adj_op = op == "<" and "<" or ">="; -- Start time inclusive, end time exclusive
	if minute == 0 then
		return "(current_hour"..adj_op..hour..")";
	else
		return "((current_hour"..op..hour..") or (current_hour == "..hour.." and current_minute"..adj_op..minute.."))";
	end
end

-- Look up the number for a day name; only the first three letters are
-- significant. Raises an error for unknown names.
local function resolve_day_number(day_name)
	return assert(day_numbers[day_name:sub(1,3):lower()], "Unknown day name: "..day_name);
end

-- DAY: comma-separated list of day names and/or ranges ("mon-fri")
function condition_handlers.DAY(days)
	local conditions = {};
	for day_range in days:gmatch("[^,]+") do
		local day_start, day_end = day_range:match("(%a+)%s*%-%s*(%a+)");
		if day_start and day_end then
			local day_start_num, day_end_num = resolve_day_number(day_start), resolve_day_number(day_end);
			local op = "and";
			if day_end_num < day_start_num then
				-- Range wraps around the end of the week
				op = "or";
			end
			table.insert(conditions, ("current_day >= %d %s current_day <= %d"):format(day_start_num, op, day_end_num));
		elseif day_range:find("%a") then
			local day = resolve_day_number(day_range:match("%a+"));
			table.insert(conditions, "current_day == "..day);
		else
			error("Unable to parse day/day range: "..day_range);
		end
	end
	assert(#conditions>0, "Expected a list of days or day ranges");
	return "("..table.concat(conditions, ") or (")..")", { "time:day" };
end

-- TIME: comma-separated list of time ranges, e.g. "9:00-17:30" or "9am - 5pm"
function condition_handlers.TIME(ranges)
	local conditions = {};
	for range in ranges:gmatch("([^,]+)") do
		-- Normalise 12-hour am/pm times to 24-hour H:M
		range = range:lower()
			:gsub("(%d+):?(%d*) *am", function (h, m)
				return tostring(tonumber(h)%12)..":"..(tonumber(m) or "00");
			end)
			:gsub("(%d+):?(%d*) *pm", function (h, m)
				return tostring(tonumber(h)%12+12)..":"..(tonumber(m) or "00");
			end);
		local start_hour, start_minute = range:match("(%d+):(%d+) *%-");
		local end_hour, end_minute = range:match("%- *(%d+):(%d+)");
		-- Validate before comparing the hours, otherwise a malformed range
		-- fails with "attempt to compare two nil values" instead of this error
		if not (start_hour and end_hour) then
			error("Unable to parse time range: "..range);
		end
		-- A range whose end hour is before its start hour wraps past midnight
		local op = tonumber(start_hour) > tonumber(end_hour) and " or " or " and ";
		local clause = {
			current_time_check(">", start_hour, start_minute),
			current_time_check("<", end_hour, end_minute),
		};
		table.insert(conditions, "("..table.concat(clause, " "..op.." ")..")");
	end
	return table.concat(conditions, " or "), { "time:hour,min" };
end

-- LIMIT: "<name>" checks a global throttle, "<name> on <expression>" checks
-- a per-key throttle where the key expression is passed through meta().
function condition_handlers.LIMIT(spec)
	local name, param = spec:match("^(%w+) on (.+)$");
	local meta_deps = {};

	if not name then
		name = spec:match("^%w+$");
		if not name then
			error("Unable to parse LIMIT specification");
		end
	else
		param = meta(("%q"):format(param), meta_deps);
	end

	if not param then
		return ("not global_throttle_%s:poll(1)"):format(name), { "globalthrottle:"..name, unpack(meta_deps) };
	end
	return ("not multi_throttle_%s:poll_on(%s, 1)"):format(name, param), { "multithrottle:"..name, unpack(meta_deps) };
end

-- ORIGIN_MARKED: "<name>" or "<name> (<seconds>s)"; checks the
-- session.firewall_marked_<name> field.
function condition_handlers.ORIGIN_MARKED(name_and_time)
	local name, time = name_and_time:match("^%s*([%w_]+)%s+%(([^)]+)s%)%s*$");
	if not name then
		name = name_and_time:match("^%s*([%w_]+)%s*$");
	end
	if not name then
		error("Error parsing mark name, see documentation for usage examples");
	end
	if time then
		return ("(current_timestamp - (session.firewall_marked_%s or 0)) < %d"):format(idsafe(name), tonumber(time)), { "timestamp" };
	end
	return ("not not session.firewall_marked_"..idsafe(name));
end

-- USER_MARKED: "<name>" or "<name> (<seconds>s)"; checks the
-- session.firewall_marks table.
function condition_handlers.USER_MARKED(name_and_time)
	local name, time = name_and_time:match("^%s*([%w_]+)%s+%(([^)]+)s%)%s*$");
	if not name then
		name = name_and_time:match("^%s*([%w_]+)%s*$");
	end
	if not name then
		error("Error parsing mark name, see documentation for usage examples");
	end
	if time then
		return ("(current_timestamp - (session.firewall_marks and session.firewall_marks.%s or 0)) < %d"):format(idsafe(name), tonumber(time)), { "timestamp" };
	end
	return ("not not (session.firewall_marks and session.firewall_marks."..idsafe(name)..")");
end

function condition_handlers.SENT_DIRECTED_PRESENCE_TO_SENDER()
	return "not not (session.directed and session.directed[from])", { "from" };
end

-- TO FULL JID?
function condition_handlers.TO_FULL_JID()
	return "not not full_sessions[to]", { "to", "full_sessions" };
end

-- CHECK LIST: spammers contains $<@from>
function condition_handlers.CHECK_LIST(list_condition)
	local list_name, expr = list_condition:match("(%S+) contains (.+)$");
	if not (list_name and expr) then
		error("Error parsing list check, syntax: LISTNAME contains EXPRESSION");
	end
	local meta_deps = {};
	expr = meta(("%q"):format(expr), meta_deps);
	return ("list_%s:contains(%s) == true"):format(list_name, expr), { "list:"..list_name, unpack(meta_deps) };
end

-- SCAN: body for word in badwords
function condition_handlers.SCAN(scan_expression)
	local search_name, pattern_name, list_name = scan_expression:match("(%S+) for (%S+) in (%S+)$");
	if not (search_name) then
		error("Error parsing SCAN expression, syntax: SEARCH for PATTERN in LIST");
	end
	return ("scan_list(list_%s, %s)"):format(list_name, "tokens_"..search_name.."_"..pattern_name),
		{ "scan_list", "tokens:"..search_name.."-"..pattern_name, "list:"..list_name };
end

-- COUNT: lines in body < 10
-- Accepted comparison operators, mapped to the Lua operator to emit
local valid_comp_ops = { [">"] = ">", ["<"] = "<", ["="] = "==", ["=="] = "==", ["<="] = "<=", [">="] = ">=" };
function condition_handlers.COUNT(count_expression)
	local pattern_name, search_name, comparator_expression = count_expression:match("(%S+) in (%S+) (.+)$");
	if not (pattern_name) then
		error("Error parsing COUNT expression, syntax: PATTERN in SEARCH COMPARATOR");
	end
	-- Pull the number out of the comparator expression; what remains
	-- (minus whitespace) is the operator
	local value;
	comparator_expression = comparator_expression:gsub("%d+", function (value_string)
		value = tonumber(value_string);
		return "";
	end);
	if not value then
		error("Error parsing COUNT expression, expected value");
	end
	local comp_op = comparator_expression:gsub("%s+", "");
	-- Emit the mapped operator, not the raw input: a user-supplied "=" must
	-- become "==" to be valid Lua
	local lua_op = valid_comp_ops[comp_op];
	assert(lua_op, "Error parsing COUNT expression, unknown comparison operator: "..comp_op);
	return ("it_count(search_%s:gmatch(pattern_%s)) %s %d"):format(
		search_name, pattern_name, lua_op, value
	), {
		"it_count",
		"search:"..search_name, "pattern:"..pattern_name
	};
end

return condition_handlers;