diff --git a/test/plans/test_proximal_plan.jl b/test/plans/test_proximal_plan.jl index 2694dded82..4018cb29d9 100644 --- a/test/plans/test_proximal_plan.jl +++ b/test/plans/test_proximal_plan.jl @@ -15,9 +15,11 @@ include("../utils/dummy_types.jl") M = Euclidean(2) p = [1.0, 2.0] Q = [[2.0, 3.0], [3.0, 4.0]] - f(M, p) = sum(distance(M, p, q) for q in Q) + f(M, p) = 0.5 * sum(distance(M, p, q)^2 for q in Q) + f2(M, p) = 0.5 * distance(M, p, Q[1]) proxes_f = Tuple((N, λ, p) -> prox_distance(N, λ, q, p) for q in Q) ppo = ManifoldProximalMapObjective(f, proxes_f) + ppo2 = ManifoldProximalMapObjective(f2, proxes_f[1]) @testset "Objective Decorator passthrough" begin dppo = DummyDecoratedObjective(ppo) for i in 1:2 @@ -44,6 +46,14 @@ include("../utils/dummy_types.jl") get_proximal_map!(M, q2, cppo2, 0.1, p, 1) @test q2 == q @test get_count(cppo2, :ProximalMap) == 2 + # single function + cppo3 = ManifoldCountObjective(M, ppo2, Dict([:ProximalMap => 0])) + q = get_proximal_map(M, cppo3, 0.1, p) + @test q == get_proximal_map(M, ppo2, 0.1, p) + q2 = copy(M, p) + get_proximal_map!(M, q2, cppo3, 0.1, p) + @test q2 == q + @test get_count(cppo3, :ProximalMap) == 2 end @testset "Cache" begin cppo = ManifoldCountObjective(M, ppo, [:ProximalMap])