Skip to content

Commit

Permalink
Merge pull request #124 from o-lim/ref-matcher
Browse files Browse the repository at this point in the history
Add ref matcher to match against object references
  • Loading branch information
DorianGray committed Jun 19, 2015
2 parents 7b2aa81 + 4a360a9 commit 4bee87f
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 15 deletions.
24 changes: 24 additions & 0 deletions spec/matchers_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ describe("Test Matchers", function()
assert.is_false(match.is_same(nil)("a string"))
end)

it("Checks ref() matcher", function()
local t = {}
local func = function() end
local mythread = coroutine.create(func)
assert.is.error(function() match.is_ref() end) -- minimum 1 arguments
assert.is.error(function() match.is_ref(0) end) -- arg1 must be an object
assert.is.error(function() match.is_ref('') end) -- arg1 must be an object
assert.is.error(function() match.is_ref(nil) end) -- arg1 must be an object
assert.is.error(function() match.is_ref(true) end) -- arg1 must be an object
assert.is.error(function() match.is_ref(false) end) -- arg1 must be an object
assert.is_true(match.is_ref(t)(t))
assert.is_true(match.is_ref(func)(func))
assert.is_true(match.is_ref(mythread)(mythread))
assert.is_false(match.is_ref(t)(func))
assert.is_false(match.is_ref(t)(mythread))
assert.is_false(match.is_ref(t)(nil))
assert.is_false(match.is_ref(t)(true))
assert.is_false(match.is_ref(t)(false))
assert.is_false(match.is_ref(t)(123))
assert.is_false(match.is_ref(t)(""))
assert.is_false(match.is_ref(t)({}))
assert.is_false(match.is_ref(t)(function() end))
end)

