Skip to content

Commit

Permalink
improve get proximal map – add test coverage.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Sep 4, 2024
1 parent 3f96950 commit 27268c6
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 31 deletions.
2 changes: 2 additions & 0 deletions src/Manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,8 @@ export adaptive_regularization_with_cubics,
primal_dual_semismooth_Newton,
proximal_bundle_method,
proximal_bundle_method!,
proximal_point,
proximal_point!,
quasi_Newton,
quasi_Newton!,
stochastic_gradient_descent,
Expand Down
88 changes: 67 additions & 21 deletions src/plans/proximal_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ mutable struct ManifoldProximalMapObjective{E<:AbstractEvaluationType,TC,TP,V} <
return new{E,F,PF,typeof(i)}(f, prox_f, i)
end
end
function check_prox_number(n, i)
(i > n) && throw(ErrorException("the $(i)th entry does not exists, only $n available."))
return true
end
@doc raw"""
q = get_proximal_map(M::AbstractManifold, mpo::ManifoldProximalMapObjective, λ, p)
get_proximal_map!(M::AbstractManifold, q, mpo::ManifoldProximalMapObjective, λ, p)
Expand All @@ -101,47 +97,97 @@ function get_proximal_map!(amp::AbstractManoptProblem, q, λ, p)
return get_proximal_map!(get_manifold(amp), q, get_objective(amp), λ, p)
end

function check_prox_number(pf::Union{Tuple,Vector}, i)
n = length(pf)
(i > n) && throw(ErrorException("the $(i)th entry does not exists, only $n available."))
return true
end

function get_proximal_map(
M::AbstractManifold, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p, i
)
check_prox_number(length(mpo.proximal_maps!!), i)
M::AbstractManifold,
mpo::ManifoldProximalMapObjective{AllocatingEvaluation,F,<:Union{<:Tuple,<:Vector}},
λ,
p,
i,
) where {F}
check_prox_number(mpo.proximal_maps!!, i)
return mpo.proximal_maps!![i](M, λ, p)
end
function get_proximal_map(
M::AbstractManifold, admo::AbstractDecoratedManifoldObjective, λ, p, i
M::AbstractManifold, admo::AbstractDecoratedManifoldObjective, args...
)
return get_proximal_map(M, get_objective(admo, false), λ, p, i)
return get_proximal_map(M, get_objective(admo, false), args...)
end

function get_proximal_map!(
M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p, i
)
check_prox_number(length(mpo.proximal_maps!!), i)
M::AbstractManifold,
q,
mpo::ManifoldProximalMapObjective{AllocatingEvaluation,F,<:Union{<:Tuple,<:Vector}},
λ,
p,
i,
) where {F}
check_prox_number(mpo.proximal_maps!!, i)
copyto!(M, q, mpo.proximal_maps!![i](M, λ, p))
return q
end
function get_proximal_map!(
M::AbstractManifold, q, admo::AbstractDecoratedManifoldObjective, λ, p, i
M::AbstractManifold, q, admo::AbstractDecoratedManifoldObjective, args...
)
return get_proximal_map!(M, q, get_objective(admo, false), λ, p, i)
return get_proximal_map!(M, q, get_objective(admo, false), args...)
end
function get_proximal_map(
M::AbstractManifold, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p, i
)
check_prox_number(length(mpo.proximal_maps!!), i)
M::AbstractManifold,
mpo::ManifoldProximalMapObjective{InplaceEvaluation,F,<:Union{<:Tuple,<:Vector}},
λ,
p,
i,
) where {F}
check_prox_number(mpo.proximal_maps!!, i)
q = allocate_result(M, get_proximal_map, p)
mpo.proximal_maps!![i](M, q, λ, p)
return q
end
function get_proximal_map!(
M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p, i
)
check_prox_number(length(mpo.proximal_maps!!), i)
M::AbstractManifold,
q,
mpo::ManifoldProximalMapObjective{InplaceEvaluation,F,<:Union{<:Tuple,<:Vector}},
λ,
p,
i,
) where {F}
check_prox_number(mpo.proximal_maps!!, i)
mpo.proximal_maps!![i](M, q, λ, p)
return q
end
#
#
# Single function accessors
function get_proximal_map(
M::AbstractManifold, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p
)
return mpo.proximal_maps!!(M, λ, p)
end
function get_proximal_map!(
M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{AllocatingEvaluation}, λ, p
)
copyto!(M, q, mpo.proximal_maps!!(M, λ, p))
return q
end
function get_proximal_map(

Check warning on line 176 in src/plans/proximal_plan.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/proximal_plan.jl#L176

Added line #L176 was not covered by tests
M::AbstractManifold, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p
)
q = allocate_result(M, get_proximal_map, p)
mpo.proximal_maps!!(M, q, λ, p)
return q

Check warning on line 181 in src/plans/proximal_plan.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/proximal_plan.jl#L179-L181

Added lines #L179 - L181 were not covered by tests
end
function get_proximal_map!(
M::AbstractManifold, q, mpo::ManifoldProximalMapObjective{InplaceEvaluation}, λ, p
)
mpo.proximal_maps!!(M, q, λ, p)
return q
end
#
#
# Proximal based State
#
#
Expand Down
16 changes: 8 additions & 8 deletions src/solvers/proximal_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ $(_var(:Argument, :M; type=true))
## Keyword arguments
* `λ=i -> 1.0 / i` a function to compute the ``λ_k, k ∈ $(_tex(:Cal, "N"))``,
* `λ=k -> 1.0` a function to compute the ``λ_k, k ∈ $(_tex(:Cal, "N"))``,
$(_var(:Keyword, :p; add=:as_Initial))
$(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(100)`"))
Expand All @@ -38,7 +38,7 @@ mutable struct ProximalPointState{P,Tλ,TStop<:StoppingCriterion} <:
end
function ProximalPointState(
M::AbstractManifold;
λ::F=(i) -> 1.0 / i,
λ::F=k -> 1.0,
p::P=rand(M),
stopping_criterion::SC=StopAfterIteration(200),
) where {P,F,SC<:StoppingCriterion}
Expand All @@ -49,7 +49,7 @@ function show(io::IO, gds::ProximalPointState)
Iter = (i > 0) ? "After $i iterations\n" : ""
Conv = indicates_convergence(gds.stop) ? "Yes" : "No"
s = """
# Solver state for `Manopt.jl`s Proximal POint Method
# Solver state for `Manopt.jl`s Proximal Point Method
$Iter
## Stopping criterion
Expand Down Expand Up @@ -82,7 +82,7 @@ $(_var(:Argument, :M; type=true))
$(_var(:Keyword, :evaluation))
* `f=nothing`: a cost function ``f: $(_math(:M))→ℝ`` to minimize. For running the algorithm, ``f`` is not required, but for example when recording the cost or using a stopping criterion that requires a cost function.
* `λ=iter -> 1/iter`: a function returning the (square summable but not summable) sequence of ``λ_i``
* `λ= k -> 1.0`: a function returning the (square summable but not summable) sequence of ``λ_i``
$(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(200)`$(_sc(:Any))[`StopWhenChangeLess`](@ref)`(1e-12)`)"))
$(_note(:OtherKeywords))
Expand All @@ -102,7 +102,7 @@ function proximal_point(
)
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
prox_f_ = _ensure_mutating_proc(prox_f, p, evaluation)
prox_f_ = _ensure_mutating_prox(prox_f, p, evaluation)
mpo = ManifoldProximalMapObjective(f_, prox_f_; evaluation=evaluation)
rs = proximal_point(M, mpo, p_; evaluation=evaluation, kwargs...)
return _ensure_matching_output(p, rs)
Expand Down Expand Up @@ -131,9 +131,9 @@ function proximal_point!(
M::AbstractManifold,
mpo::O,
p;
stopping_criterion::StoppingCriterion=StopAfterIteration(200) |
stopping_criterion::StoppingCriterion=StopAfterIteration(1000) |
StopWhenChangeLess(M, 1e-12),
λ=i -> 1 / i,
λ=k -> 1,
kwargs...,
) where {O<:Union{ManifoldProximalMapObjective,AbstractDecoratedManifoldObjective}}
dmpo = decorate_objective!(M, mpo; kwargs...)
Expand All @@ -147,6 +147,6 @@ function initialize_solver!(::AbstractManoptProblem, pps::ProximalPointState)
return pps
end
function step_solver!(amp::AbstractManoptProblem, pps::ProximalPointState, k)
get_proximal_map!(amp, pps.p, pps.λ(k), pps.p, 1)
get_proximal_map!(amp, pps.p, pps.λ(k), pps.p)
return pps
end
4 changes: 2 additions & 2 deletions test/plans/test_proximal_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ include("../utils/dummy_types.jl")
@test get_count(cppo, :ProximalMap, 1) == 2
# the single ones have to be tricked a bit
cppo2 = ManifoldCountObjective(M, ppo, Dict([:ProximalMap => 0]))
@test q == get_proximal_map(M, cppo2, 0.1, p)
get_proximal_map!(M, q2, cppo2, 0.1, p)
@test q == get_proximal_map(M, cppo2, 0.1, p, 1)
get_proximal_map!(M, q2, cppo2, 0.1, p, 1)
@test q2 == q
@test get_count(cppo2, :ProximalMap) == 2
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ include("utils/example_tasks.jl")
include("solvers/test_Levenberg_Marquardt.jl")
include("solvers/test_Nelder_Mead.jl")
include("solvers/test_proximal_bundle_method.jl")
include("solvers/test_proximal_point.jl")
include("solvers/test_quasi_Newton.jl")
include("solvers/test_particle_swarm.jl")
include("solvers/test_primal_dual_semismooth_Newton.jl")
Expand Down
27 changes: 27 additions & 0 deletions test/solvers/test_proximal_point.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Manopt, Manifolds, ManifoldDiff
using ManifoldDiff: prox_distance, prox_distance!

@testset "Proximal Point" begin
# Dummy problem
M = Sphere(2)
q = [1.0, 0.0, 0.0]
f(M, p) = 0.5 * distance(M, p, q)^2
prox_f(M, λ, p) = prox_distance(M, λ, q, p)
prox_f!(M, r, λ, p) = prox_distance!(M, r, λ, q, p)

p0 = [0.0, 0.0, 1.0]
q1 = proximal_point(M, prox_f, p0)
@test distance(M, q, q1) < 1e-12
q2 = proximal_point(M, prox_f!, p0; evaluation=InplaceEvaluation())
@test distance(M, q1, q2) == 0
os3 = proximal_point(M, prox_f, p0; return_state=true, return_objective=true)
obj = os3[1]
# test with get_prox map that these are fix points
pps = os3[2]
q3a = get_proximal_map(M, obj, 1.0, get_iterate(pps))
@test isapprox(M, q2, q3a)
q3b = rand(M)
get_proximal_map!(M, q3b, obj, 1.0, get_iterate(pps))
@test distance(M, q3a, q3b) == 0
@test startswith(repr(pps), "# Solver state for `Manopt.jl`s Proximal Point Method\n")
end

0 comments on commit 27268c6

Please sign in to comment.