New benchmark #339
Project.toml:

```diff
@@ -1,7 +1,7 @@
 name = "Manopt"
 uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
 authors = ["Ronny Bergmann <[email protected]>"]
-version = "0.4.45"
+version = "0.4.46"

 [deps]
 ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
```
New benchmark script (new file, 250 lines):

```julia
using Revise
using Optim, Manopt
using Manifolds
using LineSearches

using Profile
using ProfileView
using BenchmarkTools
using Plots

"""
    StopWhenGradientInfNormLess <: StoppingCriterion

A stopping criterion based on the infinity norm of the current gradient, taken in a
basis chosen arbitrarily for each manifold.

# Constructor

    StopWhenGradientInfNormLess(ε::Float64)

Create a stopping criterion with threshold `ε` for the gradient, that is, this criterion
indicates to stop when [`get_gradient`](@ref) returns a gradient vector of infinity norm
less than `ε`.
"""
mutable struct StopWhenGradientInfNormLess <: StoppingCriterion
    threshold::Float64
    reason::String
    at_iteration::Int
    StopWhenGradientInfNormLess(ε::Float64) = new(ε, "", 0)
end
function (c::StopWhenGradientInfNormLess)(
    mp::AbstractManoptProblem, s::AbstractManoptSolverState, i::Int
)
    if i == 0 # reset on init
        c.reason = ""
        c.at_iteration = 0
    end
    if (norm(get_gradient(s), Inf) < c.threshold) && (i > 0)
        # report the same (infinity) norm that was compared against the threshold
        c.reason = "The algorithm reached an approximately critical point after $i iterations; the gradient infinity norm ($(norm(get_gradient(s), Inf))) is less than $(c.threshold).\n"
        c.at_iteration = i
        return true
    end
    return false
end
function Manopt.status_summary(c::StopWhenGradientInfNormLess)
    has_stopped = length(c.reason) > 0
    s = has_stopped ? "reached" : "not reached"
    return "|grad f|∞ < $(c.threshold): $s"
end
Manopt.indicates_convergence(c::StopWhenGradientInfNormLess) = true
function Base.show(io::IO, c::StopWhenGradientInfNormLess)
    return print(
        io, "StopWhenGradientInfNormLess($(c.threshold))\n    $(status_summary(c))"
    )
end
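
# A short usage sketch (illustrative, relying only on the definitions above):
# stopping criteria compose with `|`, so an infinity-norm tolerance can be
# combined with an iteration cap, as done in the benchmarks further down:
# sc = StopWhenGradientInfNormLess(1e-6) | StopAfterIteration(1000)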

# Classical (extended) Rosenbrock function, summed over consecutive coordinate pairs
function f_rosenbrock(x)
    result = 0.0
    for i in 1:2:length(x)
        result += (1.0 - x[i])^2 + 100.0 * (x[i + 1] - x[i]^2)^2
    end
    return result
end
function f_rosenbrock_manopt(::AbstractManifold, x)
    return f_rosenbrock(x)
end

optimize(f_rosenbrock, [0.0, 0.0], Optim.NelderMead())

# In-place Euclidean gradient of the Rosenbrock function
function g_rosenbrock!(storage, x)
    for i in 1:2:length(x)
        storage[i] = -2.0 * (1.0 - x[i]) - 400.0 * (x[i + 1] - x[i]^2) * x[i]
        storage[i + 1] = 200.0 * (x[i + 1] - x[i]^2)
    end
    return storage
end
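
# Quick sanity check (illustrative): at the origin the analytic gradient is
# (-2, 0), and the minimizer x = (1, 1, …) gives f = 0:
# g_rosenbrock!(zeros(2), [0.0, 0.0]) == [-2.0, 0.0]
# f_rosenbrock([1.0, 1.0]) == 0.0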

optimize(f_rosenbrock, g_rosenbrock!, [0.0, 0.0], LBFGS())

function g_rosenbrock_manopt!(M::AbstractManifold, storage, x)
    g_rosenbrock!(storage, x)
    if isnan(x[1])
        error("nan")
    end
    riemannian_gradient!(M, storage, x, storage)
    return storage
end
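
# Sketch of the conversion step above (illustrative values): `riemannian_gradient!`
# turns the Euclidean gradient stored in `storage` into the Riemannian gradient on
# `M`; on a sphere this amounts to projecting onto the tangent space at `x`, e.g.
# S = Manifolds.Sphere(2)
# X = [1.0, 2.0, 3.0]
# riemannian_gradient!(S, X, [1.0, 0.0, 0.0], X)  # X is now tangent at the point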

M = Euclidean(2)
Manopt.NelderMead(M, f_rosenbrock_manopt)

qn_opts = quasi_Newton(
    M,
    f_rosenbrock_manopt,
    g_rosenbrock_manopt!,
    [0.0, 0.0];
    evaluation=InplaceEvaluation(),
    return_state=true,
)

function test_f(f_manopt, g_manopt!, x0, N::Int)
    M = Euclidean(N)
    return quasi_Newton(
        M, f_manopt, g_manopt!, x0; evaluation=InplaceEvaluation(), return_state=true
    )
end

function prof()
    N = 32
    x0 = zeros(N)
    test_f(f_rosenbrock_manopt, g_rosenbrock_manopt!, x0, N) # warm-up/compilation run

    Profile.clear()
    @profile for i in 1:100000
        test_f(f_rosenbrock_manopt, g_rosenbrock_manopt!, x0, N)
    end
    return ProfileView.view()
end

function manifold_maker(name::Symbol, N, lib::Symbol)
    if lib === :Manopt
        if name === :Euclidean
            return Euclidean(N)
        elseif name === :Sphere
            return Manifolds.Sphere(N - 1)
        end
    elseif lib === :Optim
        if name === :Euclidean
            return Optim.Flat()
        elseif name === :Sphere
            return Optim.Sphere()
        end
    else
        error("Unknown library: $lib")
    end
    # fail loudly instead of silently returning `nothing` for an unknown name
    return error("Unknown manifold: $name")
end
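
# Note on conventions (for the mapping above): Manifolds.Sphere(n) is the
# n-dimensional sphere embedded in R^{n+1}, hence Sphere(N - 1) for vectors of
# length N, while Optim.jl's Sphere()/Flat() act directly on length-N iterates.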

function generate_cmp(f, g!, f_manopt, g_manopt!; mem_len::Int=2)
    plt = plot()
    xlabel!("dimension")
    ylabel!("time [ms]")
    title!("Optimization times for $f")

    N_vals = [2^n for n in 1:3:16]
    ls_hz = LineSearches.HagerZhang()

    gtol = 1e-6
    for manifold_name in [:Euclidean, :Sphere]
        # collect timings per manifold, so each plot! call matches N_vals in length
        times_manopt = Float64[]
        times_optim = Float64[]
        println("Benchmarking $f for gtol=$gtol on $manifold_name")
        for N in N_vals
            println("Benchmarking for N=$N")
            M = manifold_maker(manifold_name, N, :Manopt)
            method_optim = LBFGS(;
                m=mem_len,
                linesearch=ls_hz,
                manifold=manifold_maker(manifold_name, N, :Optim),
            )

            x0 = zeros(N)
            x0[1] = 1
            manopt_sc = StopWhenGradientInfNormLess(gtol) | StopAfterIteration(1000)
            bench_manopt = @benchmark quasi_Newton(
                $M,
                $f_manopt,
                $g_manopt!,
                $x0;
                stepsize=$(Manopt.LineSearchesStepsize(ls_hz)),
                evaluation=$(InplaceEvaluation()),
                memory_size=$mem_len,
                stopping_criterion=$(manopt_sc),
            )

            manopt_state = quasi_Newton(
                M,
                f_manopt,
                g_manopt!,
                x0;
                stepsize=Manopt.LineSearchesStepsize(ls_hz),
                evaluation=InplaceEvaluation(),
                return_state=true,
                memory_size=mem_len,
                stopping_criterion=manopt_sc,
            )
            manopt_iters = get_count(manopt_state, :Iterations)
            # BenchmarkTools reports times in nanoseconds; convert to milliseconds
            push!(times_manopt, median(bench_manopt.times) / 1e6)
            println("Manopt.jl time: $(median(bench_manopt.times) / 1e6) ms")
            println("Manopt.jl iterations: $(manopt_iters)")

            options_optim = Optim.Options(; g_tol=gtol)
            bench_optim = @benchmark optimize($f, $g!, $x0, $method_optim, $options_optim)

            # un-benchmarked Optim.jl run, used only to read off the iteration count;
            # use the passed `f` and `g!`, not the hard-coded Rosenbrock functions
            optim_state = optimize(f, g!, x0, method_optim, options_optim)
            println("Optim.jl time: $(median(bench_optim.times) / 1e6) ms")
            push!(times_optim, median(bench_optim.times) / 1e6)
            println("Optim.jl iterations: $(optim_state.iterations)")
        end
        plot!(
            N_vals, times_manopt; label="Manopt.jl ($manifold_name)", xaxis=:log, yaxis=:log
        )
        plot!(
            N_vals, times_optim; label="Optim.jl ($manifold_name)", xaxis=:log, yaxis=:log
        )
    end
    xticks!(N_vals, string.(N_vals))

    return plt
end

#generate_cmp(f_rosenbrock, g_rosenbrock!, f_rosenbrock_manopt, g_rosenbrock_manopt!)
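
# Usage sketch: uncommenting the call above produces the log-log timing plot,
# which can then be exported with Plots.jl, e.g.
# plt = generate_cmp(f_rosenbrock, g_rosenbrock!, f_rosenbrock_manopt, g_rosenbrock_manopt!)
# savefig(plt, "rosenbrock_benchmark.png")  # file name is illustrative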

function test_case_manopt()
    N = 4
    mem_len = 2
    M = Manifolds.Euclidean(N)
    ls_hz = LineSearches.HagerZhang()

    x0 = zeros(N)
    x0[1] = 1
    manopt_sc = StopWhenGradientInfNormLess(1e-6) | StopAfterIteration(1000)

    return quasi_Newton(
        M,
        f_rosenbrock_manopt,
        g_rosenbrock_manopt!,
        x0;
        stepsize=Manopt.LineSearchesStepsize(ls_hz),
        evaluation=InplaceEvaluation(),
        return_state=true,
        memory_size=mem_len,
        stopping_criterion=manopt_sc,
    )
end

function test_case_optim()
    N = 4
    mem_len = 2
    ls_hz = LineSearches.HagerZhang()
    method_optim = LBFGS(; m=mem_len, linesearch=ls_hz, manifold=Optim.Flat())
    options_optim = Optim.Options(; g_tol=1e-6)

    x0 = zeros(N)
    x0[1] = 1
    optim_state = optimize(f_rosenbrock, g_rosenbrock!, x0, method_optim, options_optim)

    return optim_state
end
```

Review comment (on the last two lines of `test_case_optim`): Similar to the previous call, this could also be one line, just the `return`.
Changes to the `Manopt.LineSearchesStepsize` functor (two hunks):

```diff
@@ -42,15 +42,15 @@ function (cs::Manopt.LineSearchesStepsize)(
         retract!(M, p_tmp, p, η, α, cs.retraction_method)
         get_gradient!(mp, X_tmp, p_tmp)
         vector_transport_to!(M, Y_tmp, p, η, p_tmp, cs.vector_transport_method)
-        return real(inner(M, p_tmp, Y_tmp, Y_tmp))
+        return real(inner(M, p_tmp, X_tmp, Y_tmp))
     end
     function ϕdϕ(α)
         # TODO: optimize?
         retract!(M, p_tmp, p, η, α, cs.retraction_method)
         get_gradient!(mp, X_tmp, p_tmp)
         vector_transport_to!(M, Y_tmp, p, η, p_tmp, cs.vector_transport_method)
         phi = f(M, p_tmp)
-        dphi = real(inner(M, p_tmp, Y_tmp, Y_tmp))
+        dphi = real(inner(M, p_tmp, X_tmp, Y_tmp))
         return (phi, dphi)
     end
```
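The replaced lines correct the line-search slope: with `X_tmp` the gradient at the trial point and `Y_tmp` the search direction transported there, the directional derivative along the search direction is ⟨grad f(p_α), 𝒯 η⟩, i.e. `real(inner(M, p_tmp, X_tmp, Y_tmp))`. The previous code paired the transported direction with itself, which yields ‖Y_tmp‖² rather than the derivative of ϕ.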
```diff
@@ -59,6 +59,7 @@ function (cs::Manopt.LineSearchesStepsize)(
         return α
     catch ex
         if isa(ex, LineSearches.LineSearchException)
+            println(ex)
             # maybe indicate failure?
             return zero(dphi_0)
         else
```

Review thread on the added `println(ex)`:

- I think I would prefer a re-throw of the error here; and maybe we should introduce such errors in our line searches as well.
- OK, then the easiest solution is to just not catch the error. We could have that in our line searches too.
- Yes.
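Following that suggestion, a minimal sketch of the alternative (hypothetical helper names, not the committed code), where the failure simply propagates:

```julia
using LineSearches

# Hypothetical sketch of the review suggestion: no try/catch, so a
# LineSearches.LineSearchException from the line search reaches the caller
# instead of being printed and turned into a zero step size.
function stepsize_or_throw(run_linesearch)
    # `run_linesearch` stands in for the actual HagerZhang invocation
    return run_linesearch()
end

# A failing search then surfaces as an exception, e.g.:
# stepsize_or_throw(() -> throw(LineSearches.LineSearchException("no step found", 0.0)))
```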
Review thread (on the second, un-benchmarked `optimize` run in the benchmark script):

- This is the `Optim.jl` run? One could mention that in a comment.
- Yes; this will be removed or commented when preparing the final version.