Add "low-rank" variational families #76

Merged
merged 46 commits into from
Sep 13, 2024
Changes from 11 commits
Commits
03563ea
rename location scale source file
Red-Portal Aug 3, 2024
5ab7286
revert renaming of location_scale file
Red-Portal Aug 3, 2024
3e0bf3d
add location-low-rank-scale family (except `entropy` and `logpdf`)
Red-Portal Aug 3, 2024
0bd6e5c
add feature complete `MvLocationScaleLowRank` with tests
Red-Portal Aug 5, 2024
34546e1
fix remove misleading comment
Red-Portal Aug 5, 2024
e030f2d
fix add missing test files
Red-Portal Aug 5, 2024
c7f36d6
fix broadcasting error on Julia 1.6
Red-Portal Aug 5, 2024
1bb3e3e
fix bug in sampling from `LocationScaleLowRank`
Red-Portal Aug 7, 2024
ddd2122
fix missing squared bug in `LocationScaleLowRank`
Red-Portal Aug 7, 2024
b24737f
add documentation for low-rank families
Red-Portal Aug 9, 2024
1d56953
add convenience constructors for `LocationScaleLowRank`
Red-Portal Aug 9, 2024
6752c6b
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 10, 2024
52568b5
fix mhauru's suggestions and run formatter
Red-Portal Aug 10, 2024
96eae86
run formatter
Red-Portal Aug 10, 2024
15556da
run formatter
Red-Portal Aug 10, 2024
f796154
fix bugs and improve comments in `MvLocationScale` and lowrank
Red-Portal Aug 11, 2024
6b1699c
promote families.md into a higher category
Red-Portal Aug 11, 2024
5187d76
add test for `MVLocationScale` with non-Gaussian
Red-Portal Aug 14, 2024
8821908
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 27, 2024
6dfc919
tighten compat bound for `Distributions`
Red-Portal Aug 27, 2024
c3ce393
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 4, 2024
5c04d50
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 5, 2024
ba293e5
fix base distribution standardization bug in `LocationScale`
Red-Portal Sep 5, 2024
426d943
fix base distribution standardization bug in `LocationScaleLowRank`
Red-Portal Sep 5, 2024
3cc9e80
format weird indentation in test `for` loops
Red-Portal Sep 5, 2024
0481dda
update docs add example for `LocationScaleLowRank`
Red-Portal Sep 5, 2024
8449402
fix docs warn about divergence when using `MvLocationScaleLowRank`
Red-Portal Sep 6, 2024
ff14c4c
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 9, 2024
e48f231
Merge branch 'master' into lowrank
yebai Sep 10, 2024
aa8feee
Merge branch 'master' into lowrank
yebai Sep 10, 2024
5149869
Merge branch 'master' into lowrank
yebai Sep 10, 2024
e196da6
Update Benchmark.yml
yebai Sep 10, 2024
e4bff67
disable more features for PRs from forks
yebai Sep 10, 2024
894a849
fix `LocationScale` interfaces to only allow univariate base dist
Red-Portal Sep 11, 2024
f1cabba
Merge branch 'lowrank' of github.com:Red-Portal/AdvancedVI.jl into lo…
Red-Portal Sep 11, 2024
ce6793c
fix test comparison operator for families
Red-Portal Sep 11, 2024
71aeb5a
fix test comparison operator for families
Red-Portal Sep 11, 2024
77ace2b
fix test comparison operator for families
Red-Portal Sep 11, 2024
641de39
fix test comparison operator for families
Red-Portal Sep 11, 2024
a58f209
fix test comparison operator for families
Red-Portal Sep 11, 2024
846b259
fix test comparison operator for families
Red-Portal Sep 11, 2024
1116f68
fix test comparison operator for families
Red-Portal Sep 11, 2024
42d730d
fix formatting
Red-Portal Sep 11, 2024
99d08c5
fix formatting
Red-Portal Sep 11, 2024
4a90c5d
fix scale lower bound to `1e-4`
Red-Portal Sep 12, 2024
c41709b
fix docstring for `LowRankGaussian`
Red-Portal Sep 12, 2024
2 changes: 1 addition & 1 deletion docs/src/families.md
@@ -243,9 +243,9 @@ As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`.
While `FullRankGaussian` can converge to the true solution since it is a more expressive variational family, `LowRankGaussian` gets there faster.

!!! info

`MvLocationScaleLowRank` tend to work better with the `Optimisers.Adam` optimizer due to non-smoothness.
Other optimisers may experience divergences.


### API

8 changes: 4 additions & 4 deletions src/families/location_scale.jl
@@ -30,8 +30,8 @@ This is necessary to guarantee stable convergence.
function MvLocationScale(
location::AbstractVector{T},
scale::AbstractMatrix{T},
- dist::ContinuousDistribution;
- scale_eps::T=sqrt(eps(T)),
+ dist::ContinuousUnivariateDistribution;
+ scale_eps::T=eps(T)^(1//4),
) where {T<:Real}
@assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior."
return MvLocationScale(location, scale, dist, scale_eps)
@@ -143,7 +143,7 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`).
"""
function FullRankGaussian(
- μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T))
+ μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=eps(T)^(1//4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
@@ -162,7 +162,7 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`).
"""
function MeanFieldGaussian(
- μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T))
+ μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=eps(T)^(1//4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScale(μ, L, q_base, scale_eps)
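
The hunks above replace the default `scale_eps` of `sqrt(eps(T))` with `eps(T)^(1//4)`. To make the size of that change concrete, here is a minimal sketch in plain Julia that does not need AdvancedVI. The names `old_scale_eps` and `new_scale_eps` are hypothetical and exist only for this illustration.

```julia
# Hypothetical helpers (not part of AdvancedVI) mirroring the two defaults
# that appear in the diff above.
old_scale_eps(::Type{T}) where {T<:Real} = sqrt(eps(T))
new_scale_eps(::Type{T}) where {T<:Real} = eps(T)^(1//4)

for T in (Float32, Float64)
    old = old_scale_eps(T)
    new = new_scale_eps(T)
    # The rational exponent 1//4 keeps the result in type T, which the
    # keyword annotation `scale_eps::T` requires.
    @assert new isa T
    # The new bound is the square root of the old one, so it is much larger.
    @assert new > old
    println(T, ": old = ", old, ", new = ", new)
end
```

For `Float64` the new bound is roughly `1.2e-4` instead of roughly `1.5e-8`, so an initial scale whose smallest diagonal entry falls between those values would pass the old `@assert` but fail the new one.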
8 changes: 4 additions & 4 deletions src/families/location_scale_low_rank.jl
@@ -35,8 +35,8 @@ function MvLocationScaleLowRank(
location::AbstractVector{T},
scale_diag::AbstractVector{T},
scale_factors::AbstractMatrix{T},
- dist::ContinuousDistribution;
- scale_eps::T=sqrt(eps(T)),
+ dist::ContinuousUnivariateDistribution;
+ scale_eps::T=eps(T)^(1//4),
) where {T<:Real}
@assert minimum(scale_diag) ≥ scale_eps "Initial scale is too small (smallest diagonal scale value is $(minimum(scale_diag)). This might result in unstable optimization behavior."
@assert size(scale_factors, 1) == length(scale_diag)
@@ -156,7 +156,7 @@ function update_variational_params!(
end

"""
- LowRankGaussian(location, scale_diag, scale_factors; check_args = true)
+ LowRankGaussian(μ, D, U; scale_eps)

Construct a Gaussian variational approximation with a diagonal plus low-rank covariance matrix.

@@ -169,7 +169,7 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov
- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`).
"""
function LowRankGaussian(
- μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=sqrt(eps(T))
+ μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=eps(T)^(1//4)
) where {T<:Real}
q_base = Normal{T}(zero(T), one(T))
return MvLocationScaleLowRank(μ, D, U, q_base; scale_eps)
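
The docstring above describes a "diagonal plus low-rank covariance matrix". The sketch below shows what that structure looks like numerically, using only the standard library. It assumes the covariance has the form `Diagonal(D.^2) + U*U'` and that a draw is `μ + D .* ε₁ + U * ε₂` with standard normal noise. Both are assumptions made for illustration and are not taken from this diff, and `lowrank_cov` and `lowrank_sample` are hypothetical helpers rather than AdvancedVI functions.

```julia
using LinearAlgebra, Random

# Hypothetical helper: dense covariance implied by a diagonal scale `D` and
# low-rank factors `U`, ASSUMING the form Diagonal(D.^2) + U * U'.
lowrank_cov(D::AbstractVector, U::AbstractMatrix) = Diagonal(D .^ 2) + U * U'

# Hypothetical helper: one draw z = μ + D .* ε₁ + U * ε₂, where ε₁ and ε₂
# are independent standard normal vectors (Gaussian base distribution assumed).
function lowrank_sample(rng::AbstractRNG, μ, D, U)
    ε₁ = randn(rng, eltype(μ), length(D))
    ε₂ = randn(rng, eltype(μ), size(U, 2))
    return μ + D .* ε₁ + U * ε₂
end

n_dims, n_rank = 5, 2
μ = zeros(n_dims)
D = ones(n_dims)               # diagonal scale, well above any scale_eps bound
U = fill(0.1, n_dims, n_rank)  # low-rank factors, one column per rank

Σ = lowrank_cov(D, U)
z = lowrank_sample(Random.default_rng(), μ, D, U)
```

Under these assumptions the family stores `n_dims * (n_rank + 2)` numbers (location, diagonal, and factors) instead of the `n_dims^2` entries of a dense scale matrix, which is the usual motivation for a low-rank parameterization.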
6 changes: 4 additions & 2 deletions test/families/location_scale.jl
@@ -1,6 +1,7 @@

@testset "interface LocationScale" begin
- @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian, :gaussian_nonstd],
+ @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in
+ [:gaussian, :gaussian_nonstd],
covtype in [:meanfield, :fullrank],
realtype in [Float32, Float64]

@@ -143,7 +144,8 @@
end

@testset "scale positive definite projection" begin
- @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in [:meanfield, :fullrank],
+ @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in
+ [:meanfield, :fullrank],
realtype in [Float32, Float64],
bijector in [nothing, :identity]

17 changes: 9 additions & 8 deletions test/families/location_scale_low_rank.jl
@@ -1,6 +1,7 @@

@testset "interface LocationScaleLowRank" begin
- @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in [:gaussian, :gaussian_nonstd],
+ @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in
+ [:gaussian, :gaussian_nonstd],
n_rank in [1, 2],
realtype in [Float32, Float64]

@@ -56,7 +57,7 @@
@testset "statistics" begin
@testset "mean" begin
@test eltype(mean(q)) == realtype
- @test mean(q) == mean(q_true)
+ @test mean(q) mean(q_true)
end
@testset "var" begin
@test eltype(var(q)) == realtype
@@ -81,7 +82,7 @@
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

z_sample_ref = rand(StableRNG(1), q)
- @test z_sample_ref == rand(StableRNG(1), q)
+ @test z_sample_ref rand(StableRNG(1), q)
end

@testset "rand batch" begin
@@ -96,7 +97,7 @@
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

samples_ref = rand(StableRNG(1), q, n_montecarlo)
- @test samples_ref == rand(StableRNG(1), q, n_montecarlo)
+ @test samples_ref rand(StableRNG(1), q, n_montecarlo)
end

@testset "rand! AbstractVector" begin
@@ -107,7 +108,7 @@
end
z_samples = mapreduce(first, hcat, res)
z_samples_ret = mapreduce(last, hcat, res)
- @test z_samples == z_samples_ret
+ @test z_samples z_samples_ret
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype(
1e-2
)
@@ -121,13 +122,13 @@

z_sample = Array{realtype}(undef, n_dims)
rand!(StableRNG(1), q, z_sample)
- @test z_sample_ref == z_sample
+ @test z_sample_ref z_sample
end

@testset "rand! AbstractMatrix" begin
z_samples = Array{realtype}(undef, n_dims, n_montecarlo)
z_samples_ret = rand!(q, z_samples)
- @test z_samples == z_samples_ret
+ @test z_samples z_samples_ret
@test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype(
1e-2
)
@@ -141,7 +142,7 @@

z_samples = Array{realtype}(undef, n_dims, n_montecarlo)
rand!(StableRNG(1), q, z_samples)
- @test z_samples_ref == z_samples
+ @test z_samples_ref z_samples
end
end
end
3 changes: 2 additions & 1 deletion test/inference/repgradelbo_distributionsad.jl
@@ -14,7 +14,8 @@ if @isdefined(Enzyme)
end

@testset "inference RepGradELBO DistributionsAD" begin
- @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32],
+ @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
+ [Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
3 changes: 2 additions & 1 deletion test/inference/repgradelbo_locationscale.jl
@@ -14,7 +14,8 @@ if @isdefined(Enzyme)
end

@testset "inference RepGradELBO VILocationScale" begin
- @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32],
+ @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
+ [Float64, Float32],
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
3 changes: 2 additions & 1 deletion test/inference/repgradelbo_locationscale_bijectors.jl
@@ -14,7 +14,8 @@ if @isdefined(Enzyme)
end

@testset "inference RepGradELBO VILocationScale Bijectors" begin
- @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32],
+ @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
+ [Float64, Float32],
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],