diff mod_firewall/definitions.lib.lua @ 2586:d28e434cb5fd

mod_firewall: Support filters for normalizing items before checking for them in lists
author Matthew Wild <mwild1@gmail.com>
date Sun, 26 Feb 2017 11:28:56 +0000
parents 22a271641c29
children 8c879948a2cf
line wrap: on
line diff
--- a/mod_firewall/definitions.lib.lua	Sun Feb 26 09:58:07 2017 +0000
+++ b/mod_firewall/definitions.lib.lua	Sun Feb 26 11:28:56 2017 +0000
@@ -8,6 +8,8 @@
 local timer = require "util.timer";
 local set = require"util.set";
 local new_throttle = require "util.throttle".create;
+local hashes = require "util.hashes";
+local jid = require "util.jid";
 
 local multirate_cache_size = module:get_option_number("firewall_multirate_cache_limit", 1000);
 
@@ -171,6 +173,18 @@
 };
 list_backends.https = list_backends.http;
 
+local normalize_functions = {
+	upper = string.upper, lower = string.lower;
+	md5 = hashes.md5, sha1 = hashes.sha1, sha256 = hashes.sha256;
+	prep = jid.prep, bare = jid.bare;
+};
+
+local function wrap_list_method(list_method, filter)
+	return function (self, item)
+		return list_method(self, filter(item));
+	end
+end
+
 local function create_list(list_backend, list_def, opts)
 	if not list_backends[list_backend] then
 		error("Unknown list type '"..list_backend.."'", 0);
@@ -179,6 +193,38 @@
 	if list.init then
 		list:init(list_def, opts);
 	end
+	if opts.filter then
+		local filters = {};
+		for func_name in opts.filter:gmatch("[%w_]+") do
+			if func_name == "log" then
+				table.insert(filters, function (s)
+					--print("&&&&&", s);
+					module:log("debug", "Checking list <%s> for: %s", list_def, s);
+					return s;
+				end);
+			else
+				assert(normalize_functions[func_name], "Unknown list filter: "..func_name);
+				table.insert(filters, normalize_functions[func_name]);
+			end
+		end
+
+		local filter;
+		local n = #filters;
+		if n == 1 then
+			filter = filters[1];
+		else
+			function filter(s)
+				for i = 1, n do
+					s = filters[i](s or "");
+				end
+				return s;
+			end
+		end
+
+		list.add = wrap_list_method(list.add, filter);
+		list.remove = wrap_list_method(list.remove, filter);
+		list.contains = wrap_list_method(list.contains, filter);
+	end
 	return list;
 end