Skip to content

Commit

Permalink
Add TestUtils submodule (#313)
Browse files Browse the repository at this point in the history
This PR adds a `DynamicPPL.TestUtils` submodule which is meant to include functionality to make it easy to test new samplers, new implementations of `AbstractVarInfo`, etc.

As of right now, this is mainly just a collection of models with equivalent marginal posteriors using the different features of DPPL, e.g. some are using `.~`, some are using `@submodel`, etc.

Eventually this should be expanded to be of more use, but more immediately this will be useful to test functionality in open PRs, e.g. #269, #309, #295, #292.

These models are also already used in Turing.jl's test-suite (https://github.com/TuringLang/Turing.jl/blob/9f52d75c25390b68115624b2e6cf464275a88137/test/test_utils/models.jl#L55-L56), so this PR would avoid the code-duplication + make it easier to keep things up-to-date.
  • Loading branch information
torfjelde committed Sep 8, 2021
1 parent 86afffa commit c5a63b5
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 113 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.15.1"
version = "0.15.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,6 @@ include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")

include("test_utils.jl")

end # module
208 changes: 208 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
module TestUtils

using AbstractMCMC
using DynamicPPL
using Distributions
using Test

# A collection of models for which the mean-of-means for the posterior should
# be same.
@model function demo_dot_assume_dot_observe(
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
x ~ MvNormal(m, 0.25 * I)
return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_assume_index_observe(
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
x ~ MvNormal(m, 0.25 * I)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0])
# Multivariate `assume` and `observe`
m ~ MvNormal(zero(x), I)
x ~ MvNormal(m, 0.25 * I)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_dot_assume_observe_index(
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m[i], 0.5)
end

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

# Using vector of `length` 1 here so the posterior of `m` is the same
# as the others.
@model function demo_assume_dot_observe(x=[10.0])
# `assume` and `dot_observe`
m ~ Normal()
x .~ Normal(m, 0.5)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_assume_observe_literal()
# `assume` and literal `observe`
m ~ MvNormal(zeros(2), I)
[10.0, 10.0] ~ MvNormal(m, 0.25 * I)

return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__))
end

@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and literal `observe` with indexing
m = TV(undef, 2)
m .~ Normal()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end

return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__))
end

@model function demo_assume_literal_dot_observe()
# `assume` and literal `dot_observe`
m ~ Normal()
[10.0] .~ Normal(m, 0.5)

return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
end

@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, 2)
m .~ Normal()

return m
end

@model function demo_assume_submodel_observe_index_literal()
# Submodel prior
m = @submodel _prior_dot_assume()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end

return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
end

@model function _likelihood_dot_observe(m, x)
return x ~ MvNormal(m, 0.25 * I)
end

@model function demo_dot_assume_observe_submodel(
x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_dot_assume_dot_observe_matrix(
x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}
) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Dotted observe for `Matrix`.
x .~ MvNormal(m, 0.25 * I)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

const DEMO_MODELS = (
demo_dot_assume_dot_observe(),
demo_assume_index_observe(),
demo_assume_multivariate_observe_index(),
demo_dot_assume_observe_index(),
demo_assume_dot_observe(),
demo_assume_observe_literal(),
demo_dot_assume_observe_index_literal(),
demo_assume_literal_dot_observe(),
demo_assume_submodel_observe_index_literal(),
demo_dot_assume_observe_submodel(),
demo_dot_assume_dot_observe_matrix(),
)

# TODO: Is this really the best/most convenient "default" test method?
"""
test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the
`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target`
provided in `kwargs...`.
# Arguments
- `meanfunction`: A callable which computes the mean of the marginal means from the
chain resulting from the `sample` call.
- `sampler`: The `AbstractMCMC.AbstractSampler` to test.
- `args...`: Arguments forwarded to `sample`.
# Keyword arguments
- `target`: Value to compare result of `meanfunction(chain)` to.
- `atol=1e-1`: Absolute tolerance used in `@test`.
- `rtol=1e-3`: Relative tolerance used in `@test`.
- `kwargs...`: Keyword arguments forwarded to `sample`.
"""
function test_sampler_demo_models(
meanfunction,
sampler::AbstractMCMC.AbstractSampler,
args...;
target=8.0,
atol=1e-1,
rtol=1e-3,
kwargs...,
)
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
μ = meanfunction(chain)
@test μ target atol = atol rtol = rtol
end
end

"""
test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...)
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.
As of right now, this is just an alias for [`test_sampler_demo_models`](@ref).
"""
function test_sampler_continuous(
meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
)
return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
end

function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
# Default for `MCMCChains.Chains`.
return test_sampler_continuous(sampler, args...; kwargs...) do chain
mean(Array(chain))
end
end

end
113 changes: 1 addition & 112 deletions test/loglikelihoods.jl
Original file line number Diff line number Diff line change
@@ -1,116 +1,5 @@
# A collection of models for which the mean-of-means for the posterior should
# be same.
@model function gdemo1(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
return x ~ MvNormal(m, 0.25 * I)
end

@model function gdemo2(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
return x ~ MvNormal(m, 0.25 * I)
end

@model function gdemo3(x=[10.0, 10.0])
# Multivariate `assume` and `observe`
m ~ MvNormal(zero(x), I)
return x ~ MvNormal(m, 0.25 * I)
end

@model function gdemo4(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m[i], 0.5)
end
end

# Using vector of `length` 1 here so the posterior of `m` is the same
# as the others.
@model function gdemo5(x=[10.0])
# `assume` and `dot_observe`
m ~ Normal()
return x .~ Normal(m, 0.5)
end

@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
# `assume` and literal `observe`
m ~ MvNormal(zeros(2), I)
return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
end

@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and literal `observe` with indexing
m = TV(undef, 2)
m .~ Normal()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV}
# `assume` and literal `dot_observe`
m ~ Normal()
return [10.0] .~ Normal(m, 0.5)
end

@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, 2)
m .~ Normal()

return m
end

@model function gdemo9()
# Submodel prior
m = @submodel _prior_dot_assume()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function _likelihood_dot_observe(m, x)
return x ~ MvNormal(m, 0.25 * I)
end

@model function gdemo10(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)
end

@model function gdemo11(x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Dotted observe for `Matrix`.
return x .~ MvNormal(m, 0.25 * I)
end

const gdemo_models = (
gdemo1(),
gdemo2(),
gdemo3(),
gdemo4(),
gdemo5(),
gdemo6(),
gdemo7(),
gdemo8(),
gdemo9(),
gdemo10(),
gdemo11(),
)

@testset "loglikelihoods.jl" begin
for m in gdemo_models
for m in DynamicPPL.TestUtils.demo_models
vi = VarInfo(m)

vns = vi.metadata.m.vns
Expand Down

0 comments on commit c5a63b5

Please sign in to comment.