Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove special handling of SARTSA traces #75

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/common/CircularArraySARTSATraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ function CircularArraySARTSATraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1),
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
# Effective capacity of the combined traces is the smallest capacity among the
# wrapped traces, since a full sample needs a valid entry in every trace.
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces))
6 changes: 3 additions & 3 deletions src/common/CircularArraySARTSTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ())
terminal=Bool => ()
)
state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
Expand All @@ -32,4 +32,4 @@
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
# Effective capacity is the minimum over the wrapped traces' capacities
# (the multiplexed state trace holds one extra slot for next_state).
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces))

Check warning on line 35 in src/common/CircularArraySARTSTraces.jl

View check run for this annotation

Codecov / codecov/patch

src/common/CircularArraySARTSTraces.jl#L35

Added line #L35 was not covered by tests
2 changes: 1 addition & 1 deletion src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ function CircularArraySLARTTraces(;
)
end

CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = CircularArrayBuffers.capacity(minimum(map(capacity,t.traces)))
# Effective capacity is the minimum over the wrapped traces' capacities.
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces))
22 changes: 1 addition & 21 deletions src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ end
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
new_names = (:key, :priority, names...)
new_Ts = Tuple{Int,Float32,Ts.parameters...}
if traces isa CircularArraySARTSATraces
c = capacity(traces) - 1
else
c = capacity(traces)
end
c = capacity(traces)
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
CircularVectorBuffer{Int}(c),
SumTree(c),
Expand All @@ -38,22 +34,6 @@ function Base.push!(t::CircularPrioritizedTraces, x)
end
end

# NOTE(review): this method is DELETED by this PR — it special-cased SARTSA
# traces so that a key/priority pair was only appended when the insertion
# actually grew the traces (or when the buffer was already full). The generic
# Base.push!(t::CircularPrioritizedTraces, x) method now covers all trace types.
function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
initial_length = length(t.traces)
push!(t.traces, x)
if length(t.traces) == 1
# First step of the first episode: seed the key sequence at 1.
push!(t.keys, 1)
push!(t.priorities, t.default_priority)
elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 )
# only add a key if the length changes after insertion of the tuple
# or if the trace is already at capacity
push!(t.keys, t.keys[end] + 1)
push!(t.priorities, t.default_priority)
else
# may be partial inserting at the first step, ignore it
end
end

function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
if k === :priority
@assert length(vs) == length(keys)
Expand Down
43 changes: 9 additions & 34 deletions src/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
"""
EpisodesBuffer(traces::AbstractTraces)

Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
Wraps an `AbstractTraces` object, usually the container of a `Trajectory`.
`EpisodesBuffer` tracks the indexes of the `traces` object that belong to the same episodes.
To that end, it stores
To that end, it stores
1. a vector `sampleable_inds` of Booleans that determine whether an index in Traces is legally sampleable
(i.e., it is not the index of a last state of an episode);
2. a vector `episodes_lengths` that contains the total duration of the episode that each step belong to;
Expand All @@ -32,7 +32,7 @@
"""
PartialNamedTuple(::NamedTuple)

Wraps a NamedTuple to signal to the EpisodesBuffer into which it is pushed that it should
Wraps a NamedTuple to signal an EpisodesBuffer that it is pushed into that it should
ignore the fact that this is a partial insertion. Used at the end of an episode to
complete multiplex traces before moving to the next episode.
"""
Expand All @@ -43,15 +43,13 @@
# Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 in certain cases
function is_capacity_plus_one(traces::AbstractTraces)
if any(t->t isa MultiplexTraces, traces.traces)
# MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity
return true
elseif traces isa CircularPrioritizedTraces
# CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity
# MultiplexTraces buffer next_state or next_action, so we need to add one to the capacity
return true
else
false
end
end
is_capacity_plus_one(traces::CircularPrioritizedTraces) = is_capacity_plus_one(traces.traces)

function EpisodesBuffer(traces::AbstractTraces)
cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces)
Expand All @@ -70,7 +68,7 @@
end

function Base.getindex(es::EpisodesBuffer, idx::Int...)
@boundscheck all(es.sampleable_inds[idx...])
@boundscheck all(es.sampleable_inds[idx...]) || throw(BoundsError(es.sampleable_inds, idx))
getindex(es.traces, idx...)
end

Expand All @@ -79,6 +77,7 @@
end

# Delegate indexed assignment straight to the wrapped traces.
Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
# An EpisodesBuffer reports the capacity of its underlying traces.
capacity(eb::EpisodesBuffer) = capacity(eb.traces)

Check warning on line 80 in src/episodes.jl

View check run for this annotation

Codecov / codecov/patch

src/episodes.jl#L80

Added line #L80 was not covered by tests
# Forward common collection queries to the wrapped traces.
Base.size(eb::EpisodesBuffer) = size(eb.traces)
Base.length(eb::EpisodesBuffer) = length(eb.traces)
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
Expand Down Expand Up @@ -118,8 +117,6 @@
end
elseif traces_signature <: Tuple
traces_signature = traces_signature.parameters


for tr in traces_signature
if !(tr <: MultiplexTraces)
#push a duplicate of last element as a dummy element, should never be sampled.
Expand Down Expand Up @@ -148,7 +145,7 @@
push!(eb.episodes_lengths, 0)
push!(eb.sampleable_inds, 0)
elseif !partial #typical inserting
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
if haskey(eb,:next_action) # if trace has next_action
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
end
Expand All @@ -171,33 +168,11 @@
return nothing
end

function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number.
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
end

function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple)
if max_length(eb) == capacity(eb.traces)
popfirst!(eb)
end
# Push the wrapped tuple into the traces without bumping step bookkeeping;
# used to complete multiplexed traces (e.g. next_action) at episode end.
function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTuple to push without incrementing the step number.
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
end

# NOTE(review): this method is DELETED by this PR — it special-cased the
# final (partial) push of an episode into a prioritized SARTSA buffer at
# capacity: it padded the tuple with zero-valued state/reward/terminal so the
# circular buffers rotated, then popped the padding back off.
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}})
if max_length(eb) == capacity(eb.traces)
# Buffer full: build dummy entries for the traces missing from the partial tuple.
addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal])
xs = merge(xs.namedtuple, addition)
push!(eb.traces, xs)
# Remove the dummy padding again; only the action insertion should persist.
pop!(eb.traces[:state].trace)
pop!(eb.traces[:reward])
pop!(eb.traces[:terminal])
else
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1
end
end

for f in (:pop!, :popfirst!)
@eval function Base.$f(eb::EpisodesBuffer)
$f(eb.episodes_lengths)
Expand Down
46 changes: 23 additions & 23 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ export MetaSampler
"""
MetaSampler(::NamedTuple)

Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
Wraps a NamedTuple containing multiple samplers. When sampled, returns a named tuple with a
batch from each sampler.
Used internally for algorithms that sample multiple times per epoch.
Note that a single "sampling" with a MetaSampler only increases the Trajectory controller
Note that a single "sampling" with a MetaSampler only increases the Trajectory controler
count by 1, not by the number of internal samplers. This should be taken into account when
initializing an agent.

