From 124996e530350d00cb93e9592a6419b02d1c8a86 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 27 Nov 2024 11:59:19 +0000 Subject: [PATCH] Move most of test_utils into TestExt --- Project.toml | 11 +----- ext/DynamicPPLTestExt.jl | 11 ++++++ .../DynamicPPLTestExt}/contexts.jl | 38 +++++++++---------- .../DynamicPPLTestExt}/sampler.jl | 12 +++--- .../DynamicPPLTestExt}/varinfo.jl | 6 +-- src/test_utils.jl | 33 +++++++++++----- 6 files changed, 65 insertions(+), 46 deletions(-) create mode 100644 ext/DynamicPPLTestExt.jl rename {src/test_utils => ext/DynamicPPLTestExt}/contexts.jl (55%) rename {src/test_utils => ext/DynamicPPLTestExt}/sampler.jl (85%) rename {src/test_utils => ext/DynamicPPLTestExt}/varinfo.jl (89%) diff --git a/Project.toml b/Project.toml index ebc70b5ab..fade1b2a9 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] @@ -31,6 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -39,6 +39,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLTestExt = ["Test"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -67,11 +68,3 @@ ReverseDiff = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.10" - -[extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLTestExt.jl b/ext/DynamicPPLTestExt.jl new file mode 100644 index 000000000..7a6e176eb --- /dev/null +++ b/ext/DynamicPPLTestExt.jl @@ -0,0 +1,11 @@ +module DynamicPPLTestExt + +using DynamicPPL +using AbstractMCMC +using Test + +include("DynamicPPLTestExt/contexts.jl") +include("DynamicPPLTestExt/varinfo.jl") +include("DynamicPPLTestExt/sampler.jl") + +end diff --git a/src/test_utils/contexts.jl b/ext/DynamicPPLTestExt/contexts.jl similarity index 55% rename from src/test_utils/contexts.jl rename to ext/DynamicPPLTestExt/contexts.jl index 89b0bb0d7..108af567c 100644 --- a/src/test_utils/contexts.jl +++ b/ext/DynamicPPLTestExt/contexts.jl @@ -8,7 +8,7 @@ Test that `context` implements the `AbstractContext` interface. """ -function test_context_interface(context) +function DynamicPPL.TestUtils.test_context_interface(context) # Is a subtype of `AbstractContext`. @test context isa DynamicPPL.AbstractContext # Should implement `NodeTrait.` @@ -21,41 +21,41 @@ function test_context_interface(context) end end -""" -Context that multiplies each log-prior by mod -used to test whether varwise_logpriors respects child-context. -""" -struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext - mod::T - context::Ctx -end -function TestLogModifyingChildContext( +function DynamicPPL.TestUtils.TestLogModifyingChildContext( mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() ) - return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) + return DynamicPPL.TestUtils.TestLogModifyingChildContext{typeof(mod),typeof(context)}( + mod, context + ) end -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context -function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) - return TestLogModifyingChildContext(context.mod, child) +function DynamicPPL.NodeTrait(::DynamicPPL.TestUtils.TestLogModifyingChildContext) + return DynamicPPL.IsParent() +end +function DynamicPPL.childcontext(context::DynamicPPL.TestUtils.TestLogModifyingChildContext) + return context.context +end +function DynamicPPL.setchildcontext( + context::DynamicPPL.TestUtils.TestLogModifyingChildContext, child +) + return DynamicPPL.TestUtils.TestLogModifyingChildContext(context.mod, child) end -function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) +function DynamicPPL.tilde_assume(context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, vn, vi) value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) return value, logp * context.mod, vi end function DynamicPPL.dot_tilde_assume( - context::TestLogModifyingChildContext, right, left, vn, vi + context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vn, vi ) value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) return value, logp * context.mod, vi end -function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) +function DynamicPPL.tilde_observe(context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi) logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) return logp * context.mod, vi end function DynamicPPL.dot_tilde_observe( - context::TestLogModifyingChildContext, right, left, vi + context::DynamicPPL.TestUtils.TestLogModifyingChildContext, right, left, vi ) logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) return logp * context.mod, vi diff --git a/src/test_utils/sampler.jl b/ext/DynamicPPLTestExt/sampler.jl similarity index 85% rename from src/test_utils/sampler.jl rename to ext/DynamicPPLTestExt/sampler.jl index 71cdb1cac..f41b15360 100644 --- a/src/test_utils/sampler.jl +++ b/ext/DynamicPPLTestExt/sampler.jl @@ -8,7 +8,7 @@ Return the mean of variable represented by `varname` in `chain`. """ -marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) +DynamicPPL.TestUtils.marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) """ test_sampler(models, sampler, args...; kwargs...) @@ -35,7 +35,7 @@ To change how comparison is done for a particular `chain` type, one can overload - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ -function test_sampler( +function DynamicPPL.TestUtils.test_sampler( models, sampler::AbstractMCMC.AbstractSampler, args...; @@ -51,7 +51,7 @@ function test_sampler( for vn in filter(varnames_filter, varnames(model)) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. - for vn_leaf in varname_leaves(vn, get(target_values, vn)) + for vn_leaf in DynamicPPL.varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol @@ -67,10 +67,10 @@ Test `sampler` on every model in [`DEMO_MODELS`](@ref). This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. """ -function test_sampler_on_demo_models( +function DynamicPPL.TestUtils.test_sampler_on_demo_models( sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler(DEMO_MODELS, sampler, args...; kwargs...) + return test_sampler(DynamicPPL.TestUtils.DEMO_MODELS, sampler, args...; kwargs...) end """ @@ -80,6 +80,6 @@ Test that `sampler` produces the correct marginal posterior means on all models As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). """ -function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) +function DynamicPPL.TestUtils.test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) return test_sampler_on_demo_models(sampler, args...; kwargs...) end diff --git a/src/test_utils/varinfo.jl b/ext/DynamicPPLTestExt/varinfo.jl similarity index 89% rename from src/test_utils/varinfo.jl rename to ext/DynamicPPLTestExt/varinfo.jl index 6a655ded4..97c19ac74 100644 --- a/src/test_utils/varinfo.jl +++ b/ext/DynamicPPLTestExt/varinfo.jl @@ -8,7 +8,7 @@ Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`. """ -function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) +function DynamicPPL.TestUtils.test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) for vn in vns @test compare(vi[vn], get(vals, vn); kwargs...) end @@ -23,7 +23,7 @@ each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` i If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions of the varinfo instances. """ -function setup_varinfos( +function DynamicPPL.TestUtils.setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo @@ -58,7 +58,7 @@ function setup_varinfos( svi_vnv_ref, )) do vi # Set them all to the same values. - DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) + DynamicPPL.setlogp!!(DynamicPPL.update_values!!(vi, example_values, varnames), lp) end if include_threadsafe diff --git a/src/test_utils.jl b/src/test_utils.jl index c7d12c927..5200b2535 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1,22 +1,37 @@ module TestUtils -using AbstractMCMC using DynamicPPL using LinearAlgebra using Distributions -using Test using Random: Random using Bijectors: Bijectors -using Accessors: Accessors - -# For backwards compat. -using DynamicPPL: varname_leaves, update_values!! include("test_utils/model_interface.jl") include("test_utils/models.jl") -include("test_utils/contexts.jl") -include("test_utils/varinfo.jl") -include("test_utils/sampler.jl") + + +############################################################## +# The remainder of this file contains skeleton implementations for +# DynamicPPLTestExt +############################################################## + +function test_context_interface end + +""" +Context that multiplies each log-prior by mod +used to test whether varwise_logpriors respects child-context. +""" +struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext + mod::T + context::Ctx +end + +function marginal_mean_of_samples end +function test_sampler end +function test_sampler_on_demo_models end +function test_sampler_continuous end +function test_values end +function setup_varinfos end end