From 7d0d0530a2f399185ca95f9be0cf594815251714 Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <--get> Date: Fri, 22 Mar 2024 17:19:34 +0100 Subject: [PATCH] add bounds check --- src/episodes.jl | 46 +++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/src/episodes.jl b/src/episodes.jl index f37e7f6..e93fb92 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -69,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces) end end -Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...) -Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...) -Base.size(es::EpisodesBuffer) = size(es.traces) -Base.length(es::EpisodesBuffer) = length(es.traces) -Base.keys(es::EpisodesBuffer) = keys(es.traces) -Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces) +function Base.getindex(es::EpisodesBuffer, idx::Int...) + @boundscheck all(es.sampleable_inds[idx...]) + getindex(es.traces, idx...) +end + +function Base.getindex(es::EpisodesBuffer, idx...) + getindex(es.traces, idx...) +end + +Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...) +Base.size(eb::EpisodesBuffer) = size(eb.traces) +Base.length(eb::EpisodesBuffer) = length(eb.traces) +Base.keys(eb::EpisodesBuffer) = keys(eb.traces) +Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces) function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names} s = nameof(typeof(eb)) t = eb.traces @@ -83,7 +91,7 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where end ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps. -ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs) +ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs) ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) function pad!(trace::Trace) @@ -126,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) return :($ex) end -fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces) +fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces) -fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces) +fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces) function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.traces, xs) @@ -165,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl end for f in (:pop!, :popfirst!) - @eval function Base.$f(es::EpisodesBuffer) - $f(es.episodes_lengths) - $f(es.sampleable_inds) - $f(es.step_numbers) - $f(es.traces) + @eval function Base.$f(eb::EpisodesBuffer) + $f(eb.episodes_lengths) + $f(eb.sampleable_inds) + $f(eb.step_numbers) + $f(eb.traces) end end -function Base.empty!(es::EpisodesBuffer) - empty!(es.traces) - empty!(es.episodes_lengths) - empty!(es.sampleable_inds) - empty!(es.step_numbers) +function Base.empty!(eb::EpisodesBuffer) + empty!(eb.traces) + empty!(eb.episodes_lengths) + empty!(eb.sampleable_inds) + empty!(eb.step_numbers) end