[Merged by Bors] - Linearization/flattening of SimpleVarInfo #417
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
…PPL.jl into tor/link-improvements
Many thanks, @torfjelde -- I reviewed this PR carefully. It is very readable now. I think you did a nice job by introducing a set of APIs for `AbstractVarInfo`. Please see below for some comments. Among them, the major ones are:
- Consider renaming `MaybeThreadSafeVarInfo` to `VarInfoOrThreadSafeVarInfo`
- Consider avoiding using `Bijector.logpdf_with_trans` so that we can finish Adopting DensityInterface #342 later.
A bit late, but one can check the dashboard here if it is unclear if/what the problem is: https://app.bors.tech/repositories/24589 (one can navigate to it from the bors website or, IIRC, from the bors checks on GitHub). One can cancel bors with
Sounds good!
IMO this should be a separate PR. Right now, the focus is just on making it so that we can actually start using
And thank you for the information, @devmotion! Very helpful :)
bors try
try
Already running a review
bors try
Deleting the |
bors r+
This PR introduces a couple of related things:

1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`.
   - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`.
   - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
     - Seems overkill.
     - The `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo`. In addition, constructing the unflattening requires constructing the flattened representation, the way one specifies the types is a bit too opinionated (which causes some issues with certain AD frameworks), and closures can have less desirable performance characteristics.
   - The current Turing.jl codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls to `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work.
2. `link!!` and `invlink!!`, BangBang versions of `link!` and `invlink!`, respectively, with some differences:
   - These take additional arguments which should always be sufficient to determine the transformation. These are:
     - `model`
     - `sampler`
     - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an effect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline.
   - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then computing `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack.
3. `make_default_varinfo(rng, model, sampler)`, which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use.
   - Not entirely happy with this approach 😕

EDIT: This should be merged _after_ #420

Co-authored-by: Hong Ge <[email protected]>
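To make the `unflatten` and `link!!`/`invlink!!` ideas above a bit more concrete, here is a minimal, self-contained sketch using only Base Julia. It is not DynamicPPL's implementation: `ToyVarInfo`, `flatten`, and the non-bang `link`/`invlink` are hypothetical names made up for illustration, the `spl`, `model`, `sampler`, and `t` arguments are left out, every variable is assumed to be a positive scalar transformed with `log`, and the starting `logp` is an arbitrary placeholder.

```julia
# Toy stand-in for an `AbstractVarInfo` (hypothetical; NOT a DynamicPPL type).
struct ToyVarInfo{T<:NamedTuple,L<:Real}
    vals::T       # one scalar per variable
    logp::L       # accumulated log-density
    linked::Bool  # true when `vals` live in unconstrained (log) space
end

# Flatten the stored values into a plain vector.
flatten(vi::ToyVarInfo) = collect(values(vi.vals))

# Rebuild an instance shaped like the template `original` from a flat vector `x`.
function unflatten(original::ToyVarInfo, x::AbstractVector)
    names = keys(original.vals)
    length(x) == length(names) ||
        throw(DimensionMismatch("expected $(length(names)) values"))
    return ToyVarInfo(NamedTuple{names}(Tuple(x)), original.logp, original.linked)
end

# Map positive values to unconstrained space via y = log(x). For the inverse
# x = exp(y), log|det J| = sum(y), which is added to `logp`.
function link(vi::ToyVarInfo)
    vi.linked && return vi
    y = map(log, vi.vals)
    return ToyVarInfo(y, vi.logp + sum(values(y)), true)
end

# Undo `link`: map back with exp and subtract the same Jacobian term.
function invlink(vi::ToyVarInfo)
    vi.linked || return vi
    x = map(exp, vi.vals)
    return ToyVarInfo(x, vi.logp - sum(values(vi.vals)), false)
end

vi = ToyVarInfo((s = 2.0, m = 3.0), -1.3, false)  # placeholder logp
vi_linked = link(vi)                              # logp now includes the Jacobian term
x = flatten(vi_linked)                            # plain Vector in unconstrained space
vi_new = unflatten(vi_linked, x .+ 0.1)           # same structure, new values
vi_back = invlink(vi_linked)                      # recovers the original values and logp
```

The point of the template-based design is that `unflatten` only needs an existing instance plus a vector, so there is no closure to carry around, and that linking changes `logp` by exactly the log-Jacobian term, so the linked and unlinked log-densities differ.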
Pull request successfully merged into master. Build succeeded: |