Skip to content

Commit 22c7ce8

Browse files
committed
Update base point properly.
1 parent 71d1a07 commit 22c7ce8

File tree

5 files changed

+20
-18
lines changed

5 files changed

+20
-18
lines changed

docs/src/plans/objective.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -249,5 +249,6 @@ AbstractManifoldSubObjective
249249
```@docs
250250
Manopt.get_objective_cost
251251
Manopt.get_objective_gradient
252-
Manopt.get_objective_Hessian
252+
Manopt.get_objective_hessian
253+
Manopt.get_objective_preconditioner
253254
```

src/plans/problem.jl

+4
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ function set_manopt_parameter!(amp::AbstractManoptProblem, ::Val{:Manifold}, arg
9595
set_manopt_parameter!(get_manifold(amp), args...)
9696
return amp
9797
end
98+
function set_manopt_parameter!(TpM::TangentSpace, ::Val{:Basepoint}, p)
99+
copyto!(TpM.manifold, TpM.point, p)
100+
return TpM
101+
end
98102

99103
function set_manopt_parameter!(amp::AbstractManoptProblem, ::Val{:Objective}, args...)
100104
set_manopt_parameter!(get_objective(amp), args...)

src/solvers/trust_regions.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -565,9 +565,8 @@ function step_solver!(mp::AbstractManoptProblem, trs::TrustRegionsState, i)
565565
# Solve TR subproblem – update options
566566
# TODO provide these setters for the sub problem / sub state
567567
# set_paramater!(trs.sub_problem, :Basepoint, trs.p)
568-
set_manopt_parameter!(trs.sub_state, :Basepoint, trs.p)
569-
set_manopt_parameter!(trs.sub_problem, :Basepoint, trs.p)
570-
set_manopt_parameter!(trs.sub_state, :Iterate, trs.Y)
568+
set_manopt_parameter!(trs.sub_problem, :Manifold, :Basepoint, copy(M, trs.p))
569+
set_manopt_parameter!(trs.sub_state, :Iterate, copy(M, trs.p, trs.Y))
571570
set_manopt_parameter!(trs.sub_state, :TrustRegionRadius, trs.trust_region_radius)
572571
solve!(trs.sub_problem, trs.sub_state)
573572
#

test/solvers/test_difference_of_convex.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ import Manifolds: inner
129129
M, grad_h!, p0; g=g, grad_g=grad_g!, evaluation=InplaceEvaluation()
130130
)
131131
p5 = difference_of_convex_proximal_point(M, grad_h, p0; g=g, grad_g=grad_g)
132-
#p5b = difference_of_convex_proximal_point(M, grad_h; g=g, grad_g=grad_g)
132+
p5b = difference_of_convex_proximal_point(M, grad_h; g=g, grad_g=grad_g)
133133
# using gradient descent
134134
p5c = difference_of_convex_proximal_point(
135135
M,
@@ -150,12 +150,12 @@ import Manifolds: inner
150150
p6 = get_solver_result(s2)
151151
@test Manopt.get_message(s2) == ""
152152

153-
@test_broken isapprox(M, p3, p4)
153+
@test isapprox(M, p3, p4)
154154
@test isapprox(M, p4, p5)
155155
@test isapprox(M, p5, p6)
156-
# @test isapprox(f(M, p5b), 0.0; atol=2e-16) # bit might be a different min due to rand
157-
@test isapprox(f(M, p5c), 0.0; atol=1e-9) # might be a bit imprecise
158-
@test_broken isapprox(f(M, p4), 0.0; atol=1e-8) # might be a bit imprecise
156+
@test isapprox(f(M, p5b), 0.0; atol=2e-16) # bit might be a different min due to rand
157+
@test isapprox(f(M, p5c), 0.0; atol=1e-10)
158+
@test isapprox(f(M, p4), 0.0; atol=1e-14)
159159

160160
Random.seed!(23)
161161
p7 = difference_of_convex_algorithm(M, f, g, grad_h; grad_g=grad_g)
@@ -169,7 +169,7 @@ import Manifolds: inner
169169
p9 = difference_of_convex_algorithm(
170170
M, f, g, grad_h, p0; grad_g=grad_g, sub_hess=nothing
171171
)
172-
@test_broken isapprox(M, p9, p2; atol=1e-7)
172+
@test isapprox(M, p9, p2; atol=1e-9)
173173

174174
@test_throws ErrorException difference_of_convex_proximal_point(
175175
M, grad_h, p0; sub_problem=nothing

test/solvers/test_trust_regions.jl

+6-8
Original file line numberDiff line numberDiff line change
@@ -359,24 +359,22 @@ include("../utils/example_tasks.jl")
359359
end
360360
@testset "Euclidean Embedding" begin
361361
Random.seed!(42)
362-
n = 5
362+
n = 2
363363
A = Symmetric(randn(n + 1, n + 1))
364364
# Euclidean variant with conversion
365365
M = Sphere(n)
366-
p0 = rand(M)
366+
p0 = [1.0, zeros(n)...]
367367
f(E, p) = p' * A * p
368368
∇f(E, p) = A * p
369369
∇²f(M, p, X) = A * X
370370
λ = min(eigvals(A)...)
371-
q = trust_regions(M, f, ∇f, p0; objective_type=:Euclidean)
372-
q2 = trust_regions(M, f, ∇f, ∇²f, p0; objective_type=:Euclidean)
373-
@test f(M, q) λ atol = 2 * 1e-1
374-
@test_broken f(M, q) f(M, q2)
371+
q = trust_regions(M, f, ∇f, p0; objective_type=:Euclidean, (project!)=project!)
372+
@test f(M, q) λ atol = 1 * 1e-1 # a bit inprecise?
375373
grad_f(M, p) = A * p - (p' * A * p) * p
376374
Hess_f(M, p, X) = A * X - (p' * A * X) .* p - (p' * A * p) .* X
377375
q3 = trust_regions(M, f, grad_f, p0)
378376
q4 = trust_regions(M, f, grad_f, Hess_f, p0)
379-
@test f(M, q) f(M, q3) atol = 5 * 1e-1 # A bit imprecise?
380-
@test f(M, q) f(M, q4) atol = 5 * 1e-1 # A bit imprecise?
377+
@test f(M, q3) λ atol = 5 * 1e-8
378+
@test f(M, q4) λ atol = 5 * 1e-10
381379
end
382380
end

0 commit comments

Comments
 (0)