From a5280eacf2345c7626c0842eab1e235bea9c89dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 13:23:58 +0100 Subject: [PATCH 01/15] initial work on adding test utils --- src/test_utils.jl | 116 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/test_utils.jl diff --git a/src/test_utils.jl b/src/test_utils.jl new file mode 100644 index 000000000..39cdd7851 --- /dev/null +++ b/src/test_utils.jl @@ -0,0 +1,116 @@ +module TestUtils + +using AbstractMCMC +using DynamicPPL +using Distributions +using Test + +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x=10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x=10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + return x .~ Normal(m, 0.5) +end + +@model function gdemo6() + # `assume` and literal `observe` + m ~ MvNormal(2, 1.0) + return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +end + +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function gdemo8() + # `assume` and literal `dot_observe` + m ~ Normal() + return [10.0] .~ Normal(m, 0.5) +end + +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + return x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const gdemo_models = ( + gdemo1(), + gdemo2(), + gdemo3(), + gdemo4(), + gdemo5(), + gdemo6(), + gdemo7(), + gdemo8(), + gdemo9(), + gdemo10(), +) + +function test_models(meanf, spl::AbstractSampler, args...; kwargs...) + chain = sample(spl, args...) 
+ μ = meanf(chain) + @test μ ≈ 8.0 atol = atol rtol = rtol +end + +end From dd7d7740ea763ec13f1c4f9af889eb1ed757a9fd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 13:44:23 +0100 Subject: [PATCH 02/15] include test_utils --- Project.toml | 1 + src/DynamicPPL.jl | 2 ++ 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index db9f26b04..dcc8ddeb0 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..3e9aa3941 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,4 +130,6 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("test_utils.jl") + end # module From caffa78f39495c3b7d4f5ccd9b45b8e5efb73314 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 13:44:37 +0100 Subject: [PATCH 03/15] added testing for continuous samplers --- src/test_utils.jl | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 39cdd7851..4bed1d992 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -107,10 +107,23 @@ const gdemo_models = ( gdemo10(), ) -function test_models(meanf, spl::AbstractSampler, args...; kwargs...) - chain = sample(spl, args...) - μ = meanf(chain) - @test μ ≈ 8.0 atol = atol rtol = rtol +function test_sampler_gdemo(spl::AbstractMCMC.AbstractSampler, args...; kwargs...) + # Default for `MCMCChains.Chains`. + return test_sampler_gdemo(spl, args...; kwargs...) do chain + mean(Array(chain)) + end +end + +function test_sampler_gdemo(meanf, spl::AbstractMCMC.AbstractSampler, args...; kwargs...) + @testset "$(spl) on $(m.name)" for m in gdemo_models + chain = AbstractMCMC.sample(m, spl, args...) + μ = meanf(chain) + @test μ ≈ 8.0 atol = atol rtol = rtol + end +end + +function test_sampler_continuous(spl::AbstractMCMC.AbstractSampler, args...; kwargs...) + return test_sampler_gdemo(spl, args...; kwargs...) end end From 7b58174825a983ebb0b9c7ad549727645f7e6238 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 15:48:35 +0100 Subject: [PATCH 04/15] allow specification of target mean in test_sampler_continuous --- src/test_utils.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 4bed1d992..ccebeeb6e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -107,23 +107,23 @@ const gdemo_models = ( gdemo10(), ) -function test_sampler_gdemo(spl::AbstractMCMC.AbstractSampler, args...; kwargs...) - # Default for `MCMCChains.Chains`. - return test_sampler_gdemo(spl, args...; kwargs...) do chain - mean(Array(chain)) - end -end - -function test_sampler_gdemo(meanf, spl::AbstractMCMC.AbstractSampler, args...; kwargs...) - @testset "$(spl) on $(m.name)" for m in gdemo_models +function test_sampler_gdemo(meanf, spl::AbstractMCMC.AbstractSampler, args...; target = 8.0, atol=1e-1, rtol=1-1, kwargs...) + @testset "$(nameof(typeof(spl))) on $(m.name)" for m in gdemo_models chain = AbstractMCMC.sample(m, spl, args...) 
μ = meanf(chain) - @test μ ≈ 8.0 atol = atol rtol = rtol + @test μ ≈ target atol = atol rtol = rtol end end +function test_sampler_continuous(meanf, spl::AbstractMCMC.AbstractSampler, args...; kwargs...) + return test_sampler_gdemo(meanf, spl, args...; kwargs...) +end + function test_sampler_continuous(spl::AbstractMCMC.AbstractSampler, args...; kwargs...) - return test_sampler_gdemo(spl, args...; kwargs...) + # Default for `MCMCChains.Chains`. + return test_sampler_continuous(spl, args...; kwargs...) do chain + mean(Array(chain)) + end end end From a1a975262656ae7f2e0ab20fffec6b971214a8e3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 11 Jun 2021 19:32:01 +0100 Subject: [PATCH 05/15] add return-values to the test models which can be useful --- src/test_utils.jl | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index ccebeeb6e..92d9c3864 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,7 +11,8 @@ using Test # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() - return x ~ MvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.5 * ones(length(x))) + return (; m, x, logp = getlogp(__varinfo__)) end @model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} @@ -20,13 +21,17 @@ end for i in eachindex(m) m[i] ~ Normal() end - return x ~ MvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.5 * ones(length(x))) + + return (; m, x, logp = getlogp(__varinfo__)) end @model function gdemo3(x=10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) - return x ~ MvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.5 * ones(length(x))) + + return (; m, x, logp = getlogp(__varinfo__)) end @model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} @@ -36,6 +41,8 @@ end for i in eachindex(x) x[i] ~ Normal(m[i], 0.5) end + + return (; m, x, logp = getlogp(__varinfo__)) end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -43,13 +50,17 @@ end @model function gdemo5(x=10 * ones(1)) # `assume` and `dot_observe` m ~ Normal() - return x .~ Normal(m, 0.5) + x .~ Normal(m, 0.5) + + return (; m, x, logp = getlogp(__varinfo__)) end @model function gdemo6() # `assume` and literal `observe` m ~ MvNormal(2, 1.0) - return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + + return (; m, x = [10.0, 10.0], logp = getlogp(__varinfo__)) end @model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} @@ -59,12 +70,16 @@ end for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end + + return (; m, x = 10 * ones(length(m)), logp = getlogp(__varinfo__)) end @model function gdemo8() # `assume` and literal `dot_observe` m ~ Normal() - return [10.0] .~ Normal(m, 0.5) + [10.0] .~ Normal(m, 0.5) + + return (; m, x = [10.0], logp = getlogp(__varinfo__)) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} @@ -80,6 +95,8 @@ end for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end + + return (; m, x = [10.0], logp = getlogp(__varinfo__)) end @model function _likelihood_dot_observe(m, x) @@ -92,6 +109,8 @@ end # Submodel likelihood @submodel _likelihood_dot_observe(m, x) + + return (; m, x, logp = getlogp(__varinfo__)) end const gdemo_models = ( From b32eb596049d3633817b7a13352245ad8f4bac1e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 09:45:44 +0100 Subject: [PATCH 06/15] added TestUtils submodule --- Project.toml | 2 +- src/test_utils.jl | 
90 +++++++++++++++++++------------- test/loglikelihoods.jl | 113 +---------------------------------------- 3 files changed, 57 insertions(+), 148 deletions(-) diff --git a/Project.toml b/Project.toml index f397b1c70..b376cb8e3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.14.1" +version = "0.14.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/test_utils.jl b/src/test_utils.jl index 92d9c3864..b6cdecd79 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -7,15 +7,15 @@ using Test # A collection of models for which the mean-of-means for the posterior should # be same. -@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() x ~ MvNormal(m, 0.5 * ones(length(x))) - return (; m, x, logp = getlogp(__varinfo__)) + return (; m, x, logp=getlogp(__varinfo__)) end -@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) for i in eachindex(m) @@ -23,18 +23,18 @@ end end x ~ MvNormal(m, 0.5 * ones(length(x))) - return (; m, x, logp = getlogp(__varinfo__)) + return (; m, x, logp=getlogp(__varinfo__)) end -@model function gdemo3(x=10 * ones(2)) +@model function demo3(x=10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) x ~ MvNormal(m, 0.5 * ones(length(x))) - return (; m, x, logp = getlogp(__varinfo__)) + return (; m, x, logp=getlogp(__varinfo__)) end -@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) m .~ Normal() @@ -42,28 +42,28 @@ end x[i] ~ Normal(m[i], 0.5) end - return (; m, x, logp = getlogp(__varinfo__)) + return (; m, x, logp=getlogp(__varinfo__)) end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. 
-@model function gdemo5(x=10 * ones(1)) +@model function demo5(x=10 * ones(1)) # `assume` and `dot_observe` m ~ Normal() x .~ Normal(m, 0.5) - return (; m, x, logp = getlogp(__varinfo__)) + return (; m, x, logp=getlogp(__varinfo__)) end -@model function gdemo6() +@model function demo6() # `assume` and literal `observe` m ~ MvNormal(2, 1.0) [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) - return (; m, x = [10.0, 10.0], logp = getlogp(__varinfo__)) + return (; m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) end -@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} +@model function demo7(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing m = TV(undef, 2) m .~ Normal() @@ -71,15 +71,15 @@ end 10.0 ~ Normal(m[i], 0.5) end - return (; m, x = 10 * ones(length(m)), logp = getlogp(__varinfo__)) + return (; m, x=10 * ones(length(m)), logp=getlogp(__varinfo__)) end -@model function gdemo8() +@model function demo8() # `assume` and literal `dot_observe` m ~ Normal() [10.0] .~ Normal(m, 0.5) - return (; m, x = [10.0], logp = getlogp(__varinfo__)) + return (; m, x=[10.0], logp=getlogp(__varinfo__)) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} @@ -89,53 +89,73 @@ end return m end -@model function gdemo9() +@model function demo9() # Submodel prior m = @submodel _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end - return (; m, x = [10.0], logp = getlogp(__varinfo__)) + return (; m, x=[10.0], logp=getlogp(__varinfo__)) end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.5 * ones(length(m))) end -@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, length(x)) m .~ Normal() # Submodel likelihood @submodel _likelihood_dot_observe(m, x) - return (; m, x, logp = getlogp(__varinfo__)) + return (; m, x, logp=getlogp(__varinfo__)) end -const gdemo_models = ( - gdemo1(), - gdemo2(), - gdemo3(), - gdemo4(), - gdemo5(), - gdemo6(), - gdemo7(), - gdemo8(), - gdemo9(), - gdemo10(), +@model function demo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Dotted observe for `Matrix`. + return x .~ MvNormal(m, 0.5) +end + +const demo_models = ( + demo1(), + demo2(), + demo3(), + demo4(), + demo5(), + demo6(), + demo7(), + demo8(), + demo9(), + demo10(), + demo11(), ) -function test_sampler_gdemo(meanf, spl::AbstractMCMC.AbstractSampler, args...; target = 8.0, atol=1e-1, rtol=1-1, kwargs...) - @testset "$(nameof(typeof(spl))) on $(m.name)" for m in gdemo_models +# TODO: Is this really the best "default"? +function test_sampler_demo_models( + meanf, + spl::AbstractMCMC.AbstractSampler, + args...; + target=8.0, + atol=1e-1, + rtol=1 - 1, + kwargs..., +) + @testset "$(nameof(typeof(spl))) on $(m.name)" for m in demo_models chain = AbstractMCMC.sample(m, spl, args...) μ = meanf(chain) @test μ ≈ target atol = atol rtol = rtol end end -function test_sampler_continuous(meanf, spl::AbstractMCMC.AbstractSampler, args...; kwargs...) - return test_sampler_gdemo(meanf, spl, args...; kwargs...) +function test_sampler_continuous( + meanf, spl::AbstractMCMC.AbstractSampler, args...; kwargs... +) + return test_sampler_demo_models(meanf, spl, args...; kwargs...) end function test_sampler_continuous(spl::AbstractMCMC.AbstractSampler, args...; kwargs...) 
diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index db01a0b9a..6eda23cfc 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,116 +1,5 @@ -# A collection of models for which the mean-of-means for the posterior should -# be same. -@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} - # `dot_assume` and `observe` - m = TV(undef, length(x)) - m .~ Normal() - return x ~ MvNormal(m, 0.5) -end - -@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} - # `assume` with indexing and `observe` - m = TV(undef, length(x)) - for i in eachindex(m) - m[i] ~ Normal() - end - return x ~ MvNormal(m, 0.5) -end - -@model function gdemo3(x=10 * ones(2)) - # Multivariate `assume` and `observe` - m ~ MvNormal(length(x), 1.0) - return x ~ MvNormal(m, 0.5) -end - -@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} - # `dot_assume` and `observe` with indexing - m = TV(undef, length(x)) - m .~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m[i], 0.5) - end -end - -# Using vector of `length` 1 here so the posterior of `m` is the same -# as the others. -@model function gdemo5(x=10 * ones(1)) - # `assume` and `dot_observe` - m ~ Normal() - return x .~ Normal(m, 0.5) -end - -@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV} - # `assume` and literal `observe` - m ~ MvNormal(2, 1.0) - return [10.0, 10.0] ~ MvNormal(m, 0.5) -end - -@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} - # `dot_assume` and literal `observe` with indexing - m = TV(undef, 2) - m .~ Normal() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) - end -end - -@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV} - # `assume` and literal `dot_observe` - m ~ Normal() - return [10.0] .~ Normal(m, 0.5) -end - -@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} - m = TV(undef, 2) - m .~ Normal() - - return m -end - -@model function gdemo9() - # Submodel prior - m = @submodel _prior_dot_assume() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) - end -end - -@model function _likelihood_dot_observe(m, x) - return x ~ MvNormal(m, 0.5) -end - -@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} - m = TV(undef, length(x)) - m .~ Normal() - - # Submodel likelihood - @submodel _likelihood_dot_observe(m, x) -end - -@model function gdemo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} - m = TV(undef, length(x)) - m .~ Normal() - - # Dotted observe for `Matrix`. - return x .~ MvNormal(m, 0.5) -end - -const gdemo_models = ( - gdemo1(), - gdemo2(), - gdemo3(), - gdemo4(), - gdemo5(), - gdemo6(), - gdemo7(), - gdemo8(), - gdemo9(), - gdemo10(), - gdemo11(), -) - @testset "loglikelihoods.jl" begin - for m in gdemo_models + for m in DynamicPPL.TestUtils.demo_models vi = VarInfo(m) vns = vi.metadata.m.vns From b36c84238057029a65d99ee2f8594be3145eb413 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 10:51:40 +0100 Subject: [PATCH 07/15] added some docstrings to TestUtils --- src/test_utils.jl | 50 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index b6cdecd79..977c73fc9 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -135,32 +135,60 @@ const demo_models = ( demo11(), ) -# TODO: Is this really the best "default"? +# TODO: Is this really the best/most convenient "default" test method? 
+""" + test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + +Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. + +In short, this method iterates through `demo_models`, calls `AbstractMCMC.sample` on the +`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target` +provided in `kwargs...`. + +# Arguments +- `meanfunction`: A callable which computes the mean of the marginal means from the + chain resulting from the `sample` call. +- `sampler`: The `AbstractMCMC.AbstractSampler` to test. +- `args...`: Arguments forwarded to `sample`. + +# Keyword arguments +- `target`: Value to compare the result of `meanfunction(chain)` to. +- `atol=1e-1`: Absolute tolerance used in `@test`. +- `rtol=1e-3`: Relative tolerance used in `@test`. +- `kwargs...`: Keyword arguments forwarded to `sample`. +""" function test_sampler_demo_models( - meanf, - spl::AbstractMCMC.AbstractSampler, + meanfunction, + sampler::AbstractMCMC.AbstractSampler, args...; target=8.0, atol=1e-1, - rtol=1 - 1, + rtol=1e-3, kwargs..., ) - @testset "$(nameof(typeof(spl))) on $(m.name)" for m in demo_models - chain = AbstractMCMC.sample(m, spl, args...) - μ = meanf(chain) + @testset "$(nameof(typeof(sampler))) on $(m.name)" for model in demo_models + chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) + μ = meanfunction(chain) @test μ ≈ target atol = atol rtol = rtol end end +""" + test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) + +Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. + +As of right now, this is just an alias for [`test_sampler_demo_models`](@ref). +""" function test_sampler_continuous( - meanf, spl::AbstractMCMC.AbstractSampler, args...; kwargs... + meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_demo_models(meanf, spl, args...; kwargs...) + return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) end -function test_sampler_continuous(spl::AbstractMCMC.AbstractSampler, args...; kwargs...) +function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) # Default for `MCMCChains.Chains`. - return test_sampler_continuous(spl, args...; kwargs...) + return test_sampler_continuous(sampler, args...; kwargs...)
do chain mean(Array(chain)) end end From e0adc2b28c7dbfca5f3dcb33a9411d8a83b81765 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 11:05:05 +0100 Subject: [PATCH 08/15] fix 1.3 compatibility --- src/test_utils.jl | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 977c73fc9..320a13ce3 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -12,7 +12,7 @@ using Test m = TV(undef, length(x)) m .~ Normal() x ~ MvNormal(m, 0.5 * ones(length(x))) - return (; m, x, logp=getlogp(__varinfo__)) + return (; m=m, x=x, logp=getlogp(__varinfo__)) end @model function demo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} @@ -23,7 +23,7 @@ end end x ~ MvNormal(m, 0.5 * ones(length(x))) - return (; m, x, logp=getlogp(__varinfo__)) + return (; m=m, x=x, logp=getlogp(__varinfo__)) end @model function demo3(x=10 * ones(2)) @@ -31,7 +31,7 @@ end m ~ MvNormal(length(x), 1.0) x ~ MvNormal(m, 0.5 * ones(length(x))) - return (; m, x, logp=getlogp(__varinfo__)) + return (; m=m, x=x, logp=getlogp(__varinfo__)) end @model function demo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} @@ -42,7 +42,7 @@ end x[i] ~ Normal(m[i], 0.5) end - return (; m, x, logp=getlogp(__varinfo__)) + return (; m=m, x=x, logp=getlogp(__varinfo__)) end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -52,7 +52,7 @@ end m ~ Normal() x .~ Normal(m, 0.5) - return (; m, x, logp=getlogp(__varinfo__)) + return (; m=m, x=x, logp=getlogp(__varinfo__)) end @model function demo6() @@ -60,7 +60,7 @@ end m ~ MvNormal(2, 1.0) [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) - return (; m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) + return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) end @model function demo7(::Type{TV}=Vector{Float64}) where {TV} @@ -71,7 +71,7 @@ end 10.0 ~ Normal(m[i], 0.5) end - return (; m, x=10 * ones(length(m)), logp=getlogp(__varinfo__)) + return (; m=m, x=10 * ones(length(m)), logp=getlogp(__varinfo__)) end @model function demo8() @@ -79,7 +79,7 @@ end m ~ Normal() [10.0] .~ Normal(m, 0.5) - return (; m, x=[10.0], logp=getlogp(__varinfo__)) + return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} @@ -96,7 +96,7 @@ end 10.0 ~ Normal(m[i], 0.5) end - return (; m, x=[10.0], logp=getlogp(__varinfo__)) + return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) end @model function _likelihood_dot_observe(m, x) @@ -110,7 +110,7 @@ end # Submodel likelihood @submodel _likelihood_dot_observe(m, x) - return (; m, x, logp=getlogp(__varinfo__)) + return (; m=m, x=x, logp=getlogp(__varinfo__)) end @model function demo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} @@ -118,7 +118,9 @@ end m .~ Normal() # Dotted observe for `Matrix`. 
- return x .~ MvNormal(m, 0.5) + x .~ MvNormal(m, 0.5) + + return (; m=m, x=x, logp=getlogp(__varinfo__)) end const demo_models = ( From 3b71318b99137ba1288ff5ea768ccfbd124aa085 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 11:09:37 +0100 Subject: [PATCH 09/15] test model names are now more informative --- src/test_utils.jl | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 320a13ce3..3379f0ab7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -7,7 +7,7 @@ using Test # A collection of models for which the mean-of-means for the posterior should # be same. -@model function demo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_dot_observe(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() @@ -15,7 +15,7 @@ using Test return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_assume_index_observe(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) for i in eachindex(m) @@ -26,7 +26,7 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo3(x=10 * ones(2)) +@model function demo_assume_multivariate_observe_index(x=10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) x ~ MvNormal(m, 0.5 * ones(length(x))) @@ -34,7 +34,7 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_observe_index(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) m .~ Normal() @@ -47,7 +47,7 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. 
-@model function demo5(x=10 * ones(1)) +@model function demo_assume_dot_observe(x=10 * ones(1)) # `assume` and `dot_observe` m ~ Normal() x .~ Normal(m, 0.5) @@ -55,7 +55,7 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo6() +@model function demo_assume_observe_literal() # `assume` and literal `observe` m ~ MvNormal(2, 1.0) [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) @@ -63,7 +63,7 @@ end return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) end -@model function demo7(::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing m = TV(undef, 2) m .~ Normal() @@ -74,7 +74,7 @@ end return (; m=m, x=10 * ones(length(m)), logp=getlogp(__varinfo__)) end -@model function demo8() +@model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` m ~ Normal() [10.0] .~ Normal(m, 0.5) @@ -89,7 +89,7 @@ end return m end -@model function demo9() +@model function demo_assume_submodel_observe_index_literal() # Submodel prior m = @submodel _prior_dot_assume() for i in eachindex(m) @@ -103,7 +103,7 @@ end return x ~ MvNormal(m, 0.5 * ones(length(m))) end -@model function demo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_observe_submodel(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -113,7 +113,7 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_dot_observe_matrix(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -124,17 +124,17 @@ end end const demo_models = ( - demo1(), - demo2(), - demo3(), - demo4(), - demo5(), - demo6(), - demo7(), - demo8(), - demo9(), - demo10(), - demo11(), + demo_dot_assume_dot_observe(), + demo_assume_index_observe(), + demo_assume_multivariate_observe_index(), + demo_dot_assume_observe_index(), + demo_assume_dot_observe(), + demo_assume_observe_literal(), + demo_dot_assume_observe_index_literal(), + demo_assume_literal_dot_observe(), + demo_assume_submodel_observe_index_literal(), + demo_dot_assume_observe_submodel(), + demo_dot_assume_dot_observe_matrix(), ) # TODO: Is this really the best/most convenient "default" test method? From 0d3ff2612ea7c78b7071ea3660db81e698f69563 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 11:10:08 +0100 Subject: [PATCH 10/15] Apply suggestions from code review Co-authored-by: Philipp Gabler --- src/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 3379f0ab7..c3f7639e8 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -123,7 +123,7 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -const demo_models = ( +const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), demo_assume_multivariate_observe_index(), @@ -168,7 +168,7 @@ function test_sampler_demo_models( rtol=1e-3, kwargs..., ) - @testset "$(nameof(typeof(sampler))) on $(m.name)" for model in demo_models + @testset "$(nameof(typeof(sampler))) on $(m.name)" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) 
μ = meanfunction(chain) @test μ ≈ target atol = atol rtol = rtol From 334e8dafb7fbf86777238d160ccd661ceb0dbf68 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 11:23:11 +0100 Subject: [PATCH 11/15] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/test_utils.jl | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index c3f7639e8..a31fe298e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -7,7 +7,9 @@ using Test # A collection of models for which the mean-of-means for the posterior should # be same. -@model function demo_dot_assume_dot_observe(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_dot_observe( + x=10 * ones(2), ::Type{TV}=Vector{Float64} +) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() @@ -15,7 +17,9 @@ using Test return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo_assume_index_observe(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_assume_index_observe( + x=10 * ones(2), ::Type{TV}=Vector{Float64} +) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) for i in eachindex(m) @@ -34,7 +38,9 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo_dot_assume_observe_index(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_observe_index( + x=10 * ones(2), ::Type{TV}=Vector{Float64} +) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) m .~ Normal() @@ -103,7 +109,9 @@ end return x ~ MvNormal(m, 0.5 * ones(length(m))) end -@model function demo_dot_assume_observe_submodel(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_observe_submodel( + x=10 * ones(2), ::Type{TV}=Vector{Float64} +) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -113,7 +121,9 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo_dot_assume_dot_observe_matrix(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV} +@model function demo_dot_assume_dot_observe_matrix( + x=10 * ones(2, 1), ::Type{TV}=Vector{Float64} +) where {TV} m = TV(undef, length(x)) m .~ Normal() From 46237980f51edd5f1bfde8c7e2885bedf6a41ac9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Aug 2021 19:36:25 +0100 Subject: [PATCH 12/15] Apply suggestions from @devmotion Co-authored-by: David Widmann --- src/test_utils.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index a31fe298e..188366838 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -8,7 +8,7 @@ using Test # A collection of models for which the mean-of-means for the posterior should # be same. 
@model function demo_dot_assume_dot_observe( - x=10 * ones(2), ::Type{TV}=Vector{Float64} + x=[10.0, 10.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) @@ -18,7 +18,7 @@ using Test end @model function demo_assume_index_observe( - x=10 * ones(2), ::Type{TV}=Vector{Float64} + x=[10.0, 10.0], ::Type{TV}=Vector{Float64} ) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) @@ -30,7 +30,7 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end -@model function demo_assume_multivariate_observe_index(x=10 * ones(2)) +@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0]) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) x ~ MvNormal(m, 0.5 * ones(length(x))) @@ -39,7 +39,7 @@ end end @model function demo_dot_assume_observe_index( - x=10 * ones(2), ::Type{TV}=Vector{Float64} + x=[10.0, 10.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) @@ -53,7 +53,7 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function demo_assume_dot_observe(x=10 * ones(1)) +@model function demo_assume_dot_observe(x=[10.0]) # `assume` and `dot_observe` m ~ Normal() x .~ Normal(m, 0.5) @@ -77,7 +77,7 @@ end 10.0 ~ Normal(m[i], 0.5) end - return (; m=m, x=10 * ones(length(m)), logp=getlogp(__varinfo__)) + return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) end @model function demo_assume_literal_dot_observe() @@ -110,7 +110,7 @@ end end @model function demo_dot_assume_observe_submodel( - x=10 * ones(2), ::Type{TV}=Vector{Float64} + x=[10.0, 10.0], ::Type{TV}=Vector{Float64} ) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -122,7 +122,7 @@ end end @model function demo_dot_assume_dot_observe_matrix( - x=10 * ones(2, 1), ::Type{TV}=Vector{Float64} + x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} ) where {TV} m = TV(undef, length(x)) m .~ Normal() From d86ab1d7ecdd32747f27b450d5696dab1509e302 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 27 Aug 2021 01:37:20 +0100 Subject: [PATCH 13/15] Apply suggestions from code review Co-authored-by: David Widmann --- src/test_utils.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 188366838..492890336 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -13,7 +13,7 @@ using Test # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() - x ~ MvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end @@ -25,15 +25,15 @@ end for i in eachindex(m) m[i] ~ Normal() end - x ~ MvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end @model function demo_assume_multivariate_observe_index(x=[10.0, 10.0]) # Multivariate `assume` and `observe` - m ~ MvNormal(length(x), 1.0) - x ~ MvNormal(m, 0.5 * ones(length(x))) + m ~ MvNormal(zero(x), I) + x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end @@ -63,8 +63,8 @@ end @model function demo_assume_observe_literal() # `assume` and literal `observe` - m ~ MvNormal(2, 1.0) - [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) + m ~ MvNormal(zeros(2), I) + [10.0, 10.0] ~ MvNormal(m, 0.25 * I) return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) end @@ -106,7 +106,7 @@ end end @model function _likelihood_dot_observe(m, x) - return x ~ MvNormal(m, 0.5 * ones(length(m))) + return x ~ 
MvNormal(m, 0.25 * I) end @model function demo_dot_assume_observe_submodel( @@ -128,7 +128,7 @@ end m .~ Normal() # Dotted observe for `Matrix`. - x .~ MvNormal(m, 0.5) + x .~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end From 857b99b2d00df15c483f9630e44ea3e585cd0c11 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:11:30 +0100 Subject: [PATCH 14/15] fixed tests --- test/loglikelihoods.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 6eda23cfc..4d5003f03 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,5 +1,5 @@ @testset "loglikelihoods.jl" begin - for m in DynamicPPL.TestUtils.demo_models + for m in DynamicPPL.TestUtils.DEMO_MODELS vi = VarInfo(m) vns = vi.metadata.m.vns From 9199004fa1ca1bf49db3afbc345b9fcf90120d39 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:33:19 +0100 Subject: [PATCH 15/15] fixed tests --- src/test_utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 492890336..d4b5c7206 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -2,6 +2,7 @@ module TestUtils using AbstractMCMC using DynamicPPL +using LinearAlgebra using Distributions using Test
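For orientation, here is a minimal sketch of how the finished `DynamicPPL.TestUtils` API from this series might be exercised in a downstream test suite. The sampler type `MySampler` and the draw count `1_000` are placeholders for whatever `AbstractMCMC.AbstractSampler` and sample size the downstream package actually uses; `DEMO_MODELS`, the NamedTuple return values, `test_sampler_demo_models`, `test_sampler_continuous`, and the `mean(Array(chain))` reduction come from the patches above.

```julia
using DynamicPPL
using Statistics: mean

# Each demo model returns a NamedTuple of its variables and accumulated log
# density, which can be inspected directly by calling the model:
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
retval = model()  # (; m, x, logp)
@show retval.m retval.logp

# Hypothetical sampler -- a stand-in for whatever AbstractMCMC.AbstractSampler
# the downstream package defines.
sampler = MySampler()

# Check the marginal posterior mean on every model in DEMO_MODELS: each chain is
# reduced to a single number by the do-block and compared via @test to the
# default target of 8.0 with atol = 1e-1, rtol = 1e-3. The trailing 1_000 is the
# number of draws, forwarded as-is to AbstractMCMC.sample.
DynamicPPL.TestUtils.test_sampler_demo_models(sampler, 1_000) do chain
    mean(Array(chain))
end

# Convenience form that falls back to the same mean(Array(chain)) reduction.
DynamicPPL.TestUtils.test_sampler_continuous(sampler, 1_000)
```

Because `target`, `atol`, and `rtol` are keyword arguments, a sampler with different posterior accuracy can loosen the tolerances or change the expected mean without touching the demo models themselves.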