Mercurial > prosody-modules
diff mod_anti_spam/trie.lib.lua @ 5859:259ffdbf8906
mod_anti_spam: New module for spam filtering (pre-alpha)
author | Matthew Wild <mwild1@gmail.com> |
---|---|
date | Tue, 05 Mar 2024 18:26:29 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/mod_anti_spam/trie.lib.lua Tue Mar 05 18:26:29 2024 +0000 @@ -0,0 +1,168 @@ +local bit = require "prosody.util.bitcompat"; + +local trie_methods = {}; +local trie_mt = { __index = trie_methods }; + +local function new_node() + return {}; +end + +function trie_methods:set(item, value) + local node = self.root; + for i = 1, #item do + local c = item:byte(i); + if not node[c] then + node[c] = new_node(); + end + node = node[c]; + end + node.terminal = true; + node.value = value; +end + +local function _remove(node, item, i) + if i > #item then + if node.terminal then + node.terminal = nil; + node.value = nil; + end + if next(node) ~= nil then + return node; + end + return nil; + end + local c = item:byte(i); + local child = node[c]; + local ret; + if child then + ret = _remove(child, item, i+1); + node[c] = ret; + end + if ret == nil and next(node) == nil then + return nil; + end + return node; +end + +function trie_methods:remove(item) + return _remove(self.root, item, 1); +end + +function trie_methods:get(item, partial) + local value; + local node = self.root; + local len = #item; + for i = 1, len do + if partial and node.terminal then + value = node.value; + end + local c = item:byte(i); + node = node[c]; + if not node then + return value, i - 1; + end + end + return node.value, len; +end + +function trie_methods:add(item) + return self:set(item, true); +end + +function trie_methods:contains(item, partial) + return self:get(item, partial) ~= nil; +end + +function trie_methods:longest_prefix(item) + return select(2, self:get(item)); +end + +function trie_methods:add_subnet(item, bits) + item = item.packed:sub(1, math.ceil(bits/8)); + local existing = self:get(item); + if not existing then + existing = { bits }; + return self:set(item, existing); + end + + -- Simple insertion sort + for i = 1, #existing do + local v = existing[i]; + if v == bits then + return; -- Already in there + elseif v > bits then + table.insert(existing, v, i); + return; + end + end +end + +function trie_methods:remove_subnet(item, bits) + item = item.packed:sub(1, math.ceil(bits/8)); + local existing = self:get(item); + if not existing then + return; + end + + -- Simple insertion sort + for i = 1, #existing do + local v = existing[i]; + if v == bits then + table.remove(existing, i); + break; + elseif v > bits then + return; -- Stop search + end + end + + if #existing == 0 then + self:remove(item); + end +end + +function trie_methods:has_ip(item) + item = item.packed; + local node = self.root; + local len = #item; + for i = 1, len do + if node.terminal then + return true; + end + + local c = item:byte(i); + local child = node[c]; + if not child then + for child_byte, child_node in pairs(node) do + if type(child_byte) == "number" and child_node.terminal then + local bits = child_node.value; + for j = #bits, 1, -1 do + local b = bits[j]-((i-1)*8); + if b ~= 8 then + local mask = bit.bnot(2^b-1); + if bit.band(bit.bxor(c, child_byte), mask) == 0 then + return true; + end + end + end + end + end + return false; + end + node = child; + end +end + +local function new() + return setmetatable({ + root = new_node(); + }, trie_mt); +end + +local function is_trie(o) + return getmetatable(o) == trie_mt; +end + +return { + new = new; + is_trie = is_trie; +};