# HG changeset patch # User Matthew Wild # Date 1488041692 0 # Node ID 6dbd07f9a86803831f39605b2913d4e4b43bf7a5 # Parent 00cef058df8d7428f3c72701ca496c6716c14b1a mod_firewall: Various improvements allowing dynamic load/reload/unload of scripts diff -r 00cef058df8d -r 6dbd07f9a868 mod_firewall/mod_firewall.lua --- a/mod_firewall/mod_firewall.lua Sat Feb 25 16:53:45 2017 +0000 +++ b/mod_firewall/mod_firewall.lua Sat Feb 25 16:54:52 2017 +0000 @@ -1,7 +1,9 @@ +local lfs = require "lfs"; local resolve_relative_path = require "core.configmanager".resolve_relative_path; local logger = require "util.logger".init; local it = require "util.iterators"; +local set = require "util.set"; local definitions = module:shared("definitions"); local active_definitions = { @@ -549,45 +551,110 @@ return resolve_relative_path(relative_to, script_path); end +-- [filename] = { last_modified = ..., events_hooked = { [name] = handler } } +local loaded_scripts = {}; + function load_script(script) script = resolve_script_path(script); - local chain_functions, err = compile_firewall_rules(script) + local last_modified = (lfs.attributes(script) or {}).modification or os.time(); + if loaded_scripts[script] then + if loaded_scripts[script].last_modified == last_modified then + return; -- Already loaded, and source file hasn't changed + end + module:log("debug", "Reloading %s", script); + -- Already loaded, but the source file has changed + -- unload it now, and we'll load the new version below + unload_script(script, true); + end + local chain_functions, err = compile_firewall_rules(script); if not chain_functions then module:log("error", "Error compiling %s: %s", script, err or "unknown error"); - else - for chain, handler_code in pairs(chain_functions) do - local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain); - if not new_handler then - module:log("error", "Compilation error for %s: %s", script, err); - else - local chain_definition = chains[chain]; - if chain_definition and chain_definition.type == "event" then - local handler = new_handler(chain_definition.pass_return); - for _, event_name in ipairs(chain_definition) do - module:hook(event_name, handler, chain_definition.priority); - end - elseif not chain:sub(1, 5) == "user/" then - module:log("warn", "Unknown chain %q", chain); + return; + end + + -- Loop through the chains in the script, and for each chain attach the compiled code to the + -- relevant events, keeping track in events_hooked so we can cleanly unload later + local events_hooked = {}; + for chain, handler_code in pairs(chain_functions) do + local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain); + if not new_handler then + module:log("error", "Compilation error for %s: %s", script, err); + else + local chain_definition = chains[chain]; + if chain_definition and chain_definition.type == "event" then + local handler = new_handler(chain_definition.pass_return); + for _, event_name in ipairs(chain_definition) do + events_hooked[event_name] = handler; + module:hook(event_name, handler, chain_definition.priority); end - module:hook("firewall/chains/"..chain, new_handler(false)); + elseif not chain:sub(1, 5) == "user/" then + module:log("warn", "Unknown chain %q", chain); end + local event_name, handler = "firewall/chains/"..chain, new_handler(false); + events_hooked[event_name] = handler; + module:hook(event_name, handler); end end + loaded_scripts[script] = { last_modified = last_modified, events_hooked = events_hooked }; + module:log("debug", "Loaded %s", script); +end + +function unload_script(script, is_reload) + script = resolve_script_path(script); + local script_info = loaded_scripts[script]; + if not script_info then + return; -- Script not loaded + end + local events_hooked = script_info.events_hooked; + for event_name, event_handler in pairs(events_hooked) do + module:unhook(event_name, event_handler); + events_hooked[event_name] = nil; + end + loaded_scripts[script] = nil; + if not is_reload then + module:log("debug", "Unloaded %s", script); + end +end + +-- Given a set of scripts (e.g. from config) figure out which ones need to +-- be loaded, which are already loaded but need unloading, and which to reload +function load_unload_scripts(script_list) + local wanted_scripts = script_list / resolve_script_path; + local currently_loaded = set.new(it.to_array(it.keys(loaded_scripts))); + local scripts_to_unload = currently_loaded - wanted_scripts; + for script in wanted_scripts do + -- If the script is already loaded, this is fine - it will + -- reload the script for us if the file has changed + load_script(script); + end + for script in scripts_to_unload do + unload_script(script); + end end function module.load() if not prosody.arg then return end -- Don't run in prosodyctl - active_definitions = {}; local firewall_scripts = module:get_option_set("firewall_scripts", {}); - for script in firewall_scripts do - load_script(script); - end + load_unload_scripts(firewall_scripts); -- Replace contents of definitions table (shared) with active definitions for k in it.keys(definitions) do definitions[k] = nil; end for k,v in pairs(active_definitions) do definitions[k] = v; end end +function module.save() + return { active_definitions = active_definitions, loaded_scripts = loaded_scripts }; +end + +function module.restore(state) + active_definitions = state.active_definitions; + loaded_scripts = state.loaded_scripts; +end + +module:hook_global("config-reloaded", function () + load_unload_scripts(module:get_option_set("firewall_scripts", {})); +end); + function module.command(arg) if not arg[1] or arg[1] == "--help" then require"util.prosodyctl".show_usage([[mod_firewall ]], [[Compile files with firewall rules to Lua code]]);