diff --git a/test/test_util.jl b/test/test_util.jl index c71d7b486..15701b78a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual) @test back(1)[1] ≈ grad end -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. - -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end -end - """ short_varinfo_name(vi::AbstractVarInfo) diff --git a/test/varinfo.jl b/test/varinfo.jl index c45fb47e0..cd4a2a75a 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -130,6 +130,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -187,6 +188,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end + @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -339,6 +341,52 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test vals_prev == vi.metadata.x.vals end + @testset "setval! on chain" begin + # Define a helper function + """ + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + + Test `setval!` on `model` and `chain`. + + Worth noting that this only supports models containing symbols of the forms + `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. + """ + function test_setval!(model, chain; sample_idx=1, chain_idx=1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v + end + + @test v_true == chain_val + end + end + + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + # A simple way of checking that the computation is determinstic: run twice and compare. + res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) + res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) + @test all(res1 .== res2) + test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + end + end + @testset "istrans" begin @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) model = demo_constrained()