view mod_firewall/mod_firewall.lua @ 5193:2bb29ece216b

mod_http_oauth2: Implement stateless dynamic client registration Replaces previous explicit registration that required either the additional module mod_adhoc_oauth2_client or manually editing the database. That method was enough to have something to test with, but would not probably not scale easily. Dynamic client registration allows creating clients on the fly, which may be even easier in theory. In order to not allow basically unauthenticated writes to the database, we implement a stateless model here. per_host_key := HMAC(config -> oauth2_registration_key, hostname) client_id := JWT { client metadata } signed with per_host_key client_secret := HMAC(per_host_key, client_id) This should ensure everything we need to know is part of the client_id, allowing redirects etc to be validated, and the client_secret can be validated with only the client_id and the per_host_key. A nonce injected into the client_id JWT should ensure nobody can submit the same client metadata and retrieve the same client_secret
author Kim Alvefur <zash@zash.se>
date Fri, 03 Mar 2023 21:14:19 +0100
parents 8474a3b80200
children 32a9817c7516
line wrap: on
line source


local lfs = require "lfs";
local resolve_relative_path = require "core.configmanager".resolve_relative_path;
local envload = require "util.envload".envload;
local logger = require "util.logger".init;
local it = require "util.iterators";
local set = require "util.set";

local have_features, features = pcall(require, "core.features");
features = have_features and features.available or set.new();

-- [definition_type] = definition_factory(param)
local definitions = module:shared("definitions");

-- When a definition instance has been instantiated, it lives here
-- [definition_type][definition_name] = definition_object
local active_definitions = {
	ZONE = {
		-- Default zone that includes all local hosts
		["$local"] = setmetatable({}, { __index = prosody.hosts });
	};
};

local default_chains = {
	preroute = {
		type = "event";
		priority = 0.1;
		"pre-message/bare", "pre-message/full", "pre-message/host";
		"pre-presence/bare", "pre-presence/full", "pre-presence/host";
		"pre-iq/bare", "pre-iq/full", "pre-iq/host";
	};
	deliver = {
		type = "event";
		priority = 0.1;
		"message/bare", "message/full", "message/host";
		"presence/bare", "presence/full", "presence/host";
		"iq/bare", "iq/full", "iq/host";
	};
	deliver_remote = {
		type = "event"; "route/remote";
		priority = 0.1;
	};
};

local extra_chains = module:get_option("firewall_extra_chains", {});

local chains = {};
for k,v in pairs(default_chains) do
	chains[k] = v;
end
for k,v in pairs(extra_chains) do
	chains[k] = v;
end

-- Returns the input if it is safe to be used as a variable name, otherwise nil
function idsafe(name)
	return name:match("^%a[%w_]*$");
end

local meta_funcs = {
	bare = function (code)
		return "jid_bare("..code..")", {"jid_bare"};
	end;
	node = function (code)
		return "(jid_split("..code.."))", {"jid_split"};
	end;
	host = function (code)
		return "(select(2, jid_split("..code..")))", {"jid_split"};
	end;
	resource = function (code)
		return "(select(3, jid_split("..code..")))", {"jid_split"};
	end;
};

