diff mod_firewall/mod_firewall.lua @ 956:33d6642f4db7

mod_firewall: Tighten up error handling, and split rules->Lua and Lua->bytecode compilation into separate functions
author Matthew Wild <mwild1@gmail.com>
date Fri, 05 Apr 2013 18:05:21 +0100
parents 97454c088b6c
children d773a51af9b1
line wrap: on
line diff
--- a/mod_firewall/mod_firewall.lua	Thu Apr 04 23:11:36 2013 +0200
+++ b/mod_firewall/mod_firewall.lua	Fri Apr 05 18:05:21 2013 +0100
@@ -4,7 +4,6 @@
 local set = require "util.set";
 local add_filter = require "util.filters".add_filter;
 
-
 zones = {};
 local zones = zones;
 setmetatable(zones, {
@@ -113,6 +112,10 @@
 local function compile_firewall_rules(filename)
 	local line_no = 0;
 	
+	local function errmsg(err)
+		return "Error compiling "..filename.." on line "..line_no..": "..err;
+	end
+	
 	local ruleset = {
 		deliver = {};
 	};
@@ -146,9 +149,18 @@
 			state = nil;
 		elseif not(state) and line:match("^::") then
 			chain = line:gsub("^::%s*", "");
+			local chain_info = chains[chain_name];
+			if not chain_info then
+				return nil, errmsg("Unknown chain: "..chain_name);
+			elseif chain_info.type ~= "event" then
+				return nil, errmsg("Only event chains supported at the moment");
+			end
 			ruleset[chain] = ruleset[chain] or {};
 		elseif not(state) and line:match("^ZONE ") then
 			local zone_name = line:match("^ZONE ([^:]+)");
+			if not zone_name:match("^%a[%w_]*$") then
+				return nil, errmsg("Invalid character(s) in zone name: "..zone_name);
+			end
 			local zone_members = line:match("^ZONE .-: ?(.*)");
 			local zone_member_list = {};
 			for member in zone_members:gmatch("[^, ]+") do
@@ -168,7 +180,10 @@
 				return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
 			end
 			table.insert(rule.actions, "-- "..line)
-			local action_string, action_deps = action_handlers[action](line:match("=(.+)$"));
+			local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$"));
+			if not ok then
+				return nil, errmsg(action_string);
+			end
 			table.insert(rule.actions, action_string);
 			for _, dep in ipairs(action_deps or {}) do
 				table.insert(rule.deps, dep);
@@ -195,7 +210,10 @@
 				return nil, ("Unknown condition on line %d: %s"):format(line_no, condition);
 			end
 			-- Get the code for this condition
-			local condition_code, condition_deps = condition_handlers[condition](line:match(":%s?(.+)$"));
+			local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$"));
+			if not ok then
+				return nil, errmsg(condition_code);
+			end
 			if negated then condition_code = "not("..condition_code..")"; end
 			table.insert(rule.conditions, condition_code);
 			for _, dep in ipairs(condition_deps or {}) do
@@ -235,20 +253,26 @@
 			end;
 		end]];
 
-		print(code_string)
-
-		-- Prepare event handler function
-		local chunk, err = loadstring(code_string, "="..filename);
-		if not chunk then
-			return nil, "Error compiling (probably a compiler bug, please report): "..err;
-		end
-		chunk = chunk()(zones, logger(filename)); -- Returns event handler with 'zones' upvalue.
-		chain_handlers[chain_name] = chunk;
+		chain_handlers[chain_name] = code_string;
 	end
 		
 	return chain_handlers;
 end
 
+local function compile_handler(code_string, filename)
+	print(code_string)
+	-- Prepare event handler function
+	local chunk, err = loadstring(code_string, "="..filename);
+	if not chunk then
+		return nil, "Error compiling (probably a compiler bug, please report): "..err;
+	end
+	local function fire_event(name, data)
+		return module:fire_event(name, data);
+	end
+	chunk = chunk()(zones, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
+	return chunk;
+end
+
 function module.load()
 	local firewall_scripts = module:get_option_set("firewall_scripts", {});
 	for script in firewall_scripts do
@@ -258,12 +282,20 @@
 		if not chain_functions then
 			module:log("error", "Error compiling %s: %s", script, err or "unknown error");
 		else
-			for chain, handler in pairs(chain_functions) do
-				local chain_definition = chains[chain];
-				if chain_definition.type == "event" then
-					for _, event_name in ipairs(chain_definition) do
-						module:hook(event_name, handler, chain_definition.priority);
+			for chain, handler_code in pairs(chain_functions) do
+				local handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
+				if not handler then
+					module:log("error", "Compilation error for %s: %s", script, err);
+				else
+					local chain_definition = chains[chain];
+					if chain_definition and chain_definition.type == "event" then
+						for _, event_name in ipairs(chain_definition) do
+							module:hook(event_name, handler, chain_definition.priority);
+						end
+					elseif not chain_name:match("^user/") then
+						module:log("warn", "Unknown chain %q", chain);
 					end
+					module:hook("firewall/chains/"..chain, handler);
 				end
 			end
 		end