Skip to content

Commit

Permalink
Perform invlinking in assume rather than implicitly in getindex (#360)
Browse files Browse the repository at this point in the history
Currently, in `assume`, etc., `invlink` is called implicitly in `getindex` using the distribution extracted from `vi`.

This has a couple of drawbacks:
1. We can only use the distribution for a particular `vn` stored in `vi` obtained during the initial run. This means that we can't even run models where the distributions has dynamic domains, i.e. the domain of a particular random variable is dependent on the realizations of other random variables.
2. We have to store the distribution for each `vn` in `vi`. This was fine when we only had `VarInfo` because we also need it for other functionality, but this is not the case in `SimpleVarInfo` (nor will it be).

So. In this PR we introduce a `getindex_raw` which is `getindex` but without `invlink` if it's already linked, and uses this within `assume`, etc. where we now use the distributions that are passed to `assume` rather than those stored in `vi`.

E.g. the following now works:

``` julia
julia> @model demo() = x ~ InverseGamma(2, 3)
demo (generic function with 2 methods)

julia> vi = SimpleVarInfo((x = 10.0, ), true)
SimpleVarInfo((x = 10.0,), 0.0, true)

julia> _, vi = DynamicPPL.evaluate!!(model, vi, DefaultContext())
(22026.465794806718, SimpleVarInfo{NamedTuple{(:x,), Tuple{Float64}}, Float64}((x = 10.0,), -17.80291162245307, true))
```

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai committed Jul 22, 2022
1 parent 9ecf3dc commit 2ed28ea
Show file tree
Hide file tree
Showing 17 changed files with 1,060 additions and 329 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
Pkg.test() # resolver may fail with test time deps
Pkg.test(julia_args=["--depwarn=no"]) # resolver may fail with test time deps
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
# If we can't resolve that means this is incompatible by SemVer and this is fine
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.19.3"
version = "0.19.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -22,7 +24,9 @@ AbstractPPL = "0.5.1"
BangBang = "0.3"
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8"
MacroTools = "0.5.6"
Setfield = "0.7.1, 0.8"
ZygoteRules = "0.2"
Expand Down
22 changes: 21 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,14 @@ NamedDist
DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.

```@docs
DynamicPPL.TestUtils.test_sampler_demo_models
DynamicPPL.TestUtils.test_sampler
DynamicPPL.TestUtils.test_sampler_on_demo_models
DynamicPPL.TestUtils.test_sampler_continuous
DynamicPPL.TestUtils.marginal_mean_of_samples
```

```@docs
DynamicPPL.TestUtils.DEMO_MODELS
```

For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.
Expand All @@ -115,6 +121,20 @@ DynamicPPL.TestUtils.loglikelihood_true
DynamicPPL.TestUtils.logjoint_true
```

And in the case where the model includes constrained variables, it can also be useful to define

```@docs
DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian
```

Finally, the following methods can also be of use:

```@docs
DynamicPPL.TestUtils.varnames
DynamicPPL.TestUtils.posterior_mean
```

## Advanced

### Variable names
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ using MacroTools: MacroTools
using Setfield: Setfield
using ZygoteRules: ZygoteRules

using DocStringExtensions

using Random: Random

import Base:
Expand Down
94 changes: 56 additions & 38 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
Expand All @@ -64,15 +64,15 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
Expand All @@ -86,7 +86,7 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
Expand Down Expand Up @@ -194,7 +194,7 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn]
r = vi[vn, dist]
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
end

Expand All @@ -211,16 +211,21 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r))
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
push!!(vi, vn, r, dist, sampler)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
Expand Down Expand Up @@ -286,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
Expand All @@ -305,19 +310,20 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(NoDist.(right), left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
function dot_tilde_assume(
rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi
)
return dot_assume(rng, sampler, NoDist.(right), vn, left, vi)
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
Expand All @@ -326,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
Expand All @@ -345,7 +351,7 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
Expand Down Expand Up @@ -383,14 +389,14 @@ function dot_assume(
vns::AbstractVector{<:VarName},
vi::AbstractVarInfo,
)
@assert length(dist) == size(var, 1)
@assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))"
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = vi[vns]
r = vi[vns, dist]
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
Expand All @@ -412,19 +418,21 @@ function dot_assume(
end

function dot_assume(
dists::Union{Distribution,AbstractArray{<:Distribution}},
dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
)
r = getindex.((vi,), vns, (dist,))
lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)))
return r, lp, vi
end

function dot_assume(
dists::AbstractArray{<:Distribution},
var::AbstractArray,
vns::AbstractArray{<:VarName},
vi,
)
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = reshape(vi[vec(vns)], size(vns))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
r = getindex.((vi,), vns, dists)
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end

Expand All @@ -438,7 +446,7 @@ function dot_assume(
)
r = get_and_set_val!(rng, vi, vns, dists, spl)
# Make sure `r` is not a matrix for multivariate distributions
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end
function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
Expand All @@ -462,19 +470,23 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r = vi[vns, dist]
end
else
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
push!!(vi, vn, r[:, i], dist, spl)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl)
# `push!!` sets the trans-flag to `false` by default.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r[:, i], dist, spl)
end
end
end
return r
Expand All @@ -496,12 +508,13 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
# r = reshape(vi[vec(vns)], size(vns))
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -511,8 +524,13 @@ function get_and_set_val!(
# 1. Figure out the broadcast size and use a `foreach`.
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
push!!.(Ref(vi), vns, r, dists, Ref(spl))
settrans!.(Ref(vi), false, vns)
if istrans(vi)
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
push!!.((vi,), vns, r, dists, (spl,))
end
end
return r
end
Expand Down
31 changes: 27 additions & 4 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ end

NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())

Base.length(dist::NamedDist) = Base.length(dist.dist)
Base.size(dist::NamedDist) = Base.size(dist.dist)

Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.logpdf(dist.dist, x)
Expand All @@ -24,12 +27,20 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.loglikelihood(dist.dist, x)
end

Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)

struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
Distribution{variate,support}
dist::Td
end
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)

nodist(dist::Distribution) = NoDist(dist)
nodist(dists::AbstractArray) = nodist.(dists)

Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
Expand All @@ -40,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
Distributions.minimum(d::NoDist) = minimum(d.dist)
Distributions.maximum(d::NoDist) = maximum(d.dist)

Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0
Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
function Bijectors.logpdf_with_trans(
d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool
)
return 0
end
function Bijectors.logpdf_with_trans(
d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
)
return zeros(Int, size(x, 2))
end
Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
function Bijectors.logpdf_with_trans(
d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool
)
return 0
end

Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)
Loading

0 comments on commit 2ed28ea

Please sign in to comment.