Skip to content

Commit

Permalink
Fix a few tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Oct 19, 2024
1 parent 6b58dde commit 5d4304e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/Manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ using ManifoldsBase:
get_vector,
get_vector!,
get_vectors,
has_components,
injectivity_radius,
inner,
inverse_retract,
Expand Down
28 changes: 16 additions & 12 deletions src/plans/stopping_criterion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ $(_var(:Field, :storage))
* `threshold`: the threshold for the change to check (run under to stop)
* `outer_norm`: if `M` is a manifold with components, this can be used to specify the norm,
that is used to compute the overall distance based on the element-wise distance.
You can deactivate this, but setting this value to `missing`.
# Example
Expand All @@ -238,7 +239,7 @@ If the manifold does not have components, the outer norm is ignored.
threshold::Float64;
storage::StoreStateAction=StoreStateAction([:Iterate]),
inverse_retraction_method::IRT=default_inverse_retraction_method(M)
outer_norm::Float64=2.0
outer_norm=missing
)
initialize the stopping criterion to a threshold `ε` using the
Expand All @@ -261,8 +262,8 @@ function StopWhenChangeLess(
ε::F;
storage::StoreStateAction=StoreStateAction(M; store_points=Tuple{:Iterate}),
inverse_retraction_method::IRT=default_inverse_retraction_method(M),
outer_norm::N=2,
) where {F<:Real,N<:Real,IRT<:AbstractInverseRetractionMethod}
outer_norm::N=missing,
) where {F,N,IRT<:AbstractInverseRetractionMethod}
return StopWhenChangeLess{F,IRT,typeof(storage),N}(
ε, zero(ε), storage, inverse_retraction_method, -1, outer_norm
)
Expand All @@ -277,7 +278,7 @@ function (c::StopWhenChangeLess)(mp::AbstractManoptProblem, s::AbstractManoptSol
if has_storage(c.storage, PointStorageKey(:Iterate))
M = get_manifold(mp)
p_old = get_storage(c.storage, PointStorageKey(:Iterate))
r = has_components(M) ? [c.outer_norm] : []
r = (has_components(M) && !ismissing(c.outer_norm)) ? [c.outer_norm] : []
c.last_change = distance(
M, get_iterate(s), p_old, c.inverse_retraction_method, r...
)
Expand All @@ -303,9 +304,10 @@ function status_summary(c::StopWhenChangeLess)
end
indicates_convergence(c::StopWhenChangeLess) = true
function show(io::IO, c::StopWhenChangeLess)
s = ismissing(c.outer_norm) ? "" : "and outer norm $(c.outer_norm)"
return print(
io,
"StopWhenChangeLess with threshold $(c.threshold) and outer norm $(c.outer_norm).\n $(status_summary(c))",
"StopWhenChangeLess with threshold $(c.threshold)$(s).\n $(status_summary(c))",
)
end

Expand Down Expand Up @@ -480,6 +482,7 @@ $(_var(:Field, :storage))
* `threshold`: the threshold for the change to check (run under to stop)
* `outer_norm`: if `M` is a manifold with components, this can be used to specify the norm,
that is used to compute the overall distance based on the element-wise distance.
You can deactivate this, but setting this value to `missing`.
# Example
Expand All @@ -502,7 +505,7 @@ The `outer_norm` has no effect on manifols, that do not consist of components.
ε::Float64;
storage::StoreStateAction=StoreStateAction([:Iterate]),
vector_transport_method::IRT=default_vector_transport_method(M),
outer_norm::N=2
outer_norm::N=missing
)
Create a stopping criterion with threshold `ε` for the change gradient, that is, this criterion
Expand All @@ -526,9 +529,9 @@ function StopWhenGradientChangeLess(
M; store_points=Tuple{:Iterate}, store_vectors=Tuple{:Gradient}
),
vector_transport_method::VTM=default_vector_transport_method(M),
outer_norm::N=2,
outer_norm::N=missing,
) where {F,N,VTM<:AbstractVectorTransportMethod}
return StopWhenGradientChangeLess{F,VTM,typeof(storage)}(
return StopWhenGradientChangeLess{F,VTM,typeof(storage),N}(
ε, zero(ε), storage, vector_transport_method, -1, outer_norm
)
end
Expand All @@ -551,7 +554,7 @@ function (c::StopWhenGradientChangeLess)(
X_old = get_storage(c.storage, VectorStorageKey(:Gradient))
p = get_iterate(s)
Xt = vector_transport_to(M, p_old, X_old, p, c.vector_transport_method)
r = has_components(M) ? [c.outer_norm] : []
r = (has_components(M) && !ismissing(c.outer_norm)) ? [c.outer_norm] : []
c.last_change = norm(M, p, Xt - get_gradient(s), r...)
if c.last_change < c.threshold && k > 0
c.at_iteration = k
Expand All @@ -574,9 +577,10 @@ function status_summary(c::StopWhenGradientChangeLess)
return "|Δgrad f| < $(c.threshold): $s"
end
function show(io::IO, c::StopWhenGradientChangeLess)
s = ismissing(c.outer_norm) ? "" : "outer_norm=$(c.outer_norm), "
return print(
io,
"StopWhenGradientChangeLess with threshold $(c.threshold); ouer_norm=$(c.outer_norm), vector_transport_method=$(c.vector_transport_method))\n $(status_summary(c))",
"StopWhenGradientChangeLess with threshold $(c.threshold); $(s)vector_transport_method=$(c.vector_transport_method))\n $(status_summary(c))",
)
end

Expand Down Expand Up @@ -628,7 +632,7 @@ by passing `missing` to `outer_norm`.
# Constructor
StopWhenGradientNormLess(ε; norm=ManifoldsBase.norm, outer_norm=2)
StopWhenGradientNormLess(ε; norm=ManifoldsBase.norm, outer_norm=missing)
Create a stopping criterion with threshold `ε` for the gradient, that is, this criterion
indicates to stop when [`get_gradient`](@ref) returns a gradient vector of norm less than `ε`,
Expand All @@ -655,7 +659,7 @@ function (sc::StopWhenGradientNormLess)(
sc.at_iteration = -1
end
if (k > 0)
r = (has_components(M) && !ismissing(outer_norm)) ? [c.outer_norm] : []
r = (has_components(M) && !ismissing(sc.outer_norm)) ? [sc.outer_norm] : []
sc.last_change = sc.norm(M, get_iterate(s), get_gradient(s), r...)
if sc.last_change < sc.threshold
sc.at_iteration = k
Expand Down
2 changes: 1 addition & 1 deletion test/plans/test_stopping_criteria.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct DummyStoppingCriterion <: StoppingCriterion end
@testset "Stopping Criterion &/| operators" begin
a = StopAfterIteration(200)
b = StopWhenChangeLess(Euclidean(), 1e-6)
sb = "StopWhenChangeLess with threshold 1.0e-6\n $(Manopt.status_summary(b))"
sb = "StopWhenChangeLess with threshold 1.0e-6.\n $(Manopt.status_summary(b))"
@test repr(b) == sb
@test get_reason(b) == ""
b2 = StopWhenChangeLess(Euclidean(), 1e-6) # second constructor
Expand Down

0 comments on commit 5d4304e

Please sign in to comment.