Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Linearization/flattening of SimpleVarInfo #417

Closed
wants to merge 245 commits into from

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Jul 23, 2022

This PR introduces a couple of things, though these are related:

  1. unflatten(original[, spl], x): converts from a certain input x, usually a Vector, into an instance similar to original.
    • Effectively the same as the current constructor VarInfo(varinfo_old, spl, x).
    • I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
      • Seems overkill.
      • unflatten-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" AbstractVarInfo + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics.
      • The current Turing.jl-codebase is easily adapted to this unflatten since it's really just a matter of replacing calls VarInfo(varinfo_old, spl, x) with unflatten(varinfo_old, spl, x). A ParameterHandling.jl approach will require more work.
  2. link!! and invlink!!, BangBang-versions of link! and invlink!, respectively, with some differences:
    • These take additional arguments which should always be sufficient to determine the transformation. These are:
      • model
      • sampler
      • t::AbstractTransformation. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling link!! and invlink!!, not when used inside of the tilde-pipeline.
    • Also adds the logabsdet-jacobian term to the logp, so that getlogp(vi) ≠ getlogp(link!!(vi)) holds. This allows us to compute, say, logjoint by first linking vi in a single pass, and then compute logjoint(settrans!(vi, NoTransformation()), θ_constrained). Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the logpdf_with_trans(..., true) within the tilde-callstack.
  3. make_default_varinfo(rng, model, sampler) which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of AbstractVarInfo to use.
    • Not entirely happy with this approach 😕

EDIT: This should be merged after #420

torfjelde and others added 30 commits January 7, 2022 20:54
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks, @torfjelde -- I reviewed this PR carefully. It is very readable now. I think you did a nice job by introducing a set of APIs for AbstractVarInfo. Please see below for some comments. Among them, the major ones are:

  • Consider renaming MaybeThreadSafeVarInfo to VarInfoOrThreadSafeVarInfo
  • Consider avoiding using Bijector.logpdf_with_trans so that we can finish Adopting DensityInterface #342 later.

src/sampler.jl Show resolved Hide resolved
src/varinfo.jl Outdated Show resolved Hide resolved
src/varinfo.jl Show resolved Hide resolved
src/transforming.jl Show resolved Hide resolved
src/transforming.jl Show resolved Hide resolved
test/varinfo.jl Outdated Show resolved Hide resolved
@devmotion
Copy link
Member

Regarding bors: but is there a way to reset it?

A bit late, but one can check the dashboard here if it is unclear if/what the problem is: https://app.bors.tech/repositories/24589 (one can navigate to it from the bors website or, IIRC, from the bors checks on Github)

One can cancel borg with bors r- or bors try- (depending on whether it was triggered with bors r+ or bors try). One can also run bors retry (and other things: https://bors.tech/documentation/). It seems the error message shows up if one runs bors try but there is already a bors try command running.

@torfjelde
Copy link
Member Author

Consider renaming MaybeThreadSafeVarInfo to VarInfoOrThreadSafeVarInfo

Sounds good!

Consider avoiding using Bijector.logpdf_with_trans so that we can finish #342 later.

IMO this should be a separate PR. Right now, the focus is just on making it so that we can actually start using SimpleVarInfo in Turing. If I start dropping usage of logpdf_with_trans in SimpleVarInfo, IMO we should do the same for VarInfo, in which case it might as well just be a separate PR.

And thank you for the information @devmotion ! Very helpful:)

src/varinfo.jl Outdated Show resolved Hide resolved
src/varinfo.jl Outdated Show resolved Hide resolved
@yebai
Copy link
Member

yebai commented Nov 3, 2022

bors try

@bors
Copy link
Contributor

bors bot commented Nov 3, 2022

try

Already running a review

@yebai
Copy link
Member

yebai commented Nov 3, 2022

bors try

bors bot added a commit that referenced this pull request Nov 3, 2022
@yebai
Copy link
Member

yebai commented Nov 3, 2022

Deleting the try branch successfully reset bors.

@torfjelde
Copy link
Member Author

bors r+

bors bot pushed a commit that referenced this pull request Nov 4, 2022
This PR introduces a couple of things, though these are related:
1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`.
   - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`.
   - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
     - Seems overkill.
     - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics.
     - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work.
2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences:
   - These take additional arguments which should always be sufficient to determine the transformation. These are:
     - `model`
     - `sampler`
     - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline.
   - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack.
3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use.
   - Not entirely happy with this approach 😕
   

EDIT: This should be merged _after_ #420 

Co-authored-by: Hong Ge <[email protected]>
@bors bors bot changed the title Linearization/flattening of SimpleVarInfo [Merged by Bors] - Linearization/flattening of SimpleVarInfo Nov 4, 2022
@bors bors bot closed this Nov 4, 2022
@bors bors bot deleted the tor/simple-varinfo-linearization branch November 4, 2022 17:07
alyst pushed a commit to alyst/DynamicPPL.jl that referenced this pull request Mar 21, 2023
This PR introduces a couple of things, though these are related:
1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`.
   - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`.
   - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons:
     - Seems overkill.
     - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics.
     - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work.
2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences:
   - These take additional arguments which should always be sufficient to determine the transformation. These are:
     - `model`
     - `sampler`
     - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline.
   - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack.
3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use.
   - Not entirely happy with this approach 😕
   

EDIT: This should be merged _after_ TuringLang#420 

Co-authored-by: Hong Ge <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants