-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Merged by Bors] - Add TestUtils submodule #313
Changes from 10 commits
a5280ea
dd7d774
caffa78
c8d6b1f
7b58174
a1a9752
eb92fd1
fc815d0
b32eb59
b36c842
e0adc2b
3b71318
0d3ff26
334e8da
4623798
d86ab1d
0f6f784
857b99b
9199004
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
module TestUtils | ||
|
||
using AbstractMCMC | ||
using DynamicPPL | ||
using Distributions | ||
using Test | ||
|
||
# A collection of models for which the mean-of-means for the posterior should | ||
# be same. | ||
@model function demo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} | ||
# `dot_assume` and `observe` | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
x ~ MvNormal(m, 0.5 * ones(length(x))) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return (; m, x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} | ||
# `assume` with indexing and `observe` | ||
m = TV(undef, length(x)) | ||
for i in eachindex(m) | ||
m[i] ~ Normal() | ||
end | ||
x ~ MvNormal(m, 0.5 * ones(length(x))) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return (; m, x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo3(x=10 * ones(2)) | ||
# Multivariate `assume` and `observe` | ||
m ~ MvNormal(length(x), 1.0) | ||
x ~ MvNormal(m, 0.5 * ones(length(x))) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return (; m, x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} | ||
# `dot_assume` and `observe` with indexing | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
for i in eachindex(x) | ||
x[i] ~ Normal(m[i], 0.5) | ||
end | ||
|
||
return (; m, x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
# Using vector of `length` 1 here so the posterior of `m` is the same | ||
# as the others. | ||
@model function demo5(x=10 * ones(1)) | ||
# `assume` and `dot_observe` | ||
m ~ Normal() | ||
x .~ Normal(m, 0.5) | ||
|
||
return (; m, x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo6() | ||
# `assume` and literal `observe` | ||
m ~ MvNormal(2, 1.0) | ||
[10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return (; m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo7(::Type{TV}=Vector{Float64}) where {TV} | ||
# `dot_assume` and literal `observe` with indexing | ||
m = TV(undef, 2) | ||
m .~ Normal() | ||
for i in eachindex(m) | ||
10.0 ~ Normal(m[i], 0.5) | ||
end | ||
|
||
return (; m, x=10 * ones(length(m)), logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo8() | ||
# `assume` and literal `dot_observe` | ||
m ~ Normal() | ||
[10.0] .~ Normal(m, 0.5) | ||
|
||
return (; m, x=[10.0], logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} | ||
m = TV(undef, 2) | ||
m .~ Normal() | ||
|
||
return m | ||
end | ||
|
||
@model function demo9() | ||
# Submodel prior | ||
m = @submodel _prior_dot_assume() | ||
for i in eachindex(m) | ||
10.0 ~ Normal(m[i], 0.5) | ||
end | ||
|
||
return (; m, x=[10.0], logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function _likelihood_dot_observe(m, x) | ||
return x ~ MvNormal(m, 0.5 * ones(length(m))) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
@model function demo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
|
||
# Submodel likelihood | ||
@submodel _likelihood_dot_observe(m, x) | ||
|
||
return (; m, x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
|
||
# Dotted observe for `Matrix`. | ||
return x .~ MvNormal(m, 0.5) | ||
end | ||
|
||
const demo_models = ( | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
demo1(), | ||
demo2(), | ||
demo3(), | ||
demo4(), | ||
demo5(), | ||
demo6(), | ||
demo7(), | ||
demo8(), | ||
demo9(), | ||
demo10(), | ||
demo11(), | ||
) | ||
|
||
# TODO: Is this really the best/most convenient "default" test method? | ||
""" | ||
test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) | ||
|
||
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. | ||
|
||
In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the | ||
`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target` | ||
provided in `kwargs...`. | ||
|
||
# Arguments | ||
- `meanfunction`: A callable which computes the mean of the marginal means from the | ||
chain resulting from the `sample` call. | ||
- `sampler`: The `AbstractMCMC.AbstractSampler` to test. | ||
- `args...`: Arguments forwarded to `sample`. | ||
|
||
# Keyword arguments | ||
- `target`: Value to compare result of `meanfunction(chain)` to. | ||
- `atol=1e-1`: Absolute tolerance used in `@test`. | ||
- `rtol=1e-3`: Relative tolerance used in `@test`. | ||
- `kwargs...`: Keyword arguments forwarded to `sample`. | ||
""" | ||
function test_sampler_demo_models( | ||
meanfunction, | ||
sampler::AbstractMCMC.AbstractSampler, | ||
args...; | ||
target=8.0, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Haha, not entirely certain what you're saying 😅 It's "random" in the sense that it just happened to be the number that you get from the default values of the models, but it's deliberate in the sense that it's the actual true mean of the posterior of the models with the default values:) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So it's really the expected result, but only if you pass There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So if IIUC, maybe that is better? function test_sampler_demo_models(
meanfunction,
sampler::AbstractMCMC.AbstractSampler,
args...;
target,
atol=1e-1,
rtol=1e-3,
kwargs...,
)
...
end
test_sampler_demo_models(::typeof(mean), args...; kwargs...,) = test_sampler_demo_models(mean, args...; target=8.0, kwargs...) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not valid just for |
||
atol=1e-1, | ||
rtol=1e-3, | ||
kwargs..., | ||
) | ||
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in demo_models | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) | ||
μ = meanfunction(chain) | ||
@test μ ≈ target atol = atol rtol = rtol | ||
end | ||
end | ||
|
||
""" | ||
test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) | ||
|
||
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. | ||
|
||
As of right now, this is just an alias for [`test_sampler_demo_models`](@ref). | ||
""" | ||
function test_sampler_continuous( | ||
meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... | ||
) | ||
return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) | ||
end | ||
|
||
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) | ||
# Default for `MCMCChains.Chains`. | ||
return test_sampler_continuous(sampler, args...; kwargs...) do chain | ||
mean(Array(chain)) | ||
end | ||
end | ||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we could use a bit more descriptive names?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thought crossed my mind, but was worried they'd be too long. Buuuut we're not going to be using the constructors of these models very often, so the cost of making them overly verbose is approx. 0 👍
I'll do that!