Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

properly use or create buffer in DataLoader #191

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,3 @@ StatsBase = "0.33, 0.34"
Tables = "1.10"
Transducers = "0.4"
julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "CUDA", "DataFrames", "SparseArrays", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ include("batchview.jl")
export batchsize,
BatchView

include("eachobs.jl")
include("dataloader.jl")
export eachobs, DataLoader

include("parallel.jl")
Expand Down
28 changes: 21 additions & 7 deletions src/batchview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,31 @@ Base.@propagate_inbounds function Base.getindex(A::BatchView, is::AbstractVector
return _getbatch(A, obsindices)
end

function _getbatch(A::BatchView{TElem, TData, TCollate}, obsindices) where {TElem, TData, TCollate}
return A.collate([getobs(A.data, i) for i in obsindices])
# In-place variant of fetching batch `i` from a `BatchView`: the observation
# indices of the batch are computed with `_batchrange` and forwarded, together
# with `buffer`, to the `_getbatch!` method matching the view's collate setting.
getobs!(buffer, A::BatchView, i::Int) = _getbatch!(buffer, A, _batchrange(A, i))

# Generic collating method: fetch the observations of the batch one at a time,
# then pass the resulting vector to the view's `collate` callable.
function _getbatch(A::BatchView{E,D,C}, obsindices) where {E,D,C}
    observations = [getobs(A.data, k) for k in obsindices]
    return A.collate(observations)
end
function _getbatch(A::BatchView{TElem, TData, Val{false}}, obsindices) where {TElem, TData}
return [getobs(A.data, i) for i in obsindices]
# Generic collating in-place method: `buffer` is indexed per position in the
# batch (`buffer[slot]`), each observation is loaded into its own slot with
# `getobs!`, and the vector of results is passed to the view's `collate` callable.
function _getbatch!(buffer, A::BatchView{E,D,C}, obsindices) where {E,D,C}
    observations = [
        getobs!(buffer[slot], A.data, k) for (slot, k) in enumerate(obsindices)
    ]
    return A.collate(observations)
end

# `Val{false}` collate: the batch is the plain vector of individually fetched
# observations, with no collation step.
_getbatch(A::BatchView{E,D,Val{false}}, obsindices) where {E,D} =
    [getobs(A.data, k) for k in obsindices]
function _getbatch(A::BatchView{TElem, TData, Val{nothing}}, obsindices) where {TElem, TData}
# `Val{false}` collate, in-place: load each observation into its own slot of
# `buffer` and return the vector of `getobs!` results, without collating.
_getbatch!(buffer, A::BatchView{E,D,Val{false}}, obsindices) where {E,D} =
    [getobs!(buffer[slot], A.data, k) for (slot, k) in enumerate(obsindices)]

# `Val{nothing}` collate: the whole batch is fetched with a single `getobs`
# call on the underlying data, using all observation indices at once.
_getbatch(A::BatchView{E,D,Val{nothing}}, obsindices) where {E,D} =
    getobs(A.data, obsindices)
# `Val{nothing}` collate, in-place: the whole batch is loaded with a single
# `getobs!` call, so `buffer` is passed through as one object for the batch.
_getbatch!(buffer, A::BatchView{E,D,Val{nothing}}, obsindices) where {E,D} =
    getobs!(buffer, A.data, obsindices)

# The parent of a `BatchView` is the object stored in its `data` field.
Base.parent(A::BatchView) = A.data
# The element type of a `BatchView` is its first type parameter.
Base.eltype(::BatchView{Tel}) where Tel = Tel
Expand All @@ -196,5 +212,3 @@ function Base.showarg(io::IO, A::BatchView, toplevel)
print(io, ')')
toplevel && print(io, " with eltype ", nameof(eltype(A))) # simplify
end

# --------------------------------------------------------------------
18 changes: 14 additions & 4 deletions src/eachobs.jl → src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ The original data is preserved in the `data` field of the DataLoader.
- **`buffer`**: If `buffer=true` and supported by the type of `data`,
a buffer will be allocated and reused for memory efficiency.
You may want to set `partial=false` to avoid a size mismatch.
Finally, can pass an external buffer to be used in `getobs!(buffer, data, idx)`.
Alternatively, an external buffer can be passed, which will be used in `getobs!`
(depending on the `collate` and `batchsize` options, the call is either `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`).
Default `false`.
- **`collate`**: Defines the batching behavior. Default `nothing`.
- If `nothing`, a batch is `getobs(data, indices)`.
- If `false`, each batch is `[getobs(data, i) for i in indices]`.
- If `true`, applies MLUtils to the vector of observations in a batch,
- If `true`, applies `MLUtils.batch` to the vector of observations in a batch,
recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information
and examples.
- If a custom function, it will be used in place of `MLUtils.batch`. It should take a vector of observations as input.
Expand Down Expand Up @@ -138,7 +139,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn))
struct DataLoader{T,B,C,R<:AbstractRNG}
data::T
batchsize::Int
buffer::B
buffer::B # boolean, or external buffer
partial::Bool
shuffle::Bool
parallel::Bool
Expand Down Expand Up @@ -183,7 +184,7 @@ function Base.iterate(d::DataLoader)
if d.buffer == false
iter = (getobs(data, i) for i in 1:numobs(data))
elseif d.buffer == true
buf = getobs(data, 1)
buf = create_buffer(data)
iter = (getobs!(buf, data, i) for i in 1:numobs(data))
else # external buffer
buf = d.buffer
Expand All @@ -194,6 +195,15 @@ function Base.iterate(d::DataLoader)
return obs, (iter, state)
end

