Skip to content

Commit

Permalink
Linearization/flattening of SimpleVarInfo (#417)
Browse files Browse the repository at this point in the history
This PR introduces a couple of things, though these are related:
1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`.
   - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`.
   - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
     - Seems overkill.
     - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics.
     - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work.
2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences:
   - These take additional arguments which should always be sufficient to determine the transformation. These are:
     - `model`
     - `sampler`
     - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline.
   - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack.
3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use.
   - Not entirely happy with this approach 😕
   

EDIT: This should be merged _after_ #420 

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai committed Nov 4, 2022
1 parent 0457785 commit 0947bd7
Show file tree
Hide file tree
Showing 19 changed files with 1,383 additions and 463 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.20.2"
version = "0.21.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ using DynamicPPL
using DynamicPPL: AbstractPPL

# Doctest setup
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
)

makedocs(;
sitename="DynamicPPL",
Expand Down
35 changes: 33 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,56 @@ AbstractVarInfo

### Common API

#### Accumulation of log-probabilities

```@docs
getlogp
setlogp!!
acclogp!!
resetlogp!!
```

#### Variables and their realizations

```@docs
keys
getindex
DynamicPPL.getindex_raw
push!!
empty!!
isempty
```

```@docs
values_as
```

#### Transformations

```@docs
DynamicPPL.AbstractTransformation
DynamicPPL.NoTransformation
DynamicPPL.DynamicTransformation
DynamicPPL.StaticTransformation
```

```@docs
DynamicPPL.istrans
DynamicPPL.settrans!!
DynamicPPL.transformation
DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```

#### Utils

```@docs
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
```

#### `SimpleVarInfo`

```@docs
Expand All @@ -189,10 +222,8 @@ TypedVarInfo
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
tonamedtuple
link!
invlink!
istrans
```

```@docs
Expand Down
15 changes: 12 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ export AbstractVarInfo,
setorder!,
istrans,
link!,
link!!,
invlink!,
invlink!!,
tonamedtuple,
values_as,
# VarName (reexport from AbstractPPL)
Expand Down Expand Up @@ -125,27 +127,33 @@ export loglikelihood
# Used here and overloaded in Turing
function getspace end

# Necessary forward declarations
"""
AbstractVarInfo
Abstract supertype for data structures that capture random variables when executing a
probabilistic model and accumulate log densities such as the log likelihood or the
log joint probability of the model.
See also: [`VarInfo`](@ref)
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
"""
abstract type AbstractVarInfo <: AbstractModelTrace end

const LEGACY_WARNING = """
!!! warning
This method is considered legacy, and is likely to be deprecated in the future.
"""

# Necessary forward declarations
include("utils.jl")
include("selector.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
Expand All @@ -154,5 +162,6 @@ include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")

end # module
Loading

2 comments on commit 0947bd7

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/71660

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.21.0 -m "<description of version>" 0947bd791785d0935fe11aae5ce23429084b13a8
git push origin v0.21.0

Please sign in to comment.