From e5b7e44959bb81e4bacc778dcc603b3d125edfcd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 5 Nov 2024 08:46:28 +0100 Subject: [PATCH] added proper testing for PrefixContext of all demo models --- test/contexts.jl | 51 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 2767bb1ab..0491fca3e 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -138,24 +138,43 @@ end end @testset "PrefixContext" begin - ctx = @inferred PrefixContext{:f}( - PrefixContext{:e}( - PrefixContext{:d}( - PrefixContext{:c}( - PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + @testset "prefixing" begin + ctx = @inferred PrefixContext{:f}( + PrefixContext{:e}( + PrefixContext{:d}( + PrefixContext{:c}( + PrefixContext{:b}(PrefixContext{:a}(DefaultContext())) + ), ), ), - ), - ) - vn = VarName{:x}() - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) - - vn = VarName{:x}(((1,),)) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getoptic(vn_prefixed) === getoptic(vn) + ) + vn = VarName{:x}() + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) + @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") + @test getoptic(vn_prefixed) === getoptic(vn) + + vn = VarName{:x}(((1,),)) + vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) + @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") + @test getoptic(vn_prefixed) === getoptic(vn) + end + + context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Sample with the context. + varinfo = DynamicPPL.VarInfo() + DynamicPPL.evaluate!!(model, varinfo, context) + # Extract the resulting symbols. + vns_varinfo_syms = Set(map(DynamicPPL.getsym, keys(varinfo))) + + # Extract the ground truth symbols. + vns_syms = Set([ + Symbol("prefix", DynamicPPL.PREFIX_SEPARATOR, DynamicPPL.getsym(vn)) for vn in DynamicPPL.TestUtils.varnames(model) + ]) + + # Check that all variables are prefixed correctly. + @test vns_syms == vns_varinfo_syms + end end @testset "SamplingContext" begin