comparison mod_auth_sql/mod_auth_sql.lua @ 371:c416db434e5b

Do not run in transaction. Code cleanup. Changed logging to module logging. Properly count SQL result rows.
author Tomasz Sterna <tomek@xiaoka.com>
date Tue, 26 Apr 2011 19:28:08 +0200
parents a6dee73a11e7
children fdd4f5ab029a
comparison
equal deleted inserted replaced
370:16da8cd69715 371:c416db434e5b
6 local new_sasl = require "util.sasl".new; 6 local new_sasl = require "util.sasl".new;
7 local nodeprep = require "util.encodings".stringprep.nodeprep; 7 local nodeprep = require "util.encodings".stringprep.nodeprep;
8 8
9 local DBI; 9 local DBI;
10 local connection; 10 local connection;
11 local host,user,store = module.host;
12 local params = module:get_option("sql"); 11 local params = module:get_option("sql");
13 12
14 local resolve_relative_path = require "core.configmanager".resolve_relative_path; 13 local resolve_relative_path = require "core.configmanager".resolve_relative_path;
15 14
16 local function test_connection() 15 local function test_connection()
34 if not dbh then 33 if not dbh then
35 module:log("debug", "Database connection failed: %s", tostring(err)); 34 module:log("debug", "Database connection failed: %s", tostring(err));
36 return nil, err; 35 return nil, err;
37 end 36 end
38 module:log("debug", "Successfully connected to database"); 37 module:log("debug", "Successfully connected to database");
39 dbh:autocommit(false); -- don't commit automatically 38 dbh:autocommit(true); -- don't run in transaction
40 connection = dbh; 39 connection = dbh;
41 return connection; 40 return connection;
42 end 41 end
43 end 42 end
44 43
58 57
59 local function getsql(sql, ...) 58 local function getsql(sql, ...)
60 if params.driver == "PostgreSQL" then 59 if params.driver == "PostgreSQL" then
61 sql = sql:gsub("`", "\""); 60 sql = sql:gsub("`", "\"");
62 end 61 end
63 if not test_connection() then connect() end 62 if not test_connection() then connect(); end
64 -- do prepared statement stuff 63 -- do prepared statement stuff
65 local stmt, err = connection:prepare(sql); 64 local stmt, err = connection:prepare(sql);
66 if not stmt and not test_connection() then error("connection failed"); end 65 if not stmt and not test_connection() then error("connection failed"); end
67 if not stmt then module:log("error", "QUERY FAILED: %s %s", err, debug.traceback()); return nil, err; end 66 if not stmt then module:log("error", "QUERY FAILED: %s %s", err, debug.traceback()); return nil, err; end
68 -- run query 67 -- run query
73 return stmt; 72 return stmt;
74 end 73 end
75 74
76 function new_default_provider(host) 75 function new_default_provider(host)
77 local provider = { name = "sql" }; 76 local provider = { name = "sql" };
78 log("debug", "initializing default authentication provider for host '%s'", host); 77 module:log("debug", "initializing default authentication provider for host '%s'", host);
79 78
80 function provider.test_password(username, password) 79 function provider.test_password(username, password)
81 log("debug", "test password '%s' for user %s at host %s", password, username, module.host); 80 module:log("debug", "test_password '%s' for user %s at host %s", password, username, host);
82 81
83 local stmt, err = getsql("SELECT `username` FROM `authreg` WHERE `username`=? AND `password`=? AND `realm`=?", 82 local stmt, err = getsql("SELECT `username` FROM `authreg` WHERE `username`=? AND `password`=? AND `realm`=?",
84 username, password, module.host); 83 username, password, host);
85 84
86 if stmt ~= nil then 85 if stmt ~= nil then
87 if #stmt:rows(true) > 0 then 86 local count = 0;
87 for row in stmt:rows(true) do
88 count = count + 1;
89 end
90 if count > 0 then
88 return true; 91 return true;
89 end 92 end
90 else 93 else
91 log("error", "QUERY ERROR: %s %s", err, debug.traceback()); 94 module:log("error", "QUERY ERROR: %s %s", err, debug.traceback());
92 return nil, err; 95 return nil, err;
93 end 96 end
94 97
95 return false; 98 return false;
96 end 99 end
97 100
98 function provider.get_password(username) 101 function provider.get_password(username)
99 log("debug", "get_password for username '%s' at host '%s'", username, module.host); 102 module:log("debug", "get_password for username '%s' at host '%s'", username, host);
100 103
101 local stmt, err = getsql("SELECT `password` FROM `authreg` WHERE `username`=? AND `realm`=?", 104 local stmt, err = getsql("SELECT `password` FROM `authreg` WHERE `username`=? AND `realm`=?",
102 username, module.host); 105 username, host);
103 106
104 local password = nil; 107 local password = nil;
105 if stmt ~= nil then 108 if stmt ~= nil then
106 for row in stmt:rows(true) do 109 for row in stmt:rows(true) do
107 password = row.password; 110 password = row.password;
108 end 111 end
109 else 112 else
110 log("error", "QUERY ERROR: %s %s", err, debug.traceback()); 113 module:log("error", "QUERY ERROR: %s %s", err, debug.traceback());
111 return nil; 114 return nil;
112 end 115 end
113 116
114 return password; 117 return password;
115 end 118 end
117 function provider.set_password(username, password) 120 function provider.set_password(username, password)
118 return nil, "Setting password is not supported."; 121 return nil, "Setting password is not supported.";
119 end 122 end
120 123
121 function provider.user_exists(username) 124 function provider.user_exists(username)
122 log("debug", "test user %s existence at host %s", username, module.host); 125 module:log("debug", "test user %s existence at host %s", username, host);
123 126
124 local stmt, err = getsql("SELECT `username` FROM `authreg` WHERE `username`=? AND `realm`=?", 127 local stmt, err = getsql("SELECT `username` FROM `authreg` WHERE `username`=? AND `realm`=?",
125 username, module.host); 128 username, host);
126 129
127 if stmt ~= nil then 130 if stmt ~= nil then
128 if #stmt:rows(true) > 0 then 131 local count = 0;
132 for row in stmt:rows(true) do
133 count = count + 1;
134 end
135 if count > 0 then
129 return true; 136 return true;
130 end 137 end
131 else 138 else
132 log("error", "QUERY ERROR: %s %s", err, debug.traceback()); 139 module:log("error", "QUERY ERROR: %s %s", err, debug.traceback());
133 return nil, err; 140 return nil, err;
134 end 141 end
135 142
136 return false; 143 return false;
137 end 144 end
139 function provider.create_user(username, password) 146 function provider.create_user(username, password)
140 return nil, "Account creation/modification not supported."; 147 return nil, "Account creation/modification not supported.";
141 end 148 end
142 149
143 function provider.get_sasl_handler() 150 function provider.get_sasl_handler()
144 local realm = module:get_option("sasl_realm") or module.host; 151 local realm = module:get_option("sasl_realm") or host;
145 local getpass_authentication_profile = { 152 local getpass_authentication_profile = {
146 plain = function(sasl, username, realm) 153 plain = function(sasl, username, realm)
147 local prepped_username = nodeprep(username); 154 local prepped_username = nodeprep(username);
148 if not prepped_username then 155 if not prepped_username then
149 log("debug", "NODEprep failed on username: %s", username); 156 module:log("debug", "NODEprep failed on username: %s", username);
150 return "", nil; 157 return "", nil;
151 end 158 end
152 local password = usermanager.get_password(prepped_username, realm); 159 local password = usermanager.get_password(prepped_username, realm);
153 if not password then 160 if not password then
154 return "", nil; 161 return "", nil;