Skip to content

Commit

Permalink
Move most of test_utils into TestExt
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 27, 2024
1 parent 5bc980a commit dea4835
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 46 deletions.
11 changes: 2 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Expand All @@ -31,6 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
Expand All @@ -39,6 +39,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLTestExt = ["Test"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand Down Expand Up @@ -67,11 +68,3 @@ ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
11 changes: 11 additions & 0 deletions ext/DynamicPPLTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module DynamicPPLTestExt

using DynamicPPL
using AbstractMCMC
using Test

include("DynamicPPLTestExt/contexts.jl")
include("DynamicPPLTestExt/varinfo.jl")
include("DynamicPPLTestExt/sampler.jl")

end
42 changes: 23 additions & 19 deletions src/test_utils/contexts.jl → ext/DynamicPPLTestExt/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Test that `context` implements the `AbstractContext` interface.
"""
function test_context_interface(context)
function DynamicPPL.TestUtils.test_context_interface(context)
# Is a subtype of `AbstractContext`.
@test context isa DynamicPPL.AbstractContext
# Should implement `NodeTrait.`
Expand All @@ -21,41 +21,45 @@ function test_context_interface(context)
end
end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
"""
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
mod::T
context::Ctx
end
function TestLogModifyingChildContext(
function DynamicPPL.TestUtils.TestLogModifyingChildContext(
mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext()
)
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
return DynamicPPL.TestUtils.TestLogModifyingChildContext{typeof(mod),typeof(context)}(
mod, context
)
end

DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
return TestLogModifyingChildContext(context.mod, child)
function DynamicPPL.NodeTrait(::DynamicPPL.TestUtils.TestLogModifyingChildContext)
return DynamicPPL.IsParent()
end
function DynamicPPL.childcontext(context::DynamicPPL.TestUtils.TestLogModifyingChildContext)
return context.context
end
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
function DynamicPPL.setchildcontext(
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, child
)
return DynamicPPL.TestUtils.TestLogModifyingChildContext(context.mod, child)
end
function DynamicPPL.tilde_assume(
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, vn, vi
)
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
return value, logp * context.mod, vi
end
function DynamicPPL.dot_tilde_assume(
context::TestLogModifyingChildContext, right, left, vn, vi
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vn, vi
)
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
return value, logp * context.mod, vi
end
function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
function DynamicPPL.tilde_observe(
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi
)
logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
return logp * context.mod, vi
end
function DynamicPPL.dot_tilde_observe(
context::TestLogModifyingChildContext, right, left, vi
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi
)
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
return logp * context.mod, vi
Expand Down
15 changes: 9 additions & 6 deletions src/test_utils/sampler.jl → ext/DynamicPPLTestExt/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
Return the mean of variable represented by `varname` in `chain`.
"""
marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)]))
DynamicPPL.TestUtils.marginal_mean_of_samples(chain, varname) =
mean(Array(chain[Symbol(varname)]))

"""
test_sampler(models, sampler, args...; kwargs...)
Expand All @@ -35,7 +36,7 @@ To change how comparison is done for a particular `chain` type, one can overload
- `rtol=1e-3`: Relative tolerance used in `@test`.
- `kwargs...`: Keyword arguments forwarded to `sample`.
"""
function test_sampler(
function DynamicPPL.TestUtils.test_sampler(
models,
sampler::AbstractMCMC.AbstractSampler,
args...;
Expand All @@ -51,7 +52,7 @@ function test_sampler(
for vn in filter(varnames_filter, varnames(model))
# We want to compare elementwise which can be achieved by
# extracting the leaves of the `VarName` and the corresponding value.
for vn_leaf in varname_leaves(vn, get(target_values, vn))
for vn_leaf in DynamicPPL.varname_leaves(vn, get(target_values, vn))
target_value = get(target_values, vn_leaf)
chain_mean_value = marginal_mean_of_samples(chain, vn_leaf)
@test chain_mean_value target_value atol = atol rtol = rtol
Expand All @@ -67,10 +68,10 @@ Test `sampler` on every model in [`DEMO_MODELS`](@ref).
This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`.
"""
function test_sampler_on_demo_models(
function DynamicPPL.TestUtils.test_sampler_on_demo_models(
sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
)
return test_sampler(DEMO_MODELS, sampler, args...; kwargs...)
return test_sampler(DynamicPPL.TestUtils.DEMO_MODELS, sampler, args...; kwargs...)
end

"""
Expand All @@ -80,6 +81,8 @@ Test that `sampler` produces the correct marginal posterior means on all models
As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref).
"""
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
function DynamicPPL.TestUtils.test_sampler_continuous(
sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
)
return test_sampler_on_demo_models(sampler, args...; kwargs...)
end
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`.
"""
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
function DynamicPPL.TestUtils.test_values(
vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...
)
for vn in vns
@test compare(vi[vn], get(vals, vn); kwargs...)
end
Expand All @@ -23,7 +25,7 @@ each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` i
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
function setup_varinfos(
function DynamicPPL.TestUtils.setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
# VarInfo
Expand Down Expand Up @@ -58,7 +60,7 @@ function setup_varinfos(
svi_vnv_ref,
)) do vi
# Set them all to the same values.
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
DynamicPPL.setlogp!!(DynamicPPL.update_values!!(vi, example_values, varnames), lp)
end

if include_threadsafe
Expand Down
32 changes: 23 additions & 9 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
module TestUtils

using AbstractMCMC
using DynamicPPL
using LinearAlgebra
using Distributions
using Test

using Random: Random
using Bijectors: Bijectors
using Accessors: Accessors

# For backwards compat.
using DynamicPPL: varname_leaves, update_values!!

include("test_utils/model_interface.jl")
include("test_utils/models.jl")
include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")

##############################################################
# The remainder of this file contains skeleton implementations for
# DynamicPPLTestExt
##############################################################

function test_context_interface end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
"""
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
mod::T
context::Ctx
end

function marginal_mean_of_samples end
function test_sampler end
function test_sampler_on_demo_models end
function test_sampler_continuous end
function test_values end
function setup_varinfos end

end

0 comments on commit dea4835

Please sign in to comment.