Skip to content

Commit

Permalink
Update test/pointwise_logdensities.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde authored Sep 26, 2024
1 parent 073a325 commit 23e1711
Showing 1 changed file with 36 additions and 7 deletions.
43 changes: 36 additions & 7 deletions test/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,50 @@ end
model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe()
# FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced
# an impl of this for containers.
# NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed.
vns = DynamicPPL.TestUtils.varnames(model)
# Get some random `NamedTuple` samples from the prior.
vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ = 1:5]
num_iters = 3
vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters]
# Concatenate the vector representations and create a `Chains` from it.
vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals)
chain = Chains(permutedims(vals_arr), map(Symbol, vns));
chain = Chains(permutedims(vals_arr), map(Symbol, vns))

# Compute the different pointwise logdensities.
logjoints_pointwise = pointwise_logdensities(model, chain)
logpriors_pointwise = pointwise_prior_logdensities(model, chain)
loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain)

# Check that they contain the correct variables.
@test all(string(vn) in keys(logjoints_pointwise) for vn in vns)
@test all(string(vn) in keys(logpriors_pointwise) for vn in vns)
@test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise))
@test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns)
@test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise))

# Get the sum of the logjoints for each of the iterations.
logjoints = [
sum(logjoints_pointwise[string(vn)][idx] for vn in vns)
for idx = 1:5
sum(logjoints_pointwise[vn][idx] for vn in keys(logjoints_pointwise)) for
idx in 1:num_iters
]
logpriors = [
sum(logpriors_pointwise[vn][idx] for vn in keys(logpriors_pointwise)) for
idx in 1:num_iters
]
for (val, logp) in zip(vals, logjoints)
loglikelihoods = [
sum(loglikelihoods_pointwise[vn][idx] for vn in keys(loglikelihoods_pointwise)) for
idx in 1:num_iters
]

for (val, logjoint, logprior, loglikelihood) in
zip(vals, logjoints, logpriors, loglikelihoods)
# Compare true logjoint with the one obtained from `pointwise_logdensities`.
logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...)
@test logp logjoint_true
logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...)
loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...)

@test logjoint logjoint_true
@test logprior logprior_true
@test loglikelihood loglikelihood_true
end
end;
end

0 comments on commit 23e1711

Please sign in to comment.