diff --git a/src/plans/stopping_criterion.jl b/src/plans/stopping_criterion.jl index 9beb911b9e..3a6acf09ed 100644 --- a/src/plans/stopping_criterion.jl +++ b/src/plans/stopping_criterion.jl @@ -248,7 +248,7 @@ default. You can also provide an inverse_retraction_method for the `distance` or to use its default inverse retraction. """ mutable struct StopWhenChangeLess{ - F,IRT<:AbstractInverseRetractionMethod,TSSA<:StoreStateAction,N + F,IRT<:AbstractInverseRetractionMethod,TSSA<:StoreStateAction,N<:Union{Missing,Real} } <: StoppingCriterion threshold::F last_change::F @@ -263,7 +263,7 @@ function StopWhenChangeLess( storage::StoreStateAction=StoreStateAction(M; store_points=Tuple{:Iterate}), inverse_retraction_method::IRT=default_inverse_retraction_method(M), outer_norm::N=missing, -) where {F,N,IRT<:AbstractInverseRetractionMethod} +) where {F,N<:Union{Missing,Real},IRT<:AbstractInverseRetractionMethod} return StopWhenChangeLess{F,IRT,typeof(storage),N}( ε, zero(ε), storage, inverse_retraction_method, -1, outer_norm ) @@ -278,7 +278,7 @@ function (c::StopWhenChangeLess)(mp::AbstractManoptProblem, s::AbstractManoptSol if has_storage(c.storage, PointStorageKey(:Iterate)) M = get_manifold(mp) p_old = get_storage(c.storage, PointStorageKey(:Iterate)) - r = (has_components(M) && !ismissing(c.outer_norm)) ? [c.outer_norm] : [] + r = (has_components(M) && !ismissing(c.outer_norm)) ? (c.outer_norm,) : () c.last_change = distance( M, get_iterate(s), p_old, c.inverse_retraction_method, r... ) @@ -513,7 +513,7 @@ indicates to stop when [`get_gradient`](@ref) is in (norm of) its change less th `vector_transport_method` denotes the vector transport ``$(_tex(:Cal,"T"))`` used. """ mutable struct StopWhenGradientChangeLess{ - F,VTM<:AbstractVectorTransportMethod,TSSA<:StoreStateAction,N + F,VTM<:AbstractVectorTransportMethod,TSSA<:StoreStateAction,N<:Union{Missing,Real} } <: StoppingCriterion threshold::F last_change::F @@ -530,7 +530,7 @@ function StopWhenGradientChangeLess( ), vector_transport_method::VTM=default_vector_transport_method(M), outer_norm::N=missing, -) where {F,N,VTM<:AbstractVectorTransportMethod} +) where {F,N<:Union{Missing,Real},VTM<:AbstractVectorTransportMethod} return StopWhenGradientChangeLess{F,VTM,typeof(storage),N}( ε, zero(ε), storage, vector_transport_method, -1, outer_norm ) @@ -554,7 +554,7 @@ function (c::StopWhenGradientChangeLess)( X_old = get_storage(c.storage, VectorStorageKey(:Gradient)) p = get_iterate(s) Xt = vector_transport_to(M, p_old, X_old, p, c.vector_transport_method) - r = (has_components(M) && !ismissing(c.outer_norm)) ? [c.outer_norm] : [] + r = (has_components(M) && !ismissing(c.outer_norm)) ? (c.outer_norm,) : () c.last_change = norm(M, p, Xt - get_gradient(s), r...) if c.last_change < c.threshold && k > 0 c.at_iteration = k @@ -625,7 +625,7 @@ $(_tex(:norm, "X"; index="p")) = $(_tex(:Bigl))( $(_tex(:sum))_{k=1}^n $(_tex(:n ``` where the sum turns into a maximum for the case ``r=∞``. -The `outer_norm` has no effect on manifols, that do not consist of components. +The `outer_norm` has no effect on manifolds that do not consist of components. If you pass in your individual norm, this can be deactivated on such manifolds by passing `missing` to `outer_norm`. @@ -638,7 +638,7 @@ Create a stopping criterion with threshold `ε` for the gradient, that is, this indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`, where the norm to use can be specified in the `norm=` keyword. """ -mutable struct StopWhenGradientNormLess{F,TF,N} <: StoppingCriterion +mutable struct StopWhenGradientNormLess{F,TF,N<:Union{Missing,Real}} <: StoppingCriterion norm::F threshold::TF last_change::TF @@ -646,7 +646,7 @@ mutable struct StopWhenGradientNormLess{F,TF,N} <: StoppingCriterion outer_norm::N function StopWhenGradientNormLess( ε::TF; norm::F=norm, outer_norm::N=missing - ) where {F,TF,N} + ) where {F,TF,N<:Union{Missing,Real}} return new{F,TF,N}(norm, ε, zero(ε), -1, outer_norm) end end @@ -659,7 +659,7 @@ function (sc::StopWhenGradientNormLess)( sc.at_iteration = -1 end if (k > 0) - r = (has_components(M) && !ismissing(sc.outer_norm)) ? [sc.outer_norm] : [] + r = (has_components(M) && !ismissing(sc.outer_norm)) ? (sc.outer_norm,) : () sc.last_change = sc.norm(M, get_iterate(s), get_gradient(s), r...) if sc.last_change < sc.threshold sc.at_iteration = k