Mercurial > prosody-modules
diff mod_firewall/mod_firewall.lua @ 956:33d6642f4db7
mod_firewall: Tighten up error handling, and split rules->Lua and Lua->bytecode compilation into separate functions
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Fri, 05 Apr 2013 18:05:21 +0100 |
parents | 97454c088b6c |
children | d773a51af9b1 |
line wrap: on
line diff
--- a/mod_firewall/mod_firewall.lua Thu Apr 04 23:11:36 2013 +0200 +++ b/mod_firewall/mod_firewall.lua Fri Apr 05 18:05:21 2013 +0100 @@ -4,7 +4,6 @@ local set = require "util.set"; local add_filter = require "util.filters".add_filter; - zones = {}; local zones = zones; setmetatable(zones, { @@ -113,6 +112,10 @@ local function compile_firewall_rules(filename) local line_no = 0; + local function errmsg(err) + return "Error compiling "..filename.." on line "..line_no..": "..err; + end + local ruleset = { deliver = {}; }; @@ -146,9 +149,18 @@ state = nil; elseif not(state) and line:match("^::") then chain = line:gsub("^::%s*", ""); + local chain_info = chains[chain_name]; + if not chain_info then + return nil, errmsg("Unknown chain: "..chain_name); + elseif chain_info.type ~= "event" then + return nil, errmsg("Only event chains supported at the moment"); + end ruleset[chain] = ruleset[chain] or {}; elseif not(state) and line:match("^ZONE ") then local zone_name = line:match("^ZONE ([^:]+)"); + if not zone_name:match("^%a[%w_]*$") then + return nil, errmsg("Invalid character(s) in zone name: "..zone_name); + end local zone_members = line:match("^ZONE .-: ?(.*)"); local zone_member_list = {}; for member in zone_members:gmatch("[^, ]+") do @@ -168,7 +180,10 @@ return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>"); end table.insert(rule.actions, "-- "..line) - local action_string, action_deps = action_handlers[action](line:match("=(.+)$")); + local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$")); + if not ok then + return nil, errmsg(action_string); + end table.insert(rule.actions, action_string); for _, dep in ipairs(action_deps or {}) do table.insert(rule.deps, dep); @@ -195,7 +210,10 @@ return nil, ("Unknown condition on line %d: %s"):format(line_no, condition); end -- Get the code for this condition - local condition_code, condition_deps = condition_handlers[condition](line:match(":%s?(.+)$")); + local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$")); + if not ok then + return nil, errmsg(condition_code); + end if negated then condition_code = "not("..condition_code..")"; end table.insert(rule.conditions, condition_code); for _, dep in ipairs(condition_deps or {}) do @@ -235,20 +253,26 @@ end; end]]; - print(code_string) - - -- Prepare event handler function - local chunk, err = loadstring(code_string, "="..filename); - if not chunk then - return nil, "Error compiling (probably a compiler bug, please report): "..err; - end - chunk = chunk()(zones, logger(filename)); -- Returns event handler with 'zones' upvalue. - chain_handlers[chain_name] = chunk; + chain_handlers[chain_name] = code_string; end return chain_handlers; end +local function compile_handler(code_string, filename) + print(code_string) + -- Prepare event handler function + local chunk, err = loadstring(code_string, "="..filename); + if not chunk then + return nil, "Error compiling (probably a compiler bug, please report): "..err; + end + local function fire_event(name, data) + return module:fire_event(name, data); + end + chunk = chunk()(zones, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue. + return chunk; +end + function module.load() local firewall_scripts = module:get_option_set("firewall_scripts", {}); for script in firewall_scripts do @@ -258,12 +282,20 @@ if not chain_functions then module:log("error", "Error compiling %s: %s", script, err or "unknown error"); else - for chain, handler in pairs(chain_functions) do - local chain_definition = chains[chain]; - if chain_definition.type == "event" then - for _, event_name in ipairs(chain_definition) do - module:hook(event_name, handler, chain_definition.priority); + for chain, handler_code in pairs(chain_functions) do + local handler, err = compile_handler(handler_code, "mod_firewall::"..chain); + if not handler then + module:log("error", "Compilation error for %s: %s", script, err); + else + local chain_definition = chains[chain]; + if chain_definition and chain_definition.type == "event" then + for _, event_name in ipairs(chain_definition) do + module:hook(event_name, handler, chain_definition.priority); + end + elseif not chain_name:match("^user/") then + module:log("warn", "Unknown chain %q", chain); end + module:hook("firewall/chains/"..chain, handler); end end end