Skip to content

Commit

Permalink
Unify/Simplify last stepsize.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Apr 13, 2024
1 parent 975f62c commit c88da50
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 35 deletions.
5 changes: 5 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Tests now also use `Aqua.jl` to spot problems in the code, e.g. ambiguities.

### Fixed

* `get_last_stepsize` was defined in quite different ways that caused ambiguities. That is now internally a bit restructured and should work nicer.
Internally this means that the interims dispatch on `get_last_stepsize(problem, state, step, vars...)` was removed. Now the only two left are `get_last_stepsize(p, s, vars...)` and the one directly checking `get_last_stepsize(::Stepsize)` for stored values.

## [0.4.60] – April 10, 2024

### Added
Expand Down
51 changes: 27 additions & 24 deletions src/plans/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ end
return the stepsize stored within [`AbstractManoptSolverState`](@ref) `ams` when solving the
[`AbstractManoptProblem`](@ref) `amp`.
This method also works for decorated options and the [`Stepsize`](@ref) function within
the options, by default stored in `o.stepsize`.
the options, by default stored in `ams.stepsize`.
"""
function get_stepsize(
amp::AbstractManoptProblem, ams::AbstractManoptSolverState, vars...; kwargs...
Expand Down Expand Up @@ -1262,6 +1262,18 @@ function _get_initial_stepsize(
return get_initial_stepsize(ams.stepsize)
end

@doc raw"""
get_last_stepsize(amp::AbstractManoptProblem, ams::AbstractManoptSolverState, vars...)
return the last computed stepsize stored within [`AbstractManoptSolverState`](@ref) `ams`
when solving the [`AbstractManoptProblem`](@ref) `amp`.
This method takes into account that `ams` might be decorated,
then calls [`get_last_stepsize`](@ref get_last_stepsize(::Stepsize, ::Any...)),
where the stepsize is assumed to be in `ams.stepsize`.
In case this returns `NaN`, a concrete call to the stored stepsize is performed.
For this, usually, the first of the `vars...` should be the current iterate.
"""
function get_last_stepsize(
amp::AbstractManoptProblem, ams::AbstractManoptSolverState, vars...
)
Expand All @@ -1275,33 +1287,24 @@ end
function _get_last_stepsize(
amp::AbstractManoptProblem, ams::AbstractManoptSolverState, ::Val{false}, vars...
)
return get_last_stepsize(amp, ams, ams.stepsize, vars...)
s = get_last_stepsize(ams.stepsize) # if it stores the stepsize itself -> return
!isnan(s) && return s
# if not -> call step.
return ams.stepsize(amp, ams, vars...)
end
#
# dispatch on stepsize
function get_last_stepsize(
amp::AbstractManoptProblem, ams::AbstractManoptSolverState, step::Stepsize, vars...
)
return step(amp, ams, vars...)
end
function get_last_stepsize(
::AbstractManoptProblem, ::AbstractManoptSolverState, step::ArmijoLinesearch, ::Any...
)
@doc raw"""
get_last_stepsize(::Stepsize, vars...)
return the last computed stepsize from within the stepsize.
If no last step size is stored, this returns `NaN`.
"""
get_last_stepsize(::Stepsize, ::Any...) = NaN
function get_last_stepsize(step::ArmijoLinesearch, ::Any...)
return step.last_stepsize
end
function get_last_stepsize(
::AbstractManoptProblem,
::AbstractManoptSolverState,
step::WolfePowellLinesearch,
::Any...,
)
function get_last_stepsize(step::WolfePowellLinesearch, ::Any...)
return step.last_stepsize
end
function get_last_stepsize(
::AbstractManoptProblem,
::AbstractManoptSolverState,
step::WolfePowellBinaryLinesearch,
::Any...,
)
function get_last_stepsize(step::WolfePowellBinaryLinesearch, ::Any...)
return step.last_stepsize
end
4 changes: 1 addition & 3 deletions src/solvers/LevenbergMarquardt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,6 @@ function step_solver!(
return lms
end

function _get_last_stepsize(
::AbstractManoptProblem, lms::LevenbergMarquardtState, ::Val{false}, vars...
)
function get_last_stepsize(::AbstractManoptProblem, lms::LevenbergMarquardtState, ::Any...)
return lms.last_stepsize
end
4 changes: 1 addition & 3 deletions src/solvers/augmented_Lagrangian_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,6 @@ function step_solver!(mp::AbstractManoptProblem, alms::AugmentedLagrangianMethod
end
get_solver_result(alms::AugmentedLagrangianMethodState) = alms.p

function get_last_stepsize(
::AbstractManoptProblem, s::AugmentedLagrangianMethodState, args...
)
function get_last_stepsize(::AbstractManoptProblem, s::AugmentedLagrangianMethodState, i)
return s.last_stepsize
end
6 changes: 4 additions & 2 deletions src/solvers/convex_bundle_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,9 @@ function step_solver!(mp::AbstractManoptProblem, bms::ConvexBundleMethodState, i
return bms
end
get_solver_result(bms::ConvexBundleMethodState) = bms.p_last_serious
get_last_stepsize(::AbstractManoptProblem, bms::ConvexBundleMethodState) = bms.last_stepsize
function get_last_stepsize(::AbstractManoptProblem, bms::ConvexBundleMethodState, i)
return bms.last_stepsize
end

#
#
Expand Down Expand Up @@ -601,6 +603,6 @@ function (d::DebugStepsize)(
dmp::P, bms::ConvexBundleMethodState, i::Int
) where {P<:AbstractManoptProblem}
(i < 1) && return nothing
Printf.format(d.io, Printf.Format(d.format), get_last_stepsize(dmp, bms))
Printf.format(d.io, Printf.Format(d.format), get_last_stepsize(dmp, bms, i))
return nothing
end
2 changes: 1 addition & 1 deletion test/helpers/test_linesearches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ using Test
rosenbrock, rosenbrock_grad!; evaluation=InplaceEvaluation()
)
mp = DefaultManoptProblem(M, mgo)
@test get_last_stepsize(mp, x_opt, x_opt.stepsize, 1) > 0.0
@test get_last_stepsize(mp, x_opt, 1) > 0.0

# this tests catching LineSearchException
@test_throws LineSearchException ls_hz(mp, x_opt, 1, NaN * zero_vector(M, x0))
Expand Down
2 changes: 1 addition & 1 deletion test/plans/test_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct NoIterateState <: AbstractManoptSolverState end
@test repr(Manopt.ReturnSolverState(s)) == "ReturnSolverState($s)"
@test Manopt.status_summary(Manopt.ReturnSolverState(s)) == "DummyState(Float64[])"
a = ArmijoLinesearch(M; initial_stepsize=1.0)
@test get_last_stepsize(pr, s, a) == 1.0
@test get_last_stepsize(a) == 1.0
@test get_initial_stepsize(a) == 1.0
set_manopt_parameter!(s, :Dummy, 1)
end
Expand Down
1 change: 0 additions & 1 deletion test/test_aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ using Aqua, Manopt, Test
Manopt.particle_swarm, # should be fixed
Manopt.stochastic_gradient_descent, # should be fixed
Manopt.truncated_conjugate_gradient_descent!, # will be fixed by removing deprecated methods
Manopt.get_last_stepsize, #Maybe redesign?
],
broken=false,
),
Expand Down

0 comments on commit c88da50

Please sign in to comment.