From e924768499c973ab83f6749d73b596112ff2dc18 Mon Sep 17 00:00:00 2001 From: Johannes Fischer Date: Thu, 2 Jan 2025 14:24:26 +0100 Subject: [PATCH 1/3] Fix SARTSTraces etc. capacity --- src/common/CircularArraySARTSATraces.jl | 8 ++--- src/common/CircularArraySARTSTraces.jl | 6 ++-- src/common/CircularArraySLARTTraces.jl | 2 +- src/common/CircularPrioritizedTraces.jl | 6 +--- test/common.jl | 39 ++++++++++++++++--------- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/common/CircularArraySARTSATraces.jl b/src/common/CircularArraySARTSATraces.jl index 393e64b..53678b7 100644 --- a/src/common/CircularArraySARTSATraces.jl +++ b/src/common/CircularArraySARTSATraces.jl @@ -24,12 +24,12 @@ function CircularArraySARTSATraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) + Traces( - reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1), - terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1), + reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), + terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end -CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces))) +CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces)) diff --git a/src/common/CircularArraySARTSTraces.jl b/src/common/CircularArraySARTSTraces.jl index acf999c..eb43038 100644 --- a/src/common/CircularArraySARTSTraces.jl +++ b/src/common/CircularArraySARTSTraces.jl @@ -17,8 +17,8 @@ function CircularArraySARTSTraces(; state=Int => (), action=Int => (), reward=Float32 => (), - terminal=Bool => ()) - + terminal=Bool => () +) state_eltype, state_size = state action_eltype, action_size = action reward_eltype, reward_size = reward @@ -32,4 +32,4 @@ function CircularArraySARTSTraces(; ) end -CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces))) +CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces)) diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 73f6da9..0677906 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -34,4 +34,4 @@ function CircularArraySLARTTraces(; ) end -CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces))) \ No newline at end of file +CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces)) \ No newline at end of file diff --git a/src/common/CircularPrioritizedTraces.jl b/src/common/CircularPrioritizedTraces.jl index 09b2ffd..db4b409 100644 --- a/src/common/CircularPrioritizedTraces.jl +++ b/src/common/CircularPrioritizedTraces.jl @@ -12,11 +12,7 @@ end function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts} new_names = (:key, :priority, names...) new_Ts = Tuple{Int,Float32,Ts.parameters...} - if traces isa CircularArraySARTSATraces - c = capacity(traces) - 1 - else - c = capacity(traces) - end + c = capacity(traces) CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}( CircularVectorBuffer{Int}(c), SumTree(c), diff --git a/test/common.jl b/test/common.jl index 714520b..fdaf844 100644 --- a/test/common.jl +++ b/test/common.jl @@ -34,15 +34,20 @@ end ) |> gpu @test t isa CircularArraySARTSATraces + @test ReinforcementLearningTrajectories.capacity(t) == 3 + @test CircularArrayBuffers.capacity(t) == 3 - push!(t, (state=ones(Float32, 2, 3),)) + push!(t, (state=ones(Float32, 2, 3),) |> gpu) push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu) @test length(t) == 0 push!(t, (reward=1.0f0, terminal=false) |> gpu) @test length(t) == 0 # next_action is still missing - push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu) + push!(t, (action=ones(Float32, 2) * 2,) |> gpu) + @test length(t) == 1 + + push!(t, (state=ones(Float32, 2, 3) * 3,) |> gpu) @test length(t) == 1 # this will trigger the scalar indexing of CuArray @@ -71,29 +76,33 @@ end @test length(t) == 3 + push!(t, (action=ones(Float32, 2) * 6,) |> gpu) + @test length(t) == 3 + # this will trigger the scalar indexing of CuArray CUDA.@allowscalar @test t[1] == ( - state=ones(Float32, 2, 3) * 2, - next_state=ones(Float32, 2, 3) * 3, - action=ones(Float32, 2) * 2, - next_action=ones(Float32, 2) * 3, - reward=2.0f0, + state=ones(Float32, 2, 3) * 3, + next_state=ones(Float32, 2, 3) * 4, + action=ones(Float32, 2) * 3, + next_action=ones(Float32, 2) * 4, + reward=3.0f0, terminal=false, ) CUDA.@allowscalar @test t[end] == ( - state=ones(Float32, 2, 3) * 4, - next_state=ones(Float32, 2, 3) * 5, - action=ones(Float32, 2) * 4, - next_action=ones(Float32, 2) * 5, - reward=4.0f0, + state=ones(Float32, 2, 3) * 5, + next_state=ones(Float32, 2, 3) * 6, + action=ones(Float32, 2) * 5, + next_action=ones(Float32, 2) * 6, + reward=5.0f0, terminal=false, ) batch = t[1:3] @test size(batch.state) == (2, 3, 3) @test size(batch.action) == (2, 3) - @test batch.reward == [2.0, 3.0, 4.0] |> gpu + @test batch.reward == [3.0, 4.0, 5.0] |> gpu @test batch.terminal == Bool[0, 0, 0] |> gpu + end @testset "ElasticArraySARTSTraces" begin @@ -127,6 +136,8 @@ end ) @test t isa CircularArraySLARTTraces + @test ReinforcementLearningTrajectories.capacity(t) == 3 + @test CircularArrayBuffers.capacity(t) == 3 end @testset "CircularPrioritizedTraces-SARTS" begin @@ -136,6 +147,7 @@ end ), default_priority=1.0f0 ) + @test ReinforcementLearningTrajectories.capacity(t) == 3 push!(t, (state=0, action=0)) @@ -196,6 +208,7 @@ end ), default_priority=1.0f0 ) + @test ReinforcementLearningTrajectories.capacity(t) == 3 push!(t, (state=0, action=0)) From 1560ff5b4bdeb892fa2af03de0fa90ae08c46b90 Mon Sep 17 00:00:00 2001 From: Johannes Fischer Date: Thu, 2 Jan 2025 14:26:10 +0100 Subject: [PATCH 2/3] Remove whitespace --- src/episodes.jl | 10 ++++------ src/samplers.jl | 46 +++++++++++++++++++++++----------------------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/episodes.jl b/src/episodes.jl index 90a8b79..6b49aa5 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -5,9 +5,9 @@ using ElasticArrays: ElasticArray, ElasticVector """ EpisodesBuffer(traces::AbstractTraces) -Wraps an `AbstractTraces` object, usually the container of a `Trajectory`. +Wraps an `AbstractTraces` object, usually the container of a `Trajectory`. `EpisodesBuffer` tracks the indexes of the `traces` object that belong to the same episodes. -To that end, it stores +To that end, it stores 1. an vector `sampleable_inds` of Booleans that determine whether an index in Traces is legally sampleable (i.e., it is not the index of a last state of an episode); 2. a vector `episodes_lengths` that contains the total duration of the episode that each step belong to; @@ -32,7 +32,7 @@ end """ PartialNamedTuple(::NamedTuple) -Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should +Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should ignore the fact that this is a partial insertion. Used at the end of an episode to complete multiplex traces before moving to the next episode. """ @@ -118,8 +118,6 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) end elseif traces_signature <: Tuple traces_signature = traces_signature.parameters - - for tr in traces_signature if !(tr <: MultiplexTraces) #push a duplicate of last element as a dummy element, should never be sampled. @@ -171,7 +169,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) return nothing end -function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number. +function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number. push!(eb.traces, xs.namedtuple) eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. end diff --git a/src/samplers.jl b/src/samplers.jl index e5443a7..21628bd 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -93,10 +93,10 @@ export MetaSampler """ MetaSampler(::NamedTuple) -Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a +Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a batch from each sampler. Used internally for algorithms that sample multiple times per epoch. -Note that a single "sampling" with a MetaSampler only increases the Trajectory controler +Note that a single "sampling" with a MetaSampler only increases the Trajectory controler count by 1, not by the number of internal samplers. This should be taken into account when initializing an agent. @@ -131,15 +131,15 @@ export MultiBatchSampler """ MultiBatchSampler(sampler, n) -Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination +Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination with MetaSampler to allow different sampling rates between samplers. -Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory +Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory controler count by 1, not by `n`. This should be taken into account when initializing an agent. # Example ``` -MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), +MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3), critic = MultiBatchSampler(BatchSampler(100), 5)) ``` """ @@ -169,13 +169,13 @@ export NStepBatchSampler NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG) Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning. -The "next" element of Multiplexed traces (such as the next_state or the next_action) will be +The "next" element of Multiplexed traces (such as the next_state or the next_action) will be that in up to `n > 1` steps later in the buffer. The reward will be the discounted sum of the `n` rewards, with `γ` as the discount factor. -NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set +NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case -of partial observability, for example when the state is approximated by `stacksize` consecutive +of partial observability, for example when the state is approximated by `stacksize` consecutive frames. """ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG} @@ -187,17 +187,17 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRN end NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...) -function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names} +function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names} @assert n >= 1 "n must be ≥ 1." ss = stacksize == 1 ? nothing : stacksize NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng) end #return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index. -function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer) +function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer) range = copy(eb.sampleable_inds) ns = Vector{Int}(undef, length(eb.sampleable_inds)) - stacksize = isnothing(s.stacksize) ? 1 : s.stacksize + stacksize = isnothing(s.stacksize) ? 1 : s.stacksize for idx in eachindex(range) step_number = eb.step_numbers[idx] range[idx] = step_number >= stacksize && eb.sampleable_inds[idx] @@ -258,9 +258,9 @@ end """ EpisodesSampler() -A sampler that samples all Episodes present in the Trajectory and divides them into +A sampler that samples all Episodes present in the Trajectory and divides them into Episode containers. Truncated Episodes (e.g. due to the buffer capacity) are sampled as well. -There will be at most one truncated episode and it will always be the first one. +There will be at most one truncated episode and it will always be the first one. """ struct EpisodesSampler{names} end @@ -295,7 +295,7 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names) idx += 1 end end - + return [make_episode(t, r, names) for r in ranges] end @@ -304,29 +304,29 @@ end """ MultiStepSampler{names}(batchsize, n, stacksize, rng) -Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index -`x`. The samples are returned in an array of batchsize elements. For each element, n is -truncated by the end of its episode. This means that the dimensions of each sample are not -the same. +Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index +`x`. The samples are returned in an array of batchsize elements. For each element, n is +truncated by the end of its episode. This means that the dimensions of each sample are not +the same. """ struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG} n::Int batchsize::Int stacksize::S - rng::R + rng::R end MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...) -function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names} +function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names} @assert n >= 1 "n must be ≥ 1." ss = stacksize == 1 ? nothing : stacksize MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng) end -function valid_range(s::MultiStepSampler, eb::EpisodesBuffer) +function valid_range(s::MultiStepSampler, eb::EpisodesBuffer) range = copy(eb.sampleable_inds) ns = Vector{Int}(undef, length(eb.sampleable_inds)) - stacksize = isnothing(s.stacksize) ? 1 : s.stacksize + stacksize = isnothing(s.stacksize) ? 1 : s.stacksize for idx in eachindex(range) step_number = eb.step_numbers[idx] range[idx] = step_number >= stacksize && eb.sampleable_inds[idx] @@ -353,7 +353,7 @@ function fetch(::MultiStepSampler, trace, ::Val, inds, ns) [trace[idx:(idx + ns[i] - 1)] for (i,idx) in enumerate(inds)] end -function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names} +function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names} [trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)] end From bfc15918b5bc6d93bb632e52a862a42ee0f68baa Mon Sep 17 00:00:00 2001 From: Johannes Fischer Date: Thu, 2 Jan 2025 16:01:29 +0100 Subject: [PATCH 3/3] Remove special treatment of SARTSA traces Remove methods specifically defined for SARTSA traces in EpisodesBuffer and CircularPrioritizedTraces --- src/common/CircularPrioritizedTraces.jl | 16 --- src/episodes.jl | 33 +----- test/common.jl | 21 ++-- test/episodes.jl | 141 ++++++++++++++---------- test/samplers.jl | 83 +++++++------- 5 files changed, 147 insertions(+), 147 deletions(-) diff --git a/src/common/CircularPrioritizedTraces.jl b/src/common/CircularPrioritizedTraces.jl index db4b409..76581af 100644 --- a/src/common/CircularPrioritizedTraces.jl +++ b/src/common/CircularPrioritizedTraces.jl @@ -34,22 +34,6 @@ function Base.push!(t::CircularPrioritizedTraces, x) end end -function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x) - initial_length = length(t.traces) - push!(t.traces, x) - if length(t.traces) == 1 - push!(t.keys, 1) - push!(t.priorities, t.default_priority) - elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 ) - # only add a key if the length changes after insertion of the tuple - # or if the trace is already at capacity - push!(t.keys, t.keys[end] + 1) - push!(t.priorities, t.default_priority) - else - # may be partial inserting at the first step, ignore it - end -end - function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys) if k === :priority @assert length(vs) == length(keys) diff --git a/src/episodes.jl b/src/episodes.jl index 6b49aa5..d4314d7 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -43,15 +43,13 @@ end # Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 for certain cases function is_capacity_plus_one(traces::AbstractTraces) if any(t->t isa MultiplexTraces, traces.traces) - # MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity - return true - elseif traces isa CircularPrioritizedTraces - # CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity + # MultiplexTraces buffer next_state or next_action, so we need to add one to the capacity return true else false end end +is_capacity_plus_one(traces::CircularPrioritizedTraces) = is_capacity_plus_one(traces.traces) function EpisodesBuffer(traces::AbstractTraces) cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces) @@ -70,7 +68,7 @@ function EpisodesBuffer(traces::AbstractTraces) end function Base.getindex(es::EpisodesBuffer, idx::Int...) - @boundscheck all(es.sampleable_inds[idx...]) + @boundscheck all(es.sampleable_inds[idx...]) || throw(BoundsError(es.sampleable_inds, idx)) getindex(es.traces, idx...) end @@ -79,6 +77,7 @@ function Base.getindex(es::EpisodesBuffer, idx...) end Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...) +capacity(eb::EpisodesBuffer) = capacity(eb.traces) Base.size(eb::EpisodesBuffer) = size(eb.traces) Base.length(eb::EpisodesBuffer) = length(eb.traces) Base.keys(eb::EpisodesBuffer) = keys(eb.traces) @@ -146,7 +145,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.episodes_lengths, 0) push!(eb.sampleable_inds, 0) elseif !partial #typical inserting - if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched + if haskey(eb,:next_action) # if trace has next_action if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode eb.sampleable_inds[end-1] = 1 # steps are indexable one step later end @@ -174,28 +173,6 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. end -function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple) - if max_length(eb) == capacity(eb.traces) - popfirst!(eb) - end - push!(eb.traces, xs.namedtuple) - eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. -end - -function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}}) - if max_length(eb) == capacity(eb.traces) - addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal]) - xs = merge(xs.namedtuple, addition) - push!(eb.traces, xs) - pop!(eb.traces[:state].trace) - pop!(eb.traces[:reward]) - pop!(eb.traces[:terminal]) - else - push!(eb.traces, xs.namedtuple) - eb.sampleable_inds[end-1] = 1 - end -end - for f in (:pop!, :popfirst!) @eval function Base.$f(eb::EpisodesBuffer) $f(eb.episodes_lengths) diff --git a/test/common.jl b/test/common.jl index fdaf844..117d6ad 100644 --- a/test/common.jl +++ b/test/common.jl @@ -180,14 +180,14 @@ end default_priority=1.0f0 ) - eb = EpisodesBuffer(t) + eb = EpisodesBuffer(t) push!(eb, (state = 1, action = 1)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) + push!(eb, (state = i+1, action = i+1, reward = i, terminal = false)) end push!(eb, (state = 7, action = 7)) for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i, reward = i-1, terminal = false)) end s = BatchSampler(1000) b = sample(s, eb) @@ -222,6 +222,8 @@ end b = sample(s, t) + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] + t[:priority, [1, 2]] = [0, 0] # shouldn't be changed since [1,2] are old keys @@ -240,18 +242,19 @@ end ), default_priority=1.0f0 ) - - eb = EpisodesBuffer(t) + + eb = EpisodesBuffer(t) push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + push!(eb, (state = i+1, action = i, reward = i, terminal = false)) end push!(eb, PartialNamedTuple((action = 6,))) push!(eb, (state = 7,)) - for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + for i = 8:11 + push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) end - push!(eb, PartialNamedTuple((action=12,))) + push!(eb, PartialNamedTuple((action=11,))) + s = BatchSampler(1000) b = sample(s, eb) cm = counter(b[:state]) diff --git a/test/episodes.jl b/test/episodes.jl index 0932416..2633307 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -3,54 +3,65 @@ using CircularArrayBuffers using Test @testset "EpisodesBuffer" begin - @testset "with circular traces" begin + @testset "with circular SARTS traces" begin eb = EpisodesBuffer( CircularArraySARTSTraces(; capacity=10) ) - #push a first episode l=5 + + # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + push!(eb, (state = i+1, action = i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(i, i+1) end + @test eb[end] == (state = 5, next_state = 6, action = 5, reward = 5, terminal = false) @test eb.sampleable_inds == [1,1,1,1,1,0] @test length(eb.traces) == 5 - #start new episode of 6 periods. + + # start second episode push!(eb, (state = 7,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] - @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero + @test eb[:reward][6] == 0 # 6 is not a valid index, filled with dummy value zero + @test_throws BoundsError eb[6] # 6 is not a valid index + @test_throws BoundsError eb[7] # 7 is not a valid index + + # push four steps of second episode ep2_len = 0 - for (j,i) = enumerate(8:11) + for (i,s) = enumerate(8:11) ep2_len += 1 - push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 - @test eb.step_numbers[end] == j + 1 - @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) + @test eb.step_numbers[end] == i + 1 + @test eb.episodes_lengths[end-i:end] == fill(ep2_len, ep2_len + 1) end + @test eb[end] == (state = 10, next_state = 11, action = 10, reward = 10, terminal = false) @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] - @test length(eb.traces) == 10 - #three last steps replace oldest steps in the buffer. + @test length(eb) == 10 + # push two more steps of second episode, which replace the oldest steps in the buffer for (i, s) = enumerate(12:13) ep2_len += 1 - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end - #episode 1 + @test eb[end] == (state = 12, next_state = 13, action = 12, reward = 12, terminal = false) + @test eb.sampleable_inds == [1,1,1,0,1,1,1,1,1,1,0] + + # verify episode 2 for (i,s) in enumerate(3:13) if i in (4, 11) @test eb.sampleable_inds[i] == 0 @@ -62,8 +73,8 @@ using Test @test b[:state] == b[:action] == b[:reward] == s @test b[:next_state] == s + 1 end - #episode 2 - #start a third episode + + # push third episode push!(eb, (state = 14, )) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @@ -71,28 +82,28 @@ using Test @test eb.step_numbers[end] == 1 #push until it reaches it own start for (i,s) in enumerate(15:26) - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) end @test eb.sampleable_inds == [fill(true, 10); [false]] @test eb.episodes_lengths == fill(length(15:26), 11) @test eb.step_numbers == [3:13;] - step = popfirst!(eb) + popfirst!(eb) @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 @test first(eb.step_numbers) == 4 - step = pop!(eb) + pop!(eb) @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8 @test last(eb.step_numbers) == 12 @test size(eb) == size(eb.traces) == (8,) empty!(eb) @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) - show(eb); end - @testset "with PartialNamedTuple" begin + + @testset "with SARTSA traces and PartialNamedTuple" begin eb = EpisodesBuffer( CircularArraySARTSATraces(; capacity=10) ) - #push a first episode l=5 + # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @@ -107,38 +118,46 @@ using Test @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(i, i+1) end + @test eb.sampleable_inds == [1,1,1,1,0,0] push!(eb, PartialNamedTuple((action = 6,))) @test eb.sampleable_inds == [1,1,1,1,1,0] @test length(eb.traces) == 5 - #start new episode of 6 periods. + + # start second episode push!(eb, (state = 7,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] - @test eb[:action][6] == 6 - @test eb[:next_action][5] == 6 - @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero - @test_throws BoundsError eb[6] #6 is not a valid index, the reward there is dummy, filled as zero + @test eb[5][:next_action] == eb[:next_action][5] == 6 + @test eb[:reward][6] == 0 # 6 is not a valid index, the reward there is dummy, filled as zero + @test_throws BoundsError eb[6] # 6 is not a valid index ep2_len = 0 - for (j,i) = enumerate(8:11) + # push four steps of second episode + for (i,s) = enumerate(8:11) ep2_len += 1 - push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @test eb.sampleable_inds[end-2] == 1 end - @test eb.step_numbers[end] == j + 1 - @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) + @test eb.step_numbers[end] == i + 1 + @test eb.episodes_lengths[end-i:end] == fill(ep2_len, ep2_len + 1) end @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] - @test length(eb.traces) == 9 #an action is missing at this stage - #three last steps replace oldest steps in the buffer. + @test length(eb.traces) == 9 # an action is missing at this stage + @test eb.sampleable_inds[end] == 0 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end + + # push two more steps of second episode, which replace the oldest steps in the buffer for (i, s) = enumerate(12:13) ep2_len += 1 - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @@ -149,7 +168,8 @@ using Test end push!(eb, PartialNamedTuple((action = 13,))) @test length(eb.traces) == 10 - #episode 1 + + # verify episode 2 for (i,s) in enumerate(3:13) if i in (4, 11) @test eb.sampleable_inds[i] == 0 @@ -161,16 +181,16 @@ using Test @test b[:state] == b[:action] == b[:reward] == s @test b[:next_state] == b[:next_action] == s + 1 end - #episode 2 - #start a third episode + + # push third episode push!(eb, (state = 14,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 - #push until it reaches it own start + # push until it reaches it own start for (i,s) in enumerate(15:26) - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) end push!(eb, PartialNamedTuple((action = 26,))) @test eb.sampleable_inds == [fill(true, 10); [false]] @@ -185,7 +205,6 @@ using Test @test size(eb) == size(eb.traces) == (8,) empty!(eb) @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) - show(eb); end @testset "with vector traces" begin eb = EpisodesBuffer( @@ -193,7 +212,7 @@ using Test state=Int[], reward=Int[]) ) - push!(eb, (state = 1,)) #partial inserting + push!(eb, (state = 1,)) # partial inserting for i = 1:15 push!(eb, (state = i+1, reward =i)) end @@ -201,7 +220,7 @@ using Test @test eb.sampleable_inds == [fill(true, 15); [false]] @test all(==(15), eb.episodes_lengths) @test eb.step_numbers == [1:16;] - push!(eb, (state = 1,)) #partial inserting + push!(eb, (state = 1,)) # partial inserting for i = 1:15 push!(eb, (state = i+1, reward =i)) end @@ -210,11 +229,12 @@ using Test @test eb.step_numbers == [1:16;1:16] @test length(eb) == 31 end - @testset "with ElasticArraySARTSTraces traces" begin + + @testset "with ElasticArraySARTSTraces" begin eb = EpisodesBuffer( ElasticArraySARTSTraces() ) - #push a first episode l=5 + # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @@ -228,15 +248,18 @@ using Test end @test eb.sampleable_inds == [1,1,1,1,1,0] @test length(eb.traces) == 5 - #start new episode of 6 periods. + + # start second episode push!(eb, (state = 7,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] - @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero + @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero + @test_throws BoundsError eb[6] #6 is not a valid index ep2_len = 0 + # push four steps of second episode for (j,i) = enumerate(8:11) ep2_len += 1 push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) @@ -248,6 +271,7 @@ using Test @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] @test length(eb.traces) == 10 + # push two more steps of second episode, which replace the oldest steps in the buffer for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @@ -256,7 +280,7 @@ using Test @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end - #episode 1 + # verify episode 2 for i in 3:13 if i in (6, 13) @test eb.sampleable_inds[i] == 0 @@ -268,14 +292,15 @@ using Test @test b[:state] == b[:action] == b[:reward] == i @test b[:next_state] == i + 1 end - #episode 2 - #start a third episode + + # push third episode push!(eb, (state = 14, )) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 + # push until it reaches it own start for (i,s) in enumerate(15:26) push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) end @@ -292,15 +317,14 @@ using Test @test size(eb) == size(eb.traces) == (8,) empty!(eb) @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) - show(eb); =# end - @testset "ElasticArraySARTSATraces with PartialNamedTuple" begin + @testset "ElasticArraySARTSATraces with PartialNamedTuple" begin eb = EpisodesBuffer( ElasticArraySARTSATraces() ) - #push a first episode l=5 + # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @@ -318,7 +342,8 @@ using Test push!(eb, PartialNamedTuple((action = 6,))) @test eb.sampleable_inds == [1,1,1,1,1,0] @test length(eb.traces) == 5 - #start new episode of 6 periods. + + # start second episode push!(eb, (state = 7,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @@ -329,6 +354,7 @@ using Test @test eb[:next_action][5] == 6 @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 + # push four steps of second episode for (j,i) = enumerate(8:11) ep2_len += 1 push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) @@ -341,7 +367,8 @@ using Test @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] - @test length(eb.traces) == 9 #an action is missing at this stage + @test length(eb.traces) == 9 # an action is missing at this stage + # push two more steps of second episode, which replace the oldest steps in the buffer for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @@ -355,7 +382,8 @@ using Test end push!(eb, PartialNamedTuple((action = 13,))) @test length(eb.traces) == 12 - #episode 1 + + # verify episode 2 for i in 1:13 if i in (6, 13) @test eb.sampleable_inds[i] == 0 @@ -367,8 +395,8 @@ using Test @test b[:state] == b[:action] == b[:reward] == i @test b[:next_state] == b[:next_action] == i + 1 end - #episode 2 - #start a third episode + + # push third episode push!(eb, (state = 14,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @@ -392,7 +420,6 @@ using Test @test size(eb) == size(eb.traces) == (8,) empty!(eb) @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) - show(eb); =# end end diff --git a/test/samplers.jl b/test/samplers.jl index 2565ddd..dc931a2 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -13,17 +13,20 @@ import ReinforcementLearningTrajectories.fetch @test keys(b) == (:state, :action) @test size(b.state) == (3, 4, sz) @test size(b.action) == (sz,) - + #In EpisodesBuffer - eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) - push!(eb, (state = 1, action = 1)) + eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) + push!(eb, (state = i+1, action = i, reward = i, terminal = false)) end - push!(eb, (state = 7, action = 7)) + push!(eb, PartialNamedTuple((next_action = 6,))) + push!(eb, (state = 7,)) for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) end + push!(eb, PartialNamedTuple((next_action = 11,))) + s = BatchSampler(1000) b = sample(s, eb) cm = counter(b[:state]) @@ -70,7 +73,7 @@ import ReinforcementLearningTrajectories.fetch @test length(batches) == 11 @test length(batches[1][:policy][:a]) == 3 @test length(batches[1][:critic]) == 2 # we sampled 2 batches for critic - @test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples + @test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples end #! format: off @@ -79,17 +82,20 @@ import ReinforcementLearningTrajectories.fetch n_stack = 2 n_horizon = 3 batchsize = 1000 - eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) + eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stacksize=n_stack, batchsize=batchsize) - push!(eb, (state = 1, action = 1)) + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5)) + push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5)) end - push!(eb, (state = 7, action = 7)) + push!(eb, PartialNamedTuple((next_action = 6,))) + push!(eb, (state = 7,)) for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) end + push!(eb, PartialNamedTuple((next_action = 11,))) + weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb) @test weights == [0,1,1,1,1,0,0,1,1,1,0] @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode @@ -108,7 +114,7 @@ import ReinforcementLearningTrajectories.fetch @test next_states == [4 5 5 5 10 10 10; 5 6 6 6 11 11 11] @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state]))) - #action: samples normally + # action: samples normally actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds]) @test actions == inds @test all(in(actions), unique(batch[:action])) @@ -128,17 +134,17 @@ import ReinforcementLearningTrajectories.fetch γ = 0.99 n_horizon = 3 batchsize = 4 - eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0)) + eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0)) s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize) - + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5)) + push!(eb, (state = i+1, action = i, reward = i, terminal = i == 5)) end push!(eb, PartialNamedTuple((action=6,))) push!(eb, (state = 7,)) for (j,i) = enumerate(7:10) - push!(eb, (state = i+1, action =i, reward = i, terminal = i==10)) + push!(eb, (state = i+1, action = i, reward = i, terminal = i==10)) end push!(eb, PartialNamedTuple((action = 11,))) weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb) @@ -151,14 +157,14 @@ import ReinforcementLearningTrajectories.fetch @testset "EpisodesSampler" begin s = EpisodesSampler() - eb = EpisodesBuffer(CircularArraySARTSTraces(capacity=10)) + eb = EpisodesBuffer(CircularArraySARTSTraces(capacity=10)) push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + push!(eb, (state = i+1, action = i, reward = i, terminal = false)) end push!(eb, (state = 7,)) for (j,i) = enumerate(8:12) - push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) end b = sample(s, eb) @@ -171,25 +177,24 @@ import ReinforcementLearningTrajectories.fetch @test b[2][:next_state] == [8:12;] @test b[2][:action] == [7:11;] @test b[2][:reward] == [7:11;] - + for (j,i) = enumerate(2:5) push!(eb, (state = i, action =i, reward = i-1, terminal = false)) end #only the last state of the first episode is still buffered. Should not be sampled. b = sample(s, eb) @test length(b) == 1 - - #with specified traces + # with specified traces s = EpisodesSampler{(:state,)}() - eb = EpisodesBuffer(CircularArraySARTSTraces(capacity=10)) - push!(eb, (state = 1, action = 1)) + eb = EpisodesBuffer(CircularArraySARTSTraces(capacity=10)) + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) + push!(eb, (state = i+1, action = i, reward = i, terminal = false)) end - push!(eb, (state = 7, action = 7)) + push!(eb, (state = 7,)) for (j,i) = enumerate(8:12) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) end b = sample(s, eb) @@ -202,34 +207,38 @@ import ReinforcementLearningTrajectories.fetch n_stack = 2 n_horizon = 3 batchsize = 1000 - eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) + eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) s1 = MultiStepSampler(eb, n=n_horizon, stacksize=n_stack, batchsize=batchsize) - push!(eb, (state = 1, action = 1)) + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5)) + push!(eb, (state = i+1, action = i, reward = i, terminal = i == 5)) end - push!(eb, (state = 7, action = 7)) + push!(eb, PartialNamedTuple((action=6,))) + push!(eb, (state = 7,)) for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) end + push!(eb, PartialNamedTuple((action=11,))) + weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb) @test weights == [0,1,1,1,1,0,0,1,1,1,0] - @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode + @test ns == [3,3,3,2,1,-1,3,3,2,1,0] # the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode inds = [i for i in eachindex(weights) if weights[i] == 1] batch = sample(s1, eb) for key in keys(eb) @test haskey(batch, key) end - #state and next_state: samples with stacksize + + # state and next_state: samples with stacksize states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds]) @test states == [[1 2 3; 2 3 4], [2 3 4; 3 4 5], [3 4; 4 5], [4; 5;;], [7 8 9; 8 9 10], [8 9; 9 10], [9; 10;;]] @test all(in(states), batch[:state]) - #next_state: samples with stacksize and nsteps forward + # next_state: samples with stacksize and nsteps forward next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds]) @test next_states == [[2 3 4; 3 4 5], [3 4 5; 4 5 6], [4 5; 5 6], [5; 6;;], [8 9 10; 9 10 11], [9 10; 10 11], [10; 11;;]] @test all(in(next_states), batch[:next_state]) - #all other traces sample normally + # all other traces sample normally actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds]) @test actions == [[2,3,4], [3,4,5], [4,5], [5], [8,9,10], [9,10],[10]] @test all(in(actions), batch[:action])