Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Mateusz Baran <[email protected]>
  • Loading branch information
kellertuer and mateuszbaran authored Oct 19, 2024
1 parent ff0a56a commit 7ff1e88
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/plans/stopping_criterion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ default. You can also provide an inverse_retraction_method for the `distance` or
to use its default inverse retraction.
"""
mutable struct StopWhenChangeLess{
F,IRT<:AbstractInverseRetractionMethod,TSSA<:StoreStateAction,N
F,IRT<:AbstractInverseRetractionMethod,TSSA<:StoreStateAction,N<:Union{Missing,Real}
} <: StoppingCriterion
threshold::F
last_change::F
Expand All @@ -263,7 +263,7 @@ function StopWhenChangeLess(
storage::StoreStateAction=StoreStateAction(M; store_points=Tuple{:Iterate}),
inverse_retraction_method::IRT=default_inverse_retraction_method(M),
outer_norm::N=missing,
) where {F,N,IRT<:AbstractInverseRetractionMethod}
) where {F,N<:Union{Missing,Real},IRT<:AbstractInverseRetractionMethod}
return StopWhenChangeLess{F,IRT,typeof(storage),N}(
ε, zero(ε), storage, inverse_retraction_method, -1, outer_norm
)
Expand All @@ -278,7 +278,7 @@ function (c::StopWhenChangeLess)(mp::AbstractManoptProblem, s::AbstractManoptSol
if has_storage(c.storage, PointStorageKey(:Iterate))
M = get_manifold(mp)
p_old = get_storage(c.storage, PointStorageKey(:Iterate))
r = (has_components(M) && !ismissing(c.outer_norm)) ? [c.outer_norm] : []
r = (has_components(M) && !ismissing(c.outer_norm)) ? (c.outer_norm,) : ()
c.last_change = distance(
M, get_iterate(s), p_old, c.inverse_retraction_method, r...
)
Expand Down Expand Up @@ -513,7 +513,7 @@ indicates to stop when [`get_gradient`](@ref) is in (norm of) its change less th
`vector_transport_method` denotes the vector transport ``$(_tex(:Cal,"T"))`` used.
"""
mutable struct StopWhenGradientChangeLess{
F,VTM<:AbstractVectorTransportMethod,TSSA<:StoreStateAction,N
F,VTM<:AbstractVectorTransportMethod,TSSA<:StoreStateAction,N<:Union{Missing,Real}
} <: StoppingCriterion
threshold::F
last_change::F
Expand All @@ -530,7 +530,7 @@ function StopWhenGradientChangeLess(
),
vector_transport_method::VTM=default_vector_transport_method(M),
outer_norm::N=missing,
) where {F,N,VTM<:AbstractVectorTransportMethod}
) where {F,N<:Union{Missing,Real},VTM<:AbstractVectorTransportMethod}
return StopWhenGradientChangeLess{F,VTM,typeof(storage),N}(
ε, zero(ε), storage, vector_transport_method, -1, outer_norm
)
Expand All @@ -554,7 +554,7 @@ function (c::StopWhenGradientChangeLess)(
X_old = get_storage(c.storage, VectorStorageKey(:Gradient))
p = get_iterate(s)
Xt = vector_transport_to(M, p_old, X_old, p, c.vector_transport_method)
r = (has_components(M) && !ismissing(c.outer_norm)) ? [c.outer_norm] : []
r = (has_components(M) && !ismissing(c.outer_norm)) ? (c.outer_norm,) : ()
c.last_change = norm(M, p, Xt - get_gradient(s), r...)
if c.last_change < c.threshold && k > 0
c.at_iteration = k
Expand Down Expand Up @@ -625,7 +625,7 @@ $(_tex(:norm, "X"; index="p")) = $(_tex(:Bigl))( $(_tex(:sum))_{k=1}^n $(_tex(:n
```
where the sum turns into a maximum for the case ``r=∞``.
The `outer_norm` has no effect on manifols, that do not consist of components.
The `outer_norm` has no effect on manifolds that do not consist of components.
If you pass in your individual norm, this can be deactivated on such manifolds
by passing `missing` to `outer_norm`.
Expand All @@ -638,15 +638,15 @@ Create a stopping criterion with threshold `ε` for the gradient, that is, this
indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`,
where the norm to use can be specified in the `norm=` keyword.
"""
mutable struct StopWhenGradientNormLess{F,TF,N} <: StoppingCriterion
mutable struct StopWhenGradientNormLess{F,TF,N<:Union{Missing,Real}} <: StoppingCriterion
norm::F
threshold::TF
last_change::TF
at_iteration::Int
outer_norm::N
function StopWhenGradientNormLess(
ε::TF; norm::F=norm, outer_norm::N=missing
) where {F,TF,N}
) where {F,TF,N<:Union{Missing,Real}}
return new{F,TF,N}(norm, ε, zero(ε), -1, outer_norm)
end
end
Expand All @@ -659,7 +659,7 @@ function (sc::StopWhenGradientNormLess)(
sc.at_iteration = -1
end
if (k > 0)
r = (has_components(M) && !ismissing(sc.outer_norm)) ? [sc.outer_norm] : []
r = (has_components(M) && !ismissing(sc.outer_norm)) ? (sc.outer_norm,) : ()
sc.last_change = sc.norm(M, get_iterate(s), get_gradient(s), r...)
if sc.last_change < sc.threshold
sc.at_iteration = k
Expand Down

0 comments on commit 7ff1e88

Please sign in to comment.