Skip to content

Commit

Permalink
Merge pull request #24 from JuliaReinforcementLearning/jpsl/tweaks
Browse files Browse the repository at this point in the history
Add Feedback and expand test
  • Loading branch information
jeremiahpslewis authored Mar 29, 2024
2 parents c473aef + da410ca commit 3c9db4f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.1.14"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[compat]
Adapt = "2, 3, 4"
Expand Down
42 changes: 24 additions & 18 deletions src/CircularArrayBuffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0
"""
_buffer_index(cb::CircularArrayBuffer, i::Int)
Return the index of the `i`-th element in the buffer.
Return the index of the `i`-th element in the buffer. Note the `i` is assumed to be the linear indexing of `cb`.
"""
@inline function _buffer_index(cb::CircularArrayBuffer, i::Int)
idx = (cb.first - 1) * cb.step_size + i
Expand All @@ -90,25 +90,24 @@ end
"""
wrap_index(idx, n)
Return the index of the `idx`-th element in the buffer, if index is one past the size, return 1, else error.
Return the index of the `idx`-th element in the buffer, if index is one past the size, return 1, else error.
"""
function wrap_index(idx, n)
function wrap_index(idx::Int, n::Int)
if idx <= n
return idx
elseif idx <= 2n
return idx - n
else
@info "oops! idx $(idx) > 2n $(2n)"
return idx - n
return -1 # NOTE: This should never happen, due to @boundscheck
end
end

"""
_buffer_frame(cb::CircularArrayBuffer, i::Int)
Return the index of the `i`-th frame in the buffer.
Here `i` is assumed to be the last dimension of `cb`. Each `frame` means a slice of the last dimension. Since we use *circular frames* (the `data` buffer) underlying, this function transforms the logical `i`-th frame to the real frame of the internal buffer.
"""
@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int)
@inline function _buffer_frame(cb::CircularArrayBuffer{T,N}, i::Int) where {T,N}
n = capacity(cb)
idx = cb.first + i - 1
return wrap_index(idx, n)
Expand All @@ -123,19 +122,26 @@ function Base.empty!(cb::CircularArrayBuffer)
cb
end

function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
if cb.nframes == capacity(cb)
function _update_first_and_nframes!(cb)
if isfull(cb)
cb.first = (cb.first == capacity(cb) ? 1 : cb.first + 1)
else
cb.nframes += 1
end
if N == 1
i = _buffer_frame(cb, cb.nframes)
cb.buffer[i:i] .= Ref(data)
else
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
end
cb
return cb
end

function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
_update_first_and_nframes!(cb)
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
return cb
end

function Base.push!(cb::CircularVectorBuffer{T}, data) where {T}
_update_first_and_nframes!(cb)
i = _buffer_frame(cb, cb.nframes)
cb.buffer[i:i] .= Ref(data)
return cb
end

function Base.append!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
Expand Down Expand Up @@ -180,7 +186,7 @@ function Base.pop!(cb::CircularArrayBuffer{T,N}) where {T,N}
else
res = @views cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)]
cb.nframes -= 1
res
return res
end
end

Expand All @@ -194,7 +200,7 @@ function Base.popfirst!(cb::CircularArrayBuffer{T,N}) where {T,N}
if cb.first > capacity(cb)
cb.first = 1
end
res
return res
end
end

Expand Down
12 changes: 10 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ CUDA.allowscalar(false)
@test_throws BoundsError @view b[:, 9]
end

@testset "Bounds error for zero-length buffer" begin
@testset "Bounds error for zero-length / underfilled buffer" begin
b = CircularVectorBuffer{Bool}(10)
@test_throws BoundsError b[1]
@test_throws BoundsError b[end]
for i in 1:5

push!(b, true)
@test b[1] == true
@test b[end] == true
@test_throws BoundsError b[2]
for i in 1:15
push!(b, true)
end
@test b[end] == true
@test b[10] == true
@test_throws BoundsError b[15]
end

@testset "1D vector" begin
Expand Down

0 comments on commit 3c9db4f

Please sign in to comment.