Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Linearization/flattening of SimpleVarInfo (#417)
This PR introduces a couple of things, though these are related: 1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`. - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`. - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons: - Seems overkill. - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics. - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work. 2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences: - These take additional arguments which should always be sufficient to determine the transformation. These are: - `model` - `sampler` - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline. - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack. 3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use. - Not entirely happy with this approach 😕 EDIT: This should be merged _after_ #420 Co-authored-by: Hong Ge <[email protected]>
- Loading branch information
0947bd7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
0947bd7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/71660
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: