diff --git a/src/Manopt.jl b/src/Manopt.jl index a015e5a3a8..b241f595f6 100644 --- a/src/Manopt.jl +++ b/src/Manopt.jl @@ -201,6 +201,7 @@ include("solvers/LevenbergMarquardt.jl") include("solvers/particle_swarm.jl") include("solvers/primal_dual_semismooth_Newton.jl") include("solvers/proximal_bundle_method.jl") +include("solvers/prorimal_point.jl") include("solvers/quasi_Newton.jl") include("solvers/truncated_conjugate_gradient_descent.jl") include("solvers/trust_regions.jl") diff --git a/src/solvers/proximal_point.jl b/src/solvers/proximal_point.jl new file mode 100644 index 0000000000..0ab05d8479 --- /dev/null +++ b/src/solvers/proximal_point.jl @@ -0,0 +1,67 @@ +# +# +# State +""" + ProximalPointState{P} <: AbstractGradientSolverState + +# Fields + +$(_var(:Field, :p; add=[:as_Iterate])) +$(_var(:Field, :stopping_criterion, "stop")) + +# Constructor + + ProximalPointState(M::AbstractManifold; kwargs...) + +Initialize the proximal point method solver state, where + +## Input + +$(_var(:Argument, :M; type=true)) + +## Keyword arguments + +$(_var(:Keyword, :p; add=:as_Initial)) +$(_var(:Keyword, :stopping_criterion; default="[`StopAfterIteration`](@ref)`(100)`")) + +# See also + +[`proximal point`](@ref) +""" +mutable struct ProximalPointState{ + P, + TStop<:StoppingCriterion, +} <: AbstractGradientSolverState + p::P + stop::TStop +end +function ProximalPointState( + M::AbstractManifold; + p::P=rand(M), + stopping_criterion::SC=StopAfterIteration(200), +) where { + P, + SC<:StoppingCriterion, +} + return ProximalPointState{P,SC}(p, stopping_criterion) +end +function get_message(pps::ProximalPointState) + return get_message(pps.stepsize) +end +function show(io::IO, gds::ProximalPointState) + i = get_count(gds, :Iterations) + Iter = (i > 0) ? "After $i iterations\n" : "" + Conv = indicates_convergence(gds.stop) ? "Yes" : "No" + s = """ + # Solver state for `Manopt.jl`s Proximal POint Method + $Iter + + ## Stopping criterion + + $(status_summary(gds.stop)) + This indicates convergence: $Conv""" + return print(io, s) +end +# +# +# solver interface