Skip to content

Commit

Permalink
A final bug in the last (even old) cell call is left, but I am out of…
Browse files Browse the repository at this point in the history
… ideas for today.
  • Loading branch information
kellertuer committed Apr 8, 2024
1 parent aa41069 commit d56ed0a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 31 deletions.
28 changes: 15 additions & 13 deletions src/plans/record.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ A RecordGroup to record the current iteration and the cost. The cost can then be
A RecordGroup to record the current iteration and the cost, which can then be accessed using `get_record(:Cost)` or `r[:Cost]`.
r = RecordGroup([RecordIteration(), :Cost => RecordCost()])
r = RecordGroup([RecordIteration(), (:Cost, RecordCost())])
A RecordGroup identical to the previous constructor, just a little easier to use.
"""
Expand All @@ -303,16 +303,16 @@ mutable struct RecordGroup <: RecordAction
return new(g, symbols)
end
function RecordGroup(
records::Vector{<:Union{<:RecordAction,Pair{Symbol,<:RecordAction}}}
records::Vector{<:Union{<:RecordAction,Tuple{Symbol,<:RecordAction}}}
)
g = Array{RecordAction,1}()
si = Dict{Symbol,Int}()
for i in 1:length(records)
if records[i] isa RecordAction
push!(g, records[i])
else
push!(g, records[i].second)
push!(si, records[i].first => i)
push!(g, records[i][2])
push!(si, records[i][1] => i)
end
end
return RecordGroup(g, si)
Expand Down Expand Up @@ -826,17 +826,19 @@ If `:WhenActive` is present, the resulting Action is wrappedn in [`RecordWhenAct
"""
function RecordGroupFactory(s::AbstractManoptSolverState, a::Array{<:Any,1})
# filter out every
group = Array{Union{<:RecordAction,Pair{Symbol,<:RecordAction}},1}()
group = Array{Union{<:RecordAction,Tuple{Symbol,<:RecordAction}},1}()
for e in filter(x -> !isa(x, Int) && (x [:WhenActive]), a) # filter Ints, &Active
if e isa Symbol # factory for this symbol, store in a pair (for better access later)
push!(group, e => RecordActionFactory(s, e))
elseif e isa Pair{<:Symbol,<:RecordAction} #already a generated action
push!(group, e)
else # process the others as elements for an action factory
push!(group, RecordActionFactory(s, e))
end
push!(group, (e, RecordActionFactory(s, e)))
elseif e isa Tuple{<:Symbol,<:RecordAction} #already a generated action
push!(group, e)
else # process the others as elements for an action factory
push!(group, RecordActionFactory(s, e))
end
end
record = length(group) > 1 ? RecordGroup(group) : first(group)
(length(group) > 1) && (record = RecordGroup(group))
(length(group) == 1) &&
(record = first(group) isa RecordAction ? first(group) : first(group)[2])
# filter integer numbers
e = filter(x -> isa(x, Int), a)
if length(e) > 0
Expand Down Expand Up @@ -867,7 +869,7 @@ create a [`RecordAction`](@ref) where
* `:IterativeTime` to record the times taken for each iteration.
"""
RecordActionFactory(::AbstractManoptSolverState, a::RecordAction) = a
RecordActionFactory(::AbstractManoptSolverState, sa::Pair{Symbol,<:RecordAction}) = sa
RecordActionFactory(::AbstractManoptSolverState, sa::Tuple{Symbol,<:RecordAction}) = sa
function RecordActionFactory(s::AbstractManoptSolverState, symbol::Symbol)
(symbol == :Change) && return RecordChange()
(symbol == :Cost) && return RecordCost()
Expand Down
40 changes: 22 additions & 18 deletions tutorials/HowToRecord.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Pkg.activate("."); # for reproducibility use the local tutorial environment.
Let's first load the necessary packages.

```{julia}
using Manopt, Manifolds, Random, ManifoldDiff
using Manopt, Manifolds, Random, ManifoldDiff, LinearAlgebra
using ManifoldDiff: grad_distance
Random.seed!(42);
```
Expand Down Expand Up @@ -134,9 +134,9 @@ We now first build a [`RecordGroup`](https://manoptjl.org/stable/plans/record/#
```{julia}
rI = RecordEvery(
RecordGroup([
:Iteration => RecordIteration(),
:Cost => RecordCost(),
:Gradient => RecordEntry(similar(data[1]), :X),
(:Iteration, RecordIteration()),
(:Cost, RecordCost()),
(:Gradient, RecordEntry(similar(data[1]), :X)),
]),
6,
)
Expand Down Expand Up @@ -188,8 +188,8 @@ M = Sphere(d - 1)
v0 = project(M, [ones(2)..., zeros(d - 2)...])
Z = v0 * v0'
#Cost and gradient
f(M, p) = -tr(transpose(p) * Z * p) / 2
grad_f(M, p) = project(M, p, -transpose.(Z) * p / 2 - Z * p / 2)
f2(M, p) = -tr(transpose(p) * Z * p) / 2
grad_f2(M, p) = project(M, p, -transpose.(Z) * p / 2 - Z * p / 2)
# Constraints
g(M, p) = -p # now p ≥ 0
mI = -Matrix{Float64}(I, d, d)
Expand All @@ -210,8 +210,8 @@ This is done with the `:Subsolver` keyword in the main `record=` keyword.
#| output: false
s1 = exact_penalty_method(
M,
f,
grad_f,
f2,
grad_f2,
p0;
g = g,
grad_g = grad_g,
Expand All @@ -233,8 +233,8 @@ When adding a number to not record on every iteration, the `:Subsolver` keyword
#| output: false
s2 = exact_penalty_method(
M,
f,
grad_f,
f2,
grad_f2,
p0;
g = g,
grad_g = grad_g,
Expand All @@ -256,8 +256,8 @@ Finally, instead of recording iterations, we can also specify to record the stop
#| output: false
s3 = exact_penalty_method(
M,
f,
grad_f,
f2,
grad_f2,
p0;
g = g,
grad_g = grad_g,
Expand Down Expand Up @@ -311,7 +311,7 @@ Now we can initialize the new cost and call the gradient descent.
Note that this illustrates also the last use case since you can pass symbol-action pairs into the `record=`array.

```{julia}
f2 = MyCost(data)
f3 = MyCost(data)
```

Now for the plain gradient descent, we have to modify the step (to a constant stepsize) and remove the default check whether the cost increases (setting `debug` to `[]`).
Expand All @@ -320,10 +320,14 @@ We also only look at the first 20 iterations to keep this example small in recor
```{julia}
R3 = gradient_descent(
M,
f2,
f3,
grad_f,
data[1];
record=[:Iteration, :Count => RecordCount(), :Cost],
record=[:Iteration => [
:Iteration,
(:Count, RecordCount()),
:Cost],
],
stepsize = ConstantStepsize(1.0),
stopping_criterion=StopAfterIteration(20),
debug=[],
Expand All @@ -349,18 +353,18 @@ and we see that the cost function is called once per iteration.
If we use this counting cost and run the default gradient descent with Armijo line search, we can infer how many Armijo line search backtracks are preformed:

```{julia}
f3 = MyCost(data)
f4 = MyCost(data)
```

To not get too many entries let's just look at the first 20 iterations again

```{julia}
R4 = gradient_descent(
M,
f3,
f4,
grad_f,
data[1];
record=[:Count => RecordCount()],
record=[RecordCount(),],
return_state=true,
)
```
Expand Down

0 comments on commit d56ed0a

Please sign in to comment.