Skip to content

Commit

Permalink
Use ChangesOfVariables and InverseFunctions (#212)
Browse files Browse the repository at this point in the history
* Add ChangesOfVariables and InverseFunctions to deps

* Replace forward by with_logabsdet_jacobian

* Replace Base.inv with InverseFunctions.inverse

* Improve deprecation scheme for forward

Co-authored-by: David Widmann <[email protected]>

* Improve deprecation scheme for inv

* Test forward and inv deprecations

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Fixes regarding with_logabsdet_jacobian and inverse

* Fix with_logabsdet_jacobian for NamedComposition

* Fix deprecation of inv

* Use inverse instead of inv for Composed

* Use with_logabsdet_jacobian instead of forward

* Workaround for intermittent failures in Dirichlet test

* Use with_logabsdet_jacobian instead of forward

* Use with_logabsdet_jacobian instead of forward

* Add rrules for combine with PartitionMask

Zygote-generated pullback for `combine(m::PartitionMask, x_1, x_2, x_3)`
fails with `no method matching zero(::Type{Nothing})`.

* Use inv instead of inverse for numbers

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Whitespace fix.

Co-authored-by: David Widmann <[email protected]>

* Move combine rrule and add test

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Use @test_deprecated

Co-authored-by: David Widmann <[email protected]>

* Use @test_deprecated

Co-authored-by: David Widmann <[email protected]>

* Use inverse instead of inv

* Use test_inverse and test_with_logabsdet_jacobian

* Use inverse instead of inv

* Increase version number to v0.9.12

* Reexport with_logabsdet_jacobian and inverse

* Increase package version to v0.10.0

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
oschulz and devmotion committed Dec 15, 2021
1 parent 31b1c38 commit b204712
Show file tree
Hide file tree
Showing 33 changed files with 444 additions and 378 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.11"
version = "0.10.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -22,9 +24,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ArgCheck = "1, 2"
ChainRulesCore = "0.10.11, 1"
ChangesOfVariables = "0.1"
Compat = "3"
Distributions = "0.23.3, 0.24, 0.25"
Functors = "0.1, 0.2"
InverseFunctions = "0.1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3.3"
MappedArrays = "0.2.2, 0.3, 0.4"
Expand Down
78 changes: 40 additions & 38 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ The following table lists mathematical operations for a bijector and the corresp

| Operation | Method | Automatic |
|:------------------------------------:|:-----------------:|:-----------:|
| `b ↦ b⁻¹` | `inv(b)` ||
| `b ↦ b⁻¹` | `inverse(b)` ||
| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` ||
| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` ||
| `x ↦ b(x)` | `b(x)` | × |
| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × |
| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × |
| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD |
| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` ||
| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` ||
| `p ↦ q := b_* p` | `q = transformed(p, b)` ||
| `y ∼ q` | `y = rand(q)` ||
| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` ||
Expand Down Expand Up @@ -123,7 +123,7 @@ true
What about `invlink`?

```julia
julia> b⁻¹ = inv(b)
julia> b⁻¹ = inverse(b)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> b⁻¹(y)
Expand All @@ -133,7 +133,7 @@ julia> b⁻¹(y) ≈ invlink(dist, y)
true
```

Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true.

#### Dimensionality
One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:
Expand Down Expand Up @@ -162,7 +162,7 @@ true
And since `Composed isa Bijector`:

```julia
julia> id_x = inv(id_y)
julia> id_x = inverse(id_y)
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))

julia> id_x(x) x
Expand Down Expand Up @@ -199,9 +199,9 @@ julia> logpdf_forward(td, x)
-1.123311289915276
```

#### `logabsdetjac` and `forward`
#### `logabsdetjac` and `with_logabsdet_jacobian`

In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inv(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method
In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method

```julia
julia> logabsdetjac(b⁻¹, y)
Expand All @@ -218,21 +218,21 @@ julia> logabsdetjac(b, x) ≈ -logabsdetjac(b⁻¹, y)
true
```

which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use:
which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `with_logabsdet_jacobian` comes to good use:

```julia
julia> forward(b, x)
(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655)
julia> with_logabsdet_jacobian(b, x)
(-0.5369949942509267, 1.4575353795716655)
```

Similarily

```julia
julia> forward(inv(b), y)
(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655)
julia> with_logabsdet_jacobian(inverse(b), y)
(0.3688868996596376, -1.4575353795716655)
```

In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented).
In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented).

#### Sampling from `TransformedDistribution`
At this point we've only shown that we can replicate the existing functionality. But we said `TransformedDistribution isa Distribution`, so we also have `rand`:
Expand All @@ -241,7 +241,7 @@ At this point we've only shown that we can replicate the existing functionality.
julia> y = rand(td) # ∈ ℝ
0.999166054552483

julia> x = inv(td.transform)(y) # transform back to interval [0, 1]
julia> x = inverse(td.transform)(y) # transform back to interval [0, 1]
0.7308945834125756
```

Expand All @@ -261,7 +261,7 @@ Beta{Float64}(α=2.0, β=2.0)
julia> b = bijector(dist) # (0, 1) → ℝ
Logit{Float64}(0.0, 1.0)

