Skip to content

Commit

Permalink
Fix LineSearchesExt bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Dec 28, 2023
1 parent 9d7e479 commit a23838c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmark_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end
function generate_cmp(f, g!, f_manopt, g_manopt!)
plt = plot()
xlabel!("dimension")
ylabel!("time per iteration [ms]")
ylabel!("time [ms]")
title!("Optimization times for $f")

times_manopt = Float64[]
Expand Down Expand Up @@ -160,14 +160,14 @@ function generate_cmp(f, g!, f_manopt, g_manopt!)
stopping_criterion=manopt_sc,
)
manopt_iters = get_count(manopt_state, :Iterations)
push!(times_manopt, median(bench_manopt.times) / (1000 * manopt_iters))
push!(times_manopt, median(bench_manopt.times) / 1000)
println("Manopt.jl iterations: $(manopt_iters)")

options_optim = Optim.Options(; g_tol=gtol)
bench_optim = @benchmark optimize($f, $g!, $x0, $method_optim, $options_optim)

optim_state = optimize(f_rosenbrock, g_rosenbrock!, x0, method_optim, options_optim)
push!(times_optim, median(bench_optim.times) / (1000 * optim_state.iterations))
push!(times_optim, median(bench_optim.times) / 1000)
println("Optim.jl iterations: $(optim_state.iterations)")
end
plot!(N_vals, times_manopt; label="Manopt.jl", xaxis=:log, yaxis=:log)
Expand Down
4 changes: 2 additions & 2 deletions ext/ManoptLineSearchesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ function (cs::Manopt.LineSearchesStepsize)(
retract!(M, p_tmp, p, η, α, cs.retraction_method)
get_gradient!(mp, X_tmp, p_tmp)
vector_transport_to!(M, Y_tmp, p, η, p_tmp, cs.vector_transport_method)
return real(inner(M, p_tmp, Y_tmp, Y_tmp))
return real(inner(M, p_tmp, X_tmp, Y_tmp))
end
function ϕdϕ(α)
# TODO: optimize?
retract!(M, p_tmp, p, η, α, cs.retraction_method)
get_gradient!(mp, X_tmp, p_tmp)
vector_transport_to!(M, Y_tmp, p, η, p_tmp, cs.vector_transport_method)
phi = f(M, p_tmp)
dphi = real(inner(M, p_tmp, Y_tmp, Y_tmp))
dphi = real(inner(M, p_tmp, X_tmp, Y_tmp))
return (phi, dphi)
end

Expand Down

0 comments on commit a23838c

Please sign in to comment.