
refactor interface for projections/proximal operators #147

Merged: 23 commits, Dec 30, 2024
Changes from 7 commits

Commits
238128e
refactor make scale projection operator its own optimization rule
Red-Portal Nov 17, 2024
03338d6
add docs for `ProjectScale`
Red-Portal Nov 17, 2024
233cffa
refactor change of type parameter order for `LocationScaleLowRank`
Red-Portal Nov 17, 2024
960d77d
apply formatter
Red-Portal Nov 17, 2024
db42115
apply formatter
Red-Portal Nov 17, 2024
6dd0fd6
apply formatter
Red-Portal Nov 17, 2024
074218a
update README
Red-Portal Nov 17, 2024
a11e5ce
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into pro…
Red-Portal Dec 9, 2024
a3ce1d1
fix formatting
Red-Portal Dec 9, 2024
ee36164
fix outdated type parameters in `LocationScale`
Red-Portal Dec 10, 2024
cd35e4e
rename averaging function
Red-Portal Dec 24, 2024
f40df75
fix projection/proximal operator interface
Red-Portal Dec 24, 2024
97f64e1
update documentation
Red-Portal Dec 24, 2024
9f1a549
fix formatting
Red-Portal Dec 24, 2024
ebe0637
fix benchmark
Red-Portal Dec 24, 2024
dcf21db
add missing test file
Red-Portal Dec 24, 2024
7868317
fix documentation
Red-Portal Dec 24, 2024
04db344
fix documentation
Red-Portal Dec 24, 2024
f731bdc
fix ambiguous specialization error for `operate`
Red-Portal Dec 24, 2024
86e1ab3
update documentation
Red-Portal Dec 27, 2024
1b3b734
refactor `average` and `operate` to specializations of `apply`
Red-Portal Dec 29, 2024
9887bb4
Merge branch 'projected_proximal_location_scale' of github.com:Turing…
Red-Portal Dec 29, 2024
635ea4e
Update docs/src/optimization.md
yebai Dec 29, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -109,7 +109,7 @@ q_avg, _, stats, _ = AdvancedVI.optimize(
q_transformed,
max_iter;
adtype=ADTypes.AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
optimizer=ProjectScale(Optimisers.Adam(1e-3)),
)

# Evaluate final ELBO with 10^3 Monte Carlo samples
2 changes: 1 addition & 1 deletion bench/benchmarks.jl
@@ -40,7 +40,7 @@ begin
]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))
optimizer = ProjectScale(Optimisers.Adam(T(1e-3)))

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
6 changes: 3 additions & 3 deletions docs/src/elbo/repgradelbo.md
@@ -219,7 +219,7 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
optimizer = ProjectScale(Optimisers.Adam(3e-3)),
callback = callback,
);

@@ -230,7 +230,7 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
optimizer = ProjectScale(Optimisers.Adam(3e-3)),
callback = callback,
);

@@ -317,7 +317,7 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
max_iter;
show_progress = false,
adtype = AutoForwardDiff(),
optimizer = Optimisers.Adam(3e-3),
optimizer = ProjectScale(Optimisers.Adam(3e-3)),
callback = callback,
);

5 changes: 4 additions & 1 deletion docs/src/examples.md
@@ -118,11 +118,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
n_max_iter;
show_progress=false,
adtype=AutoForwardDiff(),
optimizer=Optimisers.Adam(1e-3),
optimizer=ProjectScale(Optimisers.Adam(1e-3)),
);
nothing
```

`ProjectScale` is a wrapper around an optimization rule that keeps the variational approximation within a stable region of the variational family.
For more information, see [this section](@ref projectscale).

`q_avg_trans` is the final output of the optimization procedure.
If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` is the output of the averaging strategy, while `q_trans` is the last iterate.

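Not part of the diff: a minimal usage sketch of the wrapper described above, based on the `ProjectScale(rule, scale_eps)` constructor introduced later in this PR (the second argument defaults to `1e-5`).

```julia
using Optimisers
using AdvancedVI

# Wrap any Optimisers.jl rule; the projection clamps the diagonal of the scale
# matrix after each update so it never drops below `scale_eps`.
optimizer       = ProjectScale(Optimisers.Adam(1e-3))        # default scale_eps = 1e-5
optimizer_tight = ProjectScale(Optimisers.Adam(1e-3), 1e-4)  # explicit lower bound

# Either value is passed to `AdvancedVI.optimize` through the `optimizer` keyword,
# exactly as in the README and examples diffs above.
```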
10 changes: 10 additions & 0 deletions docs/src/families.md
@@ -56,6 +56,16 @@ FullRankGaussian
MeanFieldGaussian
```

### [Scale Projection Operator](@id projectscale)

For location-scale families, optimization is often stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
To ensure this, we provide the following wrapper around an optimization rule:

```@docs
ProjectScale
```

[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
### Gaussian Variational Families

```julia
23 changes: 21 additions & 2 deletions ext/AdvancedVIBijectorsExt.jl
@@ -16,6 +16,7 @@ else
end

function AdvancedVI.update_variational_params!(
proj::ProjectScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
opt_st,
params,
@@ -24,9 +25,8 @@ function AdvancedVI.update_variational_params!(
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.dist.scale_eps
ϵ = proj.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.dist.scale)
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)

@@ -35,6 +35,25 @@ function AdvancedVI.update_variational_params!(
return opt_st, params
end

function AdvancedVI.update_variational_params!(
proj::ProjectScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
opt_st,
params,
restructure,
grad,
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = proj.scale_eps

@. q.dist.scale_diag = max(q.dist.scale_diag, ϵ)

params, _ = Optimisers.destructure(q)

return opt_st, params
end

function AdvancedVI.reparam_with_entropy(
rng::Random.AbstractRNG,
q::Bijectors.TransformedDistribution,
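The two methods above apply the post-update projection that motivates the `ProjectScale` docs in `docs/src/families.md`: after `Optimisers.update!`, the diagonal of the scale (or `scale_diag` for the low-rank family) is clamped from below. A standalone sketch of that clamp follows; `project_scale_diag!` is an illustrative helper, not an AdvancedVI function.

```julia
using LinearAlgebra

# Illustrative helper (not part of AdvancedVI): clamp the diagonal of a
# triangular scale factor so its eigenvalues stay at or above scale_eps.
function project_scale_diag!(scale::AbstractMatrix, scale_eps::Real)
    idx = diagind(scale)
    @. scale[idx] = max(scale[idx], scale_eps)
    return scale
end

L = LowerTriangular([0.5 0.0; -0.3 1e-8])  # factor with a nearly-zero diagonal entry
project_scale_diag!(L, 1e-5)               # the (2, 2) entry is now clamped to 1e-5
```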
13 changes: 7 additions & 6 deletions src/AdvancedVI.jl
@@ -63,16 +63,17 @@ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restruct

# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)
update_variational_params!(rule, family_type, opt_st, params, restructure, grad)

Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.
Update variational distribution according to the update rule in the optimizer state `opt_st`, the optimizer given by `rule`, and the variational family type `family_type`.

This is a wrapper around `Optimisers.update!` to provide some indirection.
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
As with the default behavior of `Optimisers.update!`, `params` and `opt_st` may be mutated by the routine and are no longer valid after calling this function.
Instead, the return values should be used.

# Arguments
- `rule`: Optimization rule.
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
Expand All @@ -83,9 +84,9 @@ Instead, the return values should be used.
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

function update_variational_params!(::Type, opt_st, params, restructure, grad)
function update_variational_params!(
::Optimisers.AbstractRule, family_type, opt_st, params, restructure, grad
)
return Optimisers.update!(opt_st, params, grad)
end

@@ -186,7 +187,7 @@ include("objectives/elbo/repgradelbo.jl")
include("objectives/elbo/scoregradelbo.jl")

# Variational Families
export MvLocationScale, MeanFieldGaussian, FullRankGaussian
export MvLocationScale, MeanFieldGaussian, FullRankGaussian, ProjectScale

include("families/location_scale.jl")

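The docstring above describes `update_variational_params!` as an indirection point around `Optimisers.update!`. As a hypothetical sketch (the names `SoftThresholdRule` and `MyFamily` are illustrative and not part of AdvancedVI), a custom rule/family pair could hook into this interface by mirroring the pattern `ProjectScale` uses in this PR:

```julia
using Optimisers
using AdvancedVI

struct MyFamily end  # stand-in for a custom variational family type

# A rule that composes an inner Optimisers.jl rule with a proximal step.
struct SoftThresholdRule{R<:Optimisers.AbstractRule,T<:Real} <: Optimisers.AbstractRule
    rule::R
    shrinkage::T
end

# Delegate state handling to the wrapped rule, as ProjectScale does.
Optimisers.setup(p::SoftThresholdRule, x) = Optimisers.setup(p.rule, x)
Optimisers.init(p::SoftThresholdRule, x) = Optimisers.init(p.rule, x)

function AdvancedVI.update_variational_params!(
    prox::SoftThresholdRule, ::Type{<:MyFamily}, opt_st, params, restructure, grad
)
    # Plain gradient step first, then an L1 proximal map (soft-thresholding)
    # applied to the flat parameter vector.
    opt_st, params = Optimisers.update!(opt_st, params, grad)
    @. params = sign(params) * max(abs(params) - prox.shrinkage, 0)
    return opt_st, params
end
```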
84 changes: 39 additions & 45 deletions src/families/location_scale.jl
@@ -1,14 +1,6 @@

struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
ContinuousMultivariateDistribution
location::L
scale::S
dist::D
scale_eps::E
end

"""
MvLocationScale(location, scale, dist; scale_eps)
MvLocationScale(location, scale, dist)

The location scale variational family broadly represents various variational
families using `location` and `scale` variational parameters.
@@ -20,21 +12,11 @@ represented as follows:
u = rand(dist, d)
z = scale*u + location
```

`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization.
This is necessary to guarantee stable convergence.

# Keyword Arguments
- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`).
"""
function MvLocationScale(
location::AbstractVector{T},
scale::AbstractMatrix{T},
dist::ContinuousUnivariateDistribution;
scale_eps::T=T(1e-4),
) where {T<:Real}
@assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior."
return MvLocationScale(location, scale, dist, scale_eps)
struct MvLocationScale{S,D<:ContinuousDistribution,L} <: ContinuousMultivariateDistribution
location::L
scale::S
dist::D
end

Functors.@functor MvLocationScale (location, scale)
@@ -44,18 +26,18 @@ Functors.@functor MvLocationScale (location, scale)
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{S<:Diagonal,D,L,E}
model::MvLocationScale{S,D,L,E}
struct RestructureMeanField{S<:Diagonal,D,L}
model::MvLocationScale{S,D,L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
return MvLocationScale(location, scale, re.model.dist)
end

function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
@unpack location, scale, dist = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
@@ -66,7 +48,7 @@ Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D)
Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
@@ -131,49 +113,61 @@ function Distributions.cov(q::MvLocationScale)
end

"""
FullRankGaussian(μ, L; scale_eps)
FullRankGaussian(μ, L)

Construct a Gaussian variational approximation with a dense covariance matrix.

# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function FullRankGaussian(
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4)
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

"""
MeanFieldGaussian(μ, L; scale_eps)
MeanFieldGaussian(μ, L)

Construct a Gaussian variational approximation with a diagonal covariance matrix.

# Arguments
- `μ::AbstractVector{T}`: Mean of the Gaussian.
- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
"""
function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real}
return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
end

# Keyword Arguments
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
ProjectScale(rule, scale_eps)

Compose an optimization `rule` with a projection, where the projection ensures that the scale of an `MvLocationScale` or `MvLocationScaleLowRank` has eigenvalues of at least `scale_eps`.

# Arguments
- `rule::Optimisers.AbstractRule`: Optimization rule to compose with the projection.
- `scale_eps::Real`: Lower bound enforced by the projection on the eigenvalues of the scale matrix.
"""
struct ProjectScale{Rule<:Optimisers.AbstractRule,F<:Real} <: Optimisers.AbstractRule
rule::Rule
scale_eps::F
end

function ProjectScale(rule, scale_eps::Real=1e-5)
return ProjectScale{typeof(rule),typeof(scale_eps)}(rule, scale_eps)
end

Optimisers.setup(proj::ProjectScale, x) = Optimisers.setup(proj.rule, x)

Optimisers.init(proj::ProjectScale, x) = Optimisers.init(proj.rule, x)

function update_variational_params!(
::Type{<:MvLocationScale}, opt_st, params, restructure, grad
proj::ProjectScale, ::Type{<:MvLocationScale}, opt_st, params, restructure, grad
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.scale_eps
ϵ = convert(eltype(params), proj.scale_eps)

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.scale)
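Putting the refactor together, here is a short sketch (not part of the diff) of the new construction pattern: the `scale_eps` keyword is gone from the family constructors, and the lower bound now travels with the `ProjectScale` rule instead.

```julia
using LinearAlgebra
using Optimisers
using AdvancedVI

d = 5

# Families no longer take a scale_eps keyword argument.
q_full = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
q_mf   = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))

# The lower bound on the scale is now carried by the optimization rule.
rule = ProjectScale(Optimisers.Adam(1e-3), 1e-4)
```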