changeset 971:53e158e44a44

mod_firewall: Add rate limiting capabilities, and keep zones and throttle objects in shared tables
author Matthew Wild <mwild1@gmail.com>
date Sat, 06 Apr 2013 22:20:59 +0100
parents adcb751f22f3
children 61b63affd402
files mod_firewall/conditions.lib.lua mod_firewall/mod_firewall.lua
diffstat 2 files changed, 39 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/mod_firewall/conditions.lib.lua	Sat Apr 06 21:47:46 2013 +0200
+++ b/mod_firewall/conditions.lib.lua	Sat Apr 06 22:20:59 2013 +0100
@@ -170,4 +170,8 @@
 	return table.concat(conditions, " or "), { "time:hour,min" };
 end
 
+function condition_handlers.LIMIT(name)
+	return ("not throttle_%s:poll(1)"):format(name), { "throttle:"..name };
+end
+
 return condition_handlers;
--- a/mod_firewall/mod_firewall.lua	Sat Apr 06 21:47:46 2013 +0200
+++ b/mod_firewall/mod_firewall.lua	Sat Apr 06 22:20:59 2013 +0100
@@ -2,17 +2,12 @@
 local resolve_relative_path = require "core.configmanager".resolve_relative_path;
 local logger = require "util.logger".init;
 local set = require "util.set";
+local it = require "util.iterators";
 local add_filter = require "util.filters".add_filter;
+local new_throttle = require "util.throttle".create;
 
-zones = {};
-local zones = zones;
-setmetatable(zones, {
-	__index = function (zones, zone)
-		local t = { [zone] = true };
-		rawset(zones, zone, t);
-		return t;
-	end;
-});
+local zones, throttles = module:shared("zones", "throttles");
+local active_zones, active_throttles = {}, {};
 
 local chains = {
 	preroute = {
@@ -35,6 +30,10 @@
 	};
 };
 
+local function idsafe(name)
+	return not not name:match("^%a[%w_]*$")
+end
+
 -- Dependency locations:
 -- <type lib>
 -- <type global>
@@ -73,7 +72,7 @@
 	is_admin = { global_code = [[local is_admin = require "core.usermanager".is_admin]]};
 	core_post_stanza = { global_code = [[local core_post_stanza = prosody.core_post_stanza]] };
 	zone = { global_code = function (zone)
-		assert(zone:match("^%a[%w_]*$"), "Invalid zone name: "..zone);
+		assert(idsafe(zone), "Invalid zone name: "..zone);
 		return ("local zone_%s = zones[%q] or {};"):format(zone, zone);
 	end };
 	date_time = { global_code = [[local os_date = os.date]]; local_code = [[local current_date_time = os_date("*t");]] };
@@ -84,6 +83,13 @@
 		end
 		return table.concat(defs, " ");
 	end, depends = { "date_time" }; };
+	throttle = {
+		global_code = function (throttle)
+			assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
+			assert(throttles[throttle], "Unknown rate limit: "..throttle);
+			return ("local throttle_%s = throttles.%s;"):format(throttle, throttle);
+		end;
+	};
 };
 
 local function include_dep(dep, code)
@@ -188,6 +194,14 @@
 				zone_member_list[#zone_member_list+1] = member;
 			end
 			zones[zone_name] = set.new(zone_member_list)._items;
+			table.insert(active_zones, zone_name);
+		elseif not(state) and line:match("^RATE ") then
+			local name = line:match("^RATE ([^:]+)");
+			assert(idsafe(name), "Invalid rate limit name: "..name);
+			local rate = assert(tonumber(line:match(":%s*([%d.]+)")), "Unable to parse rate");
+			local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1;
+			throttles[name] = new_throttle(rate*burst, burst);
+			table.insert(active_throttles, name);
 		elseif line:match("^[^%s:]+[%.=]") then
 			-- Action
 			if state == nil then
@@ -265,7 +279,7 @@
 			table.insert(code, rule_code);
 		end
 
-		local code_string = [[return function (zones, fire_event, log)
+		local code_string = [[return function (zones, throttles, fire_event, log)
 			]]..table.concat(code.global_header, "\n")..[[
 			local db = require 'util.debug'
 			return function (event)
@@ -291,11 +305,17 @@
 	local function fire_event(name, data)
 		return module:fire_event(name, data);
 	end
-	chunk = chunk()(zones, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
+	chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
 	return chunk;
 end
 
+local function cleanup(t, active_list)
+	local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list);
+	for k in unused do t[k] = nil; end
+end
+
 function module.load()
+	active_zones, active_throttles = {}, {};
 	local firewall_scripts = module:get_option_set("firewall_scripts", {});
 	for script in firewall_scripts do
 		script = resolve_relative_path(prosody.paths.config, script);
@@ -322,4 +342,7 @@
 			end
 		end
 	end
+	-- Remove entries from tables that are no longer in use
+	cleanup(zones, active_zones);
+	cleanup(throttles, active_throttles);
 end