Skip to content

Commit

Permalink
unify naming and add docstrings to all new (small) functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Jan 9, 2025
1 parent a7e9f8c commit e430d73
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 23 deletions.
68 changes: 57 additions & 11 deletions src/plans/mesh_adaptive_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ A subtype of this The functor has to fulfil
as well as
* provide a `get_poll_success(poll!)` function that indicates whether the last poll was successful in finding a new candidate,
* provide a `is_successful(poll!)` function that indicates whether the last poll was successful in finding a new candidate,
this returns the last successful mesh vector used.
The `kwargs...` could include
Expand Down Expand Up @@ -135,19 +135,47 @@ function LowerTriangularAdaptivePoll(
vector_transport_method,
)
end
function get_poll_success(ltap::LowerTriangularAdaptivePoll)
"""
is_successful(ltap::LowerTriangularAdaptivePoll)
Return whether the last [`LowerTriangularAdaptivePoll`](@ref) step was successful
"""
function is_successful(ltap::LowerTriangularAdaptivePoll)
return ltap.last_poll_improved
end
function get_poll_direction(ltap::LowerTriangularAdaptivePoll)
"""
get_descent_direction(ltap::LowerTriangularAdaptivePoll)
Return the direction of the last [`LowerTriangularAdaptivePoll`](@ref) that yields a descent of the cost.
If the poll was not successful, the zero vector is returned
"""
function get_descent_direction(ltap::LowerTriangularAdaptivePoll)
return ltap.X
end
function get_poll_basepoint(ltap::LowerTriangularAdaptivePoll)
"""
get_basepoint(ltap::LowerTriangularAdaptivePoll)
Return the base point of the tangent space, where the mash for the [`LowerTriangularAdaptivePoll`](@ref) is build in.
"""
function get_basepoint(ltap::LowerTriangularAdaptivePoll)
return ltap.base_point
end
function get_poll_best_candidate(ltap::LowerTriangularAdaptivePoll)
"""
get_candidate(ltap::LowerTriangularAdaptivePoll)
Return the candidate of the last successful [`LowerTriangularAdaptivePoll`](@ref).
If the poll was unsuccessful, the base point is returned.
"""
function get_candidate(ltap::LowerTriangularAdaptivePoll)
return ltap.candidate
end
function update_poll_basepoint!(M, ltap::LowerTriangularAdaptivePoll{P}, p::P) where {P}
"""
update_basepoint!(M, ltap::LowerTriangularAdaptivePoll, p)
Update the base point of the [`LowerTriangularAdaptivePoll`](@ref).
This especially also updates the basis, that is used to build a (new) mesh.
"""
function update_basepoint!(M, ltap::LowerTriangularAdaptivePoll{P}, p::P) where {P}
vector_transport_to!(
M, ltap.X, ltap.base_point, ltap.X, p, ltap.vector_transport_method
)
Expand All @@ -157,7 +185,7 @@ function update_poll_basepoint!(M, ltap::LowerTriangularAdaptivePoll{P}, p::P) w
return ltap
end
function show(io::IO, ltap::LowerTriangularAdaptivePoll)
s = "LowerTriangularAdaptivePoll using `basis=`$(ltap.basis), `retraction_method=`$(ltap.retraction_method), and `vector_transport_method=`$(ltap.vector_transport_method)"
s = "LowerTriangularAdaptivePoll on a basis $(ltap.basis), the retraction_method $(ltap.retraction_method), and the vector_transport_method $(ltap.vector_transport_method)"
return print(io, s)
end
function (ltap::LowerTriangularAdaptivePoll)(
Expand Down Expand Up @@ -270,10 +298,20 @@ function DefaultMeshAdaptiveDirectSearch(
)
return DefaultMeshAdaptiveDirectSearch(p, copy(M, p), X, false, retraction_method)
end
function get_search_success(dmads::DefaultMeshAdaptiveDirectSearch)
"""
is_successful(dmads::DefaultMeshAdaptiveDirectSearch)
Return whether the last [`DefaultMeshAdaptiveDirectSearch`](@ref) was succesful.
"""
function is_successful(dmads::DefaultMeshAdaptiveDirectSearch)
return dmads.last_search_improved
end
function get_search_point(dmads::DefaultMeshAdaptiveDirectSearch)
"""
get_candidate(dmads::DefaultMeshAdaptiveDirectSearch)
Return the last candidate a [`DefaultMeshAdaptiveDirectSearch`](@ref) found
"""
function get_candidate(dmads::DefaultMeshAdaptiveDirectSearch)
return dmads.p
end
function show(io::IO, dmads::DefaultMeshAdaptiveDirectSearch)
Expand All @@ -299,8 +337,16 @@ end
"""
MeshAdaptiveDirectSearchState <: AbstractManoptSolverState
* `p`: current iterate
* `q`: temp (old) iterate
# Fields
$(_var(:Field, :p; add=[:as_Iterate]))
* `mesh_size`: the current (internal) mesh size
* `scale_mesh`: the current scaling of the internal mesh size, yields the actual mesh size used
* `max_stepsize`: an upper bound for the longest step taken in looking for a candidate in either poll or search
* `poll_size`
$(_var(:Field, :stopping_criterion, "stop"))
* `poll::`[`AbstractMeshPollFunction`]: a poll step (functor) to perform
* `search::`[`AbstractMeshSearchFunction`}(@ref) a search step (functor) to perform
"""
mutable struct MeshAdaptiveDirectSearchState{P,F<:Real,PT,ST,SC<:StoppingCriterion} <:
Expand Down
24 changes: 12 additions & 12 deletions src/solvers/mesh_adaptive_direct_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,44 +88,44 @@ function initialize_solver!(
madss.poll(
amp, madss.mesh_size; scale_mesh=madss.scale_mesh, max_stepsize=madss.max_stepsize
)
if get_poll_success(madss.poll)
copyto!(M, madss.p, get_poll_best_candidate(madss.poll))
if is_successful(madss.poll)
copyto!(M, madss.p, get_candidate(madss.poll))
end
return madss
end
function step_solver!(amp::AbstractManoptProblem, madss::MeshAdaptiveDirectSearchState, k)
M = get_manifold(amp)
n = manifold_dimension(M)
# search if the last poll or last search was sucessful
if get_search_success(madss.search) || get_poll_success(madss.poll)
if is_successful(madss.search) || is_successful(madss.poll)
madss.search(
amp,
madss.mesh_size,
get_poll_best_candidate(madss.poll),
get_poll_direction(madss.poll);
get_candidate(madss.poll),
get_descent_direction(madss.poll);
scale_mesh=madss.scale_mesh,
max_stepsize=madss.max_stepsize,
)
end
# For succesful search, copy over iterate - skip poll, but update base
if get_search_success(madss.search)
copyto!(M, madss.p, get_search_point(madss.search))
update_poll_basepoint!(M, madss.poll, madss.p)
if is_successful(madss.search)
copyto!(M, madss.p, get_candidate(madss.search))
update_basepoint!(M, madss.poll, madss.p)
else #search was not sucessful: poll
update_poll_basepoint!(M, madss.poll, madss.p)
update_basepoint!(M, madss.poll, madss.p)
madss.poll(
amp,
madss.mesh_size;
scale_mesh=madss.scale_mesh,
max_stepsize=madss.max_stepsize,
)
# For succesfull poll, copy over iterate
if get_poll_success(madss.poll)
copyto!(M, madss.p, get_poll_best_candidate(madss.poll))
if is_successful(madss.poll)
copyto!(M, madss.p, get_candidate(madss.poll))
end
end
# If neither found a better candidate -> reduce step size, we might be close already!
if !(get_poll_success(madss.poll)) && !(get_search_success(madss.search))
if !(is_successful(madss.poll)) && !(is_successful(madss.search))
madss.mesh_size /= 4
elseif madss.mesh_size < 0.25 # else
madss.mesh_size *= 4 # Coarsen the mesh but not beyond 1
Expand Down

0 comments on commit e430d73

Please sign in to comment.