diff --git a/Project.toml b/Project.toml
index 489c40e1a..42d6baf44 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.15.1"
+version = "0.15.2"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index ac2734b47..1238add26 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -134,4 +134,6 @@ include("compat/ad.jl")
 include("loglikelihoods.jl")
 include("submodel_macro.jl")
 
+include("test_utils.jl")
+
 end # module
diff --git a/src/test_utils.jl b/src/test_utils.jl
new file mode 100644
index 000000000..492890336
--- /dev/null
+++ b/src/test_utils.jl
@@ -0,0 +1,209 @@
+module TestUtils
+
+using AbstractMCMC
+using DynamicPPL
+using Distributions
+# `LinearAlgebra` is required for the `I` (UniformScaling) used in the
+# `MvNormal(m, 0.25 * I)` observes below; it is not re-exported by any of the
+# other imports, so without it the models throw `UndefVarError: I`.
+using LinearAlgebra
+using Test
+
+# A collection of models for which the mean-of-means for the posterior should
+# be the same.
+@model function demo_dot_assume_dot_observe(
+    x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
+) where {TV}
+    # `dot_assume` (`.~`) for `m` and a multivariate `observe` (`~`) for the data `x`.
+    m = TV(undef, length(x))
+    m .~ Normal()
+    x ~ MvNormal(m, 0.25 * I)
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+@model function demo_assume_index_observe(
+    x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
+) where {TV}
+    # Scalar `assume` with indexing (`m[i] ~ ...`) and a multivariate `observe`.
+    m = TV(undef, length(x))
+    for i in eachindex(m)
+        m[i] ~ Normal()
+    end
+    x ~ MvNormal(m, 0.25 * I)
+
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0])
+    # Multivariate `assume` for `m` and a multivariate `observe` for `x`.
+    m ~ MvNormal(zero(x), I)
+    x ~ MvNormal(m, 0.25 * I)
+
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+@model function demo_dot_assume_observe_index(
+    x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
+) where {TV}
+    # `dot_assume` (`.~`) for `m` and scalar `observe` with indexing (`x[i] ~ ...`).
+    m = TV(undef, length(x))
+    m .~ Normal()
+    for i in eachindex(x)
+        x[i] ~ Normal(m[i], 0.5)
+    end
+
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+# Using a vector of length 1 here so the posterior of `m` is the same
+# as in the other models.
+@model function demo_assume_dot_observe(x=[10.0])
+    # `assume` and `dot_observe`
+    m ~ Normal()
+    x .~ Normal(m, 0.5)
+
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+@model function demo_assume_observe_literal()
+    # `assume` and literal `observe`
+    m ~ MvNormal(zeros(2), I)
+    [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
+
+    return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__))
+end
+
+@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
+    # `dot_assume` and literal `observe` with indexing
+    m = TV(undef, 2)
+    m .~ Normal()
+    for i in eachindex(m)
+        10.0 ~ Normal(m[i], 0.5)
+    end
+
+    return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__))
+end
+
+@model function demo_assume_literal_dot_observe()
+    # `assume` and literal `dot_observe`
+    m ~ Normal()
+    [10.0] .~ Normal(m, 0.5)
+
+    return (; m=m, x=[10.0], logp=getlogp(__varinfo__))
+end
+
+@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
+    m = TV(undef, 2)
+    m .~ Normal()
+
+    return m
+end
+
+@model function demo_assume_submodel_observe_index_literal()
+    # Submodel prior
+    m = @submodel _prior_dot_assume()
+    for i in eachindex(m)
+        10.0 ~ Normal(m[i], 0.5)
+    end
+
+    # `m` has length 2 and we observe `10.0` at every index, so report the full
+    # observed vector (the original returned `x=[10.0]`, inconsistent with
+    # `demo_dot_assume_observe_index_literal` above).
+    return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__))
+end
+
+@model function _likelihood_dot_observe(m, x)
+    return x ~ MvNormal(m, 0.25 * I)
+end
+
+@model function demo_dot_assume_observe_submodel(
+    x=[10.0, 10.0], ::Type{TV}=Vector{Float64}
+) where {TV}
+    m = TV(undef, length(x))
+    m .~ Normal()
+
+    # Submodel likelihood
+    @submodel _likelihood_dot_observe(m, x)
+
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+@model function demo_dot_assume_dot_observe_matrix(
+    x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}
+) where {TV}
+    m = TV(undef, length(x))
+    m .~ Normal()
+
+    # Dotted observe for `Matrix`.
+    x .~ MvNormal(m, 0.25 * I)
+
+    return (; m=m, x=x, logp=getlogp(__varinfo__))
+end
+
+const DEMO_MODELS = (
+    demo_dot_assume_dot_observe(),
+    demo_assume_index_observe(),
+    demo_assume_multivariate_observe_index(),
+    demo_dot_assume_observe_index(),
+    demo_assume_dot_observe(),
+    demo_assume_observe_literal(),
+    demo_dot_assume_observe_index_literal(),
+    demo_assume_literal_dot_observe(),
+    demo_assume_submodel_observe_index_literal(),
+    demo_dot_assume_observe_submodel(),
+    demo_dot_assume_dot_observe_matrix(),
+)
+
+# TODO: Is this really the best/most convenient "default" test method?
+"""
+    test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
+
+Test that `sampler` produces the correct marginal posterior means on all models in `DEMO_MODELS`.
+
+In short, this method iterates through `DEMO_MODELS`, calls `AbstractMCMC.sample` on the
+`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target`
+provided in `kwargs...`.
+
+# Arguments
+- `meanfunction`: A callable which computes the mean of the marginal means from the
+  chain resulting from the `sample` call.
+- `sampler`: The `AbstractMCMC.AbstractSampler` to test.
+- `args...`: Arguments forwarded to `sample`.
+
+# Keyword arguments
+- `target`: Value to compare result of `meanfunction(chain)` to.
+- `atol=1e-1`: Absolute tolerance used in `@test`.
+- `rtol=1e-3`: Relative tolerance used in `@test`.
+- `kwargs...`: Keyword arguments forwarded to `sample`.
+"""
+function test_sampler_demo_models(
+    meanfunction,
+    sampler::AbstractMCMC.AbstractSampler,
+    args...;
+    target=8.0,
+    atol=1e-1,
+    rtol=1e-3,
+    kwargs...,
+)
+    @testset "$(nameof(typeof(sampler))) on $(model.name)" for model in DEMO_MODELS
+        chain = AbstractMCMC.sample(model, sampler, args...; kwargs...)
+        μ = meanfunction(chain)
+        @test μ ≈ target atol = atol rtol = rtol
+    end
+end
+
+"""
+    test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...)
+
+Test that `sampler` produces the correct marginal posterior means on all models in `DEMO_MODELS`.
+
+As of right now, this is just an alias for [`test_sampler_demo_models`](@ref).
+"""
+function test_sampler_continuous(
+    meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs...
+)
+    return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...)
+end
+
+function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...)
+    # Default for `MCMCChains.Chains`: the `do`-block becomes the `meanfunction`
+    # (first positional argument), so this dispatches to the method above.
+    return test_sampler_continuous(sampler, args...; kwargs...) do chain
+        mean(Array(chain))
+    end
+end
+
+end
diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl
index 8a9906524..6eda23cfc 100644
--- a/test/loglikelihoods.jl
+++ b/test/loglikelihoods.jl
@@ -1,116 +1,5 @@
-# A collection of models for which the mean-of-means for the posterior should
-# be same.
-@model function gdemo1(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
-    # `dot_assume` and `observe`
-    m = TV(undef, length(x))
-    m .~ Normal()
-    return x ~ MvNormal(m, 0.25 * I)
-end
-
-@model function gdemo2(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
-    # `assume` with indexing and `observe`
-    m = TV(undef, length(x))
-    for i in eachindex(m)
-        m[i] ~ Normal()
-    end
-    return x ~ MvNormal(m, 0.25 * I)
-end
-
-@model function gdemo3(x=[10.0, 10.0])
-    # Multivariate `assume` and `observe`
-    m ~ MvNormal(zero(x), I)
-    return x ~ MvNormal(m, 0.25 * I)
-end
-
-@model function gdemo4(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
-    # `dot_assume` and `observe` with indexing
-    m = TV(undef, length(x))
-    m .~ Normal()
-    for i in eachindex(x)
-        x[i] ~ Normal(m[i], 0.5)
-    end
-end
-
-# Using vector of `length` 1 here so the posterior of `m` is the same
-# as the others.
-@model function gdemo5(x=[10.0])
-    # `assume` and `dot_observe`
-    m ~ Normal()
-    return x .~ Normal(m, 0.5)
-end
-
-@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
-    # `assume` and literal `observe`
-    m ~ MvNormal(zeros(2), I)
-    return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
-end
-
-@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
-    # `dot_assume` and literal `observe` with indexing
-    m = TV(undef, 2)
-    m .~ Normal()
-    for i in eachindex(m)
-        10.0 ~ Normal(m[i], 0.5)
-    end
-end
-
-@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV}
-    # `assume` and literal `dot_observe`
-    m ~ Normal()
-    return [10.0] .~ Normal(m, 0.5)
-end
-
-@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
-    m = TV(undef, 2)
-    m .~ Normal()
-
-    return m
-end
-
-@model function gdemo9()
-    # Submodel prior
-    m = @submodel _prior_dot_assume()
-    for i in eachindex(m)
-        10.0 ~ Normal(m[i], 0.5)
-    end
-end
-
-@model function _likelihood_dot_observe(m, x)
-    return x ~ MvNormal(m, 0.25 * I)
-end
-
-@model function gdemo10(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
-    m = TV(undef, length(x))
-    m .~ Normal()
-
-    # Submodel likelihood
-    @submodel _likelihood_dot_observe(m, x)
-end
-
-@model function gdemo11(x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}) where {TV}
-    m = TV(undef, length(x))
-    m .~ Normal()
-
-    # Dotted observe for `Matrix`.
-    return x .~ MvNormal(m, 0.25 * I)
-end
-
-const gdemo_models = (
-    gdemo1(),
-    gdemo2(),
-    gdemo3(),
-    gdemo4(),
-    gdemo5(),
-    gdemo6(),
-    gdemo7(),
-    gdemo8(),
-    gdemo9(),
-    gdemo10(),
-    gdemo11(),
-)
-
 @testset "loglikelihoods.jl" begin
-    for m in gdemo_models
+    # The constant defined in src/test_utils.jl is `DEMO_MODELS`; the original
+    # patch referenced `demo_models`, which is never defined anywhere.
+    for m in DynamicPPL.TestUtils.DEMO_MODELS
         vi = VarInfo(m)
 
         vns = vi.metadata.m.vns