-- Run quoted (%q) strings through this to allow them to contain code. e.g.: LOG=Received: $(stanza:top_tag())
function meta(s, deps, extra)
	return (s:gsub("$(%b())", function (expr)
			expr = expr:gsub("\\(.)", "%1");
			return [["..tostring(]]..expr..[[).."]];
		end)
		:gsub("$(%b<>)", function (expr)
			expr = expr:sub(2,-2);
			local default = "<undefined>";
			expr = expr:gsub("||(%b\"\")$", function (default_string)
				default = stripslashes(default_string:sub(2,-2));
				return "";
			end);
			local func_chain = expr:match("|[%w|]+$");
			if func_chain then
				expr = expr:sub(1, -1-#func_chain);
			end
			local code;
			if expr:match("^@") then
				-- Skip stanza:find() for simple attribute lookup
				local attr_name = expr:sub(2);
				if deps and (attr_name == "to" or attr_name == "from" or attr_name == "type") then
					-- These attributes may be cached in locals
					code = attr_name;
					table.insert(deps, attr_name);
				else
					code = "stanza.attr["..("%q"):format(attr_name).."]";
				end
			elseif expr:match("^%w+#$") then
				code = ("stanza:get_child_text(%q)"):format(expr:sub(1, -2));
			else
				code = ("stanza:find(%q)"):format(expr);
			end
			if func_chain then
				for func_name in func_chain:gmatch("|(%w+)") do
					-- to/from are already available in local variables, use those if possible
					if (code == "to" or code == "from") and func_name == "bare" then
						code = "bare_"..code;
						table.insert(deps, code);
					elseif (code == "to" or code == "from") and (func_name == "node" or func_name == "host" or func_name == "resource") then
						table.insert(deps, "split_"..code);
						code = code.."_"..func_name;
					else
						assert(meta_funcs[func_name], "unknown function: "..func_name);
						local new_code, new_deps = meta_funcs[func_name](code);
						code = new_code;
						if new_deps and #new_deps > 0 then
							assert(deps, "function not supported here: "..func_name);
							for _, dep in ipairs(new_deps) do
								table.insert(deps, dep);
							end
						end
					end
				end
			end
			return "\"..tostring("..code.." or "..("%q"):format(default)..")..\"";
		end)
		:gsub("$$(%a+)", extra or {})
		:gsub([[^""%.%.]], "")
		:gsub([[%.%.""$]], ""));
end

function metaq(s, ...)
	return meta(("%q"):format(s), ...);
end

local escape_chars = {
	a = "\a", b = "\b", f = "\f", n = "\n", r = "\r", t = "\t",
	v = "\v", ["\\"] = "\\", ["\""] = "\"", ["\'"] = "\'"
};
function stripslashes(s)
	return (s:gsub("\\(.)", escape_chars));
end

-- Dependency locations:
-- <type lib>
-- <type global>
-- function handler()
--   <local deps>
--   if <conditions> then
--     <actions>
--   end
-- end

local available_deps = {
	st = { global_code = [[local st = require "util.stanza";]]};
	it = { global_code = [[local it = require "util.iterators";]]};
	it_count = { global_code = [[local it_count = it.count;]], depends = { "it" } };
	current_host = { global_code = [[local current_host = module.host;]] };
	jid_split = {
		global_code = [[local jid_split = require "util.jid".split;]];
	};
	jid_bare = {
		global_code = [[local jid_bare = require "util.jid".bare;]];
	};
	to = { local_code = [[local to = stanza.attr.to or jid_bare(session.full_jid);]]; depends = { "jid_bare" } };
	from = { local_code = [[local from = stanza.attr.from;]] };
	type = { local_code = [[local type = stanza.attr.type;]] };
	name = { local_code = [[local name = stanza.name;]] };
	split_to = { -- The stanza's split to address
		depends = { "jid_split", "to" };
		local_code = [[local to_node, to_host, to_resource = jid_split(to);]];
	};
	split_from = { -- The stanza's split from address
		depends = { "jid_split", "from" };
		local_code = [[local from_node, from_host, from_resource = jid_split(from);]];
	};
	bare_to = { depends = { "jid_bare", "to" }, local_code = "local bare_to = jid_bare(to)"};
	bare_from = { depends = { "jid_bare", "from" }, local_code = "local bare_from = jid_bare(from)"};
	group_contains = {
		global_code = [[local group_contains = module:depends("groups").group_contains]];
	};
	is_admin = require"core.usermanager".is_admin and { global_code = [[local is_admin = require "core.usermanager".is_admin;]]} or nil;
	get_jid_role = require "core.usermanager".get_jid_role and { global_code = [[local get_jid_role = require "core.usermanager".get_jid_role;]] } or nil;
	core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza;]] };
	zone = { global_code = function (zone)
		local var = zone;
		if var == "$local" then
			var = "_local"; -- See #1090
		else
			assert(idsafe(var), "Invalid zone name: "..zone);
		end
		return ("local zone_%s = zones[%q] or {};"):format(var, zone);
	end };
	date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
	time = { local_code = function (what)
		local defs = {};
		for field in what:gmatch("%a+") do
			table.insert(defs, ("local current_%s = current_date_time.%s;"):format(field, field));
		end
		return table.concat(defs, " ");
	end, depends = { "date_time" }; };
	timestamp = { global_code = [[local get_time = require "socket".gettime;]]; local_code = [[local current_timestamp = get_time();]]; };
	globalthrottle = {
		global_code = function (throttle)
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local global_throttle_%s = rates.%s:single();"):format(throttle, throttle);
		end;
	};
	multithrottle = {
		global_code = function (throttle)
			assert(pcall(require, "util.cache"), "Using LIMIT with 'on' requires Prosody 0.10 or higher");
			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
			return ("local multi_throttle_%s = rates.%s:multi();"):format(throttle, throttle);
		end;
	};
	full_sessions = {
		global_code = [[local full_sessions = prosody.full_sessions;]];
	};
	rostermanager = {
		global_code = [[local rostermanager = require "core.rostermanager";]];
	};
	roster_entry = {
		local_code = [[local roster_entry = (to_node and rostermanager.load_roster(to_node, to_host) or {})[bare_from];]];
		depends = { "rostermanager", "split_to", "bare_from" };
	};
	list = { global_code = function (list)
			assert(idsafe(list), "Invalid list name: "..list);
			assert(active_definitions.LIST[list], "Unknown list: "..list);
			return ("local list_%s = lists[%q];"):format(list, list);
		end
	};
	search = {
		local_code = function (search_name)
			local search_path = assert(active_definitions.SEARCH[search_name], "Undefined search path: "..search_name);
			return ("local search_%s = tostring(stanza:find(%q) or \"\")"):format(search_name, search_path);
		end;
	};
	pattern = {
		local_code = function (pattern_name)
			local pattern = assert(active_definitions.PATTERN[pattern_name], "Undefined pattern: "..pattern_name);
			return ("local pattern_%s = %q"):format(pattern_name, pattern);
		end;
	};
	tokens = {
		local_code = function (search_and_pattern)
			local search_name, pattern_name = search_and_pattern:match("^([^%-]+)-(.+)$");
			local code = ([[local tokens_%s_%s = {};
			if search_%s then
				for s in search_%s:gmatch(pattern_%s) do
					tokens_%s_%s[s] = true;
				end
			end
			]]):format(search_name, pattern_name, search_name, search_name, pattern_name, search_name, pattern_name);
			return code, { "search:"..search_name, "pattern:"..pattern_name };
		end;
	};
	scan_list = {
		global_code = [[local function scan_list(list, items) for item in pairs(items) do if list:contains(item) then return true; end end end]];
	}
};

