Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move most of TestUtils into TestExt #723

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Expand All @@ -31,6 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
Expand All @@ -39,6 +39,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLTestExt = ["Test"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand Down Expand Up @@ -67,11 +68,3 @@ ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
11 changes: 11 additions & 0 deletions ext/DynamicPPLTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module DynamicPPLTestExt

using DynamicPPL
using AbstractMCMC
using Test

include("DynamicPPLTestExt/contexts.jl")
include("DynamicPPLTestExt/varinfo.jl")
include("DynamicPPLTestExt/sampler.jl")

end
42 changes: 23 additions & 19 deletions src/test_utils/contexts.jl → ext/DynamicPPLTestExt/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

Test that `context` implements the `AbstractContext` interface.
"""
function test_context_interface(context)
function DynamicPPL.TestUtils.test_context_interface(context)
# Is a subtype of `AbstractContext`.
@test context isa DynamicPPL.AbstractContext
# Should implement `NodeTrait.`
Expand All @@ -21,41 +21,45 @@ function test_context_interface(context)
end
end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
"""
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
mod::T
context::Ctx
end
function TestLogModifyingChildContext(
function DynamicPPL.TestUtils.TestLogModifyingChildContext(
mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext()
)
return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
return DynamicPPL.TestUtils.TestLogModifyingChildContext{typeof(mod),typeof(context)}(
mod, context
)
end

DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
return TestLogModifyingChildContext(context.mod, child)
function DynamicPPL.NodeTrait(::DynamicPPL.TestUtils.TestLogModifyingChildContext)
return DynamicPPL.IsParent()
end
function DynamicPPL.childcontext(context::DynamicPPL.TestUtils.TestLogModifyingChildContext)
return context.context
end
function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
function DynamicPPL.setchildcontext(
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, child
)
return DynamicPPL.TestUtils.TestLogModifyingChildContext(context.mod, child)
end
function DynamicPPL.tilde_assume(
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, vn, vi
)
value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
return value, logp * context.mod, vi
end
function DynamicPPL.dot_tilde_assume(
context::TestLogModifyingChildContext, right, left, vn, vi
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vn, vi
)
value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi)
return value, logp * context.mod, vi
end
function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
function DynamicPPL.tilde_observe(
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi
)
logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
return logp * context.mod, vi
end
function DynamicPPL.dot_tilde_observe(
context::TestLogModifyingChildContext, right, left, vi
context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi
)
logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi)
return logp * context.mod, vi
Expand Down
15 changes: 9 additions & 6 deletions src/test_utils/sampler.jl → ext/DynamicPPLTestExt/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

Return the mean of variable represented by `varname` in `chain`.
"""
marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)]))
DynamicPPL.TestUtils.marginal_mean_of_samples(chain, varname) =
mean(Array(chain[Symbol(varname)]))

"""
test_sampler(models, sampler, args...; kwargs...)
Expand All @@ -35,7 +36,7 @@ To change how comparison is done for a particular `chain` type, one can overload
- `rtol=1e-3`: Relative tolerance used in `@test`.
- `kwargs...`: Keyword arguments forwarded to `sample`.
"""
function test_sampler(
function DynamicPPL.TestUtils.test_sampler(
models,
sampler::AbstractMCMC.AbstractSampler,
args...;
Expand All @@ -51,7 +52,7 @@ function test_sampler(
for vn in filter(varnames_filter, varnames(model))
# We want to compare elementwise which can be achieved by
# extracting the leaves of the `VarName` and the corresponding value.
for vn_leaf in varname_leaves(vn, get(target_values, vn))
for vn_leaf in DynamicPPL.varname_leaves(vn, get(target_values, vn))
target_value = get(target_values, vn_leaf)
chain_mean_value = marginal_mean_of_samples(chain, vn_leaf)
@test chain_mean_value ≈ target_value atol = atol rtol = rtol
Expand All @@ -67,10 +68,10 @@ Test `sampler` on every model in [`DEMO_MODELS`](@ref).

This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`.
"""
function test_sampler_on_demo_models(
function DynamicPPL.TestUtils.test_sampler_on_demo_models(
sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
)
return test_sampler(DEMO_MODELS, sampler, args...; kwargs...)
return test_sampler(DynamicPPL.TestUtils.DEMO_MODELS, sampler, args...; kwargs...)
end

"""
Expand All @@ -80,6 +81,8 @@ Test that `sampler` produces the correct marginal posterior means on all models

As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref).
"""
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
function DynamicPPL.TestUtils.test_sampler_continuous(
sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
)
return test_sampler_on_demo_models(sampler, args...; kwargs...)
end
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`.
"""
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...)
function DynamicPPL.TestUtils.test_values(
vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...
)
for vn in vns
@test compare(vi[vn], get(vals, vn); kwargs...)
end
Expand All @@ -23,7 +25,7 @@ each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` i
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
function setup_varinfos(
function DynamicPPL.TestUtils.setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
# VarInfo
Expand Down Expand Up @@ -58,7 +60,7 @@ function setup_varinfos(
svi_vnv_ref,
)) do vi
# Set them all to the same values.
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
DynamicPPL.setlogp!!(DynamicPPL.update_values!!(vi, example_values, varnames), lp)
end

if include_threadsafe
Expand Down
32 changes: 23 additions & 9 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
module TestUtils

using AbstractMCMC
using DynamicPPL
using LinearAlgebra
using Distributions
using Test

using Random: Random
using Bijectors: Bijectors
using Accessors: Accessors

# For backwards compat.
using DynamicPPL: varname_leaves, update_values!!

include("test_utils/model_interface.jl")
include("test_utils/models.jl")
include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")

##############################################################
# The remainder of this file contains skeleton implementations for
# DynamicPPLTestExt
##############################################################

function test_context_interface end

"""
Context that multiplies each log-prior by mod
used to test whether varwise_logpriors respects child-context.
"""
struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
mod::T
context::Ctx
end

function marginal_mean_of_samples end
function test_sampler end
function test_sampler_on_demo_models end
function test_sampler_continuous end
function test_values end
function setup_varinfos end

end
Loading