diff --git a/Project.toml b/Project.toml index ebc70b5ab..a0c9bd9cf 100644 --- a/Project.toml +++ b/Project.toml @@ -31,6 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -39,6 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLTestExt = ["Test"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -67,11 +69,3 @@ ReverseDiff = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.10" - -[extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLTestExt.jl b/ext/DynamicPPLTestExt.jl new file mode 100644 index 000000000..dc377687b --- /dev/null +++ b/ext/DynamicPPLTestExt.jl @@ -0,0 +1,8 @@ +module DynamicPPLTestExt + +using DynamicPPL: DynamicPPL +using Test: @test, @testset, @test_throws, @test_broken + +include("DynamicPPLTestExt/utils.jl") + +end diff --git a/src/test_utils.jl b/ext/DynamicPPLTestExt/utils.jl similarity index 90% rename from src/test_utils.jl rename to ext/DynamicPPLTestExt/utils.jl index 6199138aa..6c8a7dd0e 100644 --- a/src/test_utils.jl +++ b/ext/DynamicPPLTestExt/utils.jl @@ -1,4 +1,8 @@ -module TestUtils +module TestExtUtils + +################################################### +# These used to be in DPPL/src/test_utils.jl ###### +################################################### using AbstractMCMC using DynamicPPL @@ -1097,4 +1101,123 @@ function DynamicPPL.dot_tilde_observe( return logp * context.mod, vi end + + +################################################### +# These used to be in DPPL/test/test_util.jl ###### +################################################### + +# default model +@model function gdemo_d() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + return s, m +end +const gdemo_default = gdemo_d() + +function test_model_ad(model, logp_manual) + vi = VarInfo(model) + x = DynamicPPL.getall(vi) + + # Log probabilities using the model. + ℓ = DynamicPPL.LogDensityFunction(model, vi) + logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) + + # Check that both functions return the same values. + lp = logp_manual(x) + @test logp_model(x) ≈ lp + + # Gradients based on the manual implementation. + grad = ForwardDiff.gradient(logp_manual, x) + + y, back = Tracker.forward(logp_manual, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_manual, x) + @test y ≈ lp + @test back(1)[1] ≈ grad + + # Gradients based on the model. + @test ForwardDiff.gradient(logp_model, x) ≈ grad + + y, back = Tracker.forward(logp_model, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_model, x) + @test y ≈ lp + @test back(1)[1] ≈ grad +end + +""" + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + +Test `setval!` on `model` and `chain`. + +Worth noting that this only supports models containing symbols of the forms +`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. +""" +function test_setval!(model, chain; sample_idx=1, chain_idx=1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v + end + + @test v_true == chain_val + end +end + +""" + short_varinfo_name(vi::AbstractVarInfo) + +Return string representing a short description of `vi`. +""" +short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = + "threadsafe($(short_varinfo_name(vi.varinfo)))" +function short_varinfo_name(vi::TypedVarInfo) + DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" + return "TypedVarInfo" +end +short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" +short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" +short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" +function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) + return "SimpleVarInfo{<:VarNamedVector}" +end + +# convenient functions for testing model.jl +# function to modify the representation of values based on their length +function modify_value_representation(nt::NamedTuple) + modified_nt = NamedTuple() + for (key, value) in zip(keys(nt), values(nt)) + if length(value) == 1 # Scalar value + modified_value = value[1] + else # Non-scalar value + modified_value = value + end + modified_nt = merge(modified_nt, (key => modified_value,)) + end + return modified_nt end + +end # module TestExtUtils diff --git a/test/ad.jl b/test/ad.jl index 6046cfda4..1be5d2fd8 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,11 +1,11 @@ @testset "AD: ForwardDiff and ReverseDiff" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(m.f)" for m in TU.DEMO_MODELS f = DynamicPPL.LogDensityFunction(m) - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + rand_param_values = TU.rand_prior_true(m) + vns = TU.varnames(m) + varinfos = TU.setup_varinfos(m, rand_param_values, vns) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos f = DynamicPPL.LogDensityFunction(m, varinfo) # use ForwardDiff result as reference diff --git a/test/compat/ad.jl b/test/compat/ad.jl index f76ce6f6e..44364edc1 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -12,7 +12,7 @@ logpdf(dist, 2.0) end - test_model_ad(gdemo_default, logp_gdemo_default) + TU.test_model_ad(TU.gdemo_default, logp_gdemo_default) @model function wishart_ad() return v ~ Wishart(7, [1 0.5; 0.5 1]) @@ -24,7 +24,7 @@ return logpdf(dist, reshape(x, 2, 2)) end - test_model_ad(wishart_ad(), logp_wishart_ad) + TU.test_model_ad(wishart_ad(), logp_wishart_ad) end # https://github.com/TuringLang/Turing.jl/issues/1595 diff --git a/test/contexts.jl b/test/contexts.jl index 4ec9ff945..e3f145cd2 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -168,7 +168,7 @@ end # Let's check elementwise. for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) + TU.varname_leaves(vn_without_prefix, val) if getoptic(vn_child)(val) === missing @test contextual_isassumption(context, vn_child) else @@ -201,7 +201,7 @@ end vn_without_prefix = remove_prefix(vn) for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) + TU.varname_leaves(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasconditioned_nested(context, vn_child) # Value should be the same as extracted above. @@ -216,7 +216,7 @@ end @testset "Evaluation" begin @testset "$context" for context in contexts # Just making sure that we can actually sample with each of the contexts. - @test (gdemo_default(SamplingContext(context)); true) + @test (TU.gdemo_default(SamplingContext(context)); true) end end @@ -258,7 +258,7 @@ end end @testset "FixedContext" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in TU.DEMO_MODELS retval = model() s, m = retval.s, retval.m diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 50bb5d4be..ebc7cb272 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,19 +1,19 @@ @testset "check_model" begin @testset "context interface" begin # HACK: Require a model to instantiate it, so let's just grab one. - model = first(DynamicPPL.TestUtils.DEMO_MODELS) + model = first(TU.DEMO_MODELS) context = DynamicPPL.DebugUtils.DebugContext(model) - DynamicPPL.TestUtils.test_context_interface(context) + TU.test_context_interface(context) end - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in TU.DEMO_MODELS issuccess, trace = check_model_and_trace(model) # These models should all work. @test issuccess # Check that the trace contains all the variables in the model. varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace) - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) @test vn in varnames_in_trace end @@ -156,7 +156,7 @@ end @testset "comparing multiple traces" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() + model = TU.demo_dynamic_constraint() issuccess_1, trace_1 = check_model_and_trace(model) issuccess_2, trace_2 = check_model_and_trace(model) @test issuccess_1 && issuccess_2 diff --git a/test/linking.jl b/test/linking.jl index d424a9c2d..383a85a67 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -75,8 +75,8 @@ end model = demo() example_values = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),)) - @testset "$(short_varinfo_name(vi))" for vi in vis + vis = TU.setup_varinfos(model, example_values, (@varname(m),)) + @testset "$(TU.short_varinfo_name(vi))" for vi in vis # Evaluate once to ensure we have `logp` value. vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) vi_linked = if mutable @@ -109,10 +109,10 @@ end model = demo_lkj(d) dist = LKJCholesky(d, 1.0, uplo) values_original = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos( + vis = TU.setup_varinfos( model, values_original, (@varname(x),) ) - @testset "$(short_varinfo_name(vi))" for vi in vis + @testset "$(TU.short_varinfo_name(vi))" for vi in vis val = vi[@varname(x), dist] # Ensure that `reconstruct` works as intended. @test val isa Cholesky @@ -150,8 +150,8 @@ end @testset "d=$d" for d in [2, 3, 5] model = demo_dirichlet(d) example_values = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),)) - @testset "$(short_varinfo_name(vi))" for vi in vis + vis = TU.setup_varinfos(model, example_values, (@varname(x),)) + @testset "$(TU.short_varinfo_name(vi))" for vi in vis lp = logpdf(Dirichlet(d, 1.0), vi[:]) @test length(vi[:]) == d lp_model = logjoint(model, vi) @@ -189,8 +189,8 @@ end ] model = demo_highdim_dirichlet(ns...) example_values = rand(NamedTuple, model) - vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),)) - @testset "$(short_varinfo_name(vi))" for vi in vis + vis = TU.setup_varinfos(model, example_values, (@varname(x),)) + @testset "$(TU.short_varinfo_name(vi))" for vi in vis # Linked. vi_linked = if mutable DynamicPPL.link!!(deepcopy(vi), model) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index beda767e6..72423668b 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,8 +1,8 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff @testset "`getmodel` and `setmodel`" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + @testset "$(nameof(model))" for model in TU.DEMO_MODELS + model = TU.DEMO_MODELS[1] ℓ = DynamicPPL.LogDensityFunction(model) @test DynamicPPL.getmodel(ℓ) == model @test DynamicPPL.setmodel(ℓ, model).model == model @@ -21,10 +21,10 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever end @testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + @testset "$(nameof(model))" for model in TU.DEMO_MODELS + example_values = TU.rand_prior_true(model) + vns = TU.varnames(model) + varinfos = TU.setup_varinfos(model, example_values, vns) @testset "$(varinfo)" for varinfo in varinfos logdensity = DynamicPPL.LogDensityFunction(model, varinfo) diff --git a/test/model.jl b/test/model.jl index d163f55f0..88fb0c085 100644 --- a/test/model.jl +++ b/test/model.jl @@ -31,7 +31,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @testset "model.jl" begin @testset "convenience functions" begin - model = gdemo_default # defined in test/test_util.jl + model = TU.gdemo_default # sample from model and extract variables vi = VarInfo(model) @@ -55,9 +55,9 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12 + for model in TU.DEMO_MODELS # length(TU.DEMO_MODELS)=12 var_info = VarInfo(model) - vns = DynamicPPL.TestUtils.varnames(model) + vns = TU.varnames(model) syms = unique(DynamicPPL.getsym.(vns)) # generate a chain of sample parameter values. @@ -113,20 +113,20 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true samples_dict[key] = existing_value end samples = (; samples_dict...) - samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl + samples = TU.modify_value_representation(samples) @test logpriors[i] ≈ - DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m]) - @test loglikelihoods[i] ≈ DynamicPPL.TestUtils.loglikelihood_true( + TU.logprior_true(model, samples[:s], samples[:m]) + @test loglikelihoods[i] ≈ TU.loglikelihood_true( model, samples[:s], samples[:m] ) @test logjoints[i] ≈ - DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m]) + TU.logjoint_true(model, samples[:s], samples[:m]) end end end @testset "rng" begin - model = gdemo_default + model = TU.gdemo_default for sampler in (SampleFromPrior(), SampleFromUniform()) for i in 1:10 @@ -144,7 +144,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "defaults without VarInfo, Sampler, and Context" begin - model = gdemo_default + model = TU.gdemo_default Random.seed!(100) s, m = model() @@ -184,7 +184,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "Internal methods" begin - model = gdemo_default + model = TU.gdemo_default # sample from model and extract variables vi = VarInfo(model) @@ -200,7 +200,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "Dynamic constraints, Metadata" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() + model = TU.demo_dynamic_constraint() spl = SampleFromPrior() vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata()) link!!(vi, spl, model) @@ -216,7 +216,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "Dynamic constraints, VectorVarInfo" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() + model = TU.demo_dynamic_constraint() for i in 1:10 vi = VarInfo(model) @test vi[@varname(x)] >= vi[@varname(m)] @@ -224,7 +224,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "rand" begin - model = gdemo_default + model = TU.gdemo_default Random.seed!(1776) s, m = model() @@ -256,7 +256,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "extract priors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in TU.DEMO_MODELS priors = extract_priors(model) # We know that any variable starting with `s` should have `InverseGamma` @@ -274,45 +274,45 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "TestUtils" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - x = DynamicPPL.TestUtils.rand_prior_true(model) + @testset "$(model.f)" for model in TU.DEMO_MODELS + x = TU.rand_prior_true(model) # `rand_prior_true` should return a `NamedTuple`. @test x isa NamedTuple # `rand` with a `AbstractDict` should have `varnames` as keys. x_rand_dict = rand(OrderedDict, model) - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) @test haskey(x_rand_dict, vn) end # `rand` with a `NamedTuple` should have `map(Symbol, varnames)` as keys. x_rand_nt = rand(NamedTuple, model) - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) @test haskey(x_rand_nt, Symbol(vn)) end # Ensure log-probability computations are implemented. - @test logprior(model, x) ≈ DynamicPPL.TestUtils.logprior_true(model, x...) + @test logprior(model, x) ≈ TU.logprior_true(model, x...) @test loglikelihood(model, x) ≈ - DynamicPPL.TestUtils.loglikelihood_true(model, x...) - @test logjoint(model, x) ≈ DynamicPPL.TestUtils.logjoint_true(model, x...) + TU.loglikelihood_true(model, x...) + @test logjoint(model, x) ≈ TU.logjoint_true(model, x...) @test logjoint(model, x) != - DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) + TU.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. vi = last( DynamicPPL.evaluate!!( model, SimpleVarInfo(OrderedDict()), SamplingContext() ), ) - @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) + @test all(collect(keys(vi)) .== TU.varnames(model)) # Ensure `posterior_mean` is implemented. - @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) + @test TU.posterior_mean(model) isa typeof(x) end end @testset "generated_quantities on `LKJCholesky`" begin n = 10 d = 2 - model = DynamicPPL.TestUtils.demo_lkjchol(d) + model = TU.demo_lkjchol(d) xs = [model().x for _ in 1:n] # Extract varnames and values. @@ -361,17 +361,17 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true if VERSION >= v"1.8" @testset "Type stability of models" begin models_to_test = [ - DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) + TU.DEMO_MODELS..., TU.demo_lkjchol(2) ] context = DefaultContext() @testset "$(model.f)" for model in models_to_test - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) + vns = TU.varnames(model) + example_values = TU.rand_prior_true(model) varinfos = filter( is_typed_varinfo, - DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), + TU.setup_varinfos(model, example_values, vns), ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos @test begin @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) true @@ -388,11 +388,11 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "values_as_in_model" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(model.f)" for model in TU.DEMO_MODELS + vns = TU.varnames(model) + example_values = TU.rand_prior_true(model) + varinfos = TU.setup_varinfos(model, example_values, vns) + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos realizations = values_as_in_model(model, varinfo) # Ensure that all variables are found. vns_found = collect(keys(realizations)) @@ -431,7 +431,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true DynamicPPL.typed_simple_varinfo(model), DynamicPPL.untyped_simple_varinfo(model), ] - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 93b7c59be..b1b9fd37a 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,16 +1,16 @@ @testset "logdensities_likelihoods.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) - mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) + mod_ctx = TU.TestLogModifyingChildContext(1.2) + mod_ctx2 = TU.TestLogModifyingChildContext(1.4, mod_ctx) + @testset "$(model.f)" for model in TU.DEMO_MODELS + example_values = TU.rand_prior_true(model) # Instantiate a `VarInfo` with the example values. vi = VarInfo(model) - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( + loglikelihood_true = TU.loglikelihood_true( model, example_values... ) logp_true = logprior(model, vi) @@ -49,14 +49,14 @@ end # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just # to ensure that we don't accidentally break the the version on `Chains`. - model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + model = TU.demo_dot_assume_dot_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. - vns = DynamicPPL.TestUtils.varnames(model) + vns = TU.varnames(model) # Get some random `NamedTuple` samples from the prior. num_iters = 3 - vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] + vals = [TU.rand_prior_true(model) for _ in 1:num_iters] # Concatenate the vector representations and create a `Chains` from it. vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) chain = Chains(permutedims(vals_arr), map(Symbol, vns)) @@ -90,9 +90,9 @@ end for (val, logjoint, logprior, loglikelihood) in zip(vals, logjoints, logpriors, loglikelihoods) # Compare true logjoint with the one obtained from `pointwise_logdensities`. - logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) - logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...) + logjoint_true = TU.logjoint_true(model, val...) + logprior_true = TU.logprior_true(model, val...) + loglikelihood_true = TU.loglikelihood_true(model, val...) @test logjoint ≈ logjoint_true @test logprior ≈ logprior_true diff --git a/test/runtests.jl b/test/runtests.jl index a832a0f08..da444f102 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ using LinearAlgebra # Diagonal using Combinatorics: combinations using DynamicPPL: getargs_dottilde, getargs_tilde, Selector +import DynamicPPLTestExt.TestExtUtils as TU const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") diff --git a/test/sampler.jl b/test/sampler.jl index 95e838167..8315a2b82 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -33,8 +33,8 @@ end @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in TU.DEMO_MODELS + @testset "$(model.f)" for model in TU.DEMO_MODELS N = 1000 chain_init = sample(model, SampleFromUniform(), N; progress=false) diff --git a/test/serialization.jl b/test/serialization.jl index a2d9abb36..98d0ae7e4 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -2,7 +2,7 @@ @testset "saving and loading" begin # Save model. file = joinpath(mktempdir(), "gdemo_default.jls") - serialize(file, gdemo_default) + serialize(file, TU.gdemo_default) # Sample from deserialized model. gdemo_default_copy = deserialize(file) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 4343563eb..9c07b34ae 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -86,15 +86,15 @@ end @testset "link!! & invlink!! on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) + TU.DEMO_MODELS + values_constrained = TU.rand_prior_true(model) @testset "$(typeof(vi))" for vi in ( SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), VarInfo(model), ) - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) @@ -103,7 +103,7 @@ # `link!!` vi_linked = link!!(deepcopy(vi), model) lp_linked = getlogp(vi_linked) - values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + values_unconstrained, lp_linked_true = TU.logjoint_true_with_logabsdet_jacobian( model, values_constrained... ) # Should result in the correct logjoint. @@ -120,7 +120,7 @@ # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) lp_invlinked = getlogp(vi_invlinked) - lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( + lp_invlinked_true = TU.logjoint_true( model, values_constrained... ) # Should result in the correct logjoint. @@ -132,21 +132,21 @@ @test all( DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ DynamicPPL.tovec(get(values_constrained, vn)) for - vn in DynamicPPL.TestUtils.varnames(model) + vn in TU.varnames(model) ) end end @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.DEMO_MODELS - model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix() + TU.DEMO_MODELS + model = TU.demo_dot_assume_matrix_dot_observe_matrix() # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) + svi_nt = SimpleVarInfo(TU.rand_prior_true(model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) vnv = DynamicPPL.VarNamedVector() - for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) + for (k, v) in pairs(TU.rand_prior_true(model)) vnv = push!!(vnv, VarName{k}() => v) end svi_vnv = SimpleVarInfo(vnv) @@ -168,7 +168,7 @@ _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) # Realization for `m` should be different wp. 1. - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) @test svi_new[vn] != get(retval, vn) end @@ -176,35 +176,35 @@ @test getlogp(svi_new) != 0 ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) + values_eval_constrained = TU.rand_prior_true(model) if DynamicPPL.istrans(svi) - _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + _values_prior, logpri_true = TU.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... ) - values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + values_eval, logπ_true = TU.logjoint_true_with_logabsdet_jacobian( model, values_eval_constrained... ) # Make sure that these two computation paths provide the same # transformed values. @test values_eval == _values_prior else - logpri_true = DynamicPPL.TestUtils.logprior_true( + logpri_true = TU.logprior_true( model, values_eval_constrained... ) - logπ_true = DynamicPPL.TestUtils.logjoint_true( + logπ_true = TU.logjoint_true( model, values_eval_constrained... ) values_eval = values_eval_constrained end # No logabsdet-jacobian correction needed for the likelihood. - loglik_true = DynamicPPL.TestUtils.loglikelihood_true( + loglik_true = TU.loglikelihood_true( model, values_eval_constrained... ) # Update the realizations in `svi_new`. svi_eval = svi_new - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end @@ -217,7 +217,7 @@ loglik = loglikelihood(model, svi_eval) # Values should not have changed. - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) @test svi_eval[vn] == get(values_eval, vn) end @@ -229,7 +229,7 @@ end @testset "Dynamic constraints" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() + model = TU.demo_dynamic_constraint() # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) @@ -247,12 +247,12 @@ @test retval.m == svi[@varname(m)] # `m` is unconstrained @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + retval_unconstrained, lp_true = TU.logjoint_true_with_logabsdet_jacobian( model, retval.m, retval.x ) # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) + for vn in TU.varnames(model) @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 end @@ -264,12 +264,12 @@ end @testset "Static transformation" begin - model = DynamicPPL.TestUtils.demo_static_transformation() + model = TU.demo_static_transformation() - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] + varinfos = TU.setup_varinfos( + model, TU.rand_prior_true(model), [@varname(s), @varname(m)] ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos + @testset "$(TU.short_varinfo_name(vi))" for vi in varinfos # Initialize varinfo and link. vi_linked = DynamicPPL.link!!(vi, model) @@ -304,7 +304,7 @@ @test vi_linked_result[@varname(m)] == retval.m # Compare to truth. - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + retval_unconstrained, lp_true = TU.logjoint_true_with_logabsdet_jacobian( model, retval.s, retval.m ) diff --git a/test/test_util.jl b/test/test_util.jl deleted file mode 100644 index f1325b729..000000000 --- a/test/test_util.jl +++ /dev/null @@ -1,112 +0,0 @@ -# default model -@model function gdemo_d() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - 1.5 ~ Normal(m, sqrt(s)) - 2.0 ~ Normal(m, sqrt(s)) - return s, m -end -const gdemo_default = gdemo_d() - -function test_model_ad(model, logp_manual) - vi = VarInfo(model) - x = DynamicPPL.getall(vi) - - # Log probabilities using the model. - ℓ = DynamicPPL.LogDensityFunction(model, vi) - logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) - - # Check that both functions return the same values. - lp = logp_manual(x) - @test logp_model(x) ≈ lp - - # Gradients based on the manual implementation. - grad = ForwardDiff.gradient(logp_manual, x) - - y, back = Tracker.forward(logp_manual, x) - @test Tracker.data(y) ≈ lp - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(logp_manual, x) - @test y ≈ lp - @test back(1)[1] ≈ grad - - # Gradients based on the model. - @test ForwardDiff.gradient(logp_model, x) ≈ grad - - y, back = Tracker.forward(logp_model, x) - @test Tracker.data(y) ≈ lp - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(logp_model, x) - @test y ≈ lp - @test back(1)[1] ≈ grad -end - -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. - -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end -end - -""" - short_varinfo_name(vi::AbstractVarInfo) - -Return string representing a short description of `vi`. -""" -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = - "threadsafe($(short_varinfo_name(vi.varinfo)))" -function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" - return "TypedVarInfo" -end -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) - return "SimpleVarInfo{<:VarNamedVector}" -end - -# convenient functions for testing model.jl -# function to modify the representation of values based on their length -function modify_value_representation(nt::NamedTuple) - modified_nt = NamedTuple() - for (key, value) in zip(keys(nt), values(nt)) - if length(value) == 1 # Scalar value - modified_value = value[1] - else # Non-scalar value - modified_value = value - end - modified_nt = merge(modified_nt, (key => modified_value,)) - end - return modified_nt -end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 72c439db8..47216c366 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -1,6 +1,6 @@ @testset "threadsafe.jl" begin @testset "constructor" begin - vi = VarInfo(gdemo_default) + vi = VarInfo(TU.gdemo_default) threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test threadsafe_vi.varinfo === vi @@ -11,7 +11,7 @@ # TODO: Add more tests of the public API @testset "API" begin - vi = VarInfo(gdemo_default) + vi = VarInfo(TU.gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) lp = getlogp(vi) diff --git a/test/turing/compiler.jl b/test/turing/compiler.jl index 5c46ab777..3936894be 100644 --- a/test/turing/compiler.jl +++ b/test/turing/compiler.jl @@ -158,7 +158,7 @@ end @testset "sample" begin alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s)) - chn = sample(gdemo_default, alg, 1000) + chn = sample(TU.gdemo_default, alg, 1000) end @testset "vectorization @." begin @model function vdemo1(x) diff --git a/test/turing/model.jl b/test/turing/model.jl index 599fba21b..98c3ee47d 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -1,17 +1,17 @@ @testset "model.jl" begin @testset "setval! & generated_quantities" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$model" for model in TU.DEMO_MODELS chain = sample(model, Prior(), 10) # A simple way of checking that the computation is determinstic: run twice and compare. res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + TU.test_setval!(model, MCMCChains.get_sections(chain, :parameters)) end end @testset "value_iterator_from_chain" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$model" for model in TU.DEMO_MODELS chain = sample(model, Prior(), 10; progress=false) for (i, d) in enumerate(value_iterator_from_chain(model, chain)) for vn in keys(d) diff --git a/test/varinfo.jl b/test/varinfo.jl index a2425ebc8..3cb24eb99 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -229,7 +229,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) model(vi_vnv, SampleFromPrior()) model_name = model == model_uv ? "univariate" : "multivariate" - @testset "$(model_name), $(short_varinfo_name(vi))" for vi in [ + @testset "$(model_name), $(TU.short_varinfo_name(vi))" for vi in [ vi_untyped, vi_typed, vi_vnv, vi_vnv_typed ] Random.seed!(23) @@ -395,17 +395,17 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end @testset "values_as" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) + @testset "$(nameof(model))" for model in TU.DEMO_MODELS + example_values = TU.rand_prior_true(model) + vns = TU.varnames(model) # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos( + varinfos = TU.setup_varinfos( model, example_values, vns; include_threadsafe=true ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos + @testset "$(TU.short_varinfo_name(vi))" for vi in varinfos # Just making sure. - DynamicPPL.TestUtils.test_values(vi, example_values, vns) + TU.test_values(vi, example_values, vns) @testset "NamedTuple" begin vals = values_as(vi, NamedTuple) @@ -439,16 +439,16 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "unflatten + linking" begin @testset "Model: $(model.f)" for model in [ - DynamicPPL.TestUtils.demo_one_variable_multiple_constraints(), - DynamicPPL.TestUtils.demo_lkjchol(), + TU.demo_one_variable_multiple_constraints(), + TU.demo_lkjchol(), ] @testset "mutating=$mutating" for mutating in [false, true] - value_true = DynamicPPL.TestUtils.rand_prior_true(model) - varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( + value_true = TU.rand_prior_true(model) + varnames = TU.varnames(model) + varinfos = TU.setup_varinfos( model, value_true, varnames; include_threadsafe=true ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: this is broken since we'll end up trying to set # @@ -486,8 +486,8 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) ) @test length(varinfo_linked_unflattened[:]) == length(varinfo_linked[:]) - lp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) - value_linked_true, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + lp_true = TU.logjoint_true(model, value_true...) + value_linked_true, lp_linked_true = TU.logjoint_true_with_logabsdet_jacobian( model, value_true... ) @@ -526,7 +526,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] # `VarInfo` supports, effectively, arbitrary subsetting. - varinfos = DynamicPPL.TestUtils.setup_varinfos( + varinfos = TU.setup_varinfos( model, model(), vns; include_threadsafe=true ) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) @@ -564,7 +564,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # in the model. vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) @@ -624,7 +624,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # For certain varinfos we should have errors. # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin + @testset "$(TU.short_varinfo_name(varinfo)): failure cases" begin @test_throws ArgumentError subset( varinfo, [@varname(s), @varname(m), @varname(x[1])] ) @@ -632,15 +632,15 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end @testset "merge" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( + @testset "$(model.f)" for model in TU.DEMO_MODELS + vns = TU.varnames(model) + varinfos = TU.setup_varinfos( model, - DynamicPPL.TestUtils.rand_prior_true(model), + TU.rand_prior_true(model), vns; include_threadsafe=true, ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos @testset "with itself" begin # Merging itself should be a no-op. varinfo_merged = merge(varinfo, varinfo) @@ -678,13 +678,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end @testset "with different value" begin - x = DynamicPPL.TestUtils.rand_prior_true(model) - varinfo_changed = DynamicPPL.TestUtils.update_values!!( + x = TU.rand_prior_true(model) + varinfo_changed = TU.update_values!!( deepcopy(varinfo), x, vns ) # After `merge`, we should have the same values as `x`. varinfo_merged = merge(varinfo, varinfo_changed) - DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) + TU.test_values(varinfo_merged, x, vns) end end end @@ -729,7 +729,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end @testset "VarInfo with selectors" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in TU.DEMO_MODELS varinfo = VarInfo( model, DynamicPPL.SampleFromPrior(), @@ -739,7 +739,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) selector = DynamicPPL.Selector() spl = Sampler(MySAlg(), model, selector) - vns = DynamicPPL.TestUtils.varnames(model) + vns = TU.varnames(model) vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) for vn in vns_s diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index bd3f5553f..595bb93a8 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -578,29 +578,29 @@ end end @testset "VarInfo + VarNamedVector" begin - models = DynamicPPL.TestUtils.DEMO_MODELS + models = TU.DEMO_MODELS @testset "$(model.f)" for model in models # NOTE: Need to set random seed explicitly to avoid using the same seed # for initialization as for sampling in the inner testset below. Random.seed!(42) - value_true = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( + value_true = TU.rand_prior_true(model) + vns = TU.varnames(model) + varnames = TU.varnames(model) + varinfos = TU.setup_varinfos( model, value_true, varnames; include_threadsafe=false ) # Filter out those which are not based on `VarNamedVector`. varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) # Get the true log joint. - logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) + logp_true = TU.logjoint_true(model, value_true...) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos # Need to make sure we're using a different random seed from the # one used in the above call to `rand_prior_true`. Random.seed!(43) # Are values correct? - DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) + TU.test_values(varinfo, value_true, vns) # Is evaluation correct? varinfo_eval = last( @@ -609,7 +609,7 @@ end # Log density should be the same. @test getlogp(varinfo_eval) ≈ logp_true # Values should be the same. - DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) + TU.test_values(varinfo_eval, value_true, vns) # Is sampling correct? varinfo_sample = last( @@ -618,7 +618,7 @@ end # Log density should be different. @test getlogp(varinfo_sample) != getlogp(varinfo) # Values should be different. - DynamicPPL.TestUtils.test_values( + TU.test_values( varinfo_sample, value_true, vns; compare=!isequal ) end