Skip to content

Commit

Permalink
Move src/test_utils and test/test_util to DynamicPPLTestExt
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 20, 2024
1 parent f5890a1 commit dcd24e7
Show file tree
Hide file tree
Showing 21 changed files with 287 additions and 273 deletions.
10 changes: 2 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
Expand All @@ -39,6 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLTestExt = ["Test"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand Down Expand Up @@ -67,11 +69,3 @@ ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
8 changes: 8 additions & 0 deletions ext/DynamicPPLTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module DynamicPPLTestExt

using DynamicPPL: DynamicPPL
using Test: @test, @testset, @test_throws, @test_broken

include("DynamicPPLTestExt/utils.jl")

end
125 changes: 124 additions & 1 deletion src/test_utils.jl → ext/DynamicPPLTestExt/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
module TestUtils
module TestExtUtils

###################################################
# These used to be in DPPL/src/test_utils.jl ######
###################################################

using AbstractMCMC
using DynamicPPL
Expand Down Expand Up @@ -1097,4 +1101,123 @@ function DynamicPPL.dot_tilde_observe(
return logp * context.mod, vi
end



###################################################
# These used to be in DPPL/test/test_util.jl ######
###################################################

# default model
@model function gdemo_d()
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
1.5 ~ Normal(m, sqrt(s))
2.0 ~ Normal(m, sqrt(s))
return s, m
end
const gdemo_default = gdemo_d()

function test_model_ad(model, logp_manual)
vi = VarInfo(model)
x = DynamicPPL.getall(vi)

# Log probabilities using the model.
= DynamicPPL.LogDensityFunction(model, vi)
logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# Check that both functions return the same values.
lp = logp_manual(x)
@test logp_model(x) lp

# Gradients based on the manual implementation.
grad = ForwardDiff.gradient(logp_manual, x)

y, back = Tracker.forward(logp_manual, x)
@test Tracker.data(y) lp
@test Tracker.data(back(1)[1]) grad

y, back = Zygote.pullback(logp_manual, x)
@test y lp
@test back(1)[1] grad

# Gradients based on the model.
@test ForwardDiff.gradient(logp_model, x) grad

y, back = Tracker.forward(logp_model, x)
@test Tracker.data(y) lp
@test Tracker.data(back(1)[1]) grad

y, back = Zygote.pullback(logp_model, x)
@test y lp
@test back(1)[1] grad
end

"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
Test `setval!` on `model` and `chain`.
Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

"""
short_varinfo_name(vi::AbstractVarInfo)
Return string representing a short description of `vi`.
"""
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
"threadsafe($(short_varinfo_name(vi.varinfo)))"
function short_varinfo_name(vi::TypedVarInfo)
DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
return "TypedVarInfo"
end
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
return "SimpleVarInfo{<:VarNamedVector}"
end

# convenient functions for testing model.jl
# function to modify the representation of values based on their length
function modify_value_representation(nt::NamedTuple)
modified_nt = NamedTuple()
for (key, value) in zip(keys(nt), values(nt))
if length(value) == 1 # Scalar value
modified_value = value[1]
else # Non-scalar value
modified_value = value
end
modified_nt = merge(modified_nt, (key => modified_value,))
end
return modified_nt
end

end # module TestExtUtils
10 changes: 5 additions & 5 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
@testset "AD: ForwardDiff and ReverseDiff" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(m.f)" for m in TU.DEMO_MODELS
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
rand_param_values = TU.rand_prior_true(m)
vns = TU.varnames(m)
varinfos = TU.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
@testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
Expand Down
4 changes: 2 additions & 2 deletions test/compat/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logpdf(dist, 2.0)
end

test_model_ad(gdemo_default, logp_gdemo_default)
TU.test_model_ad(TU.gdemo_default, logp_gdemo_default)

@model function wishart_ad()
return v ~ Wishart(7, [1 0.5; 0.5 1])
Expand All @@ -24,7 +24,7 @@
return logpdf(dist, reshape(x, 2, 2))
end

test_model_ad(wishart_ad(), logp_wishart_ad)
TU.test_model_ad(wishart_ad(), logp_wishart_ad)
end

# https://github.com/TuringLang/Turing.jl/issues/1595
Expand Down
8 changes: 4 additions & 4 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ end

# Let's check elementwise.
for vn_child in
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
TU.varname_leaves(vn_without_prefix, val)
if getoptic(vn_child)(val) === missing
@test contextual_isassumption(context, vn_child)
else
Expand Down Expand Up @@ -201,7 +201,7 @@ end
vn_without_prefix = remove_prefix(vn)

for vn_child in
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
TU.varname_leaves(vn_without_prefix, val)
# `vn_child` should be in `context`.
@test hasconditioned_nested(context, vn_child)
# Value should be the same as extracted above.
Expand All @@ -216,7 +216,7 @@ end
@testset "Evaluation" begin
@testset "$context" for context in contexts
# Just making sure that we can actually sample with each of the contexts.
@test (gdemo_default(SamplingContext(context)); true)
@test (TU.gdemo_default(SamplingContext(context)); true)
end
end

Expand Down Expand Up @@ -258,7 +258,7 @@ end
end

@testset "FixedContext" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in TU.DEMO_MODELS
retval = model()
s, m = retval.s, retval.m

Expand Down
10 changes: 5 additions & 5 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
@testset "check_model" begin
@testset "context interface" begin
# HACK: Require a model to instantiate it, so let's just grab one.
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
model = first(TU.DEMO_MODELS)
context = DynamicPPL.DebugUtils.DebugContext(model)
DynamicPPL.TestUtils.test_context_interface(context)
TU.test_context_interface(context)
end

@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in TU.DEMO_MODELS
issuccess, trace = check_model_and_trace(model)
# These models should all work.
@test issuccess

# Check that the trace contains all the variables in the model.
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
for vn in DynamicPPL.TestUtils.varnames(model)
for vn in TU.varnames(model)
@test vn in varnames_in_trace
end

Expand Down Expand Up @@ -156,7 +156,7 @@
end

@testset "comparing multiple traces" begin
model = DynamicPPL.TestUtils.demo_dynamic_constraint()
model = TU.demo_dynamic_constraint()
issuccess_1, trace_1 = check_model_and_trace(model)
issuccess_2, trace_2 = check_model_and_trace(model)
@test issuccess_1 && issuccess_2
Expand Down
16 changes: 8 additions & 8 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ end
model = demo()

example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, example_values, (@varname(m),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
# Evaluate once to ensure we have `logp` value.
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
vi_linked = if mutable
Expand Down Expand Up @@ -109,10 +109,10 @@ end
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(
vis = TU.setup_varinfos(
model, values_original, (@varname(x),)
)
@testset "$(short_varinfo_name(vi))" for vi in vis
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
val = vi[@varname(x), dist]
# Ensure that `reconstruct` works as intended.
@test val isa Cholesky
Expand Down Expand Up @@ -150,8 +150,8 @@ end
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
lp_model = logjoint(model, vi)
Expand Down Expand Up @@ -189,8 +189,8 @@ end
]
model = demo_highdim_dirichlet(ns...)
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
# Linked.
vi_linked = if mutable
DynamicPPL.link!!(deepcopy(vi), model)
Expand Down
12 changes: 6 additions & 6 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff

@testset "`getmodel` and `setmodel`" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
model = TU.DEMO_MODELS[1]
= DynamicPPL.LogDensityFunction(model)
@test DynamicPPL.getmodel(ℓ) == model
@test DynamicPPL.setmodel(ℓ, model).model == model
Expand All @@ -21,10 +21,10 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
example_values = TU.rand_prior_true(model)
vns = TU.varnames(model)
varinfos = TU.setup_varinfos(model, example_values, vns)

@testset "$(varinfo)" for varinfo in varinfos
logdensity = DynamicPPL.LogDensityFunction(model, varinfo)
Expand Down
Loading

0 comments on commit dcd24e7

Please sign in to comment.