From a51047faba2ef871cd2fb320ed4ca9f59adf8ca2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 13:28:26 +0100 Subject: [PATCH 01/11] fix naming --- ...RTTraces.jl => ElasticArraySARTSTraces.jl} | 8 +- src/common/common.jl | 2 +- src/episodes.jl | 2 + src/traces.jl | 2 + test/common.jl | 8 +- test/episodes.jl | 198 ++++++++++++++++++ 6 files changed, 211 insertions(+), 9 deletions(-) rename src/common/{ElasticArraySARTTraces.jl => ElasticArraySARTSTraces.jl} (86%) diff --git a/src/common/ElasticArraySARTTraces.jl b/src/common/ElasticArraySARTSTraces.jl similarity index 86% rename from src/common/ElasticArraySARTTraces.jl rename to src/common/ElasticArraySARTSTraces.jl index 1090ece..bfbbc4d 100644 --- a/src/common/ElasticArraySARTTraces.jl +++ b/src/common/ElasticArraySARTSTraces.jl @@ -1,8 +1,8 @@ -export ElasticArraySARTTraces +export ElasticArraySARTSTraces using ElasticArrays: ElasticArray, resize_lastdim! -const ElasticArraySARTTraces = Traces{ +const ElasticArraySARTSTraces = Traces{ SS′AA′RT, <:Tuple{ <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}}, @@ -12,7 +12,7 @@ const ElasticArraySARTTraces = Traces{ } } -function ElasticArraySARTTraces(; +function ElasticArraySARTSTraces(; state=Int => (), action=Int => (), reward=Float32 => (), @@ -37,4 +37,4 @@ end Base.push!(a::ElasticArray, x) = append!(a, x) Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x]) -Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0) \ No newline at end of file +Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0) diff --git a/src/common/common.jl b/src/common/common.jl index b334f7a..90e4405 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -14,4 +14,4 @@ include("CircularArraySARTSTraces.jl") include("CircularArraySARTSATraces.jl") include("CircularArraySLARTTraces.jl") include("CircularPrioritizedTraces.jl") -include("ElasticArraySARTTraces.jl") +include("ElasticArraySARTSTraces.jl") diff --git a/src/episodes.jl b/src/episodes.jl index 0113a1c..e2cb521 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -1,5 +1,6 @@ export EpisodesBuffer, PartialNamedTuple import DataStructures.CircularBuffer +using ElasticArrays: ElasticArray, resize_lastdim! """ EpisodesBuffer(traces::AbstractTraces) @@ -90,6 +91,7 @@ function pad!(trace::Trace) return nothing end +pad!(vect::ElasticArray{T, Vector{T}}) where {T} = pad!(vect, zero(T)) pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T)) pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) diff --git a/src/traces.jl b/src/traces.jl index 9676c03..1cd36bb 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces import MacroTools: @forward import CircularArrayBuffers.CircularArrayBuffer +using ElasticArrays: ElasticArray import Adapt ##### @@ -55,6 +56,7 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s capacity(t::AbstractTrace) = ReinforcementLearningTrajectories.capacity(t.parent) capacity(t::CircularArrayBuffer) = CircularArrayBuffers.capacity(t) capacity(::AbstractVector) = Inf +capacity(::ElasticArray) = Inf ##### diff --git a/test/common.jl b/test/common.jl index 06e69a8..afcc89b 100644 --- a/test/common.jl +++ b/test/common.jl @@ -94,15 +94,15 @@ end @test batch.terminal == Bool[0, 0, 0] |> gpu end -@testset "ElasticArraySARTTraces" begin - t = ElasticArraySARTTraces(; +@testset "ElasticArraySARTSTraces" begin + t = ElasticArraySARTSTraces(; state=Float32 => (2, 3), action=Int => (), reward=Float32 => (), terminal=Bool => () ) - @test t isa ElasticArraySARTTraces + @test t isa ElasticArraySARTSTraces push!(t, (state=ones(Float32, 2, 3), action=1)) push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2)) @@ -185,4 +185,4 @@ end eb[:priority, [1, 2]] = [0, 0] @test eb[:priority] == [zeros(2);ones(8)] -end \ No newline at end of file +end diff --git a/test/episodes.jl b/test/episodes.jl index 8244b19..43a8b08 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -197,4 +197,202 @@ using Test @test eb.step_numbers == [1:16;1:16] @test length(eb) == 31 end + @testset "with elastic traces" begin + t = ElasticArraySARTSTraces(; + state=Int => (), + action=Int => (), + reward=Float32 => (), + terminal=Bool => () + ) + + eb = EpisodesBuffer(t) + push!(eb, (state = 1,)) #partial inserting + for i = 1:15 + push!(eb, (state = i+1, reward =i)) + end + @test length(eb.traces) == 15 + @test eb.sampleable_inds == [fill(true, 15); [false]] + @test all(==(15), eb.episodes_lengths) + @test eb.step_numbers == [1:16;] + push!(eb, (state = 1,)) #partial inserting + for i = 1:15 + push!(eb, (state = i+1, reward =i)) + end + @test eb.sampleable_inds == [fill(true, 15); [false];fill(true, 15); [false]] + @test all(==(15), eb.episodes_lengths) + @test eb.step_numbers == [1:16;1:16] + @test length(eb) == 31 + end + @testset "with circular traces" begin + eb = EpisodesBuffer( + CircularArraySARTSTraces(; + capacity=10) + ) + #push a first episode l=5 + push!(eb, (state = 1,)) + @test eb.sampleable_inds[end] == 0 + @test eb.episodes_lengths[end] == 0 + @test eb.step_numbers[end] == 1 + for i = 1:5 + push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 1 + @test eb.step_numbers[end] == i + 1 + @test eb.episodes_lengths[end-i:end] == fill(i, i+1) + end + @test eb.sampleable_inds == [1,1,1,1,1,0] + @test length(eb.traces) == 5 + #start new episode of 6 periods. + push!(eb, (state = 7,)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 0 + @test eb.episodes_lengths[end] == 0 + @test eb.step_numbers[end] == 1 + @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero + ep2_len = 0 + for (j,i) = enumerate(8:11) + ep2_len += 1 + push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 1 + @test eb.step_numbers[end] == j + 1 + @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) + end + @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test length(eb.traces) == 10 + #three last steps replace oldest steps in the buffer. + for (i, s) = enumerate(12:13) + ep2_len += 1 + push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 1 + @test eb.step_numbers[end] == i + 1 + 4 + @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) + end + #episode 1 + for (i,s) in enumerate(3:13) + if i in (4, 11) + @test eb.sampleable_inds[i] == 0 + continue + else + @test eb.sampleable_inds[i] == 1 + end + b = eb[i] + @test b[:state] == b[:action] == b[:reward] == s + @test b[:next_state] == s + 1 + end + #episode 2 + #start a third episode + push!(eb, (state = 14, )) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 0 + @test eb.episodes_lengths[end] == 0 + @test eb.step_numbers[end] == 1 + #push until it reaches it own start + for (i,s) in enumerate(15:26) + push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + end + @test eb.sampleable_inds == [fill(true, 10); [false]] + @test eb.episodes_lengths == fill(length(15:26), 11) + @test eb.step_numbers == [3:13;] + step = popfirst!(eb) + @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 + @test first(eb.step_numbers) == 4 + step = pop!(eb) + @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8 + @test last(eb.step_numbers) == 12 + @test size(eb) == size(eb.traces) == (8,) + empty!(eb) + @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) + show(eb); + end + @testset "with PartialNamedTuple" begin + eb = EpisodesBuffer( + CircularArraySARTSATraces(; + capacity=10) + ) + #push a first episode l=5 + push!(eb, (state = 1,)) + @test eb.sampleable_inds[end] == 0 + @test eb.episodes_lengths[end] == 0 + @test eb.step_numbers[end] == 1 + for i = 1:5 + push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 1 + @test eb.step_numbers[end] == i + 1 + @test eb.episodes_lengths[end-i:end] == fill(i, i+1) + end + push!(eb, PartialNamedTuple((action = 6,))) + @test eb.sampleable_inds == [1,1,1,1,1,0] + @test length(eb.traces) == 5 + #start new episode of 6 periods. + push!(eb, (state = 7,)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 0 + @test eb.episodes_lengths[end] == 0 + @test eb.step_numbers[end] == 1 + @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero + ep2_len = 0 + for (j,i) = enumerate(8:11) + ep2_len += 1 + push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 1 + @test eb.step_numbers[end] == j + 1 + @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) + end + @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test length(eb.traces) == 9 #an action is missing at this stage + #three last steps replace oldest steps in the buffer. + for (i, s) = enumerate(12:13) + ep2_len += 1 + push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 1 + @test eb.step_numbers[end] == i + 1 + 4 + @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) + end + push!(eb, PartialNamedTuple((action = 13,))) + @test length(eb.traces) == 10 + #episode 1 + for (i,s) in enumerate(3:13) + if i in (4, 11) + @test eb.sampleable_inds[i] == 0 + continue + else + @test eb.sampleable_inds[i] == 1 + end + b = eb[i] + @test b[:state] == b[:action] == b[:reward] == s + @test b[:next_state] == b[:next_action] == s + 1 + end + #episode 2 + #start a third episode + push!(eb, (state = 14,)) + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 0 + @test eb.episodes_lengths[end] == 0 + @test eb.step_numbers[end] == 1 + #push until it reaches it own start + for (i,s) in enumerate(15:26) + push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + end + push!(eb, PartialNamedTuple((action = 26,))) + @test eb.sampleable_inds == [fill(true, 10); [false]] + @test eb.episodes_lengths == fill(length(15:26), 11) + @test eb.step_numbers == [3:13;] + step = popfirst!(eb) + @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 + @test first(eb.step_numbers) == 4 + step = pop!(eb) + @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8 + @test last(eb.step_numbers) == 12 + @test size(eb) == size(eb.traces) == (8,) + empty!(eb) + @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) + show(eb); + end end From 2de532a86b176c214761c5c97b789f178fd52777 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 14:00:57 +0100 Subject: [PATCH 02/11] Update ElasticArraySARTSTraces and add new files --- src/common/ElasticArraySARTSATraces.jl | 31 +++++++++++++ src/common/ElasticArraySARTSTraces.jl | 32 +++++--------- src/common/ElasticArraySLARTTraces.jl | 35 +++++++++++++++ src/common/ElasticPrioritizedTraces.jl | 60 ++++++++++++++++++++++++++ src/common/common.jl | 4 ++ src/common/common_elastic_array.jl | 9 ++++ src/episodes.jl | 5 ++- test/episodes.jl | 31 +------------ 8 files changed, 155 insertions(+), 52 deletions(-) create mode 100644 src/common/ElasticArraySARTSATraces.jl create mode 100644 src/common/ElasticArraySLARTTraces.jl create mode 100644 src/common/ElasticPrioritizedTraces.jl create mode 100644 src/common/common_elastic_array.jl diff --git a/src/common/ElasticArraySARTSATraces.jl b/src/common/ElasticArraySARTSATraces.jl new file mode 100644 index 0000000..d0ecf25 --- /dev/null +++ b/src/common/ElasticArraySARTSATraces.jl @@ -0,0 +1,31 @@ +export ElasticArraySARTSTraces + +const ElasticArraySARTSATraces = Traces{ + SS′AA′RT, + <:Tuple{ + <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}}, + <:MultiplexTraces{AA′,<:Trace{<:ElasticArray}}, + <:Trace{<:ElasticArray}, + <:Trace{<:ElasticArray}, + } +} + +function ElasticArraySARTSATraces(; + state=Int => (), + action=Int => (), + reward=Float32 => (), + terminal=Bool => () +) + state_eltype, state_size = state + action_eltype, action_size = action + reward_eltype, reward_size = reward + terminal_eltype, terminal_size = terminal + + MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + + MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) + + Traces( + reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), + terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), + ) +end + diff --git a/src/common/ElasticArraySARTSTraces.jl b/src/common/ElasticArraySARTSTraces.jl index bfbbc4d..3c3ba1a 100644 --- a/src/common/ElasticArraySARTSTraces.jl +++ b/src/common/ElasticArraySARTSTraces.jl @@ -1,14 +1,12 @@ export ElasticArraySARTSTraces -using ElasticArrays: ElasticArray, resize_lastdim! - const ElasticArraySARTSTraces = Traces{ - SS′AA′RT, + SS′ART, <:Tuple{ - <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}}, - <:MultiplexTraces{AA′,<:Trace{<:ElasticArray}}, - <:Trace{<:ElasticArray}, - <:Trace{<:ElasticArray}, + <:MultiplexTraces{SS′,<:Trace{<:ElasticArrayBuffer}}, + <:Trace{<:ElasticArrayBuffer}, + <:Trace{<:ElasticArrayBuffer}, + <:Trace{<:ElasticArrayBuffer}, } } @@ -16,25 +14,17 @@ function ElasticArraySARTSTraces(; state=Int => (), action=Int => (), reward=Float32 => (), - terminal=Bool => () -) + terminal=Bool => ()) + state_eltype, state_size = state action_eltype, action_size = action reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + - MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) + + MultiplexTraces{SS′}(ElasticArrayBuffer{state_eltype}(state_size..., capacity+1)) + Traces( - reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), - terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), + action = ElasticArrayBuffer{action_eltype}(action_size..., capacity), + reward=ElasticArrayBuffer{reward_eltype}(reward_size..., capacity), + terminal=ElasticArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end - -##### -# extensions for ElasticArrays -##### - -Base.push!(a::ElasticArray, x) = append!(a, x) -Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x]) -Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0) diff --git a/src/common/ElasticArraySLARTTraces.jl b/src/common/ElasticArraySLARTTraces.jl new file mode 100644 index 0000000..7c09e8d --- /dev/null +++ b/src/common/ElasticArraySLARTTraces.jl @@ -0,0 +1,35 @@ +export ElasticArraySLARTTraces + +const ElasticArraySLARTTraces = Traces{ + SS′LL′AA′RT, + <:Tuple{ + <:MultiplexTraces{SS′,<:Trace{<:ElasticArrayBuffer}}, + <:MultiplexTraces{LL′,<:Trace{<:ElasticArrayBuffer}}, + <:MultiplexTraces{AA′,<:Trace{<:ElasticArrayBuffer}}, + <:Trace{<:ElasticArrayBuffer}, + <:Trace{<:ElasticArrayBuffer}, + } +} + +function ElasticArraySLARTTraces(; + capacity::Int, + state=Int => (), + legal_actions_mask=Bool => (), + action=Int => (), + reward=Float32 => (), + terminal=Bool => () +) + state_eltype, state_size = state + action_eltype, action_size = action + legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask + reward_eltype, reward_size = reward + terminal_eltype, terminal_size = terminal + + MultiplexTraces{SS′}(ElasticArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{LL′}(ElasticArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{AA′}(ElasticArrayBuffer{action_eltype}(action_size..., capacity + 1)) + + Traces( + reward=ElasticArrayBuffer{reward_eltype}(reward_size..., capacity), + terminal=ElasticArrayBuffer{terminal_eltype}(terminal_size..., capacity), + ) +end diff --git a/src/common/ElasticPrioritizedTraces.jl b/src/common/ElasticPrioritizedTraces.jl new file mode 100644 index 0000000..2a148cf --- /dev/null +++ b/src/common/ElasticPrioritizedTraces.jl @@ -0,0 +1,60 @@ +export ElasticPrioritizedTraces + +struct ElasticPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts} + keys::ElasticVectorBuffer{Int,Vector{Int}} + priorities::SumTree{Float32} + traces::T + default_priority::Float32 +end + +function ElasticPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts} + new_names = (:key, :priority, names...) + new_Ts = Tuple{Int,Float32,Ts.parameters...} + c = capacity(traces) + ElasticPrioritizedTraces{typeof(traces),new_names,new_Ts}( + ElasticVectorBuffer{Int}(c), + SumTree(c), + traces, + default_priority + ) +end + +function Base.push!(t::ElasticPrioritizedTraces, x) + push!(t.traces, x) + if length(t.traces) == 1 + push!(t.keys, 1) + push!(t.priorities, t.default_priority) + elseif length(t.traces) > 1 + push!(t.keys, t.keys[end] + 1) + push!(t.priorities, t.default_priority) + else + # may be partial inserting at the first step, ignore it + end +end + +function Base.setindex!(t::ElasticPrioritizedTraces, vs, k::Symbol, keys) + if k === :priority + @assert length(vs) == length(keys) + for (i, v) in zip(keys, vs) + if t.keys[1] <= i <= t.keys[end] + t.priorities[i-t.keys[1]+1] = v + end + end + else + @error "unsupported yet" + end +end + +Base.size(t::ElasticPrioritizedTraces) = size(t.traces) + +function Base.getindex(ts::ElasticPrioritizedTraces, s::Symbol) + if s === :priority + Trace(ts.priorities) + elseif s === :key + Trace(ts.keys) + else + ts.traces[s] + end +end + +Base.getindex(t::ElasticPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names)) diff --git a/src/common/common.jl b/src/common/common.jl index 90e4405..b2f734f 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -14,4 +14,8 @@ include("CircularArraySARTSTraces.jl") include("CircularArraySARTSATraces.jl") include("CircularArraySLARTTraces.jl") include("CircularPrioritizedTraces.jl") +include("common_elastic_array.jl") include("ElasticArraySARTSTraces.jl") +include("ElasticArraySARTSATraces.jl") +include("ElasticArraySLARTTraces.jl") +include("ElasticPrioritizedTraces.jl") diff --git a/src/common/common_elastic_array.jl b/src/common/common_elastic_array.jl new file mode 100644 index 0000000..01b4090 --- /dev/null +++ b/src/common/common_elastic_array.jl @@ -0,0 +1,9 @@ +using ElasticArrays: ElasticArray, resize_lastdim! + +##### +# extensions for ElasticArrays +##### + +Base.push!(a::ElasticArray, x) = append!(a, x) +Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x]) +Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0) diff --git a/src/episodes.jl b/src/episodes.jl index e2cb521..d401cf7 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -1,6 +1,6 @@ export EpisodesBuffer, PartialNamedTuple import DataStructures.CircularBuffer -using ElasticArrays: ElasticArray, resize_lastdim! +using ElasticArrays: ElasticArray, ElasticVector, resize_lastdim! """ EpisodesBuffer(traces::AbstractTraces) @@ -91,7 +91,8 @@ function pad!(trace::Trace) return nothing end -pad!(vect::ElasticArray{T, Vector{T}}) where {T} = pad!(vect, zero(T)) +pad!(vect::ElasticArray{T, Vector{T}}) where {T} = push!(vect, zero(T)) +pad!(vect::ElasticVector{T, Vector{T}}) where {T} = push!(vect, zero(T)) pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T)) pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) diff --git a/test/episodes.jl b/test/episodes.jl index 43a8b08..34de131 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -197,36 +197,9 @@ using Test @test eb.step_numbers == [1:16;1:16] @test length(eb) == 31 end - @testset "with elastic traces" begin - t = ElasticArraySARTSTraces(; - state=Int => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () - ) - - eb = EpisodesBuffer(t) - push!(eb, (state = 1,)) #partial inserting - for i = 1:15 - push!(eb, (state = i+1, reward =i)) - end - @test length(eb.traces) == 15 - @test eb.sampleable_inds == [fill(true, 15); [false]] - @test all(==(15), eb.episodes_lengths) - @test eb.step_numbers == [1:16;] - push!(eb, (state = 1,)) #partial inserting - for i = 1:15 - push!(eb, (state = i+1, reward =i)) - end - @test eb.sampleable_inds == [fill(true, 15); [false];fill(true, 15); [false]] - @test all(==(15), eb.episodes_lengths) - @test eb.step_numbers == [1:16;1:16] - @test length(eb) == 31 - end - @testset "with circular traces" begin + @testset "with ElasticArraySARTSTraces traces" begin eb = EpisodesBuffer( - CircularArraySARTSTraces(; - capacity=10) + ElasticArraySARTSTraces() ) #push a first episode l=5 push!(eb, (state = 1,)) From f6cd30914f948f0244b4fea06ab56296aa568079 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 15:23:25 +0100 Subject: [PATCH 03/11] fix --- Project.toml | 2 +- src/common/ElasticArraySARTSTraces.jl | 16 ++++++++-------- src/common/ElasticArraySLARTTraces.jl | 20 ++++++++++---------- src/common/ElasticPrioritizedTraces.jl | 6 ++++-- src/episodes.jl | 2 +- 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 59e64e8..193d5a0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ReinforcementLearningTrajectories" uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c" -version = "0.3.7" +version = "0.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/common/ElasticArraySARTSTraces.jl b/src/common/ElasticArraySARTSTraces.jl index 3c3ba1a..254e549 100644 --- a/src/common/ElasticArraySARTSTraces.jl +++ b/src/common/ElasticArraySARTSTraces.jl @@ -3,10 +3,10 @@ export ElasticArraySARTSTraces const ElasticArraySARTSTraces = Traces{ SS′ART, <:Tuple{ - <:MultiplexTraces{SS′,<:Trace{<:ElasticArrayBuffer}}, - <:Trace{<:ElasticArrayBuffer}, - <:Trace{<:ElasticArrayBuffer}, - <:Trace{<:ElasticArrayBuffer}, + <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}}, + <:Trace{<:ElasticArray}, + <:Trace{<:ElasticArray}, + <:Trace{<:ElasticArray}, } } @@ -21,10 +21,10 @@ function ElasticArraySARTSTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(ElasticArrayBuffer{state_eltype}(state_size..., capacity+1)) + + MultiplexTraces{SS′}(ElasticArray{state_eltype}(state_size..., Inf)) + Traces( - action = ElasticArrayBuffer{action_eltype}(action_size..., capacity), - reward=ElasticArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=ElasticArrayBuffer{terminal_eltype}(terminal_size..., capacity), + action = ElasticArray{action_eltype}(action_size..., Inf), + reward=ElasticArray{reward_eltype}(reward_size..., Inf), + terminal=ElasticArray{terminal_eltype}(terminal_size..., Inf), ) end diff --git a/src/common/ElasticArraySLARTTraces.jl b/src/common/ElasticArraySLARTTraces.jl index 7c09e8d..cbbcaae 100644 --- a/src/common/ElasticArraySLARTTraces.jl +++ b/src/common/ElasticArraySLARTTraces.jl @@ -3,11 +3,11 @@ export ElasticArraySLARTTraces const ElasticArraySLARTTraces = Traces{ SS′LL′AA′RT, <:Tuple{ - <:MultiplexTraces{SS′,<:Trace{<:ElasticArrayBuffer}}, - <:MultiplexTraces{LL′,<:Trace{<:ElasticArrayBuffer}}, - <:MultiplexTraces{AA′,<:Trace{<:ElasticArrayBuffer}}, - <:Trace{<:ElasticArrayBuffer}, - <:Trace{<:ElasticArrayBuffer}, + <:MultiplexTraces{SS′,<:Trace{<:ElasticArray}}, + <:MultiplexTraces{LL′,<:Trace{<:ElasticArray}}, + <:MultiplexTraces{AA′,<:Trace{<:ElasticArray}}, + <:Trace{<:ElasticArray}, + <:Trace{<:ElasticArray}, } } @@ -25,11 +25,11 @@ function ElasticArraySLARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(ElasticArrayBuffer{state_eltype}(state_size..., capacity + 1)) + - MultiplexTraces{LL′}(ElasticArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + - MultiplexTraces{AA′}(ElasticArrayBuffer{action_eltype}(action_size..., capacity + 1)) + + MultiplexTraces{SS′}(ElasticArray{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{AA′}(ElasticArray{action_eltype}(action_size..., capacity + 1)) + Traces( - reward=ElasticArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=ElasticArrayBuffer{terminal_eltype}(terminal_size..., capacity), + reward=ElasticArray{reward_eltype}(reward_size..., capacity), + terminal=ElasticArray{terminal_eltype}(terminal_size..., capacity), ) end diff --git a/src/common/ElasticPrioritizedTraces.jl b/src/common/ElasticPrioritizedTraces.jl index 2a148cf..84950bb 100644 --- a/src/common/ElasticPrioritizedTraces.jl +++ b/src/common/ElasticPrioritizedTraces.jl @@ -1,7 +1,9 @@ export ElasticPrioritizedTraces +using ElasticArrays: ElasticVector + struct ElasticPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts} - keys::ElasticVectorBuffer{Int,Vector{Int}} + keys::ElasticVector{Int,Vector{Int}} priorities::SumTree{Float32} traces::T default_priority::Float32 @@ -12,7 +14,7 @@ function ElasticPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_prio new_Ts = Tuple{Int,Float32,Ts.parameters...} c = capacity(traces) ElasticPrioritizedTraces{typeof(traces),new_names,new_Ts}( - ElasticVectorBuffer{Int}(c), + ElasticVector{Int}(c), SumTree(c), traces, default_priority diff --git a/src/episodes.jl b/src/episodes.jl index d401cf7..f37e7f6 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -1,6 +1,6 @@ export EpisodesBuffer, PartialNamedTuple import DataStructures.CircularBuffer -using ElasticArrays: ElasticArray, ElasticVector, resize_lastdim! +using ElasticArrays: ElasticArray, ElasticVector """ EpisodesBuffer(traces::AbstractTraces) From 996484c68e222e6e8072a7942187092c1bd88755 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 15:24:26 +0100 Subject: [PATCH 04/11] fix --- src/common/ElasticArraySARTSTraces.jl | 8 ++++---- src/common/ElasticArraySLARTTraces.jl | 10 +++++----- test/episodes.jl | 3 ++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/common/ElasticArraySARTSTraces.jl b/src/common/ElasticArraySARTSTraces.jl index 254e549..c833882 100644 --- a/src/common/ElasticArraySARTSTraces.jl +++ b/src/common/ElasticArraySARTSTraces.jl @@ -21,10 +21,10 @@ function ElasticArraySARTSTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(ElasticArray{state_eltype}(state_size..., Inf)) + + MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + Traces( - action = ElasticArray{action_eltype}(action_size..., Inf), - reward=ElasticArray{reward_eltype}(reward_size..., Inf), - terminal=ElasticArray{terminal_eltype}(terminal_size..., Inf), + action = ElasticArray{action_eltype}(undef, action_size..., 0), + reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), + terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), ) end diff --git a/src/common/ElasticArraySLARTTraces.jl b/src/common/ElasticArraySLARTTraces.jl index cbbcaae..517eb7f 100644 --- a/src/common/ElasticArraySLARTTraces.jl +++ b/src/common/ElasticArraySLARTTraces.jl @@ -25,11 +25,11 @@ function ElasticArraySLARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(ElasticArray{state_eltype}(state_size..., capacity + 1)) + - MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + - MultiplexTraces{AA′}(ElasticArray{action_eltype}(action_size..., capacity + 1)) + + MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + + MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) + + MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) + Traces( - reward=ElasticArray{reward_eltype}(reward_size..., capacity), - terminal=ElasticArray{terminal_eltype}(terminal_size..., capacity), + reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), + terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), ) end diff --git a/test/episodes.jl b/test/episodes.jl index 34de131..06249ed 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -1,6 +1,7 @@ using ReinforcementLearningTrajectories using CircularArrayBuffers using Test + @testset "EpisodesBuffer" begin @testset "with circular traces" begin eb = EpisodesBuffer( @@ -282,7 +283,7 @@ using Test end @testset "with PartialNamedTuple" begin eb = EpisodesBuffer( - CircularArraySARTSATraces(; + ElasticArraySARTSATraces(; capacity=10) ) #push a first episode l=5 From 14b52c453a33fd5bc1ab0c349a234bbdba1fdb3d Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 16:35:35 +0100 Subject: [PATCH 05/11] tweaks --- src/common/ElasticArraySARTSATraces.jl | 2 +- src/common/ElasticPrioritizedTraces.jl | 62 -------------------------- src/common/common.jl | 1 - src/episodes.jl | 3 ++ test/episodes.jl | 31 ++++++++----- test/traces.jl | 19 ++++++++ 6 files changed, 42 insertions(+), 76 deletions(-) delete mode 100644 src/common/ElasticPrioritizedTraces.jl diff --git a/src/common/ElasticArraySARTSATraces.jl b/src/common/ElasticArraySARTSATraces.jl index d0ecf25..0644396 100644 --- a/src/common/ElasticArraySARTSATraces.jl +++ b/src/common/ElasticArraySARTSATraces.jl @@ -1,4 +1,4 @@ -export ElasticArraySARTSTraces +export ElasticArraySARTSATraces const ElasticArraySARTSATraces = Traces{ SS′AA′RT, diff --git a/src/common/ElasticPrioritizedTraces.jl b/src/common/ElasticPrioritizedTraces.jl deleted file mode 100644 index 84950bb..0000000 --- a/src/common/ElasticPrioritizedTraces.jl +++ /dev/null @@ -1,62 +0,0 @@ -export ElasticPrioritizedTraces - -using ElasticArrays: ElasticVector - -struct ElasticPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts} - keys::ElasticVector{Int,Vector{Int}} - priorities::SumTree{Float32} - traces::T - default_priority::Float32 -end - -function ElasticPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts} - new_names = (:key, :priority, names...) - new_Ts = Tuple{Int,Float32,Ts.parameters...} - c = capacity(traces) - ElasticPrioritizedTraces{typeof(traces),new_names,new_Ts}( - ElasticVector{Int}(c), - SumTree(c), - traces, - default_priority - ) -end - -function Base.push!(t::ElasticPrioritizedTraces, x) - push!(t.traces, x) - if length(t.traces) == 1 - push!(t.keys, 1) - push!(t.priorities, t.default_priority) - elseif length(t.traces) > 1 - push!(t.keys, t.keys[end] + 1) - push!(t.priorities, t.default_priority) - else - # may be partial inserting at the first step, ignore it - end -end - -function Base.setindex!(t::ElasticPrioritizedTraces, vs, k::Symbol, keys) - if k === :priority - @assert length(vs) == length(keys) - for (i, v) in zip(keys, vs) - if t.keys[1] <= i <= t.keys[end] - t.priorities[i-t.keys[1]+1] = v - end - end - else - @error "unsupported yet" - end -end - -Base.size(t::ElasticPrioritizedTraces) = size(t.traces) - -function Base.getindex(ts::ElasticPrioritizedTraces, s::Symbol) - if s === :priority - Trace(ts.priorities) - elseif s === :key - Trace(ts.keys) - else - ts.traces[s] - end -end - -Base.getindex(t::ElasticPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names)) diff --git a/src/common/common.jl b/src/common/common.jl index b2f734f..f8b4023 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -18,4 +18,3 @@ include("common_elastic_array.jl") include("ElasticArraySARTSTraces.jl") include("ElasticArraySARTSATraces.jl") include("ElasticArraySLARTTraces.jl") -include("ElasticPrioritizedTraces.jl") diff --git a/src/episodes.jl b/src/episodes.jl index f37e7f6..b128787 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -85,6 +85,7 @@ end ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps. ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs) ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) +ispartial_insert(traces::ElasticPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) function pad!(trace::Trace) pad!(trace.parent) @@ -130,6 +131,8 @@ fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces) fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces) +fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:ElasticPrioritizedTraces}) = fill_multiplex(es.traces.traces) + function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.traces, xs) partial = ispartial_insert(eb, xs) diff --git a/test/episodes.jl b/test/episodes.jl index 06249ed..e95c420 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -114,6 +114,8 @@ using Test @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb[:action][6] == 6 + @test eb[:next_action][6] == 6 @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 for (j,i) = enumerate(8:11) @@ -235,7 +237,7 @@ using Test end @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] @test length(eb.traces) == 10 - #three last steps replace oldest steps in the buffer. + for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @@ -245,16 +247,16 @@ using Test @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end #episode 1 - for (i,s) in enumerate(3:13) - if i in (4, 11) + for i in 3:13 + if i in (6, 13) @test eb.sampleable_inds[i] == 0 continue else @test eb.sampleable_inds[i] == 1 end b = eb[i] - @test b[:state] == b[:action] == b[:reward] == s - @test b[:next_state] == s + 1 + @test b[:state] == b[:action] == b[:reward] == i + @test b[:next_state] == i + 1 end #episode 2 #start a third episode @@ -263,13 +265,14 @@ using Test @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 - #push until it reaches it own start + for (i,s) in enumerate(15:26) push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) end - @test eb.sampleable_inds == [fill(true, 10); [false]] - @test eb.episodes_lengths == fill(length(15:26), 11) - @test eb.step_numbers == [3:13;] + @test eb.sampleable_inds[end-5:end] == [fill(true, 5); [false]] + @test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11) + @test eb.step_numbers[end-10:end] == [3:13;] + #= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays step = popfirst!(eb) @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 @test first(eb.step_numbers) == 4 @@ -280,11 +283,11 @@ using Test empty!(eb) @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) show(eb); + =# end - @testset "with PartialNamedTuple" begin + @testset "ElasticArraySARTSATraces with PartialNamedTuple" begin eb = EpisodesBuffer( - ElasticArraySARTSATraces(; - capacity=10) + ElasticArraySARTSATraces() ) #push a first episode l=5 push!(eb, (state = 1,)) @@ -308,6 +311,8 @@ using Test @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb[:action][6] == 6 + @test eb[:next_action][6] == 6 @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 for (j,i) = enumerate(8:11) @@ -358,6 +363,7 @@ using Test @test eb.sampleable_inds == [fill(true, 10); [false]] @test eb.episodes_lengths == fill(length(15:26), 11) @test eb.step_numbers == [3:13;] + #= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays step = popfirst!(eb) @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 @test first(eb.step_numbers) == 4 @@ -368,5 +374,6 @@ using Test empty!(eb) @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) show(eb); + =# end end diff --git a/test/traces.jl b/test/traces.jl index 63003be..762c6f2 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -137,6 +137,25 @@ using ReinforcementLearningTrajectories: build_trace_index build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2]) end +@testset "build_trace_index ElasticArraySARTSATraces" begin + t1 = ElasticArraySARTSATraces(; + capacity=3, + state=Float32 => (2, 3), + action=Float32 => (2,), + reward=Float32 => (), + terminal=Bool => () + ) + @test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict(:reward => 3, + :next_state => 1, + :state => 1, + :action => 2, + :next_action => 2, + :terminal => 4) + + t2 = Traces(; a=[2, 3], b=[false, true]) + build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2]) +end + @testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin t1 = CircularArraySARTSATraces(; capacity=3, From bd179510ae5c8688af2ddc3af2b504f1f9581889 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 16:36:07 +0100 Subject: [PATCH 06/11] drop --- src/episodes.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/episodes.jl b/src/episodes.jl index b128787..f37e7f6 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -85,7 +85,6 @@ end ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps. ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs) ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) -ispartial_insert(traces::ElasticPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) function pad!(trace::Trace) pad!(trace.parent) @@ -131,8 +130,6 @@ fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces) fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces) -fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:ElasticPrioritizedTraces}) = fill_multiplex(es.traces.traces) - function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.traces, xs) partial = ispartial_insert(eb, xs) From 1cc7b76dcb31bca9d40c6177a4c611c193e5a8d5 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 16:44:19 +0100 Subject: [PATCH 07/11] tests pass --- test/episodes.jl | 22 +++++++++++----------- test/traces.jl | 24 +++++++++++++++++++++++- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/test/episodes.jl b/test/episodes.jl index e95c420..f93c433 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -285,6 +285,7 @@ using Test show(eb); =# end + @testset "ElasticArraySARTSATraces with PartialNamedTuple" begin eb = EpisodesBuffer( ElasticArraySARTSATraces() @@ -312,8 +313,8 @@ using Test @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] @test eb[:action][6] == 6 - @test eb[:next_action][6] == 6 - @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero + @test eb[:next_action][5] == 6 + @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 for (j,i) = enumerate(8:11) ep2_len += 1 @@ -325,7 +326,6 @@ using Test end @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] @test length(eb.traces) == 9 #an action is missing at this stage - #three last steps replace oldest steps in the buffer. for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @@ -335,18 +335,18 @@ using Test @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end push!(eb, PartialNamedTuple((action = 13,))) - @test length(eb.traces) == 10 + @test length(eb.traces) == 12 #episode 1 - for (i,s) in enumerate(3:13) - if i in (4, 11) + for i in 1:13 + if i in (6, 13) @test eb.sampleable_inds[i] == 0 continue else @test eb.sampleable_inds[i] == 1 end b = eb[i] - @test b[:state] == b[:action] == b[:reward] == s - @test b[:next_state] == b[:next_action] == s + 1 + @test b[:state] == b[:action] == b[:reward] == i + @test b[:next_state] == b[:next_action] == i + 1 end #episode 2 #start a third episode @@ -360,9 +360,9 @@ using Test push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) end push!(eb, PartialNamedTuple((action = 26,))) - @test eb.sampleable_inds == [fill(true, 10); [false]] - @test eb.episodes_lengths == fill(length(15:26), 11) - @test eb.step_numbers == [3:13;] + @test eb.sampleable_inds[end-10:end] == [fill(true, 10); [false]] + @test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11) + @test eb.step_numbers[end-10:end] == [3:13;] #= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays step = popfirst!(eb) @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 diff --git a/test/traces.jl b/test/traces.jl index 762c6f2..f30b311 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -139,7 +139,6 @@ end @testset "build_trace_index ElasticArraySARTSATraces" begin t1 = ElasticArraySARTSATraces(; - capacity=3, state=Float32 => (2, 3), action=Float32 => (2,), reward=Float32 => (), @@ -171,6 +170,29 @@ end @test size(Base.getindex(t1, 1).state) == (2,3) + t2 = Traces(; a=[2, 3], b=[false, true]) + push!(t2, Val(:a), 5) + @test t2[:a][3] == 5 + + @test size(Base.getindex(t2, :a)) == (3,) + @test Base.getindex(t2, 1) == (; a = 2, b= false) +end + + +@testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin + t1 = ElasticArraySARTSATraces( + state=Float32 => (2, 3), + action=Float32 => (2,), + reward=Float32 => (), + terminal=Bool => () + ) + push!(t1, Val(:reward), 5) + @test t1[:reward][1] == 5 + + @test size(Base.getindex(t1, :reward)) == (1,) + @test size(Base.getindex(t1, :state)) == (0,) + + t2 = Traces(; a=[2, 3], b=[false, true]) push!(t2, Val(:a), 5) @test t2[:a][3] == 5 From 58bfea514934926c79eb62e5dec989a996e01b47 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 16:57:53 +0100 Subject: [PATCH 08/11] fix test --- test/episodes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/episodes.jl b/test/episodes.jl index f93c433..ef7855c 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -115,7 +115,7 @@ using Test @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] @test eb[:action][6] == 6 - @test eb[:next_action][6] == 6 + @test eb[:next_action][5] == 6 @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 for (j,i) = enumerate(8:11) From 7d0d0530a2f399185ca95f9be0cf594815251714 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 17:19:34 +0100 Subject: [PATCH 09/11] add bounds check --- src/episodes.jl | 46 +++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/src/episodes.jl b/src/episodes.jl index f37e7f6..e93fb92 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -69,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces) end end -Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...) -Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...) -Base.size(es::EpisodesBuffer) = size(es.traces) -Base.length(es::EpisodesBuffer) = length(es.traces) -Base.keys(es::EpisodesBuffer) = keys(es.traces) -Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces) +function Base.getindex(es::EpisodesBuffer, idx::Int...) + @boundscheck all(es.sampleable_inds[idx...]) + getindex(es.traces, idx...) +end + +function Base.getindex(es::EpisodesBuffer, idx...) + getindex(es.traces, idx...) +end + +Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...) +Base.size(eb::EpisodesBuffer) = size(eb.traces) +Base.length(eb::EpisodesBuffer) = length(eb.traces) +Base.keys(eb::EpisodesBuffer) = keys(eb.traces) +Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces) function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names} s = nameof(typeof(eb)) t = eb.traces @@ -83,7 +91,7 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where end ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps. -ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs) +ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs) ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) function pad!(trace::Trace) @@ -126,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) return :($ex) end -fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces) +fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces) -fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces) +fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces) function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.traces, xs) @@ -165,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl end for f in (:pop!, :popfirst!) - @eval function Base.$f(es::EpisodesBuffer) - $f(es.episodes_lengths) - $f(es.sampleable_inds) - $f(es.step_numbers) - $f(es.traces) + @eval function Base.$f(eb::EpisodesBuffer) + $f(eb.episodes_lengths) + $f(eb.sampleable_inds) + $f(eb.step_numbers) + $f(eb.traces) end end -function Base.empty!(es::EpisodesBuffer) - empty!(es.traces) - empty!(es.episodes_lengths) - empty!(es.sampleable_inds) - empty!(es.step_numbers) +function Base.empty!(eb::EpisodesBuffer) + empty!(eb.traces) + empty!(eb.episodes_lengths) + empty!(eb.sampleable_inds) + empty!(eb.step_numbers) end From a31120e1529b19a7a1c0df64a7aaf958967c4638 Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Fri, 22 Mar 2024 17:27:27 +0100 Subject: [PATCH 10/11] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 01804ca..9934931 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,8 +18,8 @@ jobs: fail-fast: false matrix: version: - - '1.9' - '1' + - '1.11' - 'nightly' os: - ubuntu-latest From 4895a41da5f3843bcae3bc4e58da7075538af033 Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Fri, 22 Mar 2024 17:33:08 +0100 Subject: [PATCH 11/11] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9934931..35a34d1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,7 @@ jobs: matrix: version: - '1' - - '1.11' + - '^1.11.0-alpha' - 'nightly' os: - ubuntu-latest