comparison mod_firewall/mod_firewall.lua @ 999:197af8440ffb

mod_firewall: Make defining objects generic (currently zones and rate limits) so that more kinds can easily be added. This is also a syntax change: definition lines must now begin with %
author Matthew Wild <mwild1@gmail.com>
date Tue, 07 May 2013 10:33:49 +0100
parents 6fdcebbd2284
children c0850793b716
comparison
equal deleted inserted replaced
998:6fdcebbd2284 999:197af8440ffb
2 local resolve_relative_path = require "core.configmanager".resolve_relative_path; 2 local resolve_relative_path = require "core.configmanager".resolve_relative_path;
3 local logger = require "util.logger".init; 3 local logger = require "util.logger".init;
4 local set = require "util.set"; 4 local set = require "util.set";
5 local it = require "util.iterators"; 5 local it = require "util.iterators";
6 local add_filter = require "util.filters".add_filter; 6 local add_filter = require "util.filters".add_filter;
7 local new_throttle = require "util.throttle".create; 7
8 8 local definitions = module:shared("definitions");
9 local zones, throttles = module:shared("zones", "throttles"); 9 local active_definitions = {};
10 local active_zones, active_throttles = {}, {};
11 10
12 local chains = { 11 local chains = {
13 preroute = { 12 preroute = {
14 type = "event"; 13 type = "event";
15 priority = 0.1; 14 priority = 0.1;
84 return table.concat(defs, " "); 83 return table.concat(defs, " ");
85 end, depends = { "date_time" }; }; 84 end, depends = { "date_time" }; };
86 throttle = { 85 throttle = {
87 global_code = function (throttle) 86 global_code = function (throttle)
88 assert(idsafe(throttle), "Invalid rate limit name: "..throttle); 87 assert(idsafe(throttle), "Invalid rate limit name: "..throttle);
89 assert(throttles[throttle], "Unknown rate limit: "..throttle); 88 assert(active_definitions.RATE[throttle], "Unknown rate limit: "..throttle);
90 return ("local throttle_%s = throttles.%s;"):format(throttle, throttle); 89 return ("local throttle_%s = rates.%s;"):format(throttle, throttle);
91 end; 90 end;
92 }; 91 };
93 }; 92 };
94 93
95 local function include_dep(dep, code) 94 local function include_dep(dep, code)
124 end 123 end
125 end 124 end
126 code.included_deps[dep] = true; 125 code.included_deps[dep] = true;
127 end 126 end
128 127
128 local definition_handlers = module:require("definitions");
129 local condition_handlers = module:require("conditions"); 129 local condition_handlers = module:require("conditions");
130 local action_handlers = module:require("actions"); 130 local action_handlers = module:require("actions");
131 131
132 local function new_rule(ruleset, chain) 132 local function new_rule(ruleset, chain)
133 assert(chain, "no chain specified"); 133 assert(chain, "no chain specified");
181 return nil, errmsg("Unknown chain: "..chain); 181 return nil, errmsg("Unknown chain: "..chain);
182 elseif chain_info.type ~= "event" then 182 elseif chain_info.type ~= "event" then
183 return nil, errmsg("Only event chains supported at the moment"); 183 return nil, errmsg("Only event chains supported at the moment");
184 end 184 end
185 ruleset[chain] = ruleset[chain] or {}; 185 ruleset[chain] = ruleset[chain] or {};
186 elseif not(state) and line:match("^ZONE ") then 186 elseif not(state) and line:match("^%%") then -- Definition (zone, limit, etc.)
187 local zone_name = line:match("^ZONE ([^:]+)"); 187 local what, name = line:match("^%%%s*(%w+) +([^ :]+)");
188 if not zone_name:match("^%a[%w_]*$") then 188 if not definition_handlers[what] then
189 return nil, errmsg("Invalid character(s) in zone name: "..zone_name); 189 return nil, errmsg("Definition of unknown object: "..what);
190 end 190 elseif not name or not idsafe(name) then
191 local zone_members = line:match("^ZONE .-: ?(.*)"); 191 return nil, errmsg("Invalid "..what.." name");
192 local zone_member_list = {}; 192 end
193 for member in zone_members:gmatch("[^, ]+") do 193
194 zone_member_list[#zone_member_list+1] = member; 194 local val = line:match(": ?(.*)$");
195 end 195 if not val and line:match(":<") then -- Read from file
196 zones[zone_name] = set.new(zone_member_list)._items; 196 local fn = line:match(":< ?(.-)%s*$");
197 table.insert(active_zones, zone_name); 197 if not fn then
198 elseif not(state) and line:match("^RATE ") then 198 return nil, errmsg("Unable to parse filename");
199 local name = line:match("^RATE ([^:]+)"); 199 end
200 assert(idsafe(name), "Invalid rate limit name: "..name); 200 local f, err = io.open(fn);
201 local rate = assert(tonumber(line:match(":%s*([%d.]+)")), "Unable to parse rate"); 201 if not f then return nil, errmsg(err); end
202 local burst = tonumber(line:match("%(%s*burst%s+([%d.]+)%s*%)")) or 1; 202 val = f:read("*a"):gsub("\r?\n", " "):gsub("%s+5", "");
203 throttles[name] = new_throttle(rate*burst, burst); 203 end
204 table.insert(active_throttles, name); 204 if not val then
205 return nil, errmsg("No value given for definition");
206 end
207
208 local ok, ret = pcall(definition_handlers[what], name, val);
209 if not ok then
210 return nil, errmsg(ret);
211 end
212
213 if not active_definitions[what] then
214 active_definitions[what] = {};
215 end
216 active_definitions[what][name] = ret;
205 elseif line:match("^[^%s:]+[%.=]") then 217 elseif line:match("^[^%s:]+[%.=]") then
206 -- Action 218 -- Action
207 if state == nil then 219 if state == nil then
208 -- This is a standalone action with no conditions 220 -- This is a standalone action with no conditions
209 rule = new_rule(ruleset, chain); 221 rule = new_rule(ruleset, chain);
293 .."\n end\n"; 305 .."\n end\n";
294 end 306 end
295 table.insert(code, rule_code); 307 table.insert(code, rule_code);
296 end 308 end
297 309
298 local code_string = [[return function (zones, throttles, fire_event, log) 310 for name in pairs(definition_handlers) do
311 table.insert(code.global_header, 1, "local "..name:lower().."s = definitions."..name..";");
312 end
313
314 local code_string = [[return function (definitions, fire_event, log)
299 ]]..table.concat(code.global_header, "\n")..[[ 315 ]]..table.concat(code.global_header, "\n")..[[
300 local db = require 'util.debug' 316 local db = require 'util.debug'
301 return function (event) 317 return function (event)
302 local stanza, session = event.stanza, event.origin; 318 local stanza, session = event.stanza, event.origin;
303 319
319 return nil, "Error compiling (probably a compiler bug, please report): "..err; 335 return nil, "Error compiling (probably a compiler bug, please report): "..err;
320 end 336 end
321 local function fire_event(name, data) 337 local function fire_event(name, data)
322 return module:fire_event(name, data); 338 return module:fire_event(name, data);
323 end 339 end
324 chunk = chunk()(zones, throttles, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue. 340 chunk = chunk()(active_definitions, fire_event, logger(filename)); -- Returns event handler with 'zones' upvalue.
325 return chunk; 341 return chunk;
326 end 342 end
327 343
328 local function cleanup(t, active_list)
329 local unused = set.new(it.to_array(it.keys(t))) - set.new(active_list);
330 for k in unused do t[k] = nil; end
331 end
332
333 function module.load() 344 function module.load()
334 active_zones, active_throttles = {}, {}; 345 active_definitions = {};
335 local firewall_scripts = module:get_option_set("firewall_scripts", {}); 346 local firewall_scripts = module:get_option_set("firewall_scripts", {});
336 for script in firewall_scripts do 347 for script in firewall_scripts do
337 script = resolve_relative_path(prosody.paths.config, script); 348 script = resolve_relative_path(prosody.paths.config, script);
338 local chain_functions, err = compile_firewall_rules(script) 349 local chain_functions, err = compile_firewall_rules(script)
339 350
356 module:hook("firewall/chains/"..chain, handler); 367 module:hook("firewall/chains/"..chain, handler);
357 end 368 end
358 end 369 end
359 end 370 end
360 end 371 end
361 -- Remove entries from tables that are no longer in use 372 -- Replace contents of definitions table (shared) with active definitions
362 cleanup(zones, active_zones); 373 for k in it.keys(definitions) do definitions[k] = nil; end
363 cleanup(throttles, active_throttles); 374 for k,v in pairs(active_definitions) do definitions[k] = v; end
364 end 375 end