# HG changeset patch # User Matthew Wild # Date 1365001880 -3600 # Node ID c91cac3b823fb0c25e9da4969f28e13f63ad5da2 # Parent 2c5430ff1c11ee2691ebbdb71be6417a1043696c mod_firewall: General stanza filtering plugin with a declarative rule-based syntax diff -r 2c5430ff1c11 -r c91cac3b823f mod_firewall/actions.lib.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/mod_firewall/actions.lib.lua Wed Apr 03 16:11:20 2013 +0100 @@ -0,0 +1,158 @@ +local action_handlers = {}; + +-- Takes an XML string and returns a code string that builds that stanza +-- using st.stanza() +local function compile_xml(data) + local code = {}; + local first, short_close = true, nil; + for tagline, text in data:gmatch("<([^>]+)>([^<]*)") do + if tagline:sub(-1,-1) == "/" then + tagline = tagline:sub(1, -2); + short_close = true; + end + if tagline:sub(1,1) == "/" then + code[#code+1] = (":up()"); + else + local name, attr = tagline:match("^(%S*)%s*(.*)$"); + local attr_str = {}; + for k, _, v in attr:gmatch("(%S+)=([\"'])([^%2]-)%2") do + if #attr_str == 0 then + table.insert(attr_str, ", { "); + else + table.insert(attr_str, ", "); + end + if k:match("^%a%w*$") then + table.insert(attr_str, string.format("%s = %q", k, v)); + else + table.insert(attr_str, string.format("[%q] = %q", k, v)); + end + end + if #attr_str > 0 then + table.insert(attr_str, " }"); + end + if first then + code[#code+1] = (string.format("st.stanza(%q %s)", name, #attr_str>0 and table.concat(attr_str) or ", nil")); + first = nil; + else + code[#code+1] = (string.format(":tag(%q%s)", name, table.concat(attr_str))); + end + end + if text and text:match("%S") then + code[#code+1] = (string.format(":text(%q)", text)); + elseif short_close then + short_close = nil; + code[#code+1] = (":up()"); + end + end + return table.concat(code, ""); +end + + +function action_handlers.DROP() + return "log('debug', 'Firewall dropping stanza: %s', tostring(stanza)); return true;"; +end + +function action_handlers.STRIP(tag_desc) + local code = {}; + local name, xmlns = tag_desc:match("^(%S+) (.+)$"); + if not name then + name, xmlns = tag_desc, nil; + end + if name == "*" then + name = nil; + end + code[#code+1] = ("local stanza_xmlns = stanza.attr.xmlns; "); + code[#code+1] = "stanza:maptags(function (tag) if "; + if name then + code[#code+1] = ("tag.name == %q and "):format(name); + end + if xmlns then + code[#code+1] = ("(tag.attr.xmlns or stanza_xmlns) == %q "):format(xmlns); + else + code[#code+1] = ("tag.attr.xmlns == stanza_xmlns "); + end + code[#code+1] = "then return nil; end return tag; end );"; + return table.concat(code); +end + +function action_handlers.INJECT(tag) + return "stanza:add_child("..compile_xml(tag)..")", { "st" }; +end + +local error_types = { + ["bad-request"] = "modify"; + ["conflict"] = "cancel"; + ["feature-not-implemented"] = "cancel"; + ["forbidden"] = "auth"; + ["gone"] = "cancel"; + ["internal-server-error"] = "cancel"; + ["item-not-found"] = "cancel"; + ["jid-malformed"] = "modify"; + ["not-acceptable"] = "modify"; + ["not-allowed"] = "cancel"; + ["not-authorized"] = "auth"; + ["payment-required"] = "auth"; + ["policy-violation"] = "modify"; + ["recipient-unavailable"] = "wait"; + ["redirect"] = "modify"; + ["registration-required"] = "auth"; + ["remote-server-not-found"] = "cancel"; + ["remote-server-timeout"] = "wait"; + ["resource-constraint"] = "wait"; + ["service-unavailable"] = "cancel"; + ["subscription-required"] = "auth"; + ["undefined-condition"] = "cancel"; + ["unexpected-request"] = "wait"; +}; + + +local function route_modify(make_new, to, drop) + local reroute, deps = "session.send(newstanza)", { "st" }; + if to then + reroute = ("newstanza.attr.to = %q; core_post_stanza(session, newstanza)"):format(to); + deps[#deps+1] = "core_post_stanza"; + end + return ([[local newstanza = st.%s; %s; %s; ]]) + :format(make_new, reroute, drop and "return true" or ""), deps; +end + +function action_handlers.BOUNCE(with) + local error = with and with:match("^%S+") or "service-unavailable"; + local error_type = error:match(":(%S+)"); + if not error_type then + error_type = error_types[error] or "cancel"; + else + error = error:match("^[^:]+"); + end + error, error_type = string.format("%q", error), string.format("%q", error_type); + local text = with and with:match(" %((.+)%)$"); + if text then + text = string.format("%q", text); + else + text = "nil"; + end + return route_modify(("error_reply(stanza, %s, %s, %s)"):format(error_type, error, text), nil, true); +end + +function action_handlers.REDIRECT(where) + return route_modify("clone(stanza)", where, true, true); +end + +function action_handlers.COPY(where) + return route_modify("clone(stanza)", where, true, false); +end + +function action_handlers.LOG(string) + local level = string:match("^%[(%a+)%]") or "info"; + string = string:gsub("^%[%a+%] ?", ""); + return (("log(%q, %q)"):format(level, string) + :gsub("$top", [["..stanza:top_tag().."]]) + :gsub("$stanza", [["..stanza.."]]) + :gsub("$(%b())", [["..%1.."]])); +end + +function action_handlers.RULEDEP(dep) + return "", { dep }; +end + +return action_handlers; diff -r 2c5430ff1c11 -r c91cac3b823f mod_firewall/conditions.lib.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/mod_firewall/conditions.lib.lua Wed Apr 03 16:11:20 2013 +0100 @@ -0,0 +1,94 @@ +local condition_handlers = {}; + +local jid = require "util.jid"; + +-- Return a code string for a condition that checks whether the contents +-- of variable with the name 'name' matches any of the values in the +-- comma/space/pipe delimited list 'values'. +local function compile_comparison_list(name, values) + local conditions = {}; + for value in values:gmatch("[^%s,|]+") do + table.insert(conditions, ("%s == %q"):format(name, value)); + end + return table.concat(conditions, " or "); +end + +function condition_handlers.KIND(kind) + return compile_comparison_list("name", kind), { "name" }; +end + +local wildcard_equivs = { ["*"] = ".*", ["?"] = "." }; + +local function compile_jid_match_part(part, match) + if not match then + return part.." == nil" + end + local pattern = match:match("<(.*)>"); + -- TODO: Support Lua pattern matching (main issue syntax... << >>?) + if pattern then + if pattern ~= "*" then + return ("%s:match(%q)"):format(part, pattern:gsub(".", wildcard_equivs)); + end + else + return ("%s == %q"):format(part, match); + end +end + +local function compile_jid_match(which, match_jid) + local match_node, match_host, match_resource = jid.split(match_jid); + local conditions = { + compile_jid_match_part(which.."_node", match_node); + compile_jid_match_part(which.."_host", match_host); + match_resource and compile_jid_match_part(which.."_resource", match_resource) or nil; + }; + return table.concat(conditions, " and "); +end + +function condition_handlers.TO(to) + return compile_jid_match("to", to), { "split_to" }; +end + +function condition_handlers.FROM(from) + return compile_jid_match("from", from), { "split_from" }; +end + +function condition_handlers.TYPE(type) + return compile_comparison_list("type", type), { "type" }; +end + +function condition_handlers.ENTERING(zone) + return ("(zones[%q] and (zones[%q][to_host] or " + .."zones[%q][to] or " + .."zones[%q][bare_to]))" + ) + :format(zone, zone, zone, zone), { "split_to", "bare_to" }; +end + +function condition_handlers.LEAVING(zone) + return ("zones[%q] and (zones[%q][from_host] or " + .."(zones[%q][from] or " + .."zones[%q][bare_from]))") + :format(zone, zone, zone, zone), { "split_from", "bare_from" }; +end + +function condition_handlers.PAYLOAD(payload_ns) + return ("stanza:get_child(nil, %q)"):format(payload_ns); +end + +function condition_handlers.FROM_GROUP(group_name) + return ("group_contains(%q, bare_from)"):format(group_name), { "group_contains", "bare_from" }; +end + +function condition_handlers.TO_GROUP(group_name) + return ("group_contains(%q, bare_to)"):format(group_name), { "group_contains", "bare_to" }; +end + +function condition_handlers.FROM_ADMIN_OF(host) + return ("is_admin(bare_from, %s)"):format(host ~= "*" and host or nil), { "is_admin", "bare_from" }; +end + +function condition_handlers.TO_ADMIN_OF(host) + return ("is_admin(bare_to, %s)"):format(host ~= "*" and host or nil), { "is_admin", "bare_to" }; +end + +return condition_handlers; diff -r 2c5430ff1c11 -r c91cac3b823f mod_firewall/mod_firewall.lua --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/mod_firewall/mod_firewall.lua Wed Apr 03 16:11:20 2013 +0100 @@ -0,0 +1,271 @@ + +local resolve_relative_path = require "core.configmanager".resolve_relative_path; +local logger = require "util.logger".init; +local set = require "util.set"; +local add_filter = require "util.filters".add_filter; + + +zones = {}; +local zones = zones; +setmetatable(zones, { + __index = function (zones, zone) + local t = { [zone] = true }; + rawset(zones, zone, t); + return t; + end; +}); + +local chains = { + preroute = { + type = "event"; + priority = 0.1; + "pre-message/bare", "pre-message/full", "pre-message/host"; + "pre-presence/bare", "pre-presence/full", "pre-presence/host"; + "pre-iq/bare", "pre-iq/full", "pre-iq/host"; + }; + deliver = { + type = "event"; + priority = 0.1; + "message/bare", "message/full", "message/host"; + "presence/bare", "presence/full", "presence/host"; + "iq/bare", "iq/full", "iq/host"; + }; + deliver_remote = { + type = "event"; "route/remote"; + priority = 0.1; + }; +}; + +-- Dependency locations: +-- +-- +-- function handler() +-- +-- if then +-- +-- end +-- end + +local available_deps = { + st = { global_code = [[local st = require "util.stanza"]]}; + jid_split = { + global_code = [[local jid_split = require "util.jid".split;]]; + }; + jid_bare = { + global_code = [[local jid_bare = require "util.jid".bare;]]; + }; + to = { local_code = [[local to = stanza.attr.to;]] }; + from = { local_code = [[local from = stanza.attr.from;]] }; + type = { local_code = [[local type = stanza.attr.type;]] }; + name = { local_code = [[local name = stanza.name]] }; + split_to = { -- The stanza's split to address + depends = { "jid_split", "to" }; + local_code = [[local to_node, to_host, to_resource = jid_split(to);]]; + }; + split_from = { -- The stanza's split from address + depends = { "jid_split", "from" }; + local_code = [[local from_node, from_host, from_resource = jid_split(from);]]; + }; + bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"}; + bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"}; + group_contains = { + global_code = [[local group_contains = module:depends("groups").group_contains]]; + }; + is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]}; + core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] }; +}; + +local function include_dep(dep, code) + local dep_info = available_deps[dep]; + if not dep_info then + module:log("error", "Dependency not found: %s", dep); + return; + end + if code.included_deps[dep] then + if code.included_deps[dep] ~= true then + module:log("error", "Circular dependency on %s", dep); + end + return; + end + code.included_deps[dep] = false; -- Pending flag (used to detect circular references) + for _, dep_dep in ipairs(dep_info.depends or {}) do + include_dep(dep_dep, code); + end + if dep_info.global_code then + table.insert(code.global_header, dep_info.global_code); + end + if dep_info.local_code then + table.insert(code, "\n\t-- "..dep.."\n\t"..dep_info.local_code.."\n\n\t"); + end + code.included_deps[dep] = true; +end + +local condition_handlers = module:require("conditions"); +local action_handlers = module:require("actions"); + +local function new_rule(ruleset, chain) + assert(chain, "no chain specified"); + local rule = { conditions = {}, actions = {}, deps = {} }; + table.insert(ruleset[chain], rule); + return rule; +end + +local function compile_firewall_rules(filename) + local line_no = 0; + + local ruleset = { + deliver = {}; + }; + + local chain = "deliver"; -- Default chain + local rule; + + local file, err = io.open(filename); + if not file then return nil, err; end + + local state; -- nil -> "rules" -> "actions" -> nil -> ... + + local line_hold; + for line in file:lines() do + line = line:match("^%s*(.-)%s*$"); + if line_hold and line:sub(-1,-1) ~= "\\" then + line = line_hold..line; + line_hold = nil; + elseif line:sub(-1,-1) == "\\" then + line_hold = (line_hold or "")..line:sub(1,-2); + end + line_no = line_no + 1; + + if line_hold or line:match("^[#;]") then + -- No action; comment or partial line + elseif line == "" then + if state == "rules" then + return nil, ("Expected an action on line %d for preceding criteria") + :format(line_no); + end + state = nil; + elseif not(state) and line:match("^::") then + chain = line:gsub("^::%s*", ""); + ruleset[chain] = ruleset[chain] or {}; + elseif not(state) and line:match("^ZONE ") then + local zone_name = line:match("^ZONE ([^:]+)"); + local zone_members = line:match("^ZONE .-: ?(.*)"); + local zone_member_list = {}; + for member in zone_members:gmatch("[^, ]+") do + zone_member_list[#zone_member_list+1] = member; + end + zones[zone_name] = set.new(zone_member_list)._items; + elseif line:match("^[^%s:]+[%.=]") then + -- Action + if state == nil then + -- This is a standalone action with no conditions + rule = new_rule(ruleset, chain); + end + state = "actions"; + -- Action handlers? + local action = line:match("^%P+"); + if not action_handlers[action] then + return nil, ("Unknown action on line %d: %s"):format(line_no, action or ""); + end + table.insert(rule.actions, "-- "..line) + local action_string, action_deps = action_handlers[action](line:match("=(.+)$")); + table.insert(rule.actions, action_string); + for _, dep in ipairs(action_deps or {}) do + table.insert(rule.deps, dep); + end + elseif state == "actions" then -- state is actions but action pattern did not match + state = nil; -- Awaiting next rule, etc. + table.insert(ruleset[chain], rule); + rule = nil; + else + if not state then + state = "rules"; + rule = new_rule(ruleset, chain); + end + -- Check standard modifiers for the condition (e.g. NOT) + local negated; + local condition = line:match("^[^:=%.]*"); + if condition:match("%f[%w]NOT%f[^%w]") then + local s, e = condition:match("%f[%w]()NOT()%f[^%w]"); + condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$"); + negated = true; + end + condition = condition:gsub(" ", ""); + if not condition_handlers[condition] then + return nil, ("Unknown condition on line %d: %s"):format(line_no, condition); + end + -- Get the code for this condition + local condition_code, condition_deps = condition_handlers[condition](line:match(":%s?(.+)$")); + if negated then condition_code = "not("..condition_code..")"; end + table.insert(rule.conditions, condition_code); + for _, dep in ipairs(condition_deps or {}) do + table.insert(rule.deps, dep); + end + end + end + + -- Compile ruleset and return complete code + + local chain_handlers = {}; + + -- Loop through the chains in the parsed ruleset (e.g. incoming, outgoing) + for chain_name, rules in pairs(ruleset) do + local code = { included_deps = {}, global_header = {} }; + -- This inner loop assumes chain is an event-based, not a filter-based + -- chain (filter-based will be added later) + for _, rule in ipairs(rules) do + for _, dep in ipairs(rule.deps) do + include_dep(dep, code); + end + local rule_code = "if ("..table.concat(rule.conditions, ") and (")..") then\n\t" + ..table.concat(rule.actions, "\n\t") + .."\n end\n"; + table.insert(code, rule_code); + end + + assert(chains[chain_name].type == "event", "Only event chains supported at the moment") + + local code_string = [[return function (zones, log) + ]]..table.concat(code.global_header, "\n")..[[ + local db = require 'util.debug' + return function (event) + local stanza, session = event.stanza, event.origin; + + ]]..table.concat(code, " ")..[[ + end; + end]]; + + print(code_string) + + -- Prepare event handler function + local chunk, err = loadstring(code_string, "="..filename); + if not chunk then + return nil, "Error compiling (probably a compiler bug, please report): "..err; + end + chunk = chunk()(zones, logger(filename)); -- Returns event handler with 'zones' upvalue. + chain_handlers[chain_name] = chunk; + end + + return chain_handlers; +end + +function module.load() + local firewall_scripts = module:get_option_set("firewall_scripts", {}); + for script in firewall_scripts do + script = resolve_relative_path(script) or script; + local chain_functions, err = compile_firewall_rules(script) + + if not chain_functions then + module:log("error", "Error compiling %s: %s", script, err or "unknown error"); + else + for chain, handler in pairs(chain_functions) do + local chain_definition = chains[chain]; + if chain_definition.type == "event" then + for _, event_name in ipairs(chain_definition) do + module:hook(event_name, handler, chain_definition.priority); + end + end + end + end + end +end