Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more NF examples #11

Draft
wants to merge 42 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4f591af
add support for hasconverged
zuhengxu Jul 11, 2023
4e31807
fix test error
zuhengxu Jul 11, 2023
99ac0c1
rm example/Manifest.toml
zuhengxu Jul 11, 2023
e4f5efa
minor bug fix for trainig loop
zuhengxu Jul 11, 2023
b345e78
test new stopping criterion
zuhengxu Jul 11, 2023
f10512f
test convergent condition/ rm unready examples
zuhengxu Jul 11, 2023
b93dbe8
Merge branch 'TuringLang:main' into hasconverge
zuhengxu Jul 11, 2023
1c1c88a
rm julia test from CI
zuhengxu Jul 11, 2023
5d2844f
Revert "rm julia test from CI"
zuhengxu Jul 11, 2023
3fdcb0e
make autodiff pkgs as extension + require for bwd compat
zuhengxu Jul 12, 2023
ddad59e
debugging Ext
zuhengxu Jul 12, 2023
ef60ee1
keep debugging
zuhengxu Jul 13, 2023
5a5deb0
Fix AD package extension loading issues
sunxd3 Jul 13, 2023
25d4211
Applying @devmotion's comment
sunxd3 Jul 13, 2023
df5eddd
patch last commit
sunxd3 Jul 13, 2023
d44143b
patch for julia 1.6
sunxd3 Jul 13, 2023
b03e922
loading dep pkgs from main pkg instead of functions for explicitness
zuhengxu Jul 13, 2023
e9acf70
fixing test err
zuhengxu Jul 13, 2023
c6bf68b
rm unready examples
zuhengxu Jul 13, 2023
fabf20a
update realnvp
zuhengxu Jul 24, 2023
7964f91
minor ed
zuhengxu Jul 24, 2023
01de9d4
removing unnecessary import
zuhengxu Jul 24, 2023
6993d33
refactor affinecoupling and example/
zuhengxu Jul 31, 2023
810881f
debug affinecoupling flow
zuhengxu Jul 31, 2023
0277c9e
adapt to the updated autoforwarddiff to resolve test err
zuhengxu Jul 31, 2023
4bc02bf
fix test err
zuhengxu Jul 31, 2023
144c668
add new implementation of affcoupling using Bijectors.Coupling
zuhengxu Jul 31, 2023
2678198
implement ham flow
zuhengxu Aug 1, 2023
ab7ac64
finish hamflow implementation
zuhengxu Aug 1, 2023
308ab61
minor update
zuhengxu Aug 1, 2023
ceafbde
rename hamflow.jl to hamiltonian_layer.jl
zuhengxu Aug 3, 2023
5589f36
upadting readme
zuhengxu Aug 3, 2023
de01e0e
rm hamflow.jl
zuhengxu Aug 3, 2023
93a8572
Merge branch 'main' of github.com:zuhengxu/NormalizingFlows.jl into m…
zuhengxu Aug 8, 2023
b8a8f5f
sync with main
zuhengxu Aug 9, 2023
c858b10
fix minor bugs in affine coupling layer
zuhengxu Aug 16, 2023
baef6a9
test affine coupling flow on banana
zuhengxu Aug 16, 2023
0323680
rename simple flow run files
zuhengxu Aug 16, 2023
144b06a
update loglikelihood to fit in optimize interface
zuhengxu Aug 16, 2023
f70a4b7
fix minor bugs in nsf_layer
zuhengxu Aug 17, 2023
30cfc32
rm unused data in nsf lfow
zuhengxu Aug 17, 2023
00387d3
rm @view to avoid zygote mutation error
zuhengxu Aug 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,36 @@