julia> b⁻¹ = inv(b) # ℝ → (0, 1)
julia> b⁻¹ = inverse(b) # ℝ → (0, 1)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1)
Expand All @@ -280,7 +280,7 @@ It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while
```julia
td = transformed(Beta())

inv(td.transform)(rand(td))
inverse(td.transform)(rand(td))
```

will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._
Expand Down Expand Up @@ -335,7 +335,7 @@ julia> # Construct the transform
bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists
(Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}())

julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained
julia> ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained
(Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}()))

julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
Expand Down Expand Up @@ -411,7 +411,7 @@ Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bo
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));
julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)]
Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
Expand Down Expand Up @@ -481,7 +481,7 @@ julia> Flux.params(flow)
Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)])
```

Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.
Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path.

```julia
julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass
Expand Down Expand Up @@ -542,41 +542,43 @@ Logit{Float64}(0.0, 1.0)
julia> b(0.6)
0.4054651081081642

julia> inv(b)(y)
julia> inverse(b)(y)
Tracked 2-element Array{Float64,1}:
0.3078149833748082
0.72380041667891

julia> logabsdetjac(b, 0.6)
1.4271163556401458

julia> logabsdetjac(inv(b), y) # defaults to `- logabsdetjac(b, inv(b)(x))`
julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))`
Tracked 2-element Array{Float64,1}:
-1.546158373866469
-1.6098711387913573

julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))`
(0.4054651081081642, 1.4271163556401458)
```

For further efficiency, one could manually implement `forward(b::Logit, x)`:
For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`:

```julia
julia> import Bijectors: forward, Logit
julia> using Bijectors: Logit

julia> function forward(b::Logit{<:Real}, x)
julia> import Bijectors: with_logabsdet_jacobian

julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x)
totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not
y = logit.(totally_worth_saving)
logjac = @. - log((b.b - x) * totally_worth_saving)
return (rv=y, logabsdetjac = logjac)
return (y, logjac)
end
forward (generic function with 16 methods)

julia> forward(b, 0.6)
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6)
(0.4054651081081642, 1.4271163556401458)

julia> @which forward(b, 0.6)
forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
julia> @which with_logabsdet_jacobian(b, 0.6)
with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
```
As you can see it's a very contrived example, but you get the idea.
Expand Down Expand Up @@ -613,10 +615,10 @@ julia> logabsdetjac(b_ad, 0.6)
julia> y = b_ad(0.6)
0.4054651081081642

julia> inv(b_ad)(y)
julia> inverse(b_ad)(y)
0.6

julia> logabsdetjac(inv(b_ad), y)
julia> logabsdetjac(inverse(b_ad), y)
-1.4271163556401458
```
Expand Down Expand Up @@ -665,7 +667,7 @@ help?> Bijectors.Composed

A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively.

Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methdos, e.g. inv.
Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse.

If you want to use an Array as the container instead you can do

Expand Down Expand Up @@ -713,9 +715,9 @@ The distribution interface consists of:
#### Methods
The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`.
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner.
- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner.
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation.
- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency.
- `dimension(b::Bijector)`: returns the dimensionality of `b`.
Expand All @@ -725,7 +727,7 @@ For `TransformedDistribution`, together with default implementations for `Distri
- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d`
- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`.
- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inv(b), b(x))` depending on which is most efficient.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient.
# Bibliography
1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6).
Expand Down
18 changes: 15 additions & 3 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

import ChangesOfVariables: with_logabsdet_jacobian
import InverseFunctions: inverse

import ChainRulesCore
import Functors
import IrrationalConstants
Expand All @@ -51,6 +54,8 @@ export TransformDistribution,
logpdf_with_trans,
isclosedform,
transform,
with_logabsdet_jacobian,
inverse,
forward,
logabsdetjac,
logabsdetjacinv,
Expand Down Expand Up @@ -121,7 +126,7 @@ end
# Distributions

link(d::Distribution, x) = bijector(d)(x)
invlink(d::Distribution, y) = inv(bijector(d))(y)
invlink(d::Distribution, y) = inverse(bijector(d))(y)
function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
Expand Down Expand Up @@ -188,14 +193,14 @@ function invlink(
y::AbstractVecOrMat{<:Real},
::Val{proj}=Val(true),
) where {proj}
return inv(SimplexBijector{proj}())(y)
return inverse(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet,
y::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(inv(SimplexBijector{proj}()), y)
return jacobian(inverse(SimplexBijector{proj}()), y)
end

## Matrix
Expand Down Expand Up @@ -249,6 +254,13 @@ include("utils.jl")
include("interface.jl")
include("chainrules.jl")

Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x))

@noinline function Base.inv(b::AbstractBijector)
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv)
inverse(b)
end

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
maporbroadcast(f, x::AbstractArray...) = f.(x...)
Expand Down
Loading

2 comments on commit b204712

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/50639

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.0 -m "<description of version>" b20471252c01dd4832e06aa80045046483f3804e
git push origin v0.10.0

Please sign in to comment.