Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Dec 29, 2023
1 parent 7fcf166 commit 9f26d83
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 33 deletions.
54 changes: 35 additions & 19 deletions benchmarks/benchmark_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,23 +135,24 @@ function manifold_maker(name::Symbol, N, lib::Symbol)
end
end

function generate_cmp(f, g!, f_manopt, g_manopt!; mem_len::Int=2)
function generate_cmp(problem_for_N; mem_len::Int=2, manifold_names=[:Euclidean, :Sphere])
plt = plot()
xlabel!("dimension")
ylabel!("time [ms]")
title!("Optimization times for $f")

times_manopt = Float64[]
times_optim = Float64[]
xlabel!(plt, "dimension")
ylabel!(plt, "time [ms]")
title!(plt, "Optimization times")

N_vals = [2^n for n in 1:3:16]
ls_hz = LineSearches.HagerZhang()

gtol = 1e-6
for manifold_name in [:Euclidean, :Sphere]
println("Benchmarking $f for gtol=$gtol on $manifold_name")
for manifold_name in manifold_names
times_manopt = Float64[]
times_optim = Float64[]

println("Benchmarking for gtol=$gtol on $manifold_name")
for N in N_vals
println("Benchmarking for N=$N")
f, g!, f_manopt, g_manopt! = problem_for_N(N)
println("Benchmarking for N=$N, f=$(typeof(f))")
M = manifold_maker(manifold_name, N, :Manopt)
method_optim = LBFGS(;
m=mem_len,
Expand Down Expand Up @@ -188,35 +189,50 @@ function generate_cmp(f, g!, f_manopt, g_manopt!; mem_len::Int=2)
push!(times_manopt, median(bench_manopt.times) / 1000)
println("Manopt.jl time: $(median(bench_manopt.times) / 1000) ms")
println("Manopt.jl iterations: $(manopt_iters)")
println("Manopt.jl objective: $(f(manopt_state.p))")

options_optim = Optim.Options(; g_tol=gtol)
bench_optim = @benchmark optimize($f, $g!, $x0, $method_optim, $options_optim)

optim_state = optimize(
f_rosenbrock, g_rosenbrock!, x0, method_optim, options_optim
)
optim_state = optimize(f, g!, x0, method_optim, options_optim)
println("Optim.jl time: $(median(bench_optim.times) / 1000) ms")
push!(times_optim, median(bench_optim.times) / 1000)
println("Optim.jl iterations: $(optim_state.iterations)")
println("Optim.jl objective: $(optim_state.minimum)")
end
plot!(
N_vals, times_manopt; label="Manopt.jl ($manifold_name)", xaxis=:log, yaxis=:log
plt,
N_vals,
times_manopt;
label="Manopt.jl ($manifold_name)",
xaxis=:log,
yaxis=:log,
)
plot!(
N_vals, times_optim; label="Optim.jl ($manifold_name)", xaxis=:log, yaxis=:log
plt,
N_vals,
times_optim;
label="Optim.jl ($manifold_name)",
xaxis=:log,
yaxis=:log,
)
end
xticks!(N_vals, string.(N_vals))
xticks!(plt, N_vals, string.(N_vals))

return plt
end

#generate_cmp(f_rosenbrock, g_rosenbrock!, f_rosenbrock_manopt, g_rosenbrock_manopt!)
# generate_cmp(N -> (f_rosenbrock, g_rosenbrock!, f_rosenbrock_manopt, g_rosenbrock_manopt!), mem_len=4)
# function gsn_problem_for_cmp(N)
# (f, g!) = make_gsn_problem(N, div(N, 10))
# return (f, g!, f, g!)
# end
# generate_cmp(gsn_problem_for_cmp, manifold_names=[:Sphere], mem_len=4)

function test_case_manopt()
N = 4
N = 2^16
mem_len = 2
M = Manifolds.Euclidean(N)
M = Manifolds.Sphere(N - 1)
ls_hz = LineSearches.HagerZhang()

x0 = zeros(N)
Expand Down
14 changes: 2 additions & 12 deletions ext/ManoptLineSearchesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,8 @@ function (cs::Manopt.LineSearchesStepsize)(
return (phi, dphi)
end

try
α, fp = cs.linesearch(ϕ, dϕ, ϕdϕ, α0, fp, dphi_0)
return α
catch ex
if isa(ex, LineSearches.LineSearchException)
println(ex)
# maybe indicate failure?
return zero(dphi_0)
else
rethrow(ex)
end
end
α, fp = cs.linesearch(ϕ, dϕ, ϕdϕ, α0, fp, dphi_0)
return α
end

end
9 changes: 7 additions & 2 deletions src/solvers/quasi_Newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,13 @@ function step_solver!(mp::AbstractManoptProblem, qns::QuasiNewtonState, iter)
copyto!(M, qns.p_old, get_iterate(qns))
retract!(M, qns.p, qns.p, qns.η, α, qns.retraction_method)
qns.η .*= α
β = locking_condition_scale(
M, qns.direction_update, qns.p_old, qns.η, qns.p, qns.vector_transport_method
# qns.yk update fails if α is equal to 0 because then β is NaN
β = ifelse(
iszero(α),
one(α),
locking_condition_scale(
M, qns.direction_update, qns.p_old, qns.η, qns.p, qns.vector_transport_method
),
)
vector_transport_to!(
M,
Expand Down

0 comments on commit 9f26d83

Please sign in to comment.