Skip to content

Commit 146d904

Browse files
authored
Lenient linear solver (#10)
* Accept inconsistent systems by default * Reexports * Fix tests
1 parent 36adb51 commit 146d904

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2121
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2222
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2323
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
24+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
2425
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
2526
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
2627
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -29,4 +30,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[targets]
32-
test = ["Aqua", "ChainRulesCore", "Documenter", "ForwardDiff", "FrankWolfe", "JET", "JuliaFormatter", "Random", "Statistics", "Test", "Zygote"]
33+
test = ["Aqua", "ChainRulesCore", "Documenter", "ForwardDiff", "FrankWolfe", "ImplicitDifferentiation", "JET", "JuliaFormatter", "Random", "Statistics", "Test", "Zygote"]

src/DifferentiableFrankWolfe.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ module DifferentiableFrankWolfe
88
using ChainRulesCore: ChainRulesCore, NoTangent
99
using FrankWolfe: FrankWolfe, LinearMinimizationOracle
1010
using FrankWolfe: away_frank_wolfe, compute_extreme_point
11-
using ImplicitDifferentiation: ImplicitFunction
11+
using ImplicitDifferentiation: ImplicitFunction, IterativeLinearSolver
1212
using LinearAlgebra: dot
1313

1414
export DiffFW
15+
export LinearMinimizationOracle, compute_extreme_point # from FrankWolfe
16+
export IterativeLinearSolver # from ImplicitDifferentiation
1517

1618
include("simplex_projection.jl")
1719
include("difffw.jl")

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Aqua
22
using DifferentiableFrankWolfe
33
using Documenter
4+
using ImplicitDifferentiation
45
using JET
56
using JuliaFormatter
67
using Test
@@ -28,4 +29,15 @@ using Zygote
2829
@testset "Tutorial" begin
2930
include(joinpath(@__DIR__, "..", "examples", "tutorial.jl"))
3031
end
32+
33+
@testset "Constructor" begin
34+
dfw1 = DiffFW(f, f_grad1, lmo)
35+
@test !dfw1.implicit.linear_solver.accept_inconsistent
36+
37+
implicit_kwargs = (;
38+
linear_solver=IterativeLinearSolver(; accept_inconsistent=true)
39+
)
40+
dfw2 = DiffFW(f, f_grad1, lmo; implicit_kwargs)
41+
@test dfw2.implicit.linear_solver.accept_inconsistent
42+
end
3143
end

0 commit comments

Comments
 (0)