view mod_auth_token/mock.lua @ 5185:09d6bbd6c8a4

mod_http_oauth2: Fix treatment of 'redirect_uri' parameter in code flow. It's optional, and the one stored in the client registration should really be used instead. RFC 6749 says a URI provided as a parameter MUST be validated against the stored one, but does not say how. Given that the client needs their secret to proceed, it seems fine to leave this for later.
author Kim Alvefur <zash@zash.se>
date Thu, 02 Mar 2023 22:00:42 +0100
parents d0ca211e1b0e
children
line wrap: on
line source

-- Source code taken from https://github.com/britzl/deftest
-- Released under the MIT License. Copyright (c) 2009-2012 Norman Clarke.

--- Provides the ability to mock any module.

-- @usage
--
-- mock.mock(sys)
--
-- -- specifying return values
-- sys.get_sys_info.returns({my_data})
-- ...
-- local sys_info = sys.get_sys_info() -- will be my_data
-- assert(sys.get_sys_info.calls == 1) -- call counting
-- ...
-- local sys_info = sys.get_sys_info() -- original response as we are now out of mocked answers
-- assert(sys.get_sys_info.calls == 2) -- call counting
-- ...
--
-- -- specifying a replacement function
-- sys.get_sys_info.replace(function () return my_data end)
--
-- ...
-- local sys_info = sys.get_sys_info() -- will be my_data
-- assert(sys.get_sys_info.calls == 3) -- call counting
-- ...
-- local sys_info = sys.get_sys_info() -- will still be my_data
-- assert(sys.get_sys_info.calls == 4) -- call counting
-- ...
--
-- -- cleaning up
-- mock.unmock(sys) -- restore the sys library again

local mock = {}

-- Wrap `orig_fn` in a callable table that counts calls, records the most
-- recent call's arguments, and can return queued answers or delegate to a
-- replacement function instead of `orig_fn`.
-- @param orig_fn the function being mocked
-- @return the mock table (callable through its `__call` metamethod)
local function new_mock_fn(orig_fn)
	local mock_fn = {
		calls = 0,
		answers = {},
		repl_fn = nil,
		orig_fn = orig_fn,
		params = {},
	}

	-- Queue answers for the following calls, consumed one per call, in order.
	-- Accepts either a single table of answers or the answers as varargs.
	-- A table argument is copied, so consuming answers never mutates it.
	function mock_fn.returns(...)
		local arg_length = select("#", ...)
		assert(arg_length > 0, "You must provide some answers")
		local args = { ... }
		local answers = {}
		if arg_length == 1 and type(args[1]) == "table" then
			local list = args[1]
			for i = 1, #list do
				answers[i] = list[i]
			end
		else
			-- Covers several answers as well as a single non-table answer,
			-- which previously crashed the next call.
			for i = 1, arg_length do
				answers[i] = args[i]
			end
		end
		mock_fn.answers = answers
	end

	-- Once queued answers run out, make every call return `answer`.
	function mock_fn.always_returns(answer)
		mock_fn.repl_fn = function()
			return answer
		end
	end

	-- Once queued answers run out, make calls delegate to `repl_fn`.
	function mock_fn.replace(repl_fn)
		mock_fn.repl_fn = repl_fn
	end

	-- Call the original function directly; this does not bump `calls`.
	function mock_fn.original(...)
		return mock_fn.orig_fn(...)
	end

	-- Drop any replacement so calls fall back to the original function.
	-- Queued answers are left in place.
	function mock_fn.restore()
		mock_fn.repl_fn = nil
	end

	return setmetatable(mock_fn, {
		__call = function(self, ...)
			self.calls = self.calls + 1

			-- Record only this call's arguments. Clear leftovers from an
			-- earlier call that passed more (or any) arguments first.
			-- Clearing existing keys during a pairs traversal is allowed.
			local params = self.params
			for key in pairs(params) do
				params[key] = nil
			end
			local n = select("#", ...)
			local args = { ... }
			for i = 1, n do
				params[i] = args[i]
			end

			-- Queued answers take priority. Test the queue length rather than
			-- the head value so `false` answers are returned and consumed too.
			if #self.answers > 0 then
				return table.remove(self.answers, 1)
			elseif self.repl_fn then
				return self.repl_fn(...)
			end
			return self.orig_fn(...)
		end,
	})
end

--- Mock the specified module.
-- Every function-valued field of the module is replaced by a callable mock
-- table whose logic can be overridden. Other fields are left untouched.
-- Each mock exposes `calls` (number of calls made through the mock),
-- `params` (arguments of the most recent call), `answers` (queued return
-- values), `orig_fn`, and the helpers `returns`, `always_returns`,
-- `replace`, `original` and `restore`.
-- @param module module to mock
-- @usage
--
-- -- mock module x
-- mock.mock(x)
--
-- -- make x.f return 1, 2 then the original value
-- x.f.returns({1, 2})
-- print(x.f()) -- prints 1
--
-- -- make x.f return 1 forever
-- x.f.replace(function () return 1 end)
-- while true do print(x.f()) end -- prints 1 forever
--
-- -- counting calls
-- assert(x.f.calls > 0)
--
-- -- return to original state of module x
-- mock.unmock(x)
--
function mock.mock(module)
	assert(module, "You must provide a module to mock")
	for name, value in pairs(module) do
		-- Overwriting an existing key is allowed while traversing with pairs.
		if type(value) == "function" then
			module[name] = new_mock_fn(value)
		end
	end
end

--- Remove the mocking capabilities from a module.
-- Every table field that carries a raw `orig_fn` function, as the mocks
-- installed by `mock.mock` do, is replaced by that function. All other
-- fields are left untouched.
-- @param module module to remove mocking from
function mock.unmock(module)
	assert(module, "You must provide a module to unmock")
	for k, v in pairs(module) do
		if type(v) == "table" then
			-- rawget avoids invoking __index on unrelated tables stored in the
			-- module (e.g. proxies or strict tables that error on unknown keys).
			local orig_fn = rawget(v, "orig_fn")
			if type(orig_fn) == "function" then
				-- Overwriting an existing key is allowed during pairs traversal.
				module[k] = orig_fn
			end
		end
	end
end

-- Export the module table holding mock.mock and mock.unmock.
return mock