local function include_dep(dependency, code)
	local dep, dep_param = dependency:match("^([^:]+):?(.*)$");
	local dep_info = available_deps[dep];
	if not dep_info then
		module:log("error", "Dependency not found: %s", dep);
		return;
	end
	if code.included_deps[dependency] ~= nil then
		if code.included_deps[dependency] ~= true then
			module:log("error", "Circular dependency on %s", dep);
		end
		return;
	end
	code.included_deps[dependency] = false; -- Pending flag (used to detect circular references)
	for _, dep_dep in ipairs(dep_info.depends or {}) do
		include_dep(dep_dep, code);
	end
	if dep_info.global_code then
		if dep_param ~= "" then
			local global_code, deps = dep_info.global_code(dep_param);
			if deps then
				for _, dep_dep in ipairs(deps) do
					include_dep(dep_dep, code);
				end
			end
			table.insert(code.global_header, global_code);
		else
			table.insert(code.global_header, dep_info.global_code);
		end
	end
	if dep_info.local_code then
		if dep_param ~= "" then
			local local_code, deps = dep_info.local_code(dep_param);
			if deps then
				for _, dep_dep in ipairs(deps) do
					include_dep(dep_dep, code);
				end
			end
			table.insert(code, "\n\t\t-- "..dep.."\n\t\t"..local_code.."\n");
		else
			table.insert(code, "\n\t\t-- "..dep.."\n\t\t"..dep_info.local_code.."\n");
		end
	end
	code.included_deps[dependency] = true;
