Skip to content

Commit

Permalink
Linearization/flattening of SimpleVarInfo (TuringLang#417)
Browse files Browse the repository at this point in the history
This PR introduces a couple of things, though these are related:
1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`.
   - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`.
   - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
     - Seems overkill.
     - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics.
     - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work.
2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences:
   - These take additional arguments which should always be sufficient to determine the transformation. These are:
     - `model`
     - `sampler`
     - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline.
   - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack.
3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use.
   - Not entirely happy with this approach 😕
   

EDIT: This should be merged _after_ TuringLang#420 

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
2 people authored and Alexey Stukalov committed Mar 21, 2023
1 parent a3f12dc commit 07cc442
Show file tree
Hide file tree
Showing 19 changed files with 1,383 additions and 463 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.20.2"
version = "0.21.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ using DynamicPPL
using DynamicPPL: AbstractPPL

# Doctest setup
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
)

makedocs(;
sitename="DynamicPPL",
Expand Down
35 changes: 33 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,23 +156,56 @@ AbstractVarInfo

### Common API

#### Accumulation of log-probabilities

```@docs
getlogp
setlogp!!
acclogp!!
resetlogp!!
```

#### Variables and their realizations

```@docs
keys
getindex
DynamicPPL.getindex_raw
push!!
empty!!
isempty
```

```@docs
values_as
```

#### Transformations

```@docs
DynamicPPL.AbstractTransformation
DynamicPPL.NoTransformation
DynamicPPL.DynamicTransformation
DynamicPPL.StaticTransformation
```

```@docs
DynamicPPL.istrans
DynamicPPL.settrans!!
DynamicPPL.transformation
DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```

#### Utils

```@docs
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
```

#### `SimpleVarInfo`

```@docs
Expand All @@ -191,10 +224,8 @@ TypedVarInfo
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
tonamedtuple
link!
invlink!
istrans
```

```@docs
Expand Down
15 changes: 12 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ export AbstractVarInfo,
setorder!,
istrans,
link!,
link!!,
invlink!,
invlink!!,
tonamedtuple,
values_as,
# VarName (reexport from AbstractPPL)
Expand Down Expand Up @@ -126,27 +128,33 @@ export loglikelihood
# Used here and overloaded in Turing
function getspace end

# Necessary forward declarations
"""
AbstractVarInfo
Abstract supertype for data structures that capture random variables when executing a
probabilistic model and accumulate log densities such as the log likelihood or the
log joint probability of the model.
See also: [`VarInfo`](@ref)
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
"""
abstract type AbstractVarInfo <: AbstractModelTrace end

const LEGACY_WARNING = """
!!! warning
This method is considered legacy, and is likely to be deprecated in the future.
"""

# Necessary forward declarations
include("utils.jl")
include("selector.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
Expand All @@ -155,5 +163,6 @@ include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")

end # module
Loading

0 comments on commit 07cc442

Please sign in to comment.