diff --git a/docs/src/api.md b/docs/src/api.md index 0e4012e02..2024b58d5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,7 +103,9 @@ pointwise_loglikelihoods ``` ```@docs +WrappedDistribution NamedDist +NoDist ``` ## Testing Utilities diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 04b08fb19..18e9f73c0 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -107,6 +107,7 @@ export AbstractVarInfo, dot_tilde_assume, dot_tilde_observe, # Pseudo distributions + WrappedDistribution, NamedDist, NoDist, # Prob macros diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index d8968a68e..844b0473f 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -2,20 +2,41 @@ using Distributions: Distributions using Bijectors: Bijectors using Distributions: Univariate, Multivariate, Matrixvariate +""" +Base type for distribution wrappers. +""" +abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <: + Distribution{variate,support} end + +wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where {Td} = Td +wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d) + +wrapped_dist(d::WrappedDistribution) = d.dist + +Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d)) +Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d)) +Base.eltype(::Type{T}) where {T<:WrappedDistribution} = eltype(wrapped_dist_type(T)) +Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d)) + +function Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution) + return rand(rng, wrapped_dist(d)) +end +Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d)) +Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d)) + +Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d)) + """ A named distribution that carries the name of the random variable with it. """ struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <: - Distribution{variate,support} + WrappedDistribution{variate,support,Td} dist::Td name::Tv end NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}()) -Base.length(dist::NamedDist) = Base.length(dist.dist) -Base.size(dist::NamedDist) = Base.size(dist.dist) - Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.logpdf(dist.dist, x) @@ -27,10 +48,13 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.loglikelihood(dist.dist, x) end -Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist) +""" +Wrapper around distribution `Td` that suppresses `logpdf()` calculation. +Note that *SampleFromPrior* would still sample from `Td`. +""" struct NoDist{variate,support,Td<:Distribution{variate,support}} <: - Distribution{variate,support} + WrappedDistribution{variate,support,Td} dist::Td end NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) @@ -38,9 +62,6 @@ NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name) nodist(dist::Distribution) = NoDist(dist) nodist(dists::AbstractArray) = nodist.(dists) -Base.length(dist::NoDist) = Base.length(dist.dist) -Base.size(dist::NoDist) = Base.size(dist.dist) - Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist) Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 @@ -48,8 +69,6 @@ function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Rea return zeros(Int, size(x, 2)) end Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 -Distributions.minimum(d::NoDist) = minimum(d.dist) -Distributions.maximum(d::NoDist) = maximum(d.dist) Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 function Bijectors.logpdf_with_trans( @@ -67,5 +86,3 @@ function Bijectors.logpdf_with_trans( ) return 0 end - -Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 8a795320d..c548a3d8a 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -68,5 +68,42 @@ end end end + + @testset "multivariate NoDist" begin + @model function genmodel() + x ~ NoDist(Product(fill(Uniform(-20, 20), 5))) + for i in eachindex(x) + x[i] ~ Normal(0, 1) + end + end + gen_model = genmodel() + vi_gen = VarInfo(gen_model) + @test isfinite(logjoint(gen_model, vi_gen)) + # test for bijector + link!(vi_gen, DynamicPPL.SampleFromPrior()) + invlink!(vi_gen, DynamicPPL.SampleFromPrior()) + + # explicit model specification + expl_model = DynamicPPL.Model(NamedTuple()) do model, varinfo, context + DynamicPPL.tilde_assume!!( + context, + NoDist(Product(fill(Uniform(-20, 20), 5))), + @varname(x), + varinfo, + ) + x = varinfo[@varname(x)] + @test x isa Vector{<:Real} + @test length(x) == 5 + return ( + nothing, + DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))), + ) + end + vi_expl = VarInfo(expl_model) + @test isfinite(logjoint(expl_model, vi_expl)) + # test for bijector + link!(vi_expl, DynamicPPL.SampleFromPrior()) + invlink!(vi_expl, DynamicPPL.SampleFromPrior()) + end end end diff --git a/test/distribution_wrappers.jl b/test/distribution_wrappers.jl index 8bb692783..fb3a24247 100644 --- a/test/distribution_wrappers.jl +++ b/test/distribution_wrappers.jl @@ -1,13 +1,33 @@ @testset "distribution_wrappers.jl" begin - d = Normal() - nd = DynamicPPL.NoDist(d) + @testset "univariate" begin + d = Normal() + nd = DynamicPPL.NoDist(d) - # Smoke test - rand(nd) + # Smoke test + rand(nd) - # Actual tests - @test minimum(nd) == -Inf - @test maximum(nd) == Inf - @test logpdf(nd, 15.0) == 0 - @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 + # Actual tests + @test minimum(nd) == -Inf + @test maximum(nd) == Inf + @test logpdf(nd, 15.0) == 0 + @test Bijectors.logpdf_with_trans(nd, 30, true) == 0 + @test Bijectors.bijector(nd) == Bijectors.bijector(d) + end + + @testset "multivariate" begin + d = Product([Normal(), Uniform()]) + nd = DynamicPPL.NoDist(d) + + # Smoke test + @test length(rand(nd)) == 2 + + # Actual tests + @test length(nd) == 2 + @test size(nd) == (2,) + @test minimum(nd) == [-Inf, 0.0] + @test maximum(nd) == [Inf, 1.0] + @test logpdf(nd, [15.0, 0.5]) == 0 + @test Bijectors.logpdf_with_trans(nd, [0, 1]) == 0 + @test Bijectors.bijector(nd) == Bijectors.bijector(d) + end end