diff mod_firewall/mod_firewall.lua @ 2558:2b533a7b5236

mod_firewall: Make PASS bubble up through all chains, and add DEFAULT and RETURN
author Matthew Wild <mwild1@gmail.com>
date Fri, 24 Feb 2017 09:38:20 +0000
parents 19a182651a9b
children 3da0e3c917cc
line wrap: on
line diff
--- a/mod_firewall/mod_firewall.lua	Thu Feb 23 14:26:19 2017 +0000
+++ b/mod_firewall/mod_firewall.lua	Fri Feb 24 09:38:20 2017 +0000
@@ -348,7 +348,7 @@
 			local chain_info = chains[chain];
 			if not chain_info then
 				if chain:match("^user/") then
-					chains[chain] = { type = "event", priority = 1, "firewall/chains/"..chain };
+					chains[chain] = { type = "event", priority = 1, pass_return = false };
 				else
 					return nil, errmsg("Unknown chain: "..chain);
 				end
@@ -504,7 +504,7 @@
 			table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
 		end
 
-		local code_string = "return function (definitions, fire_event, log, module)\n\t"
+		local code_string = "return function (definitions, fire_event, log, module, pass_return)\n\t"
 			..table.concat(code.global_header, "\n\t")
 			.."\n\tlocal db = require 'util.debug';\n\n\t"
 			.."return function (event)\n\t\t"
@@ -525,6 +525,8 @@
 	return chain_handlers;
 end
 
+-- Compile handler code into a factory that produces a valid event handler. Factory accepts
+-- a value to be returned on PASS
 local function compile_handler(code_string, filename)
 	-- Prepare event handler function
 	local chunk, err = loadstring(code_string, "="..filename);
@@ -534,8 +536,9 @@
 	local function fire_event(name, data)
 		return module:fire_event(name, data);
 	end
-	chunk = chunk()(active_definitions, fire_event, logger(filename), module); -- Returns event handler with 'zones' upvalue.
-	return chunk;
+	return function (pass_return)
+		return chunk()(active_definitions, fire_event, logger(filename), module, pass_return); -- Returns event handler with upvalues
+	end
 end
 
 local function resolve_script_path(script_path)
@@ -559,19 +562,20 @@
 			module:log("error", "Error compiling %s: %s", script, err or "unknown error");
 		else
 			for chain, handler_code in pairs(chain_functions) do
-				local handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
-				if not handler then
+				local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
+				if not new_handler then
 					module:log("error", "Compilation error for %s: %s", script, err);
 				else
 					local chain_definition = chains[chain];
 					if chain_definition and chain_definition.type == "event" then
+						local handler = new_handler(chain_definition.pass_return);
 						for _, event_name in ipairs(chain_definition) do
 							module:hook(event_name, handler, chain_definition.priority);
 						end
 					elseif not chain:sub(1, 5) == "user/" then
 						module:log("warn", "Unknown chain %q", chain);
 					end
-					module:hook("firewall/chains/"..chain, handler);
+					module:hook("firewall/chains/"..chain, new_handler(false));
 				end
 			end
 		end