[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://turinglang.github.io/NormalizingFlows.jl/dev/)
[![Build Status](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)


A normalizing flow library for Julia.

The purpose of this package is to provide a simple and flexible interface for
variational inference (VI) and normalizing flows (NF) for Bayesian computation or generative modeling.
The key focus is to ensure modularity and extensibility, so that users can easily
construct (e.g., define customized flow layers) and combine various components
(e.g., choose different VI objectives or gradient estimates)
for variational approximation of general target distributions,
without being tied to specific probabilistic programming frameworks or applications.

See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for more.

## Installation
To install the package, run the following command in the Julia REPL:
```julia
# install the package
] # enter Pkg mode
(@v1.9) pkg> add [email protected]:TuringLang/NormalizingFlows.jl.git
```
Then simply run the following command to use the package:
```julia
using NormalizingFlows
```


## Quick recap of normalizing flows


## Current status and TODOs


401 changes: 401 additions & 0 deletions example/HamiltonianVI/data/bank_dat.csv

Large diffs are not rendered by default.

112 changes: 112 additions & 0 deletions example/HamiltonianVI/hamiltonian_layer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using Functors
using Flux
using Bijectors
using Bijectors: partition, combine, PartitionMask
using SimpleUnPack: @unpack

struct LeapFrog{T<:Real,I<:Int} <: Bijectors.Bijector
"dimention of the target space"
dim::I
"leapfrog step size"
ϵ::AbstractVector{T}
"number of leapfrog steps"
L::I
"score of the target distribution"
∇logp
"score of the momentum distribution"
∇logm
end
@functor LeapFrog (ϵ,)

function LeapFrog(dim::Int, ϵ::T, L::Int, ∇logp, ∇logm) where {T<:Real}
return LeapFrog(dim, ϵ .* ones(T, dim), L, ∇logp, ∇logm)
end

function Bijectors.inverse(lf::LeapFrog)
@unpack d, ϵ, L, ∇logp, ∇logm = lf
return LeapFrog(d, -ϵ, L, ∇logp, ∇logm)
end

function Bijectors.transform(lf::LeapFrog, z::AbstractVector)
@unpack dim, ϵ, L, ∇logp, ∇logm = lf
@assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]"
# mask = PartitionMask(n, 1:dim)
# x, ρ, emp = partition(mask, z)
x, ρ = z[1:dim], z[(dim + 1):end]

ρ += ϵ ./ 2 .* ∇logp(x)
for i in 1:(L - 1)
x -= ϵ .* ∇logm(ρ)
ρ += ϵ .* ∇logp(x)
end
x -= ϵ .* ∇logm(ρ)
ρ += ϵ ./ 2 .* ∇logp(x)
# return combine(mask, x, ρ, emp)
return vcat(x, ρ)
end

function Bijectors.with_logabsdet_jacobian(lf::LeapFrog, z::AbstractVector)
return Bijectors.transform(lf, z), zero(eltype(z))
end

abstract type TrainableScore end
struct CoresetScore{T<:AbstractVector} <: TrainableScore
"coreset weights"
w::T
"weighted coreset score function of the target, ∇logpw(x, w)"
∇logpw
end
@functor CoresetScore (w,)
function CoresetScore(T, coresize::Int, datasize::Int, ∇logpw)
return CoresetScore(ones(T, coresize) .* N ./ coresize, ∇logpw)
end
(C::CoresetScore)(x::AbstractVector) = C.∇logpw(x, C.w)

struct SurrogateLeapFrog{T<:Real,I<:Int,H<:Union{TrainableScore,Flux.Chain}} <:
Bijectors.Bijector
"dimention of the target space"
dim::I
"leapfrog step size"
ϵ::AbstractVector{T}
"number of leapfrog steps"
L::I
"trainable surrogate of the score of the target distribution, e.g., coreset score or some neural net"
∇S::H
"score of the momentum distribution"
∇logm
end

@functor SurrogateLeapFrog (ϵ, ∇S)

function SurrogateLeapFrog(dim::Int, ϵ::T, L::Int, ∇S, ∇logm) where {T<:Real}
return SurrogateLeapFrog(dim, ϵ .* ones(T, dims), L, ∇S, ∇logm)
end

function Bijectors.inverse(slf::SurrogateLeapFrog)
@unpack dim, ϵ, L, ∇S, ∇logm = slf
return SurrogateLeapFrog(dim, -ϵ, L, ∇S, ∇logm)
end

function Bijectors.transform(slf::SurrogateLeapFrog, z::AbstractVector)
@unpack dim, ϵ, L, ∇S, ∇logm = slf
n = length(z)
@assert n == 2dim "dimension of input must be even, z = [x, ρ]"
# mask = PartitionMask(n, 1:dim)
x, ρ = z[1:dim], z[(dim + 1):end]
# x, ρ, emp = partition(mask, z)

ρ += ϵ ./ 2 .* ∇S(x)
for i in 1:(L - 1)
x -= ϵ .* ∇logm(ρ)
ρ += ϵ .* ∇S(x)
end
x -= ϵ .* ∇logm(ρ)
ρ += ϵ ./ 2 .* ∇S(x)
# return combine(mask, x, ρ, emp)
return vcat(x, ρ)
end

# leapfrog composes shear transformations, hence has unit jacobian
function Bijectors.with_logabsdet_jacobian(slf::SurrogateLeapFrog, z::AbstractVector)
return Bijectors.transform(slf, z), zero(eltype(z))
end
161 changes: 161 additions & 0 deletions example/HamiltonianVI/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
using CSV, DataFrames
using Random, Distributions, LinearAlgebra, Bijectors
using ADTypes
using Optimisers
using Tullio: @tullio
using FunctionChains
using NormalizingFlows
using Zygote
using Zygote: @adjoint, Buffer
using Bijectors: Shift, Scale
include("../common.jl")
include("hamiltonian_layer.jl")

##############################33
# model for posterior inference (logistic regression)
#################################

#########################################################################################
# Example of Bayesian Logistic Regression (the same setting as Gershman et al. 2012):
# The observed data D = {X, y} consist of N binary class labels,
# y_t \in {-1,+1}, and d covariates for each datapoint, X_t \in R^d.
# The hidden variables \theta = {w, \alpha} consist of d regression coefficients w_k \in R,
# and a precision parameter \alpha \in R_+. We assume the following model:
# p(α) = Gamma(α ; a, b) , τ = log α ∈ R
# p(w_k | τ) = N(w_k; 0, exp(-τ))
# p(y_t = 1| x_t, w) = 1 / (1+exp(-w^T x_t)), y ∈ {1, 0}
#########################################################################################
df = DataFrame(CSV.File("example/HamiltonianVI/data/bank_dat.csv"))
xs = Matrix(df)[:, 2:end]
X_raw = xs[:, 1:(end - 1)]
const X = (X_raw .- mean(X_raw; dims=1)) ./ std(X_raw; dims=1)
const Y = xs[:, end]
const a, b = 1.0, 0.01
(N, p) = size(X)
idx = sample(1:N, 20; replace=false)
const Xc, Yc = X[idx, :], Y[idx]

function log_sigmoid(x)
if x < -300
return x
else
return -log1p(exp(-x))
end
end

function neg_sigmoid(x)
return -1.0 / (1.0 + exp(-x))
end

# z = (τ, w1, ..., wd)
function logp(θ, X, Y, w)
τ = θ[1]
W = @view(θ[2:end])
Z = X * W
logpτ = a * τ - b * exp(τ)
logpW = 0.5 * p * τ - 0.5 * exp(τ) * sum(abs2, W)
@tullio llh := w[n] * ((Y[n] - 1.0) * Z[n] + log_sigmoid(Z[n]))
# llh = sum((Y .- 1.) .* Z .- log1p.(exp.(-Z)))
return logpτ + logpW + llh
end

function logp_subsample(θ; batch_size=10)
(N, p) = size(X)
idx = sample(1:N, batch_size; replace=false)
w = N / batch_size .* ones(batch_size)
return logp(θ, X[idx, :], Y[idx], w)
end

function ∇logp(z, X, Y, w)
τ = z[1]
W = @view(z[2:end])
grad = similar(z)
grad[1] = a - b * exp(τ) + 0.5 * p - 0.5 * exp(τ) * sum(abs2, W)
S = neg_sigmoid.(X * W)
@tullio M[j] := w[n] * X[n, j] * (S[n] + Y[n])
grad[2:end] .= -exp(τ) .* W .+ M
return grad
end

function ∇logp_subsample(z; batch_size=10)
(N, p) = size(X)
idx = sample(1:N, batch_size; replace=false)
w = N / batch_size .* ones(batch_size)
return ∇logp(z, X[idx, :], Y[idx], w)
end

function ∇logp_coreset(z, w)
@assert length(w) == size(Xc, 1)
τ = z[1]
W = z[2:end]
dim = length(z)
grad = Buffer(z)
grad[1] = a - b * exp(τ) + 0.5 * p - 0.5 * exp(τ) * sum(abs2, W)
S = neg_sigmoid.(Xc * W)
@tullio M[j] := w[n] * Xc[n, j] * (S[n] + Yc[n])
grad[2:dim] = -exp(τ) .* W .+ M
return copy(grad)
end

# customize gradient for logp_subsample (in each iteration logp evaluation only uses a small batch of the full dataset)
Zygote.refresh()
@adjoint function logp_subsample(z; batch_size=10)
return logp_subsample(z; batch_size=batch_size),
Δ -> (Δ * ∇logp_subsample(z; batch_size=batch_size),)
end

function logp_joint(z; batch_size=10)
dim = div(length(z), 2)
x, ρ = z[1:dim], z[(dim + 1):end]
return logp_subsample(x; batch_size=10) + logpdf(MvNormal(zeros(dims), I), ρ)
end

function ∇logp_joint(z; batch_size=10)
dim = div(length(z), 2)
gx = ∇logp_subsample(z[1:dim]; batch_size=batch_size)
gρ = -z[(dim + 1):end]
return vcat(gx, gρ)
end

@adjoint function logp_joint(z; batch_size=10)
return logp_joint(z; batch_size=batch_size),
Δ -> (Δ * ∇logp_joint(z; batch_size=batch_size),)
end

#################################################33
# train sparse Hamiltonian flow (https://arxiv.org/pdf/2203.05723.pdf)
# note:
# - SHF operates on a joint space (target space × momentum space)
# - instead of using the full score in the Hamiltonain flow, we use coreset score
# - instead of using the momentum refreshment as used in the paper (only perform normalization on the mometnum),
# we just stack a general shift and scaling layer after the leapfrog step
###################################################3

∇S = CoresetScore(Float64, 20, 400, ∇logp_coreset)
dims = 9
L = 20
∇logm(x) = -x
Ls = [
Scale(ones(2dims)) ∘ Shift(ones(2dims)) ∘ SurrogateLeapFrog(dims, 0.02, L, ∇S, ∇logm)
for i in 1:5
]
q0 = MvNormal(zeros(Float64, 2dims), I)
flow = Bijectors.transformed(q0, ∘(Ls...))
# flow = Bijectors.transformed(q0, trans)
flow_untrained = deepcopy(flow)

sample_per_iter = 5
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,)
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3
flow_trained, stats, _ = train_flow(
elbo,
flow,
logp_joint,
sample_per_iter;
max_iters=200_00,
optimiser=Optimisers.Adam(1e-3),
callback=cb,
ADbackend=AutoZygote(),
hasconverged=checkconv,
)
losses = map(x -> x.loss, stats)
6 changes: 6 additions & 0 deletions example/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
Expand All @@ -13,5 +16,8 @@ NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Loading