Skip to content

Commit ff46b85

Browse files
authored
refactor!: update package (#27)
* Transition to ImplicitDiff v0.8 * Add projection correctness test * Clean up * Fixes * Allow specifying x0 * Different sizes * Fix * Show for debugging * Show * Make tests pass
1 parent 9fc82fc commit ff46b85

File tree

11 files changed

+255
-111
lines changed

11 files changed

+255
-111
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
/docs/build/
66
/docs/Manifest.toml
77
/docs/src/tutorial.md
8+
playground.jl

Project.toml

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiableFrankWolfe"
22
uuid = "b383313e-5450-4164-a800-befbd27b574d"
33
authors = ["Guillaume Dalle"]
4-
version = "0.4.1"
4+
version = "0.5.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,8 +11,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111

1212
[compat]
1313
ChainRulesCore = "1.15"
14-
FrankWolfe = "0.3, 0.4, 0.5"
15-
ImplicitDifferentiation = "0.7"
14+
FrankWolfe = "0.5"
15+
ImplicitDifferentiation = "0.8"
1616
LinearAlgebra = "1"
1717
julia = "1.10"
1818

@@ -25,10 +25,31 @@ FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
2525
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
2626
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
2727
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
28+
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
2829
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
30+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2931
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3032
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
33+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
34+
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
3135
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3236

3337
[targets]
34-
test = ["Aqua", "ChainRulesCore", "Documenter", "ForwardDiff", "FrankWolfe", "ImplicitDifferentiation", "JET", "JuliaFormatter", "Random", "Statistics", "Test", "Zygote"]
38+
test = [
39+
"Aqua",
40+
"ChainRulesCore",
41+
"Documenter",
42+
"ForwardDiff",
43+
"FrankWolfe",
44+
"ImplicitDifferentiation",
45+
"JET",
46+
"JuliaFormatter",
47+
"ProximalOperators",
48+
"Random",
49+
"StableRNGs",
50+
"Statistics",
51+
"Test",
52+
"TestItems",
53+
"TestItemRunner",
54+
"Zygote",
55+
]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
55
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
66
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
7+
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
78
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
89

910
[compat]

docs/src/index.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
```@meta
2-
CurrentModule = DifferentiableFrankWolfe
3-
```
4-
51
# DifferentiableFrankWolfe
62

73
Documentation for [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl).

examples/tutorial.jl

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,49 @@
44

55
using DifferentiableFrankWolfe: DiffFW, simplex_projection
66
using ForwardDiff: ForwardDiff
7-
using FrankWolfe: UnitSimplexOracle
7+
using FrankWolfe: ProbabilitySimplexOracle
8+
using ProximalOperators: ProximalOperators
89
using Test: @test
910
using Zygote: Zygote
1011

1112
# Constructing the wrapper
1213

1314
f(x, θ) = 0.5 * sum(abs2, x - θ) # minimizing the squared distance...
1415
f_grad1(x, θ) = x - θ
15-
lmo = UnitSimplexOracle(1.0) # ... to the probability simplex
16-
dfw = DiffFW(f, f_grad1, lmo); # ... is equivalent to a simplex projection
16+
lmo = ProbabilitySimplexOracle(1.0) # ... to the probability simplex
17+
dfw = DiffFW(f, f_grad1, lmo); # ... is equivalent to a simplex projection if we're not already in it
1718

1819
# Calling the wrapper
1920

20-
θ = rand(10)
21+
x0 = ones(3) ./ 3
22+
θ = [1.0, 1.5, 0.2]
2123

2224
#-
2325

2426
frank_wolfe_kwargs = (; max_iteration=100, epsilon=1e-4)
25-
y, stats = dfw(θ, frank_wolfe_kwargs)
26-
y
27+
y = dfw(θ, x0; frank_wolfe_kwargs...)
28+
29+
#- Comparing with the ground truth
30+
31+
true_simplex_projection(x) = ProximalOperators.prox(ProximalOperators.IndSimplex(1.0), x)[1]
2732

2833
#-
2934

30-
y_true = simplex_projection(θ)
35+
y_true = true_simplex_projection(θ)
3136
@test Vector(y) ≈ Vector(y_true) atol = 1e-3
3237

3338
# Differentiating the wrapper
3439

35-
J1 = Zygote.jacobian(_θ -> dfw(_θ, frank_wolfe_kwargs)[1], θ)[1]
36-
J1_true = Zygote.jacobian(simplex_projection, θ)[1]
37-
@test J1 J1_true atol = 1e-3
40+
#-
41+
42+
J_true = ForwardDiff.jacobian(true_simplex_projection, θ)
43+
44+
#-
45+
46+
J1 = Zygote.jacobian(_θ -> dfw(_θ, x0; frank_wolfe_kwargs...), θ)[1]
47+
@test J1 ≈ J_true atol = 1e-3
3848

3949
#-
4050

41-
J2 = ForwardDiff.jacobian(_θ -> dfw(_θ, frank_wolfe_kwargs)[1], θ)
42-
J2_true = ForwardDiff.jacobian(simplex_projection, θ)
43-
@test J2 J2_true atol = 1e-3
51+
J2 = ForwardDiff.jacobian(_θ -> dfw(_θ, x0; frank_wolfe_kwargs...), θ)
52+
@test J2 ≈ J_true atol = 1e-3

src/DifferentiableFrankWolfe.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ module DifferentiableFrankWolfe
77

88
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
99
using FrankWolfe: FrankWolfe, LinearMinimizationOracle
10-
using FrankWolfe: away_frank_wolfe, compute_extreme_point
10+
using FrankWolfe:
11+
away_frank_wolfe,
12+
blended_conditional_gradient,
13+
blended_pairwise_conditional_gradient,
14+
compute_extreme_point,
15+
pairwise_frank_wolfe
1116
using ImplicitDifferentiation: ImplicitFunction
1217
using LinearAlgebra: dot
1318

src/difffw.jl

Lines changed: 70 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,35 @@
11
"""
22
ForwardFW
33
4-
Underlying solver for [`DiffFW`](@ref), which relies on a variant of Frank-Wolfe.
4+
Underlying solver for [`DiffFW`](@ref), which relies on a variant of Frank-Wolfe with active set memorization.
55
"""
66
struct ForwardFW{F,G,M,A}
77
f::F
88
f_grad1::G
99
lmo::M
1010
alg::A
11+
12+
function ForwardFW(f, f_grad1, lmo, alg)
13+
@assert alg in (
14+
away_frank_wolfe,
15+
blended_conditional_gradient,
16+
blended_pairwise_conditional_gradient,
17+
pairwise_frank_wolfe,
18+
)
19+
return new{typeof(f),typeof(f_grad1),typeof(lmo),typeof(alg)}(f, f_grad1, lmo, alg)
20+
end
21+
end
22+
23+
function (forward::ForwardFW)(θ::AbstractArray, x0::AbstractArray, frank_wolfe_kwargs)
24+
f, f_grad1, lmo, alg = forward.f, forward.f_grad1, forward.lmo, forward.alg
25+
obj(x) = f(x, θ)
26+
grad!(g, x) = copyto!(g, f_grad1(x, θ))
27+
x_final, v_final, primal_value, dual_gap, traj_data, active_set = alg(
28+
obj, grad!, lmo, x0; frank_wolfe_kwargs...
29+
)
30+
stats = (; x_final, v_final, primal_value, dual_gap, traj_data, active_set)
31+
p = active_set.weights
32+
return p, stats
1133
end
1234

1335
"""
@@ -19,20 +41,42 @@ struct ConditionsFW{G}
1941
f_grad1::G
2042
end
2143

44+
function (conditions::ConditionsFW)(
45+
θ::AbstractArray,
46+
p::AbstractVector,
47+
stats::NamedTuple,
48+
_x0::AbstractArray,
49+
_frank_wolfe_kwargs,
50+
)
51+
V = stats.active_set.atoms
52+
f_grad1 = conditions.f_grad1
53+
V_mat = stack(V)
54+
x = V_mat * p
55+
∇ₓf = f_grad1(x, θ)
56+
∇ₚg = transpose(V_mat) * ∇ₓf
57+
T = simplex_projection(p .- ∇ₚg)
58+
return T .- p
59+
end
60+
2261
"""
2362
DiffFW
2463
25-
Callable parametrized wrapper for the Frank-Wolfe algorithm to solve `θ -> argmin_{x ∈ C} f(x, θ)`, which can be differentiated implicitly wrt `θ`.
64+
Callable parametrized wrapper for the Frank-Wolfe algorithm to solve `θ -> argmin_{x ∈ C} f(x, θ)` from a given starting point `x0`.
65+
The solution routine can be differentiated implicitly with respect to `θ`, but not with respect to `x0`.
66+
67+
# Constructor
68+
69+
DiffFW(f, f_grad1, lmo, alg=away_frank_wolfe; implicit_kwargs=(;))
2670
27-
Reference: <https://arxiv.org/abs/2105.15183> (section 2 + end of appendix A).
71+
- `f`: function `f(x, θ)` to minimize with respect to `x`
72+
- `f_grad1`: gradient `∇ₓf(x, θ)` of `f` with respect to `x`
73+
- `lmo`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx` from [FrankWolfe.jl](https://github.com/ZIB-IOL/FrankWolfe.jl), implicitly defines the convex set `C`
74+
- `alg`: optimization algorithm from [FrankWolfe.jl](https://github.com/ZIB-IOL/FrankWolfe.jl), must return an `active_set`
75+
- `implicit_kwargs`: keyword arguments passed to the `ImplicitFunction` object from [ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl)
2876
29-
# Fields
77+
# References
3078
31-
- `f`: function `f(x, θ)` to minimize wrt `x`
32-
- `f_grad1`: gradient `∇ₓf(x, θ)` of `f` wrt `x`
33-
- `lmo`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx` from [FrankWolfe.jl], implicitly defines the convex set `C`
34-
- `alg`: optimization algorithm from [FrankWolfe.jl](https://github.com/ZIB-IOL/FrankWolfe.jl)
35-
- `implicit`: implicit function from [ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl)
79+
> [Efficient and Modular Implicit Differentiation](https://proceedings.neurips.cc/paper_files/paper/2022/hash/228b9279ecf9bbafe582406850c57115-Abstract-Conference.html), Blondel et al. (2022)
3680
"""
3781
struct DiffFW{F,G,M<:LinearMinimizationOracle,A,I<:ImplicitFunction}
3882
f::F
@@ -42,11 +86,6 @@ struct DiffFW{F,G,M<:LinearMinimizationOracle,A,I<:ImplicitFunction}
4286
implicit::I
4387
end
4488

45-
"""
46-
DiffFW(f, f_grad1, lmo, alg=away_frank_wolfe; implicit_kwargs=(;))
47-
48-
Constructor for [`DiffFW`](@ref) which chooses a default algorithm and creates the implicit function automatically.
49-
"""
5089
function DiffFW(
5190
f::F, f_grad1::G, lmo::L, alg::A=away_frank_wolfe; implicit_kwargs=NamedTuple()
5291
) where {F,G,L,A}
@@ -57,40 +96,30 @@ function DiffFW(
5796
end
5897

5998
"""
60-
(dfw::DiffFW)(θ::AbstractArray, frank_wolfe_kwargs::NamedTuple)
99+
detailed_output(dfw::DiffFW, θ::AbstractArray, x0::AbstractArray; kwargs...)
61100
62-
Apply the Frank-Wolfe algorithm to `θ` with settings defined by the named tuple `frank_wolfe_kwargs` (given as a positional argument).
101+
Apply the differentiable Frank-Wolfe algorithm defined by `dfw` to parameter `θ` with starting point `x0`.
102+
Keyword arguments are passed on to the Frank-Wolfe algorithm inside `dfw`.
63103
64104
Return a couple (x, stats) where `x` is the solution and `stats` is a named tuple containing additional information (its contents are not covered by public API, and mostly useful for debugging).
65105
"""
66-
function (dfw::DiffFW)(θ::AbstractArray, frank_wolfe_kwargs=NamedTuple())
67-
p, stats = dfw.implicit(θ, frank_wolfe_kwargs)
106+
function detailed_output(dfw::DiffFW, θ::AbstractArray, x0::AbstractArray; kwargs...)
107+
p, stats = dfw.implicit(θ, x0, kwargs)
68108
V = stats.active_set.atoms
69-
x = mapreduce(*,+,p,V)
109+
V_mat = stack(V)
110+
x = V_mat * p
70111
return x, stats
71112
end
72113

73-
function (forward::ForwardFW)(θ::AbstractArray, frank_wolfe_kwargs::NamedTuple)
74-
f, f_grad1, lmo, alg = forward.f, forward.f_grad1, forward.lmo, forward.alg
75-
obj(x) = f(x, θ)
76-
grad!(g, x) = copyto!(g, f_grad1(x, θ))
77-
x0 = compute_extreme_point(lmo, θ)
78-
x_final, v_final, primal_value, dual_gap, traj_data, active_set = alg(
79-
obj, grad!, lmo, x0; frank_wolfe_kwargs...
80-
)
81-
stats = (; x_final, v_final, primal_value, dual_gap, traj_data, active_set)
82-
p = active_set.weights
83-
return p, stats
84-
end
114+
"""
115+
(dfw::DiffFW)(θ::AbstractArray, x0::AbstractArray; kwargs...)
85116
86-
function (conditions::ConditionsFW)(
87-
θ::AbstractArray, p::AbstractVector, stats::NamedTuple, frank_wolfe_kwargs::NamedTuple
88-
)
89-
V = stats.active_set.atoms
90-
x = mapreduce(*,+,p,V)
91-
f_grad1 = conditions.f_grad1
92-
∇ₓf = f_grad1(x, θ)
93-
∇ₚg = dot.(V, Ref(∇ₓf))
94-
T = simplex_projection(p .- ∇ₚg)
95-
return T .- p
117+
Apply the differentiable Frank-Wolfe algorithm defined by `dfw` to parameter `θ` with starting point `x0`.
118+
Keyword arguments are passed on to the Frank-Wolfe algorithm inside `dfw`.
119+
120+
Return the optimal solution `x`.
121+
"""
122+
function (dfw::DiffFW)(θ::AbstractArray, x0::AbstractArray; kwargs...)
123+
x, _ = detailed_output(dfw, θ, x0; kwargs...)
124+
return x
96125
end

src/simplex_projection.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,26 @@ Compute the Euclidean projection of the vector `z` onto the probability simplex.
55
66
This function is differentiable thanks to a custom chain rule.
77
8-
Reference: <https://arxiv.org/abs/1602.02068>.
8+
# References
9+
10+
> [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://proceedings.mlr.press/v48/martins16.html), Martins and Astudillo (2016)
911
"""
1012
function simplex_projection(z::AbstractVector{<:Real}; kwargs...)
1113
p, _ = simplex_projection_and_support(z)
1214
return p
1315
end
1416

15-
"""
16-
simplex_projection_and_support(z)
17-
18-
Compute the Euclidean projection `p` of `z` on the probability simplex as well as the indicators `s` of its support, which are useful for differentiation.
17+
relu(x) = max(x, zero(typeof(x)))
1918

20-
Reference: <https://arxiv.org/abs/1602.02068>.
21-
"""
22-
function simplex_projection_and_support(z::AbstractVector{<:Real})
19+
function simplex_projection_and_support(z::AbstractVector{T}) where {T<:Real}
2320
d = length(z)
2421
z_sorted = sort(z; rev=true)
2522
z_sorted_cumsum = cumsum(z_sorted)
26-
k = maximum(j for j in 1:d if (1 + j * z_sorted[j]) > z_sorted_cumsum[j])
23+
ind_filter = 1 .+ (1:d) .* z_sorted .> z_sorted_cumsum
24+
k = findlast(ind_filter)
2725
τ = (z_sorted_cumsum[k] - 1) / k
28-
p = z .- τ
29-
p .= max.(p, zero(eltype(p)))
30-
s = [Int(p[i] > eps()) for i in 1:d]
26+
p = relu.(z .- τ)
27+
s = p .> eps(T)
3128
return p, s
3229
end
3330

0 commit comments

Comments
 (0)