Skip to content

Commit

Permalink
Remove special treatment of SARTSA traces
Browse files Browse the repository at this point in the history
Remove methods specifically defined for SARTSA traces in EpisodesBuffer and CircularPrioritizedTraces
  • Loading branch information
johannes-fischer committed Jan 2, 2025
1 parent 1560ff5 commit bfc1591
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 147 deletions.
16 changes: 0 additions & 16 deletions src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,6 @@ function Base.push!(t::CircularPrioritizedTraces, x)
end
end

function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
initial_length = length(t.traces)
push!(t.traces, x)
if length(t.traces) == 1
push!(t.keys, 1)
push!(t.priorities, t.default_priority)
elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 )
# only add a key if the length changes after insertion of the tuple
# or if the trace is already at capacity
push!(t.keys, t.keys[end] + 1)
push!(t.priorities, t.default_priority)
else
# may be partial inserting at the first step, ignore it
end
end

function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
if k === :priority
@assert length(vs) == length(keys)
Expand Down
33 changes: 5 additions & 28 deletions src/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ end
# Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 for certain cases
function is_capacity_plus_one(traces::AbstractTraces)
if any(t->t isa MultiplexTraces, traces.traces)
# MultiplexTraces buffer next_state and next_action, so we need to add one to the capacity
return true
elseif traces isa CircularPrioritizedTraces
# CircularPrioritizedTraces buffer next_state and next_action, so we need to add one to the capacity
# MultiplexTraces buffer next_state or next_action, so we need to add one to the capacity
return true
else
false
end
end
is_capacity_plus_one(traces::CircularPrioritizedTraces) = is_capacity_plus_one(traces.traces)

function EpisodesBuffer(traces::AbstractTraces)
cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces)
Expand All @@ -70,7 +68,7 @@ function EpisodesBuffer(traces::AbstractTraces)
end

function Base.getindex(es::EpisodesBuffer, idx::Int...)
@boundscheck all(es.sampleable_inds[idx...])
@boundscheck all(es.sampleable_inds[idx...]) || throw(BoundsError(es.sampleable_inds, idx))
getindex(es.traces, idx...)
end

Expand All @@ -79,6 +77,7 @@ function Base.getindex(es::EpisodesBuffer, idx...)
end

Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
capacity(eb::EpisodesBuffer) = capacity(eb.traces)

Check warning on line 80 in src/episodes.jl

View check run for this annotation

Codecov / codecov/patch

src/episodes.jl#L80

Added line #L80 was not covered by tests
Base.size(eb::EpisodesBuffer) = size(eb.traces)
Base.length(eb::EpisodesBuffer) = length(eb.traces)
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
Expand Down Expand Up @@ -146,7 +145,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
push!(eb.episodes_lengths, 0)
push!(eb.sampleable_inds, 0)
elseif !partial #typical inserting
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
if haskey(eb,:next_action) # if trace has next_action
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
end
Expand Down Expand Up @@ -174,28 +173,6 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
end

function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple)
if max_length(eb) == capacity(eb.traces)
popfirst!(eb)
end
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
end

function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}})
if max_length(eb) == capacity(eb.traces)
addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal])
xs = merge(xs.namedtuple, addition)
push!(eb.traces, xs)
pop!(eb.traces[:state].trace)
pop!(eb.traces[:reward])
pop!(eb.traces[:terminal])
else
push!(eb.traces, xs.namedtuple)
eb.sampleable_inds[end-1] = 1
end
end

for f in (:pop!, :popfirst!)
@eval function Base.$f(eb::EpisodesBuffer)
$f(eb.episodes_lengths)
Expand Down
21 changes: 12 additions & 9 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ end
default_priority=1.0f0
)

eb = EpisodesBuffer(t)
eb = EpisodesBuffer(t)
push!(eb, (state = 1, action = 1))
for i = 1:5
push!(eb, (state = i+1, action =i+1, reward = i, terminal = false))
push!(eb, (state = i+1, action = i+1, reward = i, terminal = false))
end
push!(eb, (state = 7, action = 7))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
push!(eb, (state = i, action = i, reward = i-1, terminal = false))
end
s = BatchSampler(1000)
b = sample(s, eb)
Expand Down Expand Up @@ -222,6 +222,8 @@ end

b = sample(s, t)

@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]

t[:priority, [1, 2]] = [0, 0]

# shouldn't be changed since [1,2] are old keys
Expand All @@ -240,18 +242,19 @@ end
),
default_priority=1.0f0
)
eb = EpisodesBuffer(t)

eb = EpisodesBuffer(t)
push!(eb, (state = 1,))
for i = 1:5
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
push!(eb, (state = i+1, action = i, reward = i, terminal = false))
end
push!(eb, PartialNamedTuple((action = 6,)))
push!(eb, (state = 7,))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
for i = 8:11
push!(eb, (state = i, action = i-1, reward = i-1, terminal = false))
end
push!(eb, PartialNamedTuple((action=12,)))
push!(eb, PartialNamedTuple((action=11,)))

s = BatchSampler(1000)
b = sample(s, eb)
cm = counter(b[:state])
Expand Down
Loading

0 comments on commit bfc1591

Please sign in to comment.