end

local definition_handlers = module:require("definitions");
local condition_handlers = module:require("conditions");
local action_handlers = module:require("actions");

if module:get_option_boolean("firewall_experimental_user_marks", false) then
	module:require"marks";
end

local function new_rule(ruleset, chain)
	assert(chain, "no chain specified");
	local rule = { conditions = {}, actions = {}, deps = {} };
	table.insert(ruleset[chain], rule);
	return rule;
end

local function parse_firewall_rules(filename)
	local line_no = 0;

	local function errmsg(err)
		return "Error compiling "..filename.." on line "..line_no..": "..err;
	end

	local ruleset = {
		deliver = {};
	};

	local chain = "deliver"; -- Default chain
	local rule;

	local file, err = io.open(filename);
	if not file then return nil, err; end

	local state; -- nil -> "rules" -> "actions" -> nil -> ...

	local line_hold;
	for line in file:lines() do
		line = line:match("^%s*(.-)%s*$");
		if line_hold and line:sub(-1,-1) ~= "\\" then
			line = line_hold..line;
			line_hold = nil;
		elseif line:sub(-1,-1) == "\\" then
			line_hold = (line_hold or "")..line:sub(1,-2);
		end
		line_no = line_no + 1;

		if line_hold or line:find("^[#;]") then -- luacheck: ignore 542
			-- No action; comment or partial line
		elseif line == "" then
			if state == "rules" then
				return nil, ("Expected an action on line %d for preceding criteria")
					:format(line_no);
			end
			state = nil;
		elseif not(state) and line:sub(1, 2) == "::" then
			chain = line:gsub("^::%s*", "");
			local chain_info = chains[chain];
			if not chain_info then
				if chain:match("^user/") then
					chains[chain] = { type = "event", priority = 1, pass_return = false };
				else
					return nil, errmsg("Unknown chain: "..chain);
				end
			elseif chain_info.type ~= "event" then
				return nil, errmsg("Only event chains supported at the moment");
			end
			ruleset[chain] = ruleset[chain] or {};
		elseif not(state) and line:sub(1,1) == "%" then -- Definition (zone, limit, etc.)
			local what, name = line:match("^%%%s*([%w_]+) +([^ :]+)");
			if not definition_handlers[what] then
				return nil, errmsg("Definition of unknown object: "..what);
			elseif not name or not idsafe(name) then
				return nil, errmsg("Invalid "..what.." name");
			end

			local val = line:match(": ?(.*)$");
			if not val and line:find(":<") then -- Read from file
				local fn = line:match(":< ?(.-)%s*$");
				if not fn then
					return nil, errmsg("Unable to parse filename");
				end
				local f, err = io.open(fn);
				if not f then return nil, errmsg(err); end
				val = f:read("*a"):gsub("\r?\n", " "):gsub("%s+$", "");
			end
			if not val then
				return nil, errmsg("No value given for definition");
			end
			val = stripslashes(val);
			local ok, ret = pcall(definition_handlers[what], name, val);
			if not ok then
				return nil, errmsg(ret);
			end

			if not active_definitions[what] then
				active_definitions[what] = {};
			end
			active_definitions[what][name] = ret;
		elseif line:find("^[%w_ ]+[%.=]") then
			-- Action
			if state == nil then
				-- This is a standalone action with no conditions
				rule = new_rule(ruleset, chain);
			end
			state = "actions";
			-- Action handlers?
			local action = line:match("^[%w_ ]+"):upper():gsub(" ", "_");
			if not action_handlers[action] then
				return nil, ("Unknown action on line %d: %s"):format(line_no, action or "<unknown>");
			end
			table.insert(rule.actions, "-- "..line)
			local ok, action_string, action_deps = pcall(action_handlers[action], line:match("=(.+)$"));
			if not ok then
				return nil, errmsg(action_string);
			end
			table.insert(rule.actions, action_string);
			for _, dep in ipairs(action_deps or {}) do
				table.insert(rule.deps, dep);
			end
		elseif state == "actions" then -- state is actions but action pattern did not match
			state = nil; -- Awaiting next rule, etc.
			table.insert(ruleset[chain], rule);
			rule = nil;
		else
			if not state then
				state = "rules";
				rule = new_rule(ruleset, chain);
			end
			-- Check standard modifiers for the condition (e.g. NOT)
			local negated;
			local condition = line:match("^[^:=%.?]*");
			if condition:find("%f[%w]NOT%f[^%w]") then
				local s, e = condition:match("%f[%w]()NOT()%f[^%w]");
				condition = (condition:sub(1,s-1)..condition:sub(e+1, -1)):match("^%s*(.-)%s*$");
				negated = true;
			end
			condition = condition:gsub(" ", "_");
			if not condition_handlers[condition] then
				return nil, ("Unknown condition on line %d: %s"):format(line_no, (condition:gsub("_", " ")));
			end
			-- Get the code for this condition
			local ok, condition_code, condition_deps = pcall(condition_handlers[condition], line:match(":%s?(.+)$"));
			if not ok then
				return nil, errmsg(condition_code);
			end
			if negated then condition_code = "not("..condition_code..")"; end
			table.insert(rule.conditions, condition_code);
			for _, dep in ipairs(condition_deps or {}) do
				table.insert(rule.deps, dep);
			end
		end
	end
	return ruleset;