Expand Down Expand Up @@ -131,15 +131,15 @@ export MultiBatchSampler
"""
MultiBatchSampler(sampler, n)

Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
Wraps a sampler. When sampled, will sample n batches using sampler. Useful in combination
with MetaSampler to allow different sampling rates between samplers.
Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
Note that a single "sampling" with a MultiBatchSampler only increases the Trajectory
controller count by 1, not by `n`. This should be taken into account when
initializing an agent.

# Example
```
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
MetaSampler(policy = MultiBatchSampler(BatchSampler(10), 3),
critic = MultiBatchSampler(BatchSampler(100), 5))
```
"""
Expand Down Expand Up @@ -169,13 +169,13 @@ export NStepBatchSampler
NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)

Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
that in up to `n > 1` steps later in the buffer. The reward will be
the discounted sum of the `n` rewards, with `γ` as the discount factor.

NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize` is set
to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
of partial observability, for example when the state is approximated by `stacksize` consecutive
of partial observability, for example when the state is approximated by `stacksize` consecutive
frames.
"""
mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
Expand All @@ -187,17 +187,17 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRN
end

NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...)
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
# Keyword constructor. `n` is the TD horizon, `γ` the discount factor.
function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names}
@assert n >= 1 "n must be ≥ 1."
# A stacksize of 1 is equivalent to no stacking; normalize it to `nothing`
# so dispatch selects the non-stacking fetch methods.
ss = stacksize == 1 ? nothing : stacksize
NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng)
end

#return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
function valid_range(s::NStepBatchSampler, eb::EpisodesBuffer)
range = copy(eb.sampleable_inds)
ns = Vector{Int}(undef, length(eb.sampleable_inds))
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
for idx in eachindex(range)
step_number = eb.step_numbers[idx]
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
Expand Down Expand Up @@ -258,9 +258,9 @@ end
"""
EpisodesSampler()

A sampler that samples all Episodes present in the Trajectory and divides them into
A sampler that samples all Episodes present in the Trajectory and divides them into
Episode containers. Truncated Episodes (e.g. due to the buffer capacity) are sampled as well.
There will be at most one truncated episode and it will always be the first one.
There will be at most one truncated episode and it will always be the first one.
"""
struct EpisodesSampler{names}
end
Expand Down Expand Up @@ -295,7 +295,7 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
idx += 1
end
end

return [make_episode(t, r, names) for r in ranges]
end

Expand All @@ -304,29 +304,29 @@ end
"""
MultiStepSampler{names}(batchsize, n, stacksize, rng)

Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
`x`. The samples are returned in an array of batchsize elements. For each element, n is
truncated by the end of its episode. This means that the dimensions of each sample are not
the same.
Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
`x`. The samples are returned in an array of batchsize elements. For each element, n is
truncated by the end of its episode. This means that the dimensions of each sample are not
the same.
"""
struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG}
n::Int
batchsize::Int
stacksize::S
rng::R
rng::R
end

MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...)
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
# Keyword constructor; mirrors NStepBatchSampler's normalization of stacksize.
function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names}
@assert n >= 1 "n must be ≥ 1."
# stacksize == 1 means no frame stacking; store `nothing` for dispatch.
ss = stacksize == 1 ? nothing : stacksize
MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng)
end

function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
function valid_range(s::MultiStepSampler, eb::EpisodesBuffer)
range = copy(eb.sampleable_inds)
ns = Vector{Int}(undef, length(eb.sampleable_inds))
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
stacksize = isnothing(s.stacksize) ? 1 : s.stacksize
for idx in eachindex(range)
step_number = eb.step_numbers[idx]
range[idx] = step_number >= stacksize && eb.sampleable_inds[idx]
Expand All @@ -353,7 +353,7 @@ function fetch(::MultiStepSampler, trace, ::Val, inds, ns)
[trace[idx:(idx + ns[i] - 1)] for (i,idx) in enumerate(inds)]
end

function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names}
[trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)]
end

Expand Down
Loading
Loading