# Fallback buffer constructor: the first observation of `x` is fetched and
# returned, to be used as the buffer.
function create_buffer(x)
    return getobs(x, 1)
end
# Buffer constructor for a `BatchView` that fetches observations one by one
# (every collate setting except `Val{nothing}`, which has its own method).
#
# Returns a vector with one freshly fetched observation per index of the first
# batch; this is the layout expected by the `_getbatch!` methods that write
# into `buffer[i]` for the i-th observation of a batch.
#
# Fixes relative to the previous body:
#   * the data is read from the argument `x` (the old code referenced an
#     undefined variable `A`, which raised `UndefVarError` on first use);
#   * the comprehension iterates the observation indices directly instead of
#     `enumerate(obsindices)`, whose elements are `(counter, index)` tuples
#     rather than observation indices.
function create_buffer(x::BatchView)
    obsindices = _batchrange(x, 1)
    return [getobs(x.data, idx) for idx in obsindices]
end
# Buffer constructor for a `BatchView` with `Val{nothing}` collate: the first
# batch is fetched with a single `getobs` call on the underlying data and
# returned as the buffer.
create_buffer(x::BatchView{E,D,Val{nothing}}) where {E,D} =
    getobs(x.data, _batchrange(x, 1))

function Base.iterate(::DataLoader, (iter, state))
ret = iterate(iter, state)
Expand Down
12 changes: 12 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
30 changes: 30 additions & 0 deletions test/batchview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,34 @@ using MLUtils: obsview
@test y isa String
end
end


# Checks the in-place `getobs!` methods of `BatchView`: the call must return the
# very buffer it was given, and the buffer must afterwards equal the result of
# the allocating `getobs` for the same batch.
@testset "getobs!" begin
X = rand(4, 15)
# No `collate` keyword: the buffer is a single 4x3 matrix for the whole batch.
buf1 = rand(4, 3)
bv = BatchView(X, batchsize=3)
@test @inferred(getobs!(buf1, bv, 2)) === buf1
@test buf1 == getobs(bv, 2)

# `collate=false`: the buffer is a vector with one length-4 vector per observation.
buf12 = [rand(4) for _=1:3]
bv12 = BatchView(X, batchsize=3, collate=false)
res = @inferred(getobs!(buf12, bv12, 2))
# Every returned element must be identical (`===`) to the matching buffer slot.
@test all(res .=== buf12)
@test buf12 == getobs(bv12, 2)

@testset "custom type" begin # issue #156
# Minimal wrapper type that forwards the observation interface to its field `x`.
struct DummyData{X}
x::X
end
MLUtils.numobs(data::DummyData) = numobs(data.x)
MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx)
MLUtils.getobs!(buffer, data::DummyData, idx) = getobs!(buffer, data.x, idx)

# Same buffer-identity and content checks as above, through the wrapper type.
data = DummyData(X)
buf = rand(4, 3)
bv = BatchView(data, batchsize=3)
@test @inferred(getobs!(buf, bv, 2)) === buf
@test buf == getobs(bv, 2)
end
end
end
69 changes: 69 additions & 0 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,72 @@
end
end
end

