diff mod_s2s_auth_dane/mod_s2s_auth_dane.lua @ 1351:a052740bbf48

mod_s2s_auth_dane: Back to _port._tcp.srvtarget.example.net
author Kim Alvefur <zash@zash.se>
date Tue, 18 Mar 2014 15:12:11 +0100
parents cda335db2cbb
children b0f780d3a24e
line wrap: on
line diff
--- a/mod_s2s_auth_dane/mod_s2s_auth_dane.lua	Fri Mar 14 14:30:33 2014 +0100
+++ b/mod_s2s_auth_dane/mod_s2s_auth_dane.lua	Tue Mar 18 15:12:11 2014 +0100
@@ -19,6 +19,7 @@
 module:set_global();
 
 local type = type;
+local t_insert = table.insert;
 local set = require"util.set";
 local dns_lookup = require"net.adns".lookup;
 local hashes = require"util.hashes";
@@ -41,34 +42,62 @@
 local configured_uses = module:get_option_set("dane_uses", { "DANE-EE" });
 local enabled_uses = set.intersection(implemented_uses, configured_uses) / function(use) return use_map[use] end;
 
-local function dane_lookup(host_session, name, cb, a,b,c)
-	if host_session.dane ~= nil then return false; end
-	local ascii_host = name and idna_to_ascii(name);
-	if not ascii_host then return false; end
-	host_session.dane = dns_lookup(function(answer)
-		if answer and (answer.secure and #answer > 0) or answer.bogus then
-			host_session.dane = answer;
-		else
-			host_session.dane = false;
-		end
-		if cb then return cb(a,b,c); end
-	end, ("_xmpp-server.%s."):format(ascii_host), "TLSA");
-	host_session.connecting = true;
-	return true;
+local function dane_lookup(host_session, cb, a,b,c,e)
+	if host_session.dane ~= nil then return end
+	if host_session.direction == "incoming" then
+		local name = idna_to_ascii(host_session.from_host);
+		if not name then return end
+		local handle = dns_lookup(function (answer)
+			if not answer.secure then return end
+			if #answer == 1 and answer[1].srv.target == '.' then return end
+			local srv_hosts = { answer = answer };
+			local dane = {};
+			host_session.dane = dane;
+			host_session.srv_hosts = srv_hosts;
+			local n = #answer
+			for _, record in ipairs(answer) do
+				t_insert(srv_hosts, record.srv);
+				dns_lookup(function(dane_answer)
+					n = n - 1;
+					if dane_answer.bogus then
+						t_insert(dane, { bogus = dane_answer.bogus });
+					elseif dane_answer.secure then
+						for _, record in ipairs(dane_answer) do
+							t_insert(dane, record);
+						end
+					end
+					if n == 0 and cb then return cb(a,b,c,e); end
+				end, ("_%d._tcp.%s."):format(record.srv.port, record.srv.target), "TLSA");
+			end
+		end, "_xmpp-server._tcp."..name..".", "SRV");
+		return true;
+	elseif host_session.direction == "outgoing" then
+		local srv_choice = host_session.srv_hosts[host_session.srv_choice];
+		host_session.dane = dns_lookup(function(answer)
+			if answer and (answer.secure and #answer > 0) or answer.bogus then
+				srv_choice.dane = answer;
+			else
+				srv_choice.dane = false;
+			end
+			host_session.dane = srv_choice.dane;
+			if cb then return cb(a,b,c,e); end
+		end, ("_%d._tcp.%s."):format(srv_choice.port, srv_choice.target), "TLSA");
+		return true;
+	end
 end
 
-local _attempt_connection = s2sout.attempt_connection;
-function s2sout.attempt_connection(host_session, err)
-	if not err and dane_lookup(host_session, host_session.to_host, _attempt_connection, host_session, err) then
+local _try_connect = s2sout.try_connect;
+function s2sout.try_connect(host_session, connect_host, connect_port, err)
+	if not err and dane_lookup(host_session, _try_connect, host_session, connect_host, connect_port, err) then
 		return true;
 	end
-	return _attempt_connection(host_session, err);
+	return _try_connect(host_session, connect_host, connect_port, err);
 end
 
 function module.add_host(module)
 	module:hook("s2s-stream-features", function(event)
-		local origin = event.origin;
-		dane_lookup(origin, origin.from_host);
+		-- dane_lookup(origin, origin.from_host);
+		dane_lookup(event.origin);
 	end, 1);
 
 	module:hook("s2s-authenticated", function(event)
@@ -144,7 +173,7 @@
 end);
 
 function module.unload()
-	-- Restore the original attempt_connection function
-	s2sout.attempt_connection = _attempt_connection;
+	-- Restore the original try_connect function
+	s2sout.try_connect = _try_connect;
 end