Skip to content

Commit

Permalink
add bounds check
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremiah Lewis committed Mar 22, 2024
1 parent 58bfea5 commit 7d0d053
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions src/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces)
end
end

Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...)
Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...)
Base.size(es::EpisodesBuffer) = size(es.traces)
Base.length(es::EpisodesBuffer) = length(es.traces)
Base.keys(es::EpisodesBuffer) = keys(es.traces)
Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces)
function Base.getindex(es::EpisodesBuffer, idx::Int...)
@boundscheck all(es.sampleable_inds[idx...])
getindex(es.traces, idx...)
end

function Base.getindex(es::EpisodesBuffer, idx...)
getindex(es.traces, idx...)
end

Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
Base.size(eb::EpisodesBuffer) = size(eb.traces)
Base.length(eb::EpisodesBuffer) = length(eb.traces)
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces)
function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names}
s = nameof(typeof(eb))
t = eb.traces
Expand All @@ -83,7 +91,7 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where
end

ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps.
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs)
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)

function pad!(trace::Trace)
Expand Down Expand Up @@ -126,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
return :($ex)
end

fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)

fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)

function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
push!(eb.traces, xs)
Expand Down Expand Up @@ -165,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
end

for f in (:pop!, :popfirst!)
@eval function Base.$f(es::EpisodesBuffer)
$f(es.episodes_lengths)
$f(es.sampleable_inds)
$f(es.step_numbers)
$f(es.traces)
@eval function Base.$f(eb::EpisodesBuffer)
$f(eb.episodes_lengths)
$f(eb.sampleable_inds)
$f(eb.step_numbers)
$f(eb.traces)
end
end

function Base.empty!(es::EpisodesBuffer)
empty!(es.traces)
empty!(es.episodes_lengths)
empty!(es.sampleable_inds)
empty!(es.step_numbers)
function Base.empty!(eb::EpisodesBuffer)
empty!(eb.traces)
empty!(eb.episodes_lengths)
empty!(eb.sampleable_inds)
empty!(eb.step_numbers)
end

0 comments on commit 7d0d053

Please sign in to comment.