# Tests for `eachobs`: plain, buffered and externally buffered iteration,
# batched iteration, shuffling, and a sweep over keyword combinations.
# NOTE(review): `X` is not defined in this block; the checks below are consistent
# with a matrix of 4 rows and 15 columns — confirm against the enclosing file.
@testset "eachobs" begin
# Unbuffered iteration yields the columns of `X` in order.
for (i,x) in enumerate(eachobs(X))
@test x == X[:,i]
end

# `buffer=true` must yield the same observations.
for (i,x) in enumerate(eachobs(X, buffer=true))
@test x == X[:,i]
end

# External buffer: same observations, and after the loop `b` is expected to
# hold the last one.
b = zeros(size(X, 1))
for (i,x) in enumerate(eachobs(X, buffer=b))
@test x == X[:,i]
end
@test b == X[:,end]

@testset "batched" begin
# With `partial=true` the 8th batch is expected to contain a single observation.
for (i, x) in enumerate(eachobs(X, batchsize=2, partial=true))
if i != 8
@test size(x) == (4,2)
@test x == X[:,2i-1:2i]
else
@test size(x) == (4,1)
@test x == X[:,2i-1:2i-1]
end
end

# With `partial=false` every batch has exactly two observations.
for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=true, partial=false))
@test size(x) == (4,2)
@test x == X[:,2i-1:2i]
end

# External batch buffer: after the loop `b` is expected to hold the last full
# batch, i.e. the two columns before the final one.
b = zeros(4, 2)
for (i, x) in enumerate(eachobs(X, batchsize=2, buffer=b, partial=false))
@test size(x) == (4,2)
@test x == X[:,2i-1:2i]
end
@test b == X[:,end-2:end-1]
end

@testset "shuffled" begin
# does not reshuffle on iteration
shuffled = eachobs(shuffleobs(1:50))
@test collect(shuffled) == collect(shuffled)

# does reshuffle
reshuffled = eachobs(1:50, shuffle = true)
@test collect(reshuffled) != collect(reshuffled)

reshuffled = eachobs(1:50, shuffle = true, buffer = true)
@test collect(reshuffled) != collect(reshuffled)

reshuffled = eachobs(1:50, shuffle = true, parallel = true)
@test collect(reshuffled) != collect(reshuffled)

reshuffled = eachobs(1:50, shuffle = true, buffer = true, parallel = true)
@test collect(reshuffled) != collect(reshuffled)
end
@testset "Argument combinations" begin
# Smoke test: iterate once over every keyword combination and require no warnings.
# NOTE(review): `collate` is part of the sweep but is never forwarded to
# `eachobs` below — confirm whether it should be passed or dropped.
for batchsize ∈ (-1, 2), buffer ∈ (true, false), collate ∈ (nothing, true, false),
parallel ∈ (true, false), shuffle ∈ (true, false), partial ∈ (true, false)
# NOTE(review): `buffer` only takes the values `true` and `false` here, so this
# branch never runs — confirm whether an external buffer was meant to be swept.
if !(buffer isa Bool) && batchsize > 0
buffer = getobs(BatchView(X; batchsize), 1)
end
iter = eachobs(X; batchsize, shuffle, buffer, parallel, partial)
@test_nowarn for _ in iter end
end
end
end
67 changes: 0 additions & 67 deletions test/eachobs.jl

This file was deleted.

3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Transducers
using ChainRulesTestUtils: test_rrule
using Zygote: ZygoteRuleConfig
using ChainRulesCore: rrule_via_ad
using DataFrames
using DataFrames: DataFrame
using CUDA

# Show `x` on `io` with the `:compact => true` IO context property set.
function showcompact(io, x)
    compact_io = IOContext(io, :compact => true)
    return show(compact_io, x)
end
Expand Down Expand Up @@ -90,7 +90,6 @@ include("test_utils.jl")
# @testset "MLUtils.jl" begin

@testset "batchview" begin; include("batchview.jl"); end
@testset "eachobs" begin; include("eachobs.jl"); end
@testset "dataloader" begin; include("dataloader.jl"); end
@testset "folds" begin; include("folds.jl"); end
@testset "observation" begin; include("observation.jl"); end
Expand Down
Loading