Skip to content

Commit

Permalink
Try #414:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Dec 19, 2022
2 parents 8c8cfc6 + 078731f commit 17bffbd
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 22 deletions.
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ pointwise_loglikelihoods
```

```@docs
WrappedDistribution
NamedDist
NoDist
```

## Testing Utilities
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export AbstractVarInfo,
dot_tilde_assume,
dot_tilde_observe,
# Pseudo distributions
WrappedDistribution,
NamedDist,
NoDist,
# Prob macros
Expand Down
43 changes: 30 additions & 13 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,41 @@ using Distributions: Distributions
using Bijectors: Bijectors
using Distributions: Univariate, Multivariate, Matrixvariate

"""
Base type for distribution wrappers.
"""
abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <:
Distribution{variate,support} end

wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where {Td} = Td
wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d)

wrapped_dist(d::WrappedDistribution) = d.dist

Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d))
Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d))
Base.eltype(::Type{T}) where {T<:WrappedDistribution} = eltype(wrapped_dist_type(T))
Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d))

function Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution)
return rand(rng, wrapped_dist(d))
end
Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d))
Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d))

Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d))

"""
A named distribution that carries the name of the random variable with it.
"""
struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <:
Distribution{variate,support}
WrappedDistribution{variate,support,Td}
dist::Td
name::Tv
end

NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())

Base.length(dist::NamedDist) = Base.length(dist.dist)
Base.size(dist::NamedDist) = Base.size(dist.dist)

Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.logpdf(dist.dist, x)
Expand All @@ -27,29 +48,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.loglikelihood(dist.dist, x)
end

Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)
"""
Wrapper around distribution `Td` that suppresses `logpdf()` calculation.
Note that *SampleFromPrior* would still sample from `Td`.
"""
struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
Distribution{variate,support}
WrappedDistribution{variate,support,Td}
dist::Td
end
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)

nodist(dist::Distribution) = NoDist(dist)
nodist(dists::AbstractArray) = nodist.(dists)

Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
return zeros(Int, size(x, 2))
end
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
Distributions.minimum(d::NoDist) = minimum(d.dist)
Distributions.maximum(d::NoDist) = maximum(d.dist)

Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
function Bijectors.logpdf_with_trans(
Expand All @@ -67,5 +86,3 @@ function Bijectors.logpdf_with_trans(
)
return 0
end

Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)
37 changes: 37 additions & 0 deletions test/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,42 @@
end
end
end

@testset "multivariate NoDist" begin
@model function genmodel()
x ~ NoDist(Product(fill(Uniform(-20, 20), 5)))
for i in eachindex(x)
x[i] ~ Normal(0, 1)
end
end
gen_model = genmodel()
vi_gen = VarInfo(gen_model)
@test isfinite(logjoint(gen_model, vi_gen))
# test for bijector
link!(vi_gen, DynamicPPL.SampleFromPrior())
invlink!(vi_gen, DynamicPPL.SampleFromPrior())

# explicit model specification
expl_model = DynamicPPL.Model(NamedTuple()) do model, varinfo, context
DynamicPPL.tilde_assume!!(
context,
NoDist(Product(fill(Uniform(-20, 20), 5))),
@varname(x),
varinfo,
)
x = varinfo[@varname(x)]
@test x isa Vector{<:Real}
@test length(x) == 5
return (
nothing,
DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))),
)
end
vi_expl = VarInfo(expl_model)
@test isfinite(logjoint(expl_model, vi_expl))
# test for bijector
link!(vi_expl, DynamicPPL.SampleFromPrior())
invlink!(vi_expl, DynamicPPL.SampleFromPrior())
end
end
end
38 changes: 29 additions & 9 deletions test/distribution_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
@testset "distribution_wrappers.jl" begin
d = Normal()
nd = DynamicPPL.NoDist(d)
@testset "univariate" begin
d = Normal()
nd = DynamicPPL.NoDist(d)

# Smoke test
rand(nd)
# Smoke test
rand(nd)

# Actual tests
@test minimum(nd) == -Inf
@test maximum(nd) == Inf
@test logpdf(nd, 15.0) == 0
@test Bijectors.logpdf_with_trans(nd, 30, true) == 0
# Actual tests
@test minimum(nd) == -Inf
@test maximum(nd) == Inf
@test logpdf(nd, 15.0) == 0
@test Bijectors.logpdf_with_trans(nd, 30, true) == 0
@test Bijectors.bijector(nd) == Bijectors.bijector(d)
end

@testset "multivariate" begin
d = Product([Normal(), Uniform()])
nd = DynamicPPL.NoDist(d)

# Smoke test
@test length(rand(nd)) == 2

# Actual tests
@test length(nd) == 2
@test size(nd) == (2,)
@test minimum(nd) == [-Inf, 0.0]
@test maximum(nd) == [Inf, 1.0]
@test logpdf(nd, [15.0, 0.5]) == 0
@test Bijectors.logpdf_with_trans(nd, [0, 1]) == 0
@test Bijectors.bijector(nd) == Bijectors.bijector(d)
end
end

0 comments on commit 17bffbd

Please sign in to comment.