it("Checks matches() matcher does string matching", function()
assert.is.error(function() match.matches() end) -- minimum 1 arguments
assert.is.error(function() match.matches({}) end) -- arg1 must be a string
Expand Down
13 changes: 13 additions & 0 deletions spec/spies_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ describe("Tests dealing with spies", function()
assert.has_error(function() assert.spy(s).was.called_with(5, 6) end)
end)

it("checks called() and called_with() assertions using refs", function()
local s = spy.new(function() end)
local t1 = { foo = { bar = { "test" } } }
local t2 = { foo = { bar = { "test" } } }

s(t1)
t1.foo.bar = "value"

assert.spy(s).was.called_with(t2)
assert.spy(s).was_not.called_with(match.is_ref(t2))
assert.spy(s).was.called_with(match.is_ref(t1))
end)

it("checks called_with(aspy) assertions", function()
local s = spy.new(function() end)

Expand Down
17 changes: 17 additions & 0 deletions spec/stub_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,21 @@ describe("Tests dealing with stubs", function()
assert.is.equal("foo foo", foofoo)
end)

it("on_call_with matches arguments using refs", function()
local t1 = { foo = { bar = { "test" } } }
local t2 = { foo = { bar = { "test" } } }
stub(test, "key").returns("foo foo")
test.key.on_call_with(match.is_ref(t1)).returns("bar")
t1.foo.bar = "value"
t2.foo.bar = "value"

local bar = test.key(t1)
local foo = test.key(t2)
local foofoo = test.key({ foo = { bar = { "test" } } })

assert.is.equal("bar", bar)
assert.is.equal("foo foo", foo)
assert.is.equal("foo foo", foofoo)
end)

end)
6 changes: 6 additions & 0 deletions src/match.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ local state_mt = {
arguments.n = select('#', ...) -- add argument count for trailing nils
local matches = matcher.callback(self, arguments, util.errorlevel())
return setmetatable({
name = matcher.name,
mod = self.mod,
callback = matches,
}, matcher_mt)
Expand Down Expand Up @@ -63,6 +64,11 @@ local match = {
is_matcher = function(object)
return type(object) == "table" and getmetatable(object) == matcher_mt
end,

is_ref_matcher = function(object)
local ismatcher = (type(object) == "table" and getmetatable(object) == matcher_mt)
return ismatcher and object.name == "ref"
end,
}

local mt = {
Expand Down
17 changes: 15 additions & 2 deletions src/matchers/core.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
-- returns; function (or callable object); a function that, given an argument, returns a boolean

local assert = require('luassert.assert')
local astate = require ('luassert.state')
local util = require ('luassert.util')
local astate = require('luassert.state')
local util = require('luassert.util')
local s = require('say')

local function format(val)
Expand Down Expand Up @@ -98,6 +98,18 @@ local function same(state, arguments, level)
end
end

local function ref(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
local argtype = type(arguments[1])
local isobject = (argtype == "table" or argtype == "function" or argtype == "thread" or argtype == "userdata")
assert(argcnt > 0, s("assertion.internal.argtolittle", { "ref", 1, tostring(argcnt) }), level)
assert(isobject, s("assertion.internal.badargtype", { 1, "ref", "object", argtype }), level)
return function(value)
return value == arguments[1]
end
end

local function is_true(state, arguments, level)
return function(value)
return value == true
Expand Down Expand Up @@ -150,6 +162,7 @@ assert:register("matcher", "function", is_function)
assert:register("matcher", "userdata", is_userdata)
assert:register("matcher", "thread", is_thread)

assert:register("matcher", "ref", ref)
assert:register("matcher", "same", same)
assert:register("matcher", "matches", matches)
assert:register("matcher", "match", matches)
Expand Down
4 changes: 2 additions & 2 deletions src/spy.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ local spy_mt = {
__call = function(self, ...)
local arguments = {...}
arguments.n = select('#',...) -- add argument count for trailing nils
table.insert(self.calls, util.deepcopy(arguments))
table.insert(self.calls, util.copyargs(arguments))
local function get_returns(...)
local returnvals = {...}
returnvals.n = select('#',...) -- add argument count for trailing nils
table.insert(self.returnvals, util.deepcopy(returnvals))
table.insert(self.returnvals, util.copyargs(returnvals))
return ...
end
return get_returns(self.callback(...))
Expand Down
3 changes: 2 additions & 1 deletion src/stub.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ function stub.new(object, key, ...)
}

s.on_call_with = function(...)
local match_args = util.copyargs({...})
local match_args = {...}
match_args.n = select('#', ...)
match_args = util.copyargs(match_args)
return {
returns = function(...)
local return_args = {...}
Expand Down
24 changes: 14 additions & 10 deletions src/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function util.deepcopy(t, deepmt, cache)
end

-----------------------------------------------
-- Copies arguments in a of arguments
-- Copies arguments as a list of arguments
-- @param args the arguments of which to copy
-- @return the copy of the arguments
function util.copyargs(args)
Expand All @@ -96,7 +96,7 @@ function util.copyargs(args)
for k,v in pairs(args) do
copy[k] = ((match.is_matcher(v) or spy.is_spy(v)) and v or util.deepcopy(v))
end
return copy
return { vals = copy, refs = util.shallowcopy(args) }
end

-----------------------------------------------
Expand All @@ -105,32 +105,36 @@ end
-- @param args the arguments of which to find a match
-- @return the matching arguments if a match is found, otherwise nil
function util.matchargs(argslist, args)
local function matches(t1, t2)
local function matches(t1, t2, t1refs)
local match = require 'luassert.match'
for k1,v1 in pairs(t1) do
local v2 = t2[k1]
if match.is_matcher(v1) then
if not v1(v2) then return false end
elseif match.is_matcher(v2) then
if match.is_ref_matcher(v2) then v1 = t1refs[k1] end
if not v2(v1) then return false end
elseif (v2 == nil or not util.deepcompare(v1,v2)) then
return false
end
end
for k2,v2 in pairs(t2) do
-- only check wether each element has a t1 counterpart, actual comparison
-- has been done in first loop above
local v1 = t1[k2]
if match.is_matcher(v1) then
if not v1(v2) then return false end
elseif match.is_matcher(v2) then
if not v2(v1) then return false end
elseif v1 == nil then
return false
if v1 == nil then
-- no t1 counterpart, so try to compare using matcher
if match.is_matcher(v2) then
if not v2(v1) then return false end
else
return false
end
end
end
return true
end
for k,v in ipairs(argslist) do
if matches(v, args) then
if matches(v.vals, args, v.refs) then
return v
end
end
Expand Down

0 comments on commit 4bee87f

Please sign in to comment.