-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Merged by Bors] - Add TestUtils submodule #313
Closed
Closed
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
a5280ea
initial work on adding test utils
torfjelde dd7d774
include test_utils
torfjelde caffa78
added testing for continuous samplers
torfjelde c8d6b1f
Merge branch 'master' into tor/test-utils
torfjelde 7b58174
allow specification of target mean in test_sampler_continuous
torfjelde a1a9752
add return-values to the test models which can be useful
torfjelde eb92fd1
Merge branch 'master' into tor/test-utils
torfjelde fc815d0
Merge branch 'master' into tor/test-utils
torfjelde b32eb59
added TestUtils submodule
torfjelde b36c842
added some docstrings to TestUtils
torfjelde e0adc2b
fix 1.3 compatibility
torfjelde 3b71318
test model names are now more informative
torfjelde 0d3ff26
Apply suggestions from code review
torfjelde 334e8da
Apply suggestions from code review
torfjelde 4623798
Apply suggestions from @devmotion
torfjelde d86ab1d
Apply suggestions from code review
torfjelde 0f6f784
Merge branch 'master' into tor/test-utils
torfjelde 857b99b
fixed tests
torfjelde 9199004
fixed tests
torfjelde File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
module TestUtils | ||
|
||
using AbstractMCMC | ||
using DynamicPPL | ||
using Distributions | ||
using Test | ||
|
||
# A collection of models for which the mean-of-means for the posterior should | ||
# be same. | ||
@model function demo_dot_assume_dot_observe( | ||
x=[10.0, 10.0], ::Type{TV}=Vector{Float64} | ||
) where {TV} | ||
# `dot_assume` and `observe` | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
x ~ MvNormal(m, 0.25 * I) | ||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_assume_index_observe( | ||
x=[10.0, 10.0], ::Type{TV}=Vector{Float64} | ||
) where {TV} | ||
# `assume` with indexing and `observe` | ||
m = TV(undef, length(x)) | ||
for i in eachindex(m) | ||
m[i] ~ Normal() | ||
end | ||
x ~ MvNormal(m, 0.25 * I) | ||
|
||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0]) | ||
# Multivariate `assume` and `observe` | ||
m ~ MvNormal(zero(x), I) | ||
x ~ MvNormal(m, 0.25 * I) | ||
|
||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_dot_assume_observe_index( | ||
x=[10.0, 10.0], ::Type{TV}=Vector{Float64} | ||
) where {TV} | ||
# `dot_assume` and `observe` with indexing | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
for i in eachindex(x) | ||
x[i] ~ Normal(m[i], 0.5) | ||
end | ||
|
||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
# Using vector of `length` 1 here so the posterior of `m` is the same | ||
# as the others. | ||
@model function demo_assume_dot_observe(x=[10.0]) | ||
# `assume` and `dot_observe` | ||
m ~ Normal() | ||
x .~ Normal(m, 0.5) | ||
|
||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_assume_observe_literal() | ||
# `assume` and literal `observe` | ||
m ~ MvNormal(zeros(2), I) | ||
[10.0, 10.0] ~ MvNormal(m, 0.25 * I) | ||
|
||
return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} | ||
# `dot_assume` and literal `observe` with indexing | ||
m = TV(undef, 2) | ||
m .~ Normal() | ||
for i in eachindex(m) | ||
10.0 ~ Normal(m[i], 0.5) | ||
end | ||
|
||
return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_assume_literal_dot_observe() | ||
# `assume` and literal `dot_observe` | ||
m ~ Normal() | ||
[10.0] .~ Normal(m, 0.5) | ||
|
||
return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} | ||
m = TV(undef, 2) | ||
m .~ Normal() | ||
|
||
return m | ||
end | ||
|
||
@model function demo_assume_submodel_observe_index_literal() | ||
# Submodel prior | ||
m = @submodel _prior_dot_assume() | ||
for i in eachindex(m) | ||
10.0 ~ Normal(m[i], 0.5) | ||
end | ||
|
||
return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function _likelihood_dot_observe(m, x) | ||
return x ~ MvNormal(m, 0.25 * I) | ||
end | ||
|
||
@model function demo_dot_assume_observe_submodel( | ||
x=[10.0, 10.0], ::Type{TV}=Vector{Float64} | ||
) where {TV} | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
|
||
# Submodel likelihood | ||
@submodel _likelihood_dot_observe(m, x) | ||
|
||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
@model function demo_dot_assume_dot_observe_matrix( | ||
x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} | ||
) where {TV} | ||
m = TV(undef, length(x)) | ||
m .~ Normal() | ||
|
||
# Dotted observe for `Matrix`. | ||
x .~ MvNormal(m, 0.25 * I) | ||
|
||
return (; m=m, x=x, logp=getlogp(__varinfo__)) | ||
end | ||
|
||
const DEMO_MODELS = ( | ||
demo_dot_assume_dot_observe(), | ||
demo_assume_index_observe(), | ||
demo_assume_multivariate_observe_index(), | ||
demo_dot_assume_observe_index(), | ||
demo_assume_dot_observe(), | ||
demo_assume_observe_literal(), | ||
demo_dot_assume_observe_index_literal(), | ||
demo_assume_literal_dot_observe(), | ||
demo_assume_submodel_observe_index_literal(), | ||
demo_dot_assume_observe_submodel(), | ||
demo_dot_assume_dot_observe_matrix(), | ||
) | ||
|
||
# TODO: Is this really the best/most convenient "default" test method? | ||
""" | ||
test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) | ||
|
||
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. | ||
|
||
In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the | ||
`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target` | ||
provided in `kwargs...`. | ||
|
||
# Arguments | ||
- `meanfunction`: A callable which computes the mean of the marginal means from the | ||
chain resulting from the `sample` call. | ||
- `sampler`: The `AbstractMCMC.AbstractSampler` to test. | ||
- `args...`: Arguments forwarded to `sample`. | ||
|
||
# Keyword arguments | ||
- `target`: Value to compare result of `meanfunction(chain)` to. | ||
- `atol=1e-1`: Absolute tolerance used in `@test`. | ||
- `rtol=1e-3`: Relative tolerance used in `@test`. | ||
- `kwargs...`: Keyword arguments forwarded to `sample`. | ||
""" | ||
function test_sampler_demo_models( | ||
meanfunction, | ||
sampler::AbstractMCMC.AbstractSampler, | ||
args...; | ||
target=8.0, | ||
atol=1e-1, | ||
rtol=1e-3, | ||
kwargs..., | ||
) | ||
@testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS | ||
chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) | ||
μ = meanfunction(chain) | ||
@test μ ≈ target atol = atol rtol = rtol | ||
end | ||
end | ||
|
||
""" | ||
test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) | ||
|
||
Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. | ||
|
||
As of right now, this is just an alias for [`test_sampler_demo_models`](@ref). | ||
""" | ||
function test_sampler_continuous( | ||
meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... | ||
) | ||
return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) | ||
end | ||
|
||
function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) | ||
# Default for `MCMCChains.Chains`. | ||
return test_sampler_continuous(sampler, args...; kwargs...) do chain | ||
mean(Array(chain)) | ||
end | ||
end | ||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://xkcd.com/221/?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, not entirely certain what you're saying 😅 It's "random" in the sense that it just happened to be the number that you get from the default values of the models, but it's deliberate in the sense that it's the actual true mean of the posterior of the models with the default values:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it's really the expected result, but only if you pass
meanfunction = mean
? I suspected it's either that or a random you put in for testing.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if IIUC, maybe that is better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not valid just for
mean
though. This is because different samplers can have completely different return-values fromsample
, and so we want to allow different mean functions, while still wanting the target to be8.0
.