diff mod_firewall/mod_firewall.lua @ 999:197af8440ffb

mod_firewall: Make defining objects generic (currently zones and rate limits), so more can easily be added. Also a syntax change... definition lines must begin with %
author Matthew Wild <mwild1@gmail.com>
date Tue, 07 May 2013 10:33:49 +0100
parents 6fdcebbd2284
children c0850793b716
line wrap: on
line diff
--- a/mod_firewall/mod_firewall.lua	Tue May 07 10:32:48 2013 +0100
+++ b/mod_firewall/mod_firewall.lua	Tue May 07 10:33:49 2013 +0100
@@ -4,10 +4,9 @@
 local set = require "util.set";
 local it = require "util.iterators";
 local add_filter = require "util.filters".add_filter;
-local new_throttle = require "util.throttle".create;
 
-local zones, throttles = module:shared("zones", "throttles");
-local active_zones, active_throttles = {}, {};
+local definitions = module:shared("definitions");
+local active_definitions = {};
 
 local chains = {
 	preroute = {
@@ -86,8 +85,8 @@
 	throttle = {
 		global_code = function (throttle)
 			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
-			assert(throttles[throttle], "Unknown rate limit: "..throttle);
-			return ("local throttle_%s = throttles.%s;"):format(throttle, throttle);
+			assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
+			return ("local throttle_%s = rates.%s;"):format(throttle, throttle);
 		end;
 	};
 };
@@ -126,6 +125,7 @@
 	code.included_deps[dep] = true;
 end
 
+local definition_handlers = module:require("definitions");
 local condition_handlers = module:require("conditions");
 local action_handlers = module:require("actions");
 
@@ -183,25 +183,37 @@
 				return nil, errmsg("Only event chains supported at the moment");
 			end
 			ruleset[chain] = ruleset[chain] or {};
-		elseif not(state) and line:match("^ZONE ") then
-			local zone_name = line:match("^ZONE ([^:]+)");
-			if not zone_name:match("^%a[%w_]*$") then
-				return nil, errmsg("Invalid character(s) in zone name: "..zone_name);
-			end
-			local zone_members = line:match("^ZONE .-: ?(.*)");
-			local zone_member_list = {};
-			for member in zone_members:gmatch("[^, ]+") do
-				zone_member_list[#zone_member_list+1] = member;
+		elseif not(state) and line:match("^%%") then -- Definition (zone, limit, etc.)
+			local what, name = line:match("^%%%s*(%w+) +([^ :]+)");
+			if not definition_handlers[what] then
+				return nil, errmsg("Definition of unknown object: "..what);
+			elseif not name or not idsafe(name) then
+				return nil, errmsg("Invalid "..what.." name");
 			end
-			zones[zone_name] = set.new(zone_member_list)._items;
-			table.insert(active_zones, zone_name);
-		elseif not(state) and line:match("^RATE ") then
-			local name = line:match("^RATE ([^:]+)");
-			assert(idsafe(name), "Invalid rate limit name: "..name);
-			local rate = assert(tonumber(line:match(":%s*([%d.]+)")), "Unable to parse rate");
-			local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
-			throttles[name] = new_throttle(rate*burst, burst);
-			table.insert(active_throttles, name);
+
+			local val = line:match(": ?(.*)$");
+			if not val and line:match(":<") then -- Read from file
+				local fn = line:match(":< ?(.-)%s*$");
+				if not fn then
+					return nil, errmsg("Unable to parse filename");
+				end
+				local f, err = io.open(fn);
+				if not f then return nil, errmsg(err); end
+				val = f:read("*a"):gsub("\r?\n", " "):gsub("%s+5", "");
+			end
+			if not val then
+				return nil, errmsg("No value given for definition");
+			end
+
+			local ok, ret = pcall(definition_handlers[what], name, val);
+			if not ok then
+				return nil, errmsg(ret);
+			end
+
+			if not active_definitions[what] then
+				active_definitions[what] = {};
+			end
+			active_definitions[what][name] = ret;
 		elseif line:match("^[^%s:]+[%.=]") then
 			-- Action
 			if state == nil then
@@ -295,7 +307,11 @@
 			table.insert(code, rule_code);
 		end
 
-		local code_string = [[return function (zones, throttles, fire_event, log)
+		for name in pairs(definition_handlers) do
+			table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
+		end
+
+		local code_string = [[return function (definitions, fire_event, log)
 			]]..table.concat(code.global_header, "\n")..[[
 			local db = require 'util.debug'
 			return function (event)
@@ -321,17 +337,12 @@
 	local function fire_event(name, data)
 		return module:fire_event(name, data);
 	end
-	chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
+	chunk = chunk()(active_definitions, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
 	return chunk;
 end
 
-local function cleanup(t, active_list)
-	local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list);
-	for k in unused do t[k] = nil; end
-end
-
 function module.load()
-	active_zones, active_throttles = {}, {};
+	active_definitions = {};
 	local firewall_scripts = module:get_option_set("firewall_scripts", {});
 	for script in firewall_scripts do
 		script = resolve_relative_path(prosody.paths.config, script);
@@ -358,7 +369,7 @@
 			end
 		end
 	end
-	-- Remove entries from tables that are no longer in use
-	cleanup(zones, active_zones);
-	cleanup(throttles, active_throttles);
+	-- Replace contents of definitions table (shared) with active definitions
+	for k in it.keys(definitions) do definitions[k] = nil; end
+	for k,v in pairs(active_definitions) do definitions[k] = v; end
 end