From 27268c6d84293e3150ce00b9755625a4353a918d Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Wed, 4 Sep 2024 12:08:16 +0200 Subject: [PATCH] =?UTF-8?q?improve=20get=20proximal=20map=20=E2=80=93=20ad?= =?UTF-8?q?d=20test=20coverage.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/Manopt.jl | 2 + src/plans/proximal_plan.jl | 88 ++++++++++++++++++++++------- src/solvers/proximal_point.jl | 16 +++--- test/plans/test_proximal_plan.jl | 4 +- test/runtests.jl | 1 + test/solvers/test_proximal_point.jl | 27 +++++++++ 6 files changed, 107 insertions(+), 31 deletions(-) create mode 100644 test/solvers/test_proximal_point.jl diff --git a/src/Manopt.jl b/src/Manopt.jl index dc77243040..4973e04b76 100644 --- a/src/Manopt.jl +++ b/src/Manopt.jl @@ -508,6 +508,8 @@ export adaptive_regularization_with_cubics, primal_dual_semismooth_Newton, proximal_bundle_method, proximal_bundle_method!, + proximal_point, + proximal_point!, quasi_Newton, quasi_Newton!, stochastic_gradient_descent, diff --git a/src/plans/proximal_plan.jl b/src/plans/proximal_plan.jl index eb2f877d51..13be94d213 100644 --- a/src/plans/proximal_plan.jl +++ b/src/plans/proximal_plan.jl @@ -74,10 +74,6 @@ mutable struct ManifoldProximalMapObjective{E<:AbstractEvaluationType,TC,TP,V} < return new{E,F,PF,typeof(i)}(f, prox_f, i) end end -function check_prox_number(n, i) - (i > n) && throw(ErrorException("the $(i)th entry does not exists, only $n available.")) - return true -end @doc raw""" q = get_proximal_map(M::AbstractManifold, mpo::ManifoldProximalMapObjective, λ, p) get_proximal_map!(M::AbstractManifold, q, mpo::ManifoldProximalMapObjective, λ, p) @@ -101,47 +97,97 @@ function get_proximal_map!(amp::AbstractManoptProblem, q, λ, p) return get_proximal_map!(get_manifold(amp), q, get_objective(amp), λ, p) end +function check_prox_number(pf::Union{Tuple,Vector}, i) + n = length(pf) + (i > n) && throw(ErrorException("the $(i)th entry does not exists, only $n available.")) + return true +end + function get_proximal_map( - M::AbstractManifold, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p, i -) - check_prox_number(length(mpo.proximal_maps!!), i) + M::AbstractManifold, + mpo::ManifoldProximalMapObjective{AllocatingEvaluation,F,<:Union{<:Tuple,<:Vector}}, + λ, + p, + i, +) where {F} + check_prox_number(mpo.proximal_maps!!, i) return mpo.proximal_maps!![i](M, λ, p) end function get_proximal_map( - M::AbstractManifold, admo::AbstractDecoratedManifoldObjective, λ, p, i + M::AbstractManifold, admo::AbstractDecoratedManifoldObjective, args... ) - return get_proximal_map(M, get_objective(admo, false), λ, p, i) + return get_proximal_map(M, get_objective(admo, false), args...) end - function get_proximal_map!( - M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p, i -) - check_prox_number(length(mpo.proximal_maps!!), i) + M::AbstractManifold, + q, + mpo::ManifoldProximalMapObjective{AllocatingEvaluation,F,<:Union{<:Tuple,<:Vector}}, + λ, + p, + i, +) where {F} + check_prox_number(mpo.proximal_maps!!, i) copyto!(M, q, mpo.proximal_maps!![i](M, λ, p)) return q end function get_proximal_map!( - M::AbstractManifold, q, admo::AbstractDecoratedManifoldObjective, λ, p, i + M::AbstractManifold, q, admo::AbstractDecoratedManifoldObjective, args... ) - return get_proximal_map!(M, q, get_objective(admo, false), λ, p, i) + return get_proximal_map!(M, q, get_objective(admo, false), args...) end function get_proximal_map( - M::AbstractManifold, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p, i -) - check_prox_number(length(mpo.proximal_maps!!), i) + M::AbstractManifold, + mpo::ManifoldProximalMapObjective{InplaceEvaluation,F,<:Union{<:Tuple,<:Vector}}, + λ, + p, + i, +) where {F} + check_prox_number(mpo.proximal_maps!!, i) q = allocate_result(M, get_proximal_map, p) mpo.proximal_maps!![i](M, q, λ, p) return q end function get_proximal_map!( - M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p, i -) - check_prox_number(length(mpo.proximal_maps!!), i) + M::AbstractManifold, + q, + mpo::ManifoldProximalMapObjective{InplaceEvaluation,F,<:Union{<:Tuple,<:Vector}}, + λ, + p, + i, +) where {F} + check_prox_number(mpo.proximal_maps!!, i) mpo.proximal_maps!![i](M, q, λ, p) return q end # # +# Single function accessors +function get_proximal_map( + M::AbstractManifold, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p +) + return mpo.proximal_maps!!(M, λ, p) +end +function get_proximal_map!( + M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p +) + copyto!(M, q, mpo.proximal_maps!!(M, λ, p)) + return q +end +function get_proximal_map( + M::AbstractManifold, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p +) + q = allocate_result(M, get_proximal_map, p) + mpo.proximal_maps!!(M, q, λ, p) + return q +end +function get_proximal_map!( + M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p +) + mpo.proximal_maps!!(M, q, λ, p) + return q +end +# +# # Proximal based State # # diff --git a/src/solvers/proximal_point.jl b/src/solvers/proximal_point.jl index fa2a89cc9d..98ce60bb17 100644 --- a/src/solvers/proximal_point.jl +++ b/src/solvers/proximal_point.jl @@ -22,7 +22,7 @@ $(_var(:Argument, :M; type=true)) ## Keyword arguments -* `λ=i -> 1.0 / i` a function to compute the ``λ_k, k ∈ $(_tex(:Cal, "N"))``, +* `λ=k -> 1.0` a function to compute the ``λ_k, k ∈ $(_tex(:Cal, "N"))``, $(_var(:Keyword, :p; add=:as_Initial)) $(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(100)`")) @@ -38,7 +38,7 @@ mutable struct ProximalPointState{P,Tλ,TStop<:StoppingCriterion} <: end function ProximalPointState( M::AbstractManifold; - λ::F=(i) -> 1.0 / i, + λ::F=k -> 1.0, p::P=rand(M), stopping_criterion::SC=StopAfterIteration(200), ) where {P,F,SC<:StoppingCriterion} @@ -49,7 +49,7 @@ function show(io::IO, gds::ProximalPointState) Iter = (i > 0) ? "After $i iterations\n" : "" Conv = indicates_convergence(gds.stop) ? "Yes" : "No" s = """ - # Solver state for `Manopt.jl`s Proximal POint Method + # Solver state for `Manopt.jl`s Proximal Point Method $Iter ## Stopping criterion @@ -82,7 +82,7 @@ $(_var(:Argument, :M; type=true)) $(_var(:Keyword, :evaluation)) * `f=nothing`: a cost function ``f: $(_math(:M))→ℝ`` to minimize. For running the algorithm, ``f`` is not required, but for example when recording the cost or using a stopping criterion that requires a cost function. -* `λ=iter -> 1/iter`: a function returning the (square summable but not summable) sequence of ``λ_i`` +* `λ= k -> 1.0`: a function returning the (square summable but not summable) sequence of ``λ_i`` $(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(200)`$(_sc(:Any))[`StopWhenChangeLess`](@ref)`(1e-12)`)")) $(_note(:OtherKeywords)) @@ -102,7 +102,7 @@ function proximal_point( ) p_ = _ensure_mutating_variable(p) f_ = _ensure_mutating_cost(f, p) - prox_f_ = _ensure_mutating_proc(prox_f, p, evaluation) + prox_f_ = _ensure_mutating_prox(prox_f, p, evaluation) mpo = ManifoldProximalMapObjective(f_, prox_f_; evaluation=evaluation) rs = proximal_point(M, mpo, p_; evaluation=evaluation, kwargs...) return _ensure_matching_output(p, rs) @@ -131,9 +131,9 @@ function proximal_point!( M::AbstractManifold, mpo::O, p; - stopping_criterion::StoppingCriterion=StopAfterIteration(200) | + stopping_criterion::StoppingCriterion=StopAfterIteration(1000) | StopWhenChangeLess(M, 1e-12), - λ=i -> 1 / i, + λ=k -> 1, kwargs..., ) where {O<:Union{ManifoldProximalMapObjective,AbstractDecoratedManifoldObjective}} dmpo = decorate_objective!(M, mpo; kwargs...) @@ -147,6 +147,6 @@ function initialize_solver!(::AbstractManoptProblem, pps::ProximalPointState) return pps end function step_solver!(amp::AbstractManoptProblem, pps::ProximalPointState, k) - get_proximal_map!(amp, pps.p, pps.λ(k), pps.p, 1) + get_proximal_map!(amp, pps.p, pps.λ(k), pps.p) return pps end diff --git a/test/plans/test_proximal_plan.jl b/test/plans/test_proximal_plan.jl index 469cc7af3b..2694dded82 100644 --- a/test/plans/test_proximal_plan.jl +++ b/test/plans/test_proximal_plan.jl @@ -40,8 +40,8 @@ include("../utils/dummy_types.jl") @test get_count(cppo, :ProximalMap, 1) == 2 # the single ones have to be tricked a bit cppo2 = ManifoldCountObjective(M, ppo, Dict([:ProximalMap => 0])) - @test q == get_proximal_map(M, cppo2, 0.1, p) - get_proximal_map!(M, q2, cppo2, 0.1, p) + @test q == get_proximal_map(M, cppo2, 0.1, p, 1) + get_proximal_map!(M, q2, cppo2, 0.1, p, 1) @test q2 == q @test get_count(cppo2, :ProximalMap) == 2 end diff --git a/test/runtests.jl b/test/runtests.jl index e254f85b13..3d90036bcb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,7 @@ include("utils/example_tasks.jl") include("solvers/test_Levenberg_Marquardt.jl") include("solvers/test_Nelder_Mead.jl") include("solvers/test_proximal_bundle_method.jl") + include("solvers/test_proximal_point.jl") include("solvers/test_quasi_Newton.jl") include("solvers/test_particle_swarm.jl") include("solvers/test_primal_dual_semismooth_Newton.jl") diff --git a/test/solvers/test_proximal_point.jl b/test/solvers/test_proximal_point.jl new file mode 100644 index 0000000000..96fc1d8bd4 --- /dev/null +++ b/test/solvers/test_proximal_point.jl @@ -0,0 +1,27 @@ +using Manopt, Manifolds, ManifoldDiff +using ManifoldDiff: prox_distance, prox_distance! + +@testset "Proximal Point" begin + # Dummy problem + M = Sphere(2) + q = [1.0, 0.0, 0.0] + f(M, p) = 0.5 * distance(M, p, q)^2 + prox_f(M, λ, p) = prox_distance(M, λ, q, p) + prox_f!(M, r, λ, p) = prox_distance!(M, r, λ, q, p) + + p0 = [0.0, 0.0, 1.0] + q1 = proximal_point(M, prox_f, p0) + @test distance(M, q, q1) < 1e-12 + q2 = proximal_point(M, prox_f!, p0; evaluation=InplaceEvaluation()) + @test distance(M, q1, q2) == 0 + os3 = proximal_point(M, prox_f, p0; return_state=true, return_objective=true) + obj = os3[1] + # test with get_prox map that these are fix points + pps = os3[2] + q3a = get_proximal_map(M, obj, 1.0, get_iterate(pps)) + @test isapprox(M, q2, q3a) + q3b = rand(M) + get_proximal_map!(M, q3b, obj, 1.0, get_iterate(pps)) + @test distance(M, q3a, q3b) == 0 + @test startswith(repr(pps), "# Solver state for `Manopt.jl`s Proximal Point Method\n") +end