end

local function process_firewall_rules(ruleset)
	-- Compile ruleset and return complete code

	local chain_handlers = {};

	-- Loop through the chains in the parsed ruleset (e.g. incoming, outgoing)
	for chain_name, rules in pairs(ruleset) do
		local code = { included_deps = {}, global_header = {} };
		local condition_uses = {};
		-- This inner loop assumes chain is an event-based, not a filter-based
		-- chain (filter-based will be added later)
		for _, rule in ipairs(rules) do
			for _, condition in ipairs(rule.conditions) do
				if condition:find("^not%(.+%)$") then
					condition = condition:match("^not%((.+)%)$");
				end
				condition_uses[condition] = (condition_uses[condition] or 0) + 1;
			end
		end

		local condition_cache, n_conditions = {}, 0;
		for _, rule in ipairs(rules) do
			for _, dep in ipairs(rule.deps) do
				include_dep(dep, code);
			end
			table.insert(code, "\n\t\t");
			local rule_code;
			if #rule.conditions > 0 then
				for i, condition in ipairs(rule.conditions) do
					local negated = condition:match("^not%(.+%)$");
					if negated then
						condition = condition:match("^not%((.+)%)$");
					end
					if condition_uses[condition] > 1 then
						local name = condition_cache[condition];
						if not name then
							n_conditions = n_conditions + 1;
							name = "condition"..n_conditions;
							condition_cache[condition] = name;
							table.insert(code, "local "..name.." = "..condition..";\n\t\t");
						end
						rule.conditions[i] = (negated and "not(" or "")..name..(negated and ")" or "");
					else
						rule.conditions[i] = (negated and "not(" or "(")..condition..")";
					end
				end

				rule_code = "if "..table.concat(rule.conditions, " and ").." then\n\t\t\t"
					..table.concat(rule.actions, "\n\t\t\t")
					.."\n\t\tend\n";
			else
				rule_code = table.concat(rule.actions, "\n\t\t");
			end
			table.insert(code, rule_code);
		end

		for name in pairs(definition_handlers) do
			table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
		end

		local code_string = "return function (definitions, fire_event, log, module, pass_return)\n\t"
			..table.concat(code.global_header, "\n\t")
			.."\n\tlocal db = require 'util.debug';\n\n\t"
			.."return function (event)\n\t\t"
			.."local stanza, session = event.stanza, event.origin;\n"
			..table.concat(code, "")
			.."\n\tend;\nend";

		chain_handlers[chain_name] = code_string;
	end

	return chain_handlers;
end

local function compile_firewall_rules(filename)
	local ruleset, err = parse_firewall_rules(filename);
	if not ruleset then return nil, err; end
	local chain_handlers = process_firewall_rules(ruleset);
	return chain_handlers;
end

-- Compile handler code into a factory that produces a valid event handler. Factory accepts
-- a value to be returned on PASS
local function compile_handler(code_string, filename)
	-- Prepare event handler function
	local chunk, err = envload(code_string, "="..filename, _G);
	if not chunk then
		return nil, "Error compiling (probably a compiler bug, please report): "..err;
	end
	local function fire_event(name, data)
		return module:fire_event(name, data);
	end
	return function (pass_return)
		return chunk()(active_definitions, fire_event, logger(filename), module, pass_return); -- Returns event handler with upvalues
	end
end

local function resolve_script_path(script_path)
	local relative_to = prosody.paths.config;
	if script_path:match("^module:") then
		relative_to = module.path:sub(1, -#("/mod_"..module.name..".lua"));
		script_path = script_path:match("^module:(.+)$");
	end
	return resolve_relative_path(relative_to, script_path);
end

-- [filename] = { last_modified = ..., events_hooked = { [name] = handler } }
local loaded_scripts = {};

function load_script(script)
	script = resolve_script_path(script);
	local last_modified = (lfs.attributes(script) or {}).modification or os.time();
	if loaded_scripts[script] then
		if loaded_scripts[script].last_modified == last_modified then
			return; -- Already loaded, and source file hasn't changed
		end
		module:log("debug", "Reloading %s", script);
		-- Already loaded, but the source file has changed
		-- unload it now, and we'll load the new version below
		unload_script(script, true);
	end
	local chain_functions, err = compile_firewall_rules(script);

	if not chain_functions then
		module:log("error", "Error compiling %s: %s", script, err or "unknown error");
		return;
	end

	-- Loop through the chains in the script, and for each chain attach the compiled code to the
	-- relevant events, keeping track in events_hooked so we can cleanly unload later
	local events_hooked = {};
	for chain, handler_code in pairs(chain_functions) do
		local new_handler, err = compile_handler(handler_code, "mod_firewall::"..chain);
		if not new_handler then
			module:log("error", "Compilation error for %s: %s", script, err);
		else
			local chain_definition = chains[chain];
			if chain_definition and chain_definition.type == "event" then
				local handler = new_handler(chain_definition.pass_return);
				for _, event_name in ipairs(chain_definition) do
					events_hooked[event_name] = handler;
					module:hook(event_name, handler, chain_definition.priority);
				end
			elseif not chain:sub(1, 5) == "user/" then
				module:log("warn", "Unknown chain %q", chain);
			end
			local event_name, handler = "firewall/chains/"..chain, new_handler(false);
			events_hooked[event_name] = handler;
			module:hook(event_name, handler);
		end
	end
	loaded_scripts[script] = { last_modified = last_modified, events_hooked = events_hooked };
	module:log("debug", "Loaded %s", script);
end

--COMPAT w/0.9 (no module:unhook()!)
local function module_unhook(event, handler)
	return module:unhook_object_event((hosts[module.host] or prosody).events, event, handler);
end

function unload_script(script, is_reload)
	script = resolve_script_path(script);
	local script_info = loaded_scripts[script];
	if not script_info then
		return; -- Script not loaded
	end
	local events_hooked = script_info.events_hooked;
	for event_name, event_handler in pairs(events_hooked) do
		module_unhook(event_name, event_handler);
		events_hooked[event_name] = nil;
	end
	loaded_scripts[script] = nil;
	if not is_reload then
		module:log("debug", "Unloaded %s", script);
	end
end

-- Given a set of scripts (e.g. from config) figure out which ones need to
-- be loaded, which are already loaded but need unloading, and which to reload
function load_unload_scripts(script_list)
	local wanted_scripts = script_list / resolve_script_path;
	local currently_loaded = set.new(it.to_array(it.keys(loaded_scripts)));
	local scripts_to_unload = currently_loaded - wanted_scripts;
	for script in wanted_scripts do
		-- If the script is already loaded, this is fine - it will
		-- reload the script for us if the file has changed
		load_script(script);
	end
	for script in scripts_to_unload do
		unload_script(script);
	end
end

function module.load()
	if not prosody.arg then return end -- Don't run in prosodyctl
	local firewall_scripts = module:get_option_set("firewall_scripts", {});
	load_unload_scripts(firewall_scripts);
	-- Replace contents of definitions table (shared) with active definitions
	for k in it.keys(definitions) do definitions[k] = nil; end
	for k,v in pairs(active_definitions) do definitions[k] = v; end
end

function module.save()
	return { active_definitions = active_definitions, loaded_scripts = loaded_scripts };
end

function module.restore(state)
	active_definitions = state.active_definitions;
	loaded_scripts = state.loaded_scripts;
end

module:hook_global("config-reloaded", function ()
	load_unload_scripts(module:get_option_set("firewall_scripts", {}));
end);

function module.command(arg)
	if not arg[1] or arg[1] == "--help" then
		require"util.prosodyctl".show_usage([[mod_firewall <firewall.pfw>]], [[Compile files with firewall rules to Lua code]]);
		return 1;
	end
	local verbose = arg[1] == "-v";
	if verbose then table.remove(arg, 1); end

	if arg[1] == "test" then
		table.remove(arg, 1);
		return module:require("test")(arg);
	end

	local serialize = require "util.serialization".serialize;
	if verbose then
		print("local logger = require \"util.logger\".init;");
		print();
		print("local function fire_event(name, data)\n\tmodule:fire_event(name, data)\nend");
		print();
	end

	for _, filename in ipairs(arg) do
		filename = resolve_script_path(filename);
		print("do -- File "..filename);
		local chain_functions = assert(compile_firewall_rules(filename));
		if verbose then
			print();
			print("local active_definitions = "..serialize(active_definitions)..";");
			print();
		end
		local c = 0;
		for chain, handler_code in pairs(chain_functions) do
			c = c + 1;
			print("---- Chain "..chain:gsub("_", " "));
			local chain_func_name = "chain_"..tostring(c).."_"..chain:gsub("%p", "_");
			if not verbose then
				print(("%s = %s;"):format(chain_func_name, handler_code:sub(8)));
			else

				print(("local %s = (%s)(active_definitions, fire_event, logger(%q));"):format(chain_func_name, handler_code:sub(8), filename));
				print();

				local chain_definition = chains[chain];
				if chain_definition and chain_definition.type == "event" then
					for _, event_name in ipairs(chain_definition) do
						print(("module:hook(%q, %s, %d);"):format(event_name, chain_func_name, chain_definition.priority or 0));
					end
				end
				print(("module:hook(%q, %s, %d);"):format("firewall/chains/"..chain, chain_func_name, chain_definition.priority or 0));
			end

			print("---- End of chain "..chain);
			print();
		end
		print("end -- End of file "..filename);
	end
end