Skip to content

Commit

Permalink
Finally got it to work.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Apr 9, 2024
1 parent 5a2cb02 commit 325f2cf
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
14 changes: 8 additions & 6 deletions src/plans/record.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ mutable struct RecordGroup <: RecordAction
return new(g, symbols)
end
function RecordGroup(
records::Vector{<:Union{<:RecordAction,Pair{<:RecordAction, Symbol}}}
records::Vector{<:Union{<:RecordAction,Pair{<:RecordAction,Symbol}}}
)
g = Array{RecordAction,1}()
si = Dict{Symbol,Int}()
Expand Down Expand Up @@ -776,10 +776,12 @@ function RecordFactory(s::AbstractManoptSolverState, a::Array{<:Any,1})
# filter out :Iteration defaults
# filter numbers & stop & pairs (pairs handles separately, numbers at the end)
iter_entries = filter(
x -> !isa(x, Pair) && (x [:Stop, :WhenActive]) && !isa(x, Int), a
x ->
!isa(x, Pair{Symbol,T} where {T}) && (x [:Stop, :WhenActive]) && !isa(x, Int),
a,
)
# Filter pairs
b = filter(x -> isa(x, Pair), a)
b = filter(x -> isa(x, Pair{Symbol,T} where {T}), a)
# Push this to the :Iteration if that exists or add that pair
i = findlast(x -> (isa(x, Pair)) && (x.first == :Iteration), b)
if !isnothing(i)
Expand Down Expand Up @@ -832,11 +834,11 @@ If `:WhenActive` is present, the resulting Action is wrappedn in [`RecordWhenAct
"""
function RecordGroupFactory(s::AbstractManoptSolverState, a::Array{<:Any,1})
# filter out every
group = Array{Union{<:RecordAction,Tuple{Symbol,<:RecordAction}},1}()
group = Array{Union{<:RecordAction,Pair{<:RecordAction,Symbol}},1}()
for e in filter(x -> !isa(x, Int) && (x [:WhenActive]), a) # filter Ints, &Active
if e isa Symbol # factory for this symbol, store in a pair (for better access later)
push!(group, (e, RecordActionFactory(s, e)))
elseif e isa Pair{<:RecordAction, <:Symbol} #already a generated action => symbol to store at
push!(group, RecordActionFactory(s, e) => e)
elseif e isa Pair{<:RecordAction,Symbol} #already a generated action => symbol to store at
push!(group, e)

Check warning on line 842 in src/plans/record.jl

View check run for this annotation

Codecov / codecov/patch

src/plans/record.jl#L842

Added line #L842 was not covered by tests
else # process the others as elements for an action factory
push!(group, RecordActionFactory(s, e))
Expand Down
2 changes: 1 addition & 1 deletion test/plans/test_record.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Manopt.get_manopt_parameter(d::TestRecordParameterState, ::Val{:value}) = d.valu
RecordIteration,
)
@test isa(RecordFactory(gds, :Iteration)[:Iteration], RecordIteration)
sa = :It3 => RecordIteration()
sa = RecordIteration() => :It3
@test RecordActionFactory(gds, sa) === sa
@test !has_record(gds)
@test_throws ErrorException get_record(gds)
Expand Down
16 changes: 10 additions & 6 deletions tutorials/HowToRecord.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,16 @@ We now first build a [`RecordGroup`](https://manoptjl.org/stable/plans/record/#
```{julia}
rI = RecordEvery(
RecordGroup([
(:Iteration, RecordIteration()),
(:Cost, RecordCost()),
(:Gradient, RecordEntry(similar(data[1]), :X)),
RecordIteration() => :Iteration,
RecordCost() => :Cost,
RecordEntry(similar(data[1]), :X) => :Gradient,
]),
6,
)
```

and for recording the final iteration number
where the notation as a pair with the symbol can be read as “Is accessible by”.
For recording the final iteration number

```{julia}
sI = RecordIteration()
Expand Down Expand Up @@ -174,6 +175,9 @@ and the other values during the iterations are
get_record(res, :Iteration, (:Iteration, :Cost))
```

where the last tuple contains the names from the pairs when we generated the record group.
So similarly we can use `:Gradient` as specified before to access the recorded gradient.

## Recording from a Subsolver

One can also record from a subsolver. For that we need a problem that actually requires a subsolver. We take the constraint example from the
Expand Down Expand Up @@ -325,7 +329,7 @@ R3 = gradient_descent(
data[1];
record=[:Iteration => [
:Iteration,
(:Count, RecordCount()),
RecordCount() => :Count,
:Cost],
],
stepsize = ConstantStepsize(1.0),
Expand All @@ -335,7 +339,7 @@ R3 = gradient_descent(
)
```

For `:Cost` we already learned how to access them, the `:Count =>` introduces the following action to obtain the `:Count`. We can again access the whole sets of records
For `:Cost` we already learned how to access them, the ` => :Count` introduces preceeding action to obtain the `:Count` symbol as its access. We can again access the whole sets of records

```{julia}
get_record(R3)
Expand Down

0 comments on commit 325f2cf

Please sign in to comment.