From 5b68fe6ceaf8c8f72d7f40ab6d529b9910376a50 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Tue, 7 May 2024 12:26:04 +0200 Subject: [PATCH 1/9] Fix bugs with SARTSATraces Bug 1 - When pushing more traces into CircularArraySARTSATraces than its capacity, the state and action traces are no longer aligned Bug 2 - sampleable_inds were not correct for CircularArraySARTSATraces Bug 3 - CircularArraySARTSATraces were not sampleable by an EpisodesSampler --- src/common/CircularArraySARTSATraces.jl | 6 ++--- src/episodes.jl | 20 +++++++++++--- test/common.jl | 16 ++++++----- test/episodes.jl | 36 ++++++++++++++++++------- 4 files changed, 55 insertions(+), 23 deletions(-) diff --git a/src/common/CircularArraySARTSATraces.jl b/src/common/CircularArraySARTSATraces.jl index 94f2aa1..393e64b 100644 --- a/src/common/CircularArraySARTSATraces.jl +++ b/src/common/CircularArraySARTSATraces.jl @@ -24,11 +24,11 @@ function CircularArraySARTSATraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) + Traces( - reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), + reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1), + terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1), ) end diff --git a/src/episodes.jl b/src/episodes.jl index e93fb92..effde62 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -138,6 +138,8 @@ fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces) fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces) +max_length(eb::EpisodesBuffer) = max_length(eb.traces) + function
Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.traces, xs) partial = ispartial_insert(eb, xs) @@ -146,10 +148,12 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.episodes_lengths, 0) push!(eb.sampleable_inds, 0) elseif !partial #typical inserting - if length(eb.traces) < length(eb) && length(eb) > 2 #case when PartialNamedTuple is used. Steps are indexable one step later - eb.sampleable_inds[end-1] = 1 - else #case when we don't, length of traces and eb will match. - eb.sampleable_inds[end] = 1 #previous step is now indexable + if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched + if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode + eb.sampleable_inds[end-1] = 1 # steps are indexable one step later + end + else + eb.sampleable_inds[end] = 1 # otherwise, previous step is now indexable end push!(eb.sampleable_inds, 0) #this one is no longer ep_length = last(eb.step_numbers) @@ -172,6 +176,14 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. end +function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple) + if max_length(eb) == capacity(eb.traces) + popfirst!(eb) + end + push!(eb.traces, xs.namedtuple) + eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. +end + for f in (:pop!, :popfirst!) 
@eval function Base.$f(eb::EpisodesBuffer) $f(eb.episodes_lengths) diff --git a/test/common.jl b/test/common.jl index afcc89b..63e2309 100644 --- a/test/common.jl +++ b/test/common.jl @@ -24,7 +24,7 @@ @test length(t) == 0 end -@testset "CircularArraySARTSTraces" begin +@testset "CircularArraySARTSATraces" begin t = CircularArraySARTSATraces(; capacity=3, state=Float32 => (2, 3), @@ -35,13 +35,14 @@ end @test t isa CircularArraySARTSATraces - push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu) + push!(t, (state=ones(Float32, 2, 3),)) + push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu) @test length(t) == 0 push!(t, (reward=1.0f0, terminal=false) |> gpu) - @test length(t) == 0 # next_state and next_action is still missing + @test length(t) == 0 # next_action is still missing - push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu) @test length(t) == 1 # this will trigger the scalar indexing of CuArray @@ -55,17 +56,18 @@ end ) push!(t, (reward=2.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu) @test length(t) == 2 push!(t, (reward=3.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu) @test length(t) == 3 push!(t, (reward=4.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu) + push!(t, (reward=5.0f0, terminal=false)) @test length(t) == 3 diff --git a/test/episodes.jl b/test/episodes.jl index ef7855c..76848ae 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -100,7 +100,10 @@ using Test for i = 1:5 push!(eb, (state = 
i+1, action =i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if length(eb) >= 1 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(i, i+1) end @@ -116,24 +119,30 @@ using Test @test eb.sampleable_inds == [1,1,1,1,1,0,0] @test eb[:action][6] == 6 @test eb[:next_action][5] == 6 - @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero + @test eb[6][:reward] == 0 broken = true #6 is not a valid index and cannot be indexed because a PartialNamedTuple is used ep2_len = 0 for (j,i) = enumerate(8:11) ep2_len += 1 push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == j + 1 @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] @test length(eb.traces) == 9 #an action is missing at this stage #three last steps replace oldest steps in the buffer. 
for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end @@ -298,7 +307,10 @@ using Test for i = 1:5 push!(eb, (state = i+1, action =i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(i, i+1) end @@ -320,17 +332,23 @@ using Test ep2_len += 1 push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == j + 1 @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] @test length(eb.traces) == 9 #an action is missing at this stage for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end From e78aa467eca550c090483bff599e50d42cb756f3 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Tue, 7 May 2024 12:28:03 +0200 Subject: [PATCH 2/9] Fix bugs with SARTSATraces Bug 1 - When pushing 
more traces into CircularArraySARTSATraces than its capacity, the state and action traces are no longer aligned Bug 2 - sampleable_inds were not correct for CircularArraySARTSATraces Bug 3 - CircularArraySARTSATraces were not sampleable by an EpisodesSampler --- src/traces.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/traces.jl b/src/traces.jl index 1cd36bb..35e1f77 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -247,6 +247,7 @@ function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T end Base.size(t::Traces) = (mapreduce(length, min, t.traces),) +max_length(t::Traces) = mapreduce(length, max, t.traces) function capacity(t::Traces{names,Trs,N,E}) where {names,Trs,N,E} minimum(map(idx->capacity(t[idx]), names)) From 0ec5743ef25ea0d3c3d4e22582076ea350791c53 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Thu, 9 May 2024 11:21:18 +0200 Subject: [PATCH 3/9] Fix CircularPrioritizedTraces with SARTSA --- src/common/CircularPrioritizedTraces.jl | 23 ++++++++++++++++++++++- src/episodes.jl | 14 ++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/common/CircularPrioritizedTraces.jl b/src/common/CircularPrioritizedTraces.jl index 755b9a1..09b2ffd 100644 --- a/src/common/CircularPrioritizedTraces.jl +++ b/src/common/CircularPrioritizedTraces.jl @@ -12,7 +12,11 @@ end function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts} new_names = (:key, :priority, names...)
new_Ts = Tuple{Int,Float32,Ts.parameters...} - c = capacity(traces) + if traces isa CircularArraySARTSATraces + c = capacity(traces) - 1 + else + c = capacity(traces) + end CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}( CircularVectorBuffer{Int}(c), SumTree(c), @@ -34,6 +38,22 @@ function Base.push!(t::CircularPrioritizedTraces, x) end end +function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x) + initial_length = length(t.traces) + push!(t.traces, x) + if length(t.traces) == 1 + push!(t.keys, 1) + push!(t.priorities, t.default_priority) + elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 ) + # only add a key if the length changes after insertion of the tuple + # or if the trace is already at capacity + push!(t.keys, t.keys[end] + 1) + push!(t.priorities, t.default_priority) + else + # may be partial inserting at the first step, ignore it + end +end + function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys) if k === :priority @assert length(vs) == length(keys) @@ -48,6 +68,7 @@ function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys) end Base.size(t::CircularPrioritizedTraces) = size(t.traces) +max_length(t::CircularPrioritizedTraces) = max_length(t.traces) function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol) if s === :priority diff --git a/src/episodes.jl b/src/episodes.jl index effde62..ac82866 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -184,6 +184,20 @@ function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. 
end +function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}}) + if max_length(eb) == capacity(eb.traces) + addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal]) + xs = (xs.namedtuple, addition) + push!(eb.traces, xs) + pop!(eb.traces[:state].trace) + pop!(eb.traces[:reward]) + pop!(eb.traces[:terminal]) + else + push!(eb.traces, xs.namedtuple) + eb.sampleable_inds[end-1] = 1 + end +end + for f in (:pop!, :popfirst!) @eval function Base.$f(eb::EpisodesBuffer) $f(eb.episodes_lengths) From 73f2efb0a4de800362cf5c2caddcc46b276b8776 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Thu, 9 May 2024 11:21:32 +0200 Subject: [PATCH 4/9] Fix sampling of CircularPrioritizedTraces --- src/samplers.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/samplers.jl b/src/samplers.jl index 8701189..e5443a7 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir t = e.traces p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) - w .*= e.sampleable_inds[1:end-1] + w .*= e.sampleable_inds[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...)) end @@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) valids, ns = valid_range(s,e) - w .*= valids[1:end-1] + w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) merge( (key=t.keys[inds], priority=p[inds]), @@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, < p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) valids, ns = 
valid_range(s,e) - w .*= valids[1:end-1] + w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) merge( (key=t.keys[inds], priority=p[inds]), From fbf054af9c43be7bfa09fb76d9642b7272bb3226 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Thu, 9 May 2024 11:21:58 +0200 Subject: [PATCH 5/9] New test for CircularPrioritizedTraces with SARTSA --- test/common.jl | 64 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/test/common.jl b/test/common.jl index 63e2309..e445e47 100644 --- a/test/common.jl +++ b/test/common.jl @@ -129,9 +129,9 @@ end @test t isa CircularArraySLARTTraces end -@testset "CircularPrioritizedTraces" begin +@testset "CircularPrioritizedTraces-SARTS" begin t = CircularPrioritizedTraces( - CircularArraySARTSATraces(; + CircularArraySARTSTraces(; capacity=3 ), default_priority=1.0f0 @@ -161,8 +161,68 @@ end @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 #EpisodesBuffer + t = CircularPrioritizedTraces( + CircularArraySARTSTraces(; + capacity=10 + ), + default_priority=1.0f0 + ) + + eb = EpisodesBuffer(t) + push!(eb, (state = 1, action = 1)) + for i = 1:5 + push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) + end + push!(eb, (state = 7, action = 7)) + for (j,i) = enumerate(8:11) + push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + end + s = BatchSampler(1000) + b = sample(s, eb) + cm = counter(b[:state]) + @test !haskey(cm, 6) + @test !haskey(cm, 11) + @test all(in(keys(cm)), [1:5;7:10]) + + + eb[:priority, [1, 2]] = [0, 0] + @test eb[:priority] == [zeros(2);ones(8)] +end + +@testset "CircularPrioritizedTraces-SARTSA" begin t = CircularPrioritizedTraces( CircularArraySARTSATraces(; + capacity=3 + ), + default_priority=1.0f0 + ) + + push!(t, (state=0, action=0)) + + for i in 1:5 + push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) + end + + @test length(t) == 3 + + s = BatchSampler(5) + + 
b = sample(s, t) + + t[:priority, [1, 2]] = [0, 0] + + # shouldn't be changed since [1,2] are old keys + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] + + t[:priority, [3, 4, 5]] = [0, 1, 0] + + b = sample(s, t) + + @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 + + #EpisodesBuffer + t = CircularPrioritizedTraces( + CircularArraySARTSTraces(; capacity=10 ), default_priority=1.0f0 From 320e3f88d3a44df57ac6dc7237bfbdc85d1479b5 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Thu, 9 May 2024 11:22:34 +0200 Subject: [PATCH 6/9] Fix test of CircularPrioritizedTraces with SARTSA The usage of SARTSA traces is more restrictive and should be done in this way --- test/samplers.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/samplers.jl b/test/samplers.jl index 1b914bb..2565ddd 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch batchsize = 4 eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0)) s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize) - - push!(eb, (state = 1, action = 1)) + + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5)) + push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5)) end - push!(eb, (state = 7, action = 7)) - for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, PartialNamedTuple((action=6,))) + push!(eb, (state = 7,)) + for (j,i) = enumerate(7:10) + push!(eb, (state = i+1, action =i, reward = i, terminal = i==10)) end + push!(eb, PartialNamedTuple((action = 11,))) weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb) inds = [i for i in eachindex(weights) if weights[i] == 1] batch = sample(s1, eb) From 687d3eda958a0b7a8e686eb990a827796cc6a351 Mon Sep 17 00:00:00 2001 From: Dharanish Date: Thu, 9 May 2024 
11:51:13 +0200 Subject: [PATCH 7/9] Fix CircularPrioritizedTraces --- src/episodes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/episodes.jl b/src/episodes.jl index ac82866..90a8b79 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -187,7 +187,7 @@ end function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}}) if max_length(eb) == capacity(eb.traces) addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal]) - xs = (xs.namedtuple, addition) + xs = merge(xs.namedtuple, addition) push!(eb.traces, xs) pop!(eb.traces[:state].trace) pop!(eb.traces[:reward]) From dcec549ed1aab08ca9da8b96e83170809e4dd8fe Mon Sep 17 00:00:00 2001 From: Dharanish Date: Thu, 9 May 2024 11:51:43 +0200 Subject: [PATCH 8/9] Fix CircularPrioritizedTraces test with SARTSA --- test/common.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/common.jl b/test/common.jl index e445e47..714520b 100644 --- a/test/common.jl +++ b/test/common.jl @@ -222,29 +222,27 @@ end #EpisodesBuffer t = CircularPrioritizedTraces( - CircularArraySARTSTraces(; + CircularArraySARTSATraces(; capacity=10 ), default_priority=1.0f0 ) - + eb = EpisodesBuffer(t) - push!(eb, (state = 1, action = 1)) + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) + push!(eb, (state = i+1, action =i, reward = i, terminal = false)) end - push!(eb, (state = 7, action = 7)) + push!(eb, PartialNamedTuple((action = 6,))) + push!(eb, (state = 7,)) for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) end + push!(eb, PartialNamedTuple((action=12,))) s = BatchSampler(1000) b = sample(s, eb) cm = counter(b[:state]) @test !haskey(cm, 6) @test !haskey(cm, 11) @test all(in(keys(cm)), 
[1:5;7:10]) - - - eb[:priority, [1, 2]] = [0, 0] - @test eb[:priority] == [zeros(2);ones(8)] end From 29a6a3eefbb30b1388d65aba64450789f23687ae Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Thu, 9 May 2024 22:54:51 +0200 Subject: [PATCH 9/9] Version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 193d5a0..040791e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ReinforcementLearningTrajectories" uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c" -version = "0.4" +version = "0.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"