Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add indirection for update step, add projection for LocationScale #65

Merged
merged 4 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,37 @@ module AdvancedVIBijectorsExt
if isdefined(Base, :get_extension)
using AdvancedVI
using Bijectors
using LinearAlgebra
using Optimisers
using Random
else
using ..AdvancedVI
using ..Bijectors
using ..LinearAlgebra
using ..Optimisers
using ..Random
end

function AdvancedVI.update_variational_params!(
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
opt_st,
params,
restructure,
grad
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.dist.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.dist.scale)
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

opt_st, params
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
Expand Down
27 changes: 27 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,33 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
"""
function value_and_gradient! end

# Update for gradient descent step
"""
update_variational_params!(family_type, opt_st, params, restructure, grad)

Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.

This is a wrapper around `Optimisers.update!` to provide some indirection.
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this functino.
Instead, the return values should be used.

# Arguments
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
- `params`: Current set of parameters to be updated.
- `restructure`: Callable for restructuring the varitional distribution from `params`.
- `grad`: Gradient to be used by the update rule of `opt_st`.

# Returns
- `opt_st`: Updated optimizer state.
- `params`: Updated parameters.
"""
function update_variational_params! end

update_variational_params!(::Type, opt_st, params, restructure, grad) =
Optimisers.update!(opt_st, params, grad)

# estimators
"""
AbstractVariationalObjective
Expand Down
58 changes: 39 additions & 19 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@ represented as follows:
```
"""
struct MvLocationScale{
S, D <: ContinuousDistribution, L
S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
location::L
scale ::S
dist ::D
location ::L
scale ::S
dist ::D
scale_eps::E
end

function MvLocationScale(
location ::AbstractVector{T},
scale ::AbstractMatrix{T},
dist ::ContinuousDistribution;
scale_eps::T = sqrt(eps(T))
) where {T <: Real}
MvLocationScale(location, scale, dist, scale_eps)
end

Functors.@functor MvLocationScale (location, scale)
Expand Down Expand Up @@ -57,17 +67,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale))
n_dims*convert(eltype(location), entropy(dist)) + logdet(scale)
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved

function Distributions.rand(q::MvLocationScale)
Expand Down Expand Up @@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
function FullRankGaussian(
μ::AbstractVector{T},
L::LinearAlgebra.AbstractTriangular{T};
check_args::Bool = true
scale_eps::T = sqrt(eps(T))
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
@assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base)
MvLocationScale(μ, L, q_base, scale_eps)
end

"""
Expand All @@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
function MeanFieldGaussian(
μ::AbstractVector{T},
L::Diagonal{T};
check_args::Bool = true
scale_eps::T = sqrt(eps(T)),
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
@assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base)
MvLocationScale(μ, L, q_base, scale_eps)
end

function update_variational_params!(
::MvLocationScale, opt_st, params, restructure, grad
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
)
opt_st, params = Optimisers.update!(opt_st, params, grad)
q = restructure(params)
ϵ = q.scale_eps

# Project the scale matrix to the set of positive definite triangular matrices
diag_idx = diagind(q.scale)
@. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ)

params, _ = Optimisers.destructure(q)

opt_st, params
end
4 changes: 3 additions & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ function optimize(
stat = merge(stat, stat′)

grad = DiffResults.gradient(grad_buf)
opt_st, params = Optimisers.update!(opt_st, params, grad)
opt_st, params = update_variational_params!(
typeof(q_init), opt_st, params, restructure, grad
)

if !isnothing(callback)
stat′ = callback(
Expand Down
Loading