changeset 2578:6dbd07f9a868

mod_firewall: Various improvements allowing dynamic load/reload/unload of scripts
author Matthew Wild <mwild1@gmail.com>
date Sat, 25 Feb 2017 16:54:52 +0000
parents 00cef058df8d
children 5e948d1392a5
files mod_firewall/mod_firewall.lua
diffstat 1 files changed, 87 insertions(+), 20 deletions(-) [+]
line wrap: on
line diff
--- a/mod_firewall/mod_firewall.lua	Sat Feb 25 16:53:45 2017 +0000
+++ b/mod_firewall/mod_firewall.lua	Sat Feb 25 16:54:52 2017 +0000
@@ -1,7 +1,9 @@
 
+local lfs = require "lfs";
 local resolve_relative_path = require "core.configmanager".resolve_relative_path;
 local logger = require "util.logger".init;
 local it = require "util.iterators";
+local set = require "util.set";
 
 local definitions = module:shared("definitions");
 local active_definitions = {
@@ -549,45 +551,110 @@
 	return resolve_relative_path(relative_to, script_path);
 end
 
+-- [filename] = { last_modified = ..., events_hooked = { [name] = handler } }
+local loaded_scripts = {};
+
 function load_script(script)
 	script = resolve_script_path(script);
-	local chain_functions, err = compile_firewall_rules(script)
+	local last_modified = (lfs.attributes(script) or {}).modification or os.time();
+	if loaded_scripts[script] then
+		if loaded_scripts[script].last_modified == last_modified then
+			return; -- Already loaded, and source file hasn't changed
+		end
+		module:log("debug", "Reloading %s", script);
+		-- Already loaded, but the source file has changed
+		-- unload it now, and we'll load the new version below
+		unload_script(script, true);
+	end
+	local chain_functions, err = compile_firewall_rules(script);
 	
 	if not chain_functions then
 		module:log("error", "Error compiling %s: %s", script, err or "unknown error");
-	else
-		for chain, handler_code in pairs(chain_functions) do
-			local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
-			if not new_handler then
-				module:log("error", "Compilation error for %s: %s", script, err);
-			else
-				local chain_definition = chains[chain];
-				if chain_definition and chain_definition.type == "event" then
-					local handler = new_handler(chain_definition.pass_return);
-					for _, event_name in ipairs(chain_definition) do
-						module:hook(event_name, handler, chain_definition.priority);
-					end
-				elseif not chain:sub(1, 5) == "user/" then
-					module:log("warn", "Unknown chain %q", chain);
+		return;
+	end
+
+	-- Loop through the chains in the script, and for each chain attach the compiled code to the
+	-- relevant events, keeping track in events_hooked so we can cleanly unload later
+	local events_hooked = {};
+	for chain, handler_code in pairs(chain_functions) do
+		local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
+		if not new_handler then
+			module:log("error", "Compilation error for %s: %s", script, err);
+		else
+			local chain_definition = chains[chain];
+			if chain_definition and chain_definition.type == "event" then
+				local handler = new_handler(chain_definition.pass_return);
+				for _, event_name in ipairs(chain_definition) do
+					events_hooked[event_name] = handler;
+					module:hook(event_name, handler, chain_definition.priority);
 				end
-				module:hook("firewall/chains/"..chain, new_handler(false));
+			elseif not chain:sub(1, 5) == "user/" then
+				module:log("warn", "Unknown chain %q", chain);
 			end
+			local event_name, handler = "firewall/chains/"..chain, new_handler(false);
+			events_hooked[event_name] = handler;
+			module:hook(event_name, handler);
 		end
 	end
+	loaded_scripts[script] = { last_modified = last_modified, events_hooked = events_hooked };
+	module:log("debug", "Loaded %s", script);
+end
+
+function unload_script(script, is_reload)
+	script = resolve_script_path(script);
+	local script_info = loaded_scripts[script];
+	if not script_info then
+		return; -- Script not loaded
+	end
+	local events_hooked = script_info.events_hooked;
+	for event_name, event_handler in pairs(events_hooked) do
+		module:unhook(event_name, event_handler);
+		events_hooked[event_name] = nil;
+	end
+	loaded_scripts[script] = nil;
+	if not is_reload then
+		module:log("debug", "Unloaded %s", script);
+	end
+end
+
+-- Given a set of scripts (e.g. from config) figure out which ones need to
+-- be loaded, which are already loaded but need unloading, and which to reload
+function load_unload_scripts(script_list)
+	local wanted_scripts = script_list / resolve_script_path;
+	local currently_loaded = set.new(it.to_array(it.keys(loaded_scripts)));
+	local scripts_to_unload = currently_loaded - wanted_scripts;
+	for script in wanted_scripts do
+		-- If the script is already loaded, this is fine - it will
+		-- reload the script for us if the file has changed
+		load_script(script);
+	end
+	for script in scripts_to_unload do
+		unload_script(script);
+	end
 end
 
 function module.load()
 	if not prosody.arg then return end -- Don't run in prosodyctl
-	active_definitions = {};
 	local firewall_scripts = module:get_option_set("firewall_scripts", {});
-	for script in firewall_scripts do
-		load_script(script);
-	end
+	load_unload_scripts(firewall_scripts);
 	-- Replace contents of definitions table (shared) with active definitions
 	for k in it.keys(definitions) do definitions[k] = nil; end
 	for k,v in pairs(active_definitions) do definitions[k] = v; end
 end
 
+function module.save()
+	return { active_definitions = active_definitions, loaded_scripts = loaded_scripts };
+end
+
+function module.restore(state)
+	active_definitions = state.active_definitions;
+	loaded_scripts = state.loaded_scripts;
+end
+
+module:hook_global("config-reloaded", function ()
+	load_unload_scripts(module:get_option_set("firewall_scripts", {}));
+end);
+
 function module.command(arg)
 	if not arg[1] or arg[1] == "--help" then
 		require"util.prosodyctl".show_usage([[mod_firewall <firewall.pfw>]], [[Compile files with firewall rules to Lua code]]);