Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Add TestUtils submodule #313

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.14.1"
version = "0.14.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,6 @@ include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")

include("test_utils.jl")

end # module
208 changes: 208 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
module TestUtils

using AbstractMCMC
using DynamicPPL
using Distributions
using Test

# A collection of models for which the mean-of-means for the posterior should
# be same.
@model function demo_dot_assume_dot_observe(
x=10 * ones(2), ::Type{TV}=Vector{Float64}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
x ~ MvNormal(m, 0.5 * ones(length(x)))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_assume_index_observe(
x=10 * ones(2), ::Type{TV}=Vector{Float64}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
x ~ MvNormal(m, 0.5 * ones(length(x)))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_assume_multivariate_observe_index(x=10 * ones(2))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# Multivariate `assume` and `observe`
m ~ MvNormal(length(x), 1.0)
x ~ MvNormal(m, 0.5 * ones(length(x)))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_dot_assume_observe_index(
x=10 * ones(2), ::Type{TV}=Vector{Float64}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m[i], 0.5)
end

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

# Using vector of `length` 1 here so the posterior of `m` is the same
# as the others.
@model function demo_assume_dot_observe(x=10 * ones(1))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# `assume` and `dot_observe`
m ~ Normal()
x .~ Normal(m, 0.5)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_assume_observe_literal()
# `assume` and literal `observe`
m ~ MvNormal(2, 1.0)
[10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__))
end

@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and literal `observe` with indexing
m = TV(undef, 2)
m .~ Normal()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end

return (; m=m, x=10 * ones(length(m)), logp=getlogp(__varinfo__))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

@model function demo_assume_literal_dot_observe()
# `assume` and literal `dot_observe`
m ~ Normal()
[10.0] .~ Normal(m, 0.5)

return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
end

@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, 2)
m .~ Normal()

return m
end

@model function demo_assume_submodel_observe_index_literal()
# Submodel prior
m = @submodel _prior_dot_assume()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end

return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
end

@model function _likelihood_dot_observe(m, x)
return x ~ MvNormal(m, 0.5 * ones(length(m)))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

@model function demo_dot_assume_observe_submodel(
x=10 * ones(2), ::Type{TV}=Vector{Float64}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

@model function demo_dot_assume_dot_observe_matrix(
x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Dotted observe for `Matrix`.
x .~ MvNormal(m, 0.5)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return (; m=m, x=x, logp=getlogp(__varinfo__))
end

const DEMO_MODELS = (
demo_dot_assume_dot_observe(),
demo_assume_index_observe(),
demo_assume_multivariate_observe_index(),
demo_dot_assume_observe_index(),
demo_assume_dot_observe(),
demo_assume_observe_literal(),
demo_dot_assume_observe_index_literal(),
demo_assume_literal_dot_observe(),
demo_assume_submodel_observe_index_literal(),
demo_dot_assume_observe_submodel(),
demo_dot_assume_dot_observe_matrix(),
)

# TODO: Is this really the best/most convenient "default" test method?
"""
test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)

Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.

In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the
`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target`
provided in `kwargs...`.

# Arguments
- `meanfunction`: A callable which computes the mean of the marginal means from the
chain resulting from the `sample` call.
- `sampler`: The `AbstractMCMC.AbstractSampler` to test.
- `args...`: Arguments forwarded to `sample`.

# Keyword arguments
- `target`: Value to compare result of `meanfunction(chain)` to.
- `atol=1e-1`: Absolute tolerance used in `@test`.
- `rtol=1e-3`: Relative tolerance used in `@test`.
- `kwargs...`: Keyword arguments forwarded to `sample`.
"""
function test_sampler_demo_models(
meanfunction,
sampler::AbstractMCMC.AbstractSampler,
args...;
target=8.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, not entirely certain what you're saying 😅 It's "random" in the sense that it just happened to be the number that you get from the default values of the models, but it's deliberate in the sense that it's the actual true mean of the posterior of the models with the default values:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it's really the expected result, but only if you pass meanfunction = mean? I suspected it's either that or a random you put in for testing.

Copy link
Member

@phipsgabler phipsgabler Aug 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if IIUC, maybe that is better?

function test_sampler_demo_models(
    meanfunction,
    sampler::AbstractMCMC.AbstractSampler,
    args...;
    target,
    atol=1e-1,
    rtol=1e-3,
    kwargs...,
)
    ...
end
test_sampler_demo_models(::typeof(mean), args...; kwargs...,) = test_sampler_demo_models(mean, args...; target=8.0, kwargs...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not valid just for mean though. This is because different samplers can have completely different return-values from sample, and so we want to allow different mean functions, while still wanting the target to be 8.0.

atol=1e-1,
rtol=1e-3,
kwargs...,
)
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
μ = meanfunction(chain)
@test μ ≈ target atol = atol rtol = rtol
end
end

"""
test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...)

Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`.

As of right now, this is just an alias for [`test_sampler_demo_models`](@ref).
"""
function test_sampler_continuous(
meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
)
return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
end

function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
# Default for `MCMCChains.Chains`.
return test_sampler_continuous(sampler, args...; kwargs...) do chain
mean(Array(chain))
end
end
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved

end
113 changes: 1 addition & 112 deletions test/loglikelihoods.jl
Original file line number Diff line number Diff line change
@@ -1,116 +1,5 @@
# A collection of models for which the mean-of-means for the posterior should
# be same.
@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
return x ~ MvNormal(m, 0.5)
end

@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
return x ~ MvNormal(m, 0.5)
end

@model function gdemo3(x=10 * ones(2))
# Multivariate `assume` and `observe`
m ~ MvNormal(length(x), 1.0)
return x ~ MvNormal(m, 0.5)
end

@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m[i], 0.5)
end
end

# Using vector of `length` 1 here so the posterior of `m` is the same
# as the others.
@model function gdemo5(x=10 * ones(1))
# `assume` and `dot_observe`
m ~ Normal()
return x .~ Normal(m, 0.5)
end

@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
# `assume` and literal `observe`
m ~ MvNormal(2, 1.0)
return [10.0, 10.0] ~ MvNormal(m, 0.5)
end

@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and literal `observe` with indexing
m = TV(undef, 2)
m .~ Normal()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV}
# `assume` and literal `dot_observe`
m ~ Normal()
return [10.0] .~ Normal(m, 0.5)
end

@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, 2)
m .~ Normal()

return m
end

@model function gdemo9()
# Submodel prior
m = @submodel _prior_dot_assume()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function _likelihood_dot_observe(m, x)
return x ~ MvNormal(m, 0.5)
end

@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)
end

@model function gdemo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Dotted observe for `Matrix`.
return x .~ MvNormal(m, 0.5)
end

const gdemo_models = (
gdemo1(),
gdemo2(),
gdemo3(),
gdemo4(),
gdemo5(),
gdemo6(),
gdemo7(),
gdemo8(),
gdemo9(),
gdemo10(),
gdemo11(),
)

@testset "loglikelihoods.jl" begin
for m in gdemo_models
for m in DynamicPPL.TestUtils.demo_models
vi = VarInfo(m)

vns = vi.metadata.m.vns
Expand Down