diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index e796c89ef..a422965ef 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -52,21 +52,50 @@ end model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced # an impl of this for containers. + # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. vns = DynamicPPL.TestUtils.varnames(model) # Get some random `NamedTuple` samples from the prior. - vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ = 1:5] + num_iters = 3 + vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] # Concatenate the vector representations and create a `Chains` from it. vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) - chain = Chains(permutedims(vals_arr), map(Symbol, vns)); + chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + + # Compute the different pointwise logdensities. logjoints_pointwise = pointwise_logdensities(model, chain) + logpriors_pointwise = pointwise_prior_logdensities(model, chain) + loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + + # Check that they contain the correct variables. + @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) + @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) + # Get the sum of the logjoints for each of the iterations. logjoints = [ - sum(logjoints_pointwise[string(vn)][idx] for vn in vns) - for idx = 1:5 + sum(logjoints_pointwise[vn][idx] for vn in keys(logjoints_pointwise)) for + idx in 1:num_iters + ] + logpriors = [ + sum(logpriors_pointwise[vn][idx] for vn in keys(logpriors_pointwise)) for + idx in 1:num_iters ] - for (val, logp) in zip(vals, logjoints) + loglikelihoods = [ + sum(loglikelihoods_pointwise[vn][idx] for vn in keys(loglikelihoods_pointwise)) for + idx in 1:num_iters + ] + + for (val, logjoint, logprior, loglikelihood) in + zip(vals, logjoints, logpriors, loglikelihoods) # Compare true logjoint with the one obtained from `pointwise_logdensities`. logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) - @test logp ≈ logjoint_true + logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...) + + @test logjoint ≈ logjoint_true + @test logprior ≈ logprior_true + @test loglikelihood ≈ loglikelihood_true end -end; +end