From f90056be018d2012afc40877a2a6b775a8a05f87 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Fri, 24 Jun 2022 22:51:02 -0700 Subject: [PATCH 1/9] enhance wrapped distributions --- src/distribution_wrappers.jl | 42 +++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index d8968a68e..58de91968 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -2,20 +2,40 @@ using Distributions: Distributions using Bijectors: Bijectors using Distributions: Univariate, Multivariate, Matrixvariate +""" +Base type for distribution wrappers. +""" +abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <: + Distribution{variate,support} +end + +wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where Td = Td +wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d) + +wrapped_dist(d::WrappedDistribution) = d.dist + +Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d)) +Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d)) +Base.eltype(::Type{T}) where T <: WrappedDistribution = eltype(wrapped_dist_type(T)) +Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d)) + +Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) = rand(rng, wrapped_dist(d)) +Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d)) +Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d)) + +Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d)) + """ A named distribution that carries the name of the random variable with it. """ struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <: - Distribution{variate,support} + WrappedDistribution{variate,support,Td} dist::Td name::Tv end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}()) -Base.length(dist::NamedDist) = Base.length(dist.dist) -Base.size(dist::NamedDist) = Base.size(dist.dist) - Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.logpdf(dist.dist, x) @@ -27,10 +47,13 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.loglikelihood(dist.dist, x) end -Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist) +""" +Wrapper around distribution `Td` that suppresses `logpdf()` calculation. +Note that *SampleFromPrior* would still sample from `Td`. +""" struct NoDist{variate,support,Td<:Distribution{variate,support}} <: - Distribution{variate,support} + WrappedDistribution{variate,support,Td} dist::Td end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) @@ -38,9 +61,6 @@ NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) nodist(dist::Distribution) = NoDist(dist) nodist(dists::AbstractArray) = nodist.(dists) -Base.length(dist::NoDist) = Base.length(dist.dist) -Base.size(dist::NoDist) = Base.size(dist.dist) - Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 @@ -48,8 +68,6 @@ function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Rea return zeros(Int, size(x, 2)) end Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 -Distributions.minimum(d::NoDist) = minimum(d.dist) -Distributions.maximum(d::NoDist) = maximum(d.dist) Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 function Bijectors.logpdf_with_trans( @@ -67,5 +85,3 @@ function Bijectors.logpdf_with_trans( ) return 0 end - -Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) From 7f79c23e152a14566dc13b8cb2f603bc14b9efaf Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 27 Jun 2022 13:59:31 -0700 Subject: [PATCH 2/9] distr_wrappers: add tests for multivariate distrs --- test/distribution_wrappers.jl | 38 ++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl index 8bb692783..2c117b7c2 100644 --- a/test/distribution_wrappers.jl +++ b/test/distribution_wrappers.jl @@ -1,13 +1,33 @@ @testset "distribution_wrappers.jl" begin - d = Normal() - nd = DynamicPPL.NoDist(d) + @testset "univariate" begin + d = Normal() + nd = DynamicPPL.NoDist(d) - # Smoke test - rand(nd) + # Smoke test + rand(nd) - # Actual tests - @test minimum(nd) == -Inf - @test maximum(nd) == Inf - @test logpdf(nd, 15.0) == 0 - @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 + # Actual tests + @test minimum(nd) == -Inf + @test maximum(nd) == Inf + @test logpdf(nd, 15.0) == 0 + @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 + @test Bijectors.bijector(nd) == Bijectors.bijector(d) + end + + @testset "multivariate" begin + d = Product([Normal(), Uniform()]) + nd = DynamicPPL.NoDist(d) + + # Smoke test + @test length(rand(nd)) == 2 + + # Actual tests + #@test length(nd) == 2 + #@test size(nd) == (2,) + @test minimum(nd) == [-Inf, 0.0] + @test maximum(nd) == [Inf, 1.0] + @test logpdf(nd, [15.0, 0.5]) == 0 + @test Bijectors.logpdf_with_trans(nd, [0, 1]) == 0 + @test Bijectors.bijector(nd) == Bijectors.bijector(d) + end end From bcca942a4cb286121e0e8224fa89c20197a76baa Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 27 Jun 2022 14:01:04 -0700 Subject: [PATCH 3/9] add tests for model with multivariate NoDist --- test/context_implementations.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 8a795320d..d4b2cc177 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -68,5 +68,34 @@ end end end + + @testset "multivariate NoDist" begin + @model function genmodel() + x ~ NoDist(Product(fill(Uniform(-20, 20), 5))) + for i in eachindex(x) + x[i] ~ Normal(0, 1) + end + end + gen_model = genmodel() + vi_gen = VarInfo(gen_model) + @test isfinite(logjoint(gen_model, vi_gen)) + # test for bijector + link!(vi_gen, DynamicPPL.SampleFromPrior()) + invlink!(vi_gen, DynamicPPL.SampleFromPrior()) + + # explicit model specification + expl_model = DynamicPPL.Model(NamedTuple()) do model, varinfo, context + DynamicPPL.tilde_assume!!(context, NoDist(Product(fill(Uniform(-20, 20), 5))), @varname(x), varinfo) + x = varinfo[@varname(x)] + @test x isa Vector{<:Real} + @test length(x) == 5 + return nothing, DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))) + end + vi_expl = VarInfo(expl_model) + @test isfinite(logjoint(expl_model, vi_expl)) + # test for bijector + link!(vi_expl, DynamicPPL.SampleFromPrior()) + invlink!(vi_expl, DynamicPPL.SampleFromPrior()) + end end end From ec07ed53842be43947d777ea69754713f0ac1368 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 27 Jun 2022 14:21:21 -0700 Subject: [PATCH 4/9] fix commented out tests --- test/distribution_wrappers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl index 2c117b7c2..fb3a24247 100644 --- a/test/distribution_wrappers.jl +++ b/test/distribution_wrappers.jl @@ -22,8 +22,8 @@ @test length(rand(nd)) == 2 # Actual tests - #@test length(nd) == 2 - #@test size(nd) == (2,) + @test length(nd) == 2 + @test size(nd) == (2,) @test minimum(nd) == [-Inf, 0.0] @test maximum(nd) == [Inf, 1.0] @test logpdf(nd, [15.0, 0.5]) == 0 From e8710b3a34cab68981ea02fd3b64e6024293b268 Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 27 Jun 2022 14:35:21 -0700 Subject: [PATCH 5/9] fix reviewdog formatting issues --- src/distribution_wrappers.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 58de91968..89f7e8a0d 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -6,20 +6,22 @@ using Distributions: Univariate, Multivariate, Matrixvariate Base type for distribution wrappers. """ abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <: - Distribution{variate,support} + Distribution{variate,support} end -wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where Td = Td +wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where {Td} = Td wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d) wrapped_dist(d::WrappedDistribution) = d.dist Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d)) Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d)) -Base.eltype(::Type{T}) where T <: WrappedDistribution = eltype(wrapped_dist_type(T)) +Base.eltype(::Type{T}) where {T<:WrappedDistribution} = eltype(wrapped_dist_type(T)) Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d)) -Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) = rand(rng, wrapped_dist(d)) +function Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) + rand(rng, wrapped_dist(d)) +end Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d)) Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d)) From 2976bbf06371feaa790432c9786ccbc30508917b Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 27 Jun 2022 18:48:46 -0700 Subject: [PATCH 6/9] 2nd round of reviewdog fixes --- src/distribution_wrappers.jl | 5 ++--- test/context_implementations.jl | 12 +++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 89f7e8a0d..844b0473f 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -6,8 +6,7 @@ using Distributions: Univariate, Multivariate, Matrixvariate Base type for distribution wrappers. """ abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <: - Distribution{variate,support} -end + Distribution{variate,support} end wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where {Td} = Td wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d) @@ -20,7 +19,7 @@ Base.eltype(::Type{T}) where {T<:WrappedDistribution} = eltype(wrapped_dist_type Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d)) function Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) - rand(rng, wrapped_dist(d)) + return rand(rng, wrapped_dist(d)) end Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d)) Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d)) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index d4b2cc177..62d8469c4 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -82,14 +82,20 @@ # test for bijector link!(vi_gen, DynamicPPL.SampleFromPrior()) invlink!(vi_gen, DynamicPPL.SampleFromPrior()) - + # explicit model specification expl_model = DynamicPPL.Model(NamedTuple()) do model, varinfo, context - DynamicPPL.tilde_assume!!(context, NoDist(Product(fill(Uniform(-20, 20), 5))), @varname(x), varinfo) + DynamicPPL.tilde_assume!!( + context, + NoDist(Product(fill(Uniform(-20, 20), 5))), + @varname(x), + varinfo + ) x = varinfo[@varname(x)] @test x isa Vector{<:Real} @test length(x) == 5 - return nothing, DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))) + return (nothing, + DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x)))) end vi_expl = VarInfo(expl_model) @test isfinite(logjoint(expl_model, vi_expl)) From ece33c6f307a2008df33cdc2926e5850d053685d Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Mon, 27 Jun 2022 18:53:31 -0700 Subject: [PATCH 7/9] refer WrappedDist and NoDist from API docs --- docs/src/api.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 2dfda9119..9d8f16b22 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,7 +103,9 @@ pointwise_loglikelihoods ``` ```@docs +WrappedDistribution NamedDist +NoDist ``` ## Testing Utilities From 490257ab21d4c5494225ffad243e32cad4735e4e Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Tue, 28 Jun 2022 10:12:38 -0700 Subject: [PATCH 8/9] export WrappedDist to make docs happy --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 594084d66..2192bb4ab 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -106,6 +106,7 @@ export AbstractVarInfo, dot_tilde_assume, dot_tilde_observe, # Pseudo distributions + WrappedDistribution, NamedDist, NoDist, # Prob macros From 71f330490b3b4a9e08dcad10576ec295835237cf Mon Sep 17 00:00:00 2001 From: Alexey Stukalov Date: Tue, 28 Jun 2022 10:20:21 -0700 Subject: [PATCH 9/9] 3rd round of trying to make the format doggy happy --- test/context_implementations.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 62d8469c4..c548a3d8a 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -89,13 +89,15 @@ context, NoDist(Product(fill(Uniform(-20, 20), 5))), @varname(x), - varinfo + varinfo, ) x = varinfo[@varname(x)] @test x isa Vector{<:Real} @test length(x) == 5 - return (nothing, - DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x)))) + return ( + nothing, + DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))), + ) end vi_expl = VarInfo(expl_model) @test isfinite(logjoint(expl_model, vi_expl))