Skip to content

Commit

Permalink
Implement and document PPA.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Sep 3, 2024
1 parent 158f99b commit 834b859
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/src/solvers/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ If the gradient does not exist everywhere, that is if the splitting yields summa
* [Difference of Convex Proximal Point](@ref solver-difference-of-convex-proximal-point) uses a splitting of the (non-convex) function ``f = g - h`` into a difference of two functions; provided the proximal map of ``g`` and the subgradient of ``h``, the next iterate is computed. Compared to DCA, the corresponding sub problem is here written in a form that yields the proximal map.
* [Douglas—Rachford](DouglasRachford.md) uses a splitting ``f(p) = F(p) + G(p)`` and their proximal maps to compute a minimizer of ``f``, which can be non-smooth.
* [Primal-dual Riemannian semismooth Newton Algorithm](@ref solver-pdrssn) extends Chambolle-Pock and requires the differentials of the proximal maps additionally.
* The [Proximal Point](proximal_point.md) method uses the proximal map of ``f`` iteratively.

## Constrained

Expand Down Expand Up @@ -120,6 +121,7 @@ For these you can use
| [Particle Swarm](particle_swarm.md) | [`particle_swarm`](@ref) | [`ParticleSwarmState`](@ref) |
| [Primal-dual Riemannian semismooth Newton Algorithm](@ref solver-pdrssn) | [`primal_dual_semismooth_Newton`](@ref) | [`PrimalDualSemismoothNewtonState`](@ref) |
| [Proximal Bundle Method](proximal_bundle_method.md) | [`proximal_bundle_method`](@ref) | [`ProximalBundleMethodState`](@ref) |
| [Proximal Point](proximal_point.md) | [`proximal_point`](@ref) | [`ProximalPointState`](@ref) |
| [Quasi-Newton Method](quasi_Newton.md) | [`quasi_Newton`](@ref) | [`QuasiNewtonState`](@ref) |
| [Steihaug-Toint Truncated Conjugate-Gradient Method](@ref tCG) | [`truncated_conjugate_gradient_descent`](@ref) | [`TruncatedConjugateGradientState`](@ref) |
| [Subgradient Method](subgradient.md) | [`subgradient_method`](@ref) | [`SubGradientMethodState`](@ref) |
Expand Down
21 changes: 21 additions & 0 deletions docs/src/solvers/proximal_point.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Proximal Point Method

```@meta
CurrentModule = Manopt
```

```@docs
proximal_point
proximal_point!
```

## State

```@docs
ProximalPointState
```

```@bibliography
Pages = ["proximal_point.md"]
Canonical=false
```
2 changes: 1 addition & 1 deletion src/Manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ include("solvers/LevenbergMarquardt.jl")
include("solvers/particle_swarm.jl")
include("solvers/primal_dual_semismooth_Newton.jl")
include("solvers/proximal_bundle_method.jl")
include("solvers/prorimal_point.jl")
include("solvers/proximal_point.jl")
include("solvers/quasi_Newton.jl")
include("solvers/truncated_conjugate_gradient_descent.jl")
include("solvers/trust_regions.jl")
Expand Down
8 changes: 6 additions & 2 deletions src/plans/proximal_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,20 @@ stores options for the [`cyclic_proximal_point`](@ref) algorithm. These are the
$(_var(:Field, :p; add=[:as_Iterate]))
$(_var(:Field, :stopping_criterion, "stop"))
* `λ`: a function for the values of ``λ_k`` per iteration(cycle ``ì``
* `λ`: a function for the values of ``λ_k`` per iteration (cycle) ``k``
* `order_type`: whether to use a randomly permuted sequence (`:FixedRandomOrder`),
a per cycle permuted sequence (`:RandomOrder`) or the default linear one.
# Constructor
CyclicProximalPointState(M; kwargs...)
CyclicProximalPointState(M::AbstractManifold; kwargs...)
Generate the options
## Input
$(_var(:Argument, :M; type=true))
# Keyword arguments
* `evaluation_order=:LinearOrder`: specify the `order_type`
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/LevenbergMarquardt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ The second signature performs the optimization in-place of `p`.
# Input
$(_var(:Argument, :M; type=true))
* `f`: a cost function ``f: $(_math(:M)) M→ℝ^d``
* `f`: a cost function ``f: $(_math(:M))→ℝ^d``
* `jacobian_f`: the Jacobian of ``f``. The Jacobian is supposed to accept a keyword argument
`basis_domain` which specifies basis of the tangent space at a given point in which the
Jacobian is to be calculated. By default it should be the `DefaultOrthonormalBasis`.
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/cyclic_proximal_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ perform a cyclic proximal point algorithm. This can be done in-place of `p`.
# Input
$(_var(:Argument, :M; type=true))
* `f`: a cost function ``f: $(_math(:M)) M→ℝ`` to minimize
* `f`: a cost function ``f: $(_math(:M))→ℝ`` to minimize
* `proxes_f`: an Array of proximal maps (`Function`s) `(M,λ,p) -> q` or `(M, q, λ, p) -> q` for the summands of ``f`` (see `evaluation`)
where `f` and the proximal maps `proxes_f` can also be given directly as a [`ManifoldProximalMapObjective`](@ref) `mpo`
Expand Down
111 changes: 98 additions & 13 deletions src/solvers/proximal_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
$(_var(:Field, :p; add=[:as_Iterate]))
$(_var(:Field, :stopping_criterion, "stop"))
* `λ`: a function for the values of ``λ_k`` per iteration ``k``
# Constructor
Expand All @@ -21,32 +22,27 @@ $(_var(:Argument, :M; type=true))
## Keyword arguments
* `λ=i -> 1.0 / i`: a function to compute the ``λ_k, k ∈ $(_tex(:Cal, "N"))``
$(_var(:Keyword, :p; add=:as_Initial))
$(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(200)`"))
# See also
[`proximal point`](@ref)
[`proximal_point`](@ref)
"""
mutable struct ProximalPointState{
P,
TStop<:StoppingCriterion,
} <: AbstractGradientSolverState
mutable struct ProximalPointState{P,Tλ,TStop<:StoppingCriterion} <:
AbstractGradientSolverState
λ::Tλ
p::P
stop::TStop
end
function ProximalPointState(
M::AbstractManifold;
λ::F=(i) -> 1.0 / i,
p::P=rand(M),
stopping_criterion::SC=StopAfterIteration(200),
) where {
P,
SC<:StoppingCriterion,
}
return ProximalPointState{P,SC}(p, stopping_criterion)
end
function get_message(pps::ProximalPointState)
return get_message(pps.stepsize)
) where {P,F,SC<:StoppingCriterion}
return ProximalPointState{P,F,SC}(λ, p, stopping_criterion)
end
function show(io::IO, gds::ProximalPointState)
i = get_count(gds, :Iterations)
Expand All @@ -65,3 +61,92 @@ end
#
#
# solver interface
# Shared docstring for `proximal_point` and `proximal_point!`; it is attached to
# both generic signatures via `@doc`.
_doc_PPA = """
    proximal_point(M, prox_f, p=rand(M); kwargs...)
    proximal_point(M, mpmo, p=rand(M); kwargs...)
    proximal_point!(M, prox_f, p; kwargs...)
    proximal_point!(M, mpmo, p; kwargs...)

Perform the proximal point algorithm from [FerreiraOliveira:2002](@cite) which reads

```math
p^{(k+1)} = $(_tex(:prox))_{λ_kf}(p^{(k)})
```

The variants ending in `!` perform the optimization in-place of `p`.

# Input

$(_var(:Argument, :M; type=true))
* `prox_f`: a proximal map `(M,λ,p) -> q` or `(M, q, λ, p) -> q` of ``f`` (see `evaluation`)
* `p`: the point to start the iteration from

where the cost `f` and the proximal map `prox_f` can also be given directly as a [`ManifoldProximalMapObjective`](@ref) `mpmo`

# Keyword arguments

$(_var(:Keyword, :evaluation))
* `f=nothing`: a cost function ``f: $(_math(:M))→ℝ`` to minimize. For running the algorithm, ``f`` is not required, but it is needed for example when recording the cost or when using a stopping criterion that requires a cost function.
* `λ=iter -> 1/iter`: a function returning the (square summable but not summable) sequence of ``λ_i``
$(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(200)`$(_sc(:Any))[`StopWhenChangeLess`](@ref)`(1e-12)`"))

$(_note(:OtherKeywords))

$(_note(:OutputSection))
"""

# Attach the shared docstring `_doc_PPA` to the generic signature of `proximal_point`;
# the concrete methods below carry no docstrings of their own.
@doc "$(_doc_PPA)"
proximal_point(M::AbstractManifold, args...; kwargs...)
# Allocating entry point that takes the proximal map `prox_f` directly.
# It wraps the (optional) cost `f` and `prox_f` into a `ManifoldProximalMapObjective`,
# forwards to the objective-based `proximal_point` method and finally passes the
# solver result, together with the originally provided `p`, to `_ensure_matching_output`.
#
# NOTE(review): the `_ensure_mutating_*` helpers are defined elsewhere; presumably they
# adapt number-valued start points and the corresponding functions — confirm.
# NOTE(review): `_ensure_mutating_proc` might be a misspelling of a `…_prox` helper;
# verify that a helper of this exact name exists.
function proximal_point(
    M::AbstractManifold,
    prox_f,
    p=rand(M);
    f=nothing,
    evaluation::AbstractEvaluationType=AllocatingEvaluation(),
    kwargs...,
)
    start = _ensure_mutating_variable(p)
    objective = ManifoldProximalMapObjective(
        _ensure_mutating_cost(f, p),
        _ensure_mutating_proc(prox_f, p, evaluation);
        evaluation=evaluation,
    )
    result = proximal_point(M, objective, start; evaluation=evaluation, kwargs...)
    return _ensure_matching_output(p, result)
end
# Allocating entry point for an already constructed (possibly decorated) objective:
# the in-place solver `proximal_point!` is run on a copy of `p` created with
# `copy(M, p)`, so the solver works on the copy instead of the `p` passed in.
function proximal_point(
    M::AbstractManifold,
    mpo::Union{ManifoldProximalMapObjective,AbstractDecoratedManifoldObjective},
    p;
    kwargs...,
)
    return proximal_point!(M, mpo, copy(M, p); kwargs...)
end

# Attach the shared docstring `_doc_PPA` to the generic signature of the in-place
# variant `proximal_point!`.
@doc "$(_doc_PPA)"
proximal_point!(M::AbstractManifold, args...; kwargs...)
# In-place entry point that takes the proximal map `prox_f` directly.
# The (optional) cost `f` and `prox_f` are wrapped into a
# `ManifoldProximalMapObjective`, which is handed to the objective-based
# `proximal_point!` method together with all remaining keyword arguments.
function proximal_point!(
    M::AbstractManifold,
    prox_f,
    p;
    f=nothing,
    evaluation::AbstractEvaluationType=AllocatingEvaluation(),
    kwargs...,
)
    return proximal_point!(
        M,
        ManifoldProximalMapObjective(f, prox_f; evaluation=evaluation),
        p;
        evaluation=evaluation,
        kwargs...,
    )
end
# In-place solver driver for a (possibly decorated) proximal map objective.
#
# Steps:
# 1. decorate the objective with `decorate_objective!` and wrap it, together with `M`,
#    in a `DefaultManoptProblem`,
# 2. create a `ProximalPointState` from `p`, the stopping criterion and `λ`, and
#    decorate it with `decorate_state!`,
# 3. run `solve!` and return what `get_solver_return` extracts from objective and state.
#
# Keywords handled here:
# * `stopping_criterion`: defaults to stopping after 200 iterations or when
#   `StopWhenChangeLess(M, 1e-12)` indicates to stop,
# * `λ`: function of the iteration number, defaults to `1 / k`.
# All remaining `kwargs` are passed to both decorate calls.
function proximal_point!(
    M::AbstractManifold,
    mpo::Union{ManifoldProximalMapObjective,AbstractDecoratedManifoldObjective},
    p;
    stopping_criterion::StoppingCriterion=StopAfterIteration(200) |
                                          StopWhenChangeLess(M, 1e-12),
    λ=k -> 1 / k,
    kwargs...,
)
    problem = DefaultManoptProblem(M, decorate_objective!(M, mpo; kwargs...))
    state = decorate_state!(
        ProximalPointState(M; p=p, stopping_criterion=stopping_criterion, λ=λ); kwargs...
    )
    solve!(problem, state)
    return get_solver_return(get_objective(problem), state)
end
# The proximal point state needs no set-up before the first iteration:
# the state is returned as it was passed in.
initialize_solver!(::AbstractManoptProblem, pps::ProximalPointState) = pps
# Perform iteration `k` of the proximal point algorithm: call `get_proximal_map!`
# with the parameter `pps.λ(k)` and the index `1`, passing `pps.p` in both point
# positions.
#
# NOTE(review): presumably the argument order is (problem, result, λ, point, index),
# so the iterate `pps.p` is overwritten by its proximal point — confirm against
# `get_proximal_map!`, and check that the proximal map tolerates result and input
# being the same object.
function step_solver!(amp::AbstractManoptProblem, pps::ProximalPointState, k)
    iterate = pps.p
    get_proximal_map!(amp, iterate, pps.λ(k), pps.p, 1)
    return pps
end

0 comments on commit 834b859

Please sign in to comment.