Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move src/test_utils.jl and test/test_util.jl to DynamicPPLTestExt #718

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
Expand All @@ -39,6 +40,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLTestExt = ["Test"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand Down Expand Up @@ -67,11 +69,3 @@ ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
8 changes: 8 additions & 0 deletions ext/DynamicPPLTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module DynamicPPLTestExt

using DynamicPPL: DynamicPPL
using Test: @test, @testset, @test_throws, @test_broken

include("DynamicPPLTestExt/utils.jl")

end
123 changes: 122 additions & 1 deletion src/test_utils.jl → ext/DynamicPPLTestExt/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
module TestUtils
module TestExtUtils

###################################################
# These used to be in DPPL/src/test_utils.jl ######
###################################################

using AbstractMCMC
using DynamicPPL
Expand Down Expand Up @@ -1097,4 +1101,121 @@ function DynamicPPL.dot_tilde_observe(
return logp * context.mod, vi
end

###################################################
# These used to be in DPPL/test/test_util.jl ######
###################################################

# default model
@model function gdemo_d()
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
1.5 ~ Normal(m, sqrt(s))
2.0 ~ Normal(m, sqrt(s))
return s, m
end
const gdemo_default = gdemo_d()

function test_model_ad(model, logp_manual)
vi = VarInfo(model)
x = DynamicPPL.getall(vi)

# Log probabilities using the model.
= DynamicPPL.LogDensityFunction(model, vi)
logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# Check that both functions return the same values.
lp = logp_manual(x)
@test logp_model(x) lp

# Gradients based on the manual implementation.
grad = ForwardDiff.gradient(logp_manual, x)

y, back = Tracker.forward(logp_manual, x)
@test Tracker.data(y) lp
@test Tracker.data(back(1)[1]) grad

y, back = Zygote.pullback(logp_manual, x)
@test y lp
@test back(1)[1] grad

# Gradients based on the model.
@test ForwardDiff.gradient(logp_model, x) grad

y, back = Tracker.forward(logp_model, x)
@test Tracker.data(y) lp
@test Tracker.data(back(1)[1]) grad

y, back = Zygote.pullback(logp_model, x)
@test y lp
@test back(1)[1] grad
end

"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
Test `setval!` on `model` and `chain`.
Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

"""
short_varinfo_name(vi::AbstractVarInfo)
Return string representing a short description of `vi`.
"""
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) =
"threadsafe($(short_varinfo_name(vi.varinfo)))"
function short_varinfo_name(vi::TypedVarInfo)
DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector"
return "TypedVarInfo"
end
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
return "SimpleVarInfo{<:VarNamedVector}"
end

# convenient functions for testing model.jl
# function to modify the representation of values based on their length
function modify_value_representation(nt::NamedTuple)
modified_nt = NamedTuple()
for (key, value) in zip(keys(nt), values(nt))
if length(value) == 1 # Scalar value
modified_value = value[1]
else # Non-scalar value
modified_value = value
end
modified_nt = merge(modified_nt, (key => modified_value,))
end
return modified_nt
end

end # module TestExtUtils
28 changes: 0 additions & 28 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ include("context_implementations.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
Expand All @@ -196,33 +195,6 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
function __init__()
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/DynamicPPLForwardDiffExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/DynamicPPLReverseDiffExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
end
end

# Standard tag: Improves stacktraces
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct DynamicPPLTag end
Expand Down
10 changes: 5 additions & 5 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
@testset "AD: ForwardDiff and ReverseDiff" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(m.f)" for m in TU.DEMO_MODELS
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
rand_param_values = TU.rand_prior_true(m)
vns = TU.varnames(m)
varinfos = TU.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
@testset "$(TU.short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
Expand Down
4 changes: 2 additions & 2 deletions test/compat/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logpdf(dist, 2.0)
end

test_model_ad(gdemo_default, logp_gdemo_default)
TU.test_model_ad(TU.gdemo_default, logp_gdemo_default)

@model function wishart_ad()
return v ~ Wishart(7, [1 0.5; 0.5 1])
Expand All @@ -24,7 +24,7 @@
return logpdf(dist, reshape(x, 2, 2))
end

test_model_ad(wishart_ad(), logp_wishart_ad)
TU.test_model_ad(wishart_ad(), logp_wishart_ad)
end

# https://github.com/TuringLang/Turing.jl/issues/1595
Expand Down
10 changes: 4 additions & 6 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ end
vn_without_prefix = remove_prefix(vn)

# Let's check elementwise.
for vn_child in
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
for vn_child in TU.varname_leaves(vn_without_prefix, val)
if getoptic(vn_child)(val) === missing
@test contextual_isassumption(context, vn_child)
else
Expand Down Expand Up @@ -200,8 +199,7 @@ end
# `ConditionContext` with the conditioned variable.
vn_without_prefix = remove_prefix(vn)

for vn_child in
DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val)
for vn_child in TU.varname_leaves(vn_without_prefix, val)
# `vn_child` should be in `context`.
@test hasconditioned_nested(context, vn_child)
# Value should be the same as extracted above.
Expand All @@ -216,7 +214,7 @@ end
@testset "Evaluation" begin
@testset "$context" for context in contexts
# Just making sure that we can actually sample with each of the contexts.
@test (gdemo_default(SamplingContext(context)); true)
@test (TU.gdemo_default(SamplingContext(context)); true)
end
end

Expand Down Expand Up @@ -258,7 +256,7 @@ end
end

@testset "FixedContext" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in TU.DEMO_MODELS
retval = model()
s, m = retval.s, retval.m

Expand Down
10 changes: 5 additions & 5 deletions test/debug_utils.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
@testset "check_model" begin
@testset "context interface" begin
# HACK: Require a model to instantiate it, so let's just grab one.
model = first(DynamicPPL.TestUtils.DEMO_MODELS)
model = first(TU.DEMO_MODELS)
context = DynamicPPL.DebugUtils.DebugContext(model)
DynamicPPL.TestUtils.test_context_interface(context)
TU.test_context_interface(context)
end

@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in TU.DEMO_MODELS
issuccess, trace = check_model_and_trace(model)
# These models should all work.
@test issuccess

# Check that the trace contains all the variables in the model.
varnames_in_trace = DynamicPPL.DebugUtils.varnames_in_trace(trace)
for vn in DynamicPPL.TestUtils.varnames(model)
for vn in TU.varnames(model)
@test vn in varnames_in_trace
end

Expand Down Expand Up @@ -156,7 +156,7 @@
end

@testset "comparing multiple traces" begin
model = DynamicPPL.TestUtils.demo_dynamic_constraint()
model = TU.demo_dynamic_constraint()
issuccess_1, trace_1 = check_model_and_trace(model)
issuccess_2, trace_2 = check_model_and_trace(model)
@test issuccess_1 && issuccess_2
Expand Down
18 changes: 8 additions & 10 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ end
model = demo()

example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, example_values, (@varname(m),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
# Evaluate once to ensure we have `logp` value.
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
vi_linked = if mutable
Expand Down Expand Up @@ -109,10 +109,8 @@ end
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(
model, values_original, (@varname(x),)
)
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, values_original, (@varname(x),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
val = vi[@varname(x), dist]
# Ensure that `reconstruct` works as intended.
@test val isa Cholesky
Expand Down Expand Up @@ -150,8 +148,8 @@ end
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
lp_model = logjoint(model, vi)
Expand Down Expand Up @@ -189,8 +187,8 @@ end
]
model = demo_highdim_dirichlet(ns...)
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
vis = TU.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(TU.short_varinfo_name(vi))" for vi in vis
# Linked.
vi_linked = if mutable
DynamicPPL.link!!(deepcopy(vi), model)
Expand Down
12 changes: 6 additions & 6 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff

@testset "`getmodel` and `setmodel`" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
model = TU.DEMO_MODELS[1]
= DynamicPPL.LogDensityFunction(model)
@test DynamicPPL.getmodel(ℓ) == model
@test DynamicPPL.setmodel(ℓ, model).model == model
Expand All @@ -21,10 +21,10 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Rever
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(nameof(model))" for model in TU.DEMO_MODELS
example_values = TU.rand_prior_true(model)
vns = TU.varnames(model)
varinfos = TU.setup_varinfos(model, example_values, vns)

@testset "$(varinfo)" for varinfo in varinfos
logdensity = DynamicPPL.LogDensityFunction(model, varinfo)
Expand Down
Loading
Loading