Commit
Linearization/flattening of SimpleVarInfo (TuringLang#417)
This PR introduces a few related things:

1. `unflatten(original[, spl], x)`: converts an input `x`, usually a `Vector`, into an instance similar to `original`.
   - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`.
   - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
     - Seems overkill.
     - The `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo`.
     - Construction of the unflattening requires construction of the flattened representation.
     - The way one specifies the types is a bit too opinionated (which causes some issues with certain AD frameworks).
     - Closures can have less desirable performance characteristics.
   - The current Turing.jl codebase is easily adapted to this `unflatten`, since it's really just a matter of replacing calls to `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach would require more work.
2. `link!!` and `invlink!!`, BangBang versions of `link!` and `invlink!`, respectively, with some differences:
   - These take additional arguments which should always be sufficient to determine the transformation. These are:
     - `model`
     - `sampler`
     - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an effect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline.
   - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then computing `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than calling `logpdf_with_trans(..., true)` within the tilde-callstack.
3. `make_default_varinfo(rng, model, sampler)`, which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use.
   - Not entirely happy with this approach 😕

EDIT: This should be merged _after_ TuringLang#420

Co-authored-by: Hong Ge <[email protected]>
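
To make the "template" idea behind `unflatten` and the log-Jacobian bookkeeping of `link!!` concrete, here is a minimal, self-contained sketch. It does not use DynamicPPL: `ToyVarInfo`, `flatten`, `toy_link`, and this toy `unflatten` are hypothetical stand-ins written only for illustration, and the sign convention for the Jacobian term is the usual change-of-variables one, assumed rather than taken from the PR.

```julia
# Hypothetical stand-in for an `AbstractVarInfo`; NOT DynamicPPL's SimpleVarInfo.
struct ToyVarInfo{NT<:NamedTuple}
    vals::NT        # one scalar or array per variable
    logp::Float64   # accumulated log-density
end

# Helpers turning a single value into a Vector{Float64}.
_tovec(x::Real) = [float(x)]
_tovec(x::AbstractArray) = float.(vec(x))

# Concatenate all values into one flat vector.
flatten(vi::ToyVarInfo) = vcat(map(_tovec, values(vi.vals))...)

# Rebuild a ToyVarInfo shaped like `original` from the flat vector `x`.
# `original` acts as the template, so no closure has to be stored.
function unflatten(original::ToyVarInfo, x::AbstractVector{<:Real})
    offset = 0
    new_vals = Any[]
    for v in values(original.vals)
        n = length(v)
        chunk = x[offset+1:offset+n]
        push!(new_vals, v isa Real ? chunk[1] : reshape(chunk, size(v)))
        offset += n
    end
    nt = NamedTuple{keys(original.vals)}(Tuple(new_vals))
    return ToyVarInfo(nt, original.logp)
end

# Toy "link": map the positive variable `s` to y = log(s) and add the
# log-abs-det Jacobian of the inverse map, log|d exp(y)/dy| = y, to logp.
function toy_link(vi::ToyVarInfo)
    y = log(vi.vals.s)
    return ToyVarInfo(merge(vi.vals, (s = y,)), vi.logp + y)
end

vi = ToyVarInfo((s = 2.0, m = [0.5, -1.0]), -3.0)
x = flatten(vi)                  # flat parameter vector
vi2 = unflatten(vi, x .+ 1.0)    # same structure as `vi`, new values
linked = toy_link(vi)            # logp now differs from vi.logp
```

Because the template carries the structure, a sampler only needs to hold on to an existing instance and the flat vector, which is the property the PR contrasts with a closure-based approach.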