Skip to content

Commit

Permalink
Merge pull request #22 from JuliaReinforcementLearning/jpsl/patch
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiahpslewis authored Mar 27, 2024
2 parents 020a842 + 2faa73d commit 67b944b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
name = "CircularArrayBuffers"
uuid = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
authors = ["Jun Tian <[email protected]> and contributors"]
version = "0.1.13"
version = "0.1.14"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[compat]
Adapt = "2, 3, 4"
julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CUDA", "Test"]
38 changes: 28 additions & 10 deletions src/CircularArrayBuffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,42 @@ capacity(cb::CircularArrayBuffer{T,N}) where {T,N} = size(cb.buffer, N)
isfull(cb::CircularArrayBuffer) = cb.nframes == capacity(cb)
Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0

"""
_buffer_index(cb::CircularArrayBuffer, i::Int)
Return the index of the `i`-th element in the buffer.
"""
@inline function _buffer_index(cb::CircularArrayBuffer, i::Int)
ind = (cb.first - 1) * cb.step_size + i
if ind > length(cb.buffer)
ind - length(cb.buffer)
idx = (cb.first - 1) * cb.step_size + i
return wrap_index(idx, length(cb.buffer))
end
@inline _buffer_index(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(Base.Fix1(_buffer_index, cb), I)

"""
wrap_index(idx, n)
Return the index of the `idx`-th element in the buffer, if index is one past the size, return 1, else error.
"""
function wrap_index(idx, n)
if idx <= n
return idx
elseif idx <= 2n
return idx - n
else
ind
@info "oops! idx $(idx) > 2n $(2n)"
return idx - n
end
end
@inline _buffer_index(cb::CircularArrayBuffer, I::AbstractVector{<:Integer}) = map(Base.Fix1(_buffer_index, cb), I)

"""
_buffer_frame(cb::CircularArrayBuffer, i::Int)
Return the index of the `i`-th frame in the buffer.
"""
@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int)
n = capacity(cb)
idx = cb.first + i - 1
if idx > n
idx - n
else
idx
end
return wrap_index(idx, n)
end

_buffer_frame(cb::CircularArrayBuffer, I::CartesianIndex) = CartesianIndex(map(i->_buffer_frame(cb, i), Tuple(I)))
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ if CUDA.functional()
@test isempty(b) == true
@test length(b) == 0
@test size(b) == (0,)
# element must has the exact same length with the element of buffer
# element must have the exact same length with the element of buffer
@test_throws Exception push!(b, [1, 2])

for x in 1:3
Expand Down

0 comments on commit 67b944b

Please sign in to comment.