Skip to content

Commit

Permalink
rework dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jan 28, 2025
1 parent 69e0401 commit 56f29dc
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 131 deletions.
10 changes: 4 additions & 6 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ export mapobs,
shuffleobs

include("batchview.jl")
export batchsize,
BatchView
export batchsize, BatchView

include("obsview.jl")
export obsview, ObsView

include("dataloader.jl")
export eachobs, DataLoader
Expand All @@ -48,10 +50,6 @@ include("folds.jl")
export kfolds,
leavepout

include("obsview.jl")
export obsview,
ObsView

include("randobs.jl")
export randobs

Expand Down
204 changes: 96 additions & 108 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn))
"ab"
```
"""
struct DataLoader{T,B,C,R<:AbstractRNG}
struct DataLoader{T<:Union{ObsView,BatchView},B,C,R<:AbstractRNG}
data::T
batchsize::Int
buffer::B # boolean, or external buffer
Expand All @@ -157,74 +157,72 @@ function DataLoader(
collate = Val(nothing),
rng::AbstractRNG = Random.default_rng())

if !(buffer isa Bool) && parallel
throw(ArgumentError("If `parallel=true`, `buffer` must be a boolean."))
end

if collate isa Bool || collate === nothing
collate = Val(collate)
end
return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng)
end

function Base.iterate(d::DataLoader)
    # TODO move ObsView and BatchView wrapping to the constructor, so that
# we can parametrize the DataLoader with ObsView and BatchView and define specialized methods.

# Wrapping with ObsView in order to work around
# issue https://github.com/FluxML/Flux.jl/issues/1935
data = ObsView(d.data)
# issue https://github.com/FluxML/Flux.jl/issues/1935
data = ObsView(data)
if batchsize > 0
data = BatchView(data; batchsize, partial, collate)
end

if buffer == true
buffer = _create_buffer(data)
end
# for buffer == false and external buffer, we keep as is

data = d.shuffle ? shuffleobs(d.rng, data) : data
data = d.batchsize > 0 ? BatchView(data; d.batchsize, d.partial, d.collate) : data
return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng)
end

function Base.iterate(d::DataLoader)
    # Begin a fresh epoch: reshuffle if requested, then build the observation
    # iterator. Fix: the scraped diff interleaved old and new buffer handling,
    # leaving two `else` branches and a call to the stale `create_buffer` name.
    # Kept is the new scheme, where `d.buffer` is either `false` or an
    # already-created buffer object (the constructor converts `buffer=true`
    # into a concrete buffer — see the `_create_buffer` call there).
    data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data
    if d.parallel
        iter = eachobsparallel(data; d.buffer)
    else
        if d.buffer == false
            iter = (getobs(data, i) for i in 1:numobs(data))
        else
            # Reuse the preallocated buffer for every observation.
            iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data))
        end
    end
    ret = iterate(iter)
    # Guard against an empty dataset instead of erroring while destructuring.
    isnothing(ret) && return nothing
    obs, state = ret
    return obs, (iter, state)
end

# Create a reusable buffer holding one observation (or one batch) of `x`,
# to be filled in-place with `getobs!` during iteration.
create_buffer(x) = getobs(x, 1)

function create_buffer(x::BatchView)
    obsindices = _batchrange(x, 1)
    # Fix: the original referenced an undefined `A` and iterated
    # `enumerate(obsindices)`, which yields (position, index) tuples instead of
    # observation indices. Buffer one observation per index of the first batch.
    return [getobs(x.data, i) for i in obsindices]
end

# Specialization for batch views whose third type parameter is `Val{nothing}`.
# NOTE(review): presumably this parameter is the collate setting, so the whole
# first batch is fetched as one buffer — confirm against batchview.jl.
function create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
    obsindices = _batchrange(x, 1)
    return getobs(x.data, obsindices)
end

function Base.iterate(::DataLoader, (iter, state))
    # Continue the epoch by delegating to the inner iterator carried in the state.
    next = iterate(iter, state)
    next === nothing && return nothing
    obs, newstate = next
    return obs, (iter, newstate)
end

# Recursively strip ObsView/BatchView wrappers to recover the dataset
# originally handed to the DataLoader.
_unwrapdata(x::Union{ObsView, BatchView}) = _unwrapdata(x.data)
_unwrapdata(x) = x

# Number of iterations in one epoch: batches when batching is enabled,
# otherwise raw observations.
function Base.length(d::DataLoader)
    d.batchsize > 0 && return numobs(BatchView(d.data; d.batchsize, d.partial))
    return numobs(d.data)
end
# Shuffle the observation order while preserving the wrapper type of `data`.
_shuffledata(rng, ov::ObsView) = shuffleobs(rng, ov)

# For a BatchView, shuffle the underlying dataset and rebuild the batched view
# with the same batchsize/partial/collate settings.
function _shuffledata(rng, bv::BatchView)
    shuffled = shuffleobs(rng, bv.data)
    return BatchView(shuffled; batchsize = bv.batchsize, partial = bv.partial,
                     collate = bv.collate)
end

Base.size(loader::DataLoader) = (length(loader),)
# Create a reusable buffer holding one observation (or one batch) of `x`,
# to be filled in-place with `getobs!` during iteration.
_create_buffer(x) = getobs(x, 1)

function _create_buffer(x::BatchView)
    obsindices = _batchrange(x, 1)
    # Fix: the original referenced an undefined `A` and iterated
    # `enumerate(obsindices)`, which yields (position, index) tuples instead of
    # observation indices. Buffer one observation per index of the first batch.
    return [getobs(x.data, i) for i in obsindices]
end

# Specialization for batch views whose third type parameter is `Val{nothing}`.
# NOTE(review): presumably this parameter is the collate setting, so the whole
# first batch is fetched as one buffer — confirm against batchview.jl.
function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
    obsindices = _batchrange(x, 1)
    return getobs(x.data, obsindices)
end

# Length/size delegate to the wrapped ObsView/BatchView; the element type is
# reported as unknown.
# Fix: the scraped diff carried two identical `Base.IteratorEltype` methods
# (old and new hunk); keep a single definition.
Base.length(d::DataLoader) = numobs(d.data)
Base.size(d::DataLoader) = (length(d),)
Base.IteratorEltype(::DataLoader) = Base.EltypeUnknown()

## This causes error in some cases of `collect(loader)`
# function Base.eltype(e::DataLoader)
Expand Down Expand Up @@ -288,100 +286,90 @@ function mapobs(f, d::DataLoader)
collate = f d.collate
end

DataLoader(d.data,
batchsize=d.batchsize,
buffer=d.buffer,
partial=d.partial,
shuffle=d.shuffle,
parallel=d.parallel,
collate=collate,
rng=d.rng)
return DataLoader(_unwrapdata(d.data);
batchsize=d.batchsize,
buffer=d.buffer,
partial=d.partial,
shuffle=d.shuffle,
parallel=d.parallel,
collate=collate,
rng=d.rng)
end


@inline function _dataloader_foldl1(rf, val, e::DataLoader, data)
if e.shuffle
_dataloader_foldl2(rf, val, e, shuffleobs(e.rng, data))
else
_dataloader_foldl2(rf, val, e, data)
end
# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
function Base.showarg(io::IO, d::DataLoader, toplevel)
    # Constructor-like representation; only keywords that differ from their
    # defaults are printed.
    print(io, "DataLoader(")
    Base.showarg(io, _unwrapdata(d.data), false)
    if d.buffer != false
        print(io, ", buffer=", d.buffer)
    end
    if d.parallel != false
        print(io, ", parallel=", d.parallel)
    end
    if d.shuffle != false
        print(io, ", shuffle=", d.shuffle)
    end
    if d.batchsize != 1
        print(io, ", batchsize=", d.batchsize)
    end
    if d.partial != true
        print(io, ", partial=", d.partial)
    end
    if d.collate !== Val(nothing)
        print(io, ", collate=", d.collate)
    end
    if d.rng != Random.default_rng()
        print(io, ", rng=", d.rng)
    end
    print(io, ")")
end

@inline function _dataloader_foldl2(rf, val, e::DataLoader, data)
if e.batchsize > 0
_dataloader_foldl3(rf, val, e, BatchView(data; e.batchsize, e.partial))
# Compact one-line display reuses the `showarg` form.
function Base.show(io::IO, loader::DataLoader)
    Base.showarg(io, loader, false)
end

function Base.show(io::IO, ::MIME"text/plain", d::DataLoader)
    # Multi-line REPL display: "<n>-element DataLoader(...)" followed by a
    # type summary of the first element.
    print(io, length(d), "-element ")
    Base.showarg(io, d, false)
    print(io, "\n with first element:", "\n ", _expanded_summary(first(d)))
end

# Compact type-level description of a loader element, used by the
# `text/plain` show method.
_expanded_summary(x) = summary(x)

function _expanded_summary(t::Tuple)
    inner = join((_expanded_summary(el) for el in t), ", ")
    return "(" * inner * ",)"
end

function _expanded_summary(nt::NamedTuple)
    entries = ["$k = " * _expanded_summary(v) for (k, v) in pairs(nt)]
    return "(; " * join(entries, ", ") * ")"
end


### TRANSDUCERS IMPLEMENTATION #############################


@inline function _dataloader_foldl1(rf, val, d::DataLoader, data)
    # Entry point of the sequential fold: apply the epoch shuffle first.
    src = d.shuffle ? _shuffledata(d.rng, data) : data
    return _dataloader_foldl2(rf, val, d, src)
end

@inline function _dataloader_foldl3(rf, val, e::DataLoader, data)
if e.buffer > 0
_dataloader_foldl4_buffered(rf, val, data)
@inline function _dataloader_foldl2(rf, val, d::DataLoader, data)
    # Dispatch on buffering: `false` means allocate per observation; any other
    # value is treated as a reusable buffer object.
    d.buffer == false && return _dataloader_foldl3(rf, val, data)
    return _dataloader_foldl3_buffered(rf, val, data, d.buffer)
end

@inline function _dataloader_foldl4(rf, val, data)
# Sequential, unbuffered fold implementing the Transducers.jl protocol:
# fetch each observation with `getobs` and feed it to the reducing function.
@inline function _dataloader_foldl3(rf, val, data)
    for i in 1:numobs(data)
        @inbounds x = getobs(data, i)
        # TODO: in 1.8 we could @inline this at the callsite,
        # optimizer seems to be very sensitive to inlining and
        # quite brittle in its capacity to keep this type stable
        val = Transducers.@next(rf, val, x)
    end
    return Transducers.complete(rf, val)
end

@inline function _dataloader_foldl4_buffered(rf, val, data)
buf = getobs(data, 1)
# Buffered variant of the sequential fold: each observation is written into the
# caller-supplied `buf` via `getobs!`, avoiding a fresh allocation per step.
@inline function _dataloader_foldl3_buffered(rf, val, data, buf)
    for i in 1:numobs(data)
        @inbounds x = getobs!(buf, data, i)
        val = Transducers.@next(rf, val, x)
    end
    return Transducers.complete(rf, val)
end

# Transducers.jl entry point; only sequential loaders support the fold
# protocol. The data is wrapped in an ObsView before folding.
@inline function Transducers.__foldl__(rf, val, dl::DataLoader)
    if dl.parallel
        throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
    end
    return _dataloader_foldl1(rf, val, dl, ObsView(dl.data))
end

# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
# Constructor-like representation; only keywords that differ from their
# defaults are printed.
# Fix: the scraped diff dropped the closing `end` of this method; restored.
function Base.showarg(io::IO, e::DataLoader, toplevel)
    print(io, "DataLoader(")
    Base.showarg(io, e.data, false)
    e.buffer == false || print(io, ", buffer=", e.buffer)
    e.parallel == false || print(io, ", parallel=", e.parallel)
    e.shuffle == false || print(io, ", shuffle=", e.shuffle)
    e.batchsize == 1 || print(io, ", batchsize=", e.batchsize)
    e.partial == true || print(io, ", partial=", e.partial)
    e.collate === Val(nothing) || print(io, ", collate=", e.collate)
    e.rng == Random.default_rng() || print(io, ", rng=", e.rng)
    print(io, ")")
end
# Transducers.jl entry point; only sequential loaders support the fold protocol.
@inline function Transducers.__foldl__(rf, val, d::DataLoader)
    if d.parallel
        throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
    end
    return _dataloader_foldl1(rf, val, d, d.data)
end

# Compact one-line display reuses the `showarg` form.
function Base.show(io::IO, e::DataLoader)
    Base.showarg(io, e, false)
end

function Base.show(io::IO, ::MIME"text/plain", e::DataLoader)
    # Header states the element count when the loader has a known length.
    header = Base.haslength(e) ? string(length(e), "-element ") : "Unknown-length "
    print(io, header)
    Base.showarg(io, e, false)
    print(io, "\n with first element:")
    print(io, "\n ", _expanded_summary(first(e)))
end

# Compact type-level description of a loader element, used by the
# `text/plain` show method.
_expanded_summary(x) = summary(x)

function _expanded_summary(t::Tuple)
    pieces = join((_expanded_summary(el) for el in t), ", ")
    return "(" * pieces * ",)"
end

function _expanded_summary(nt::NamedTuple)
    pieces = ["$k = " * _expanded_summary(v) for (k, v) in pairs(nt)]
    return "(; " * join(pieces, ", ") * ")"
end

29 changes: 17 additions & 12 deletions src/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,28 @@ joinobs(datas...) = JoinedData(datas)
"""
shuffleobs([rng], data)
Return a "subset" of `data` that spans all observations, but
has the order of the observations shuffled.
Return a version of the dataset `data` that contains all the
original observations in a random reordering.
The values of `data` itself are not copied. Instead only the
indices are shuffled. This function calls [`obsview`](@ref) to
accomplish that, which means that the return value is likely of a
different type than `data`.
Optionally, a random number generator `rng` can be passed as the
first argument.
The optional parameter `rng` allows one to specify the
random number generator used for shuffling. This is useful when
reproducible results are desired.
For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref).
See also [`obsview`](@ref).
# Examples
```julia
# For Arrays the subset will be of type SubArray
@assert typeof(shuffleobs(rand(4,10))) <: SubArray
Expand All @@ -216,18 +230,9 @@ for x in eachobs(shuffleobs(X))
...
end
```
The optional parameter `rng` allows one to specify the
random number generator used for shuffling. This is useful when
reproducible results are desired. By default, uses the global RNG.
See `Random` in Julia's standard library for more info.
For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref). See [`ObsView`](@ref)
for more information.
"""
# Convenience method using the task-global RNG.
shuffleobs(data) = shuffleobs(Random.default_rng(), data)

function shuffleobs(rng::AbstractRNG, data)
    # Shuffle lazily: only the index permutation is materialized, the data is
    # viewed through it.
    perm = randperm(rng, numobs(data))
    return obsview(data, perm)
end
10 changes: 5 additions & 5 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@
# Iterate `data` in parallel over a channel of size `channelsize`.
# Fix: the scraped diff interleaved the old (`buffer::Bool`) and new hunks,
# leaving two `buffer` keyword lines and conflicting `if`/`else` branches;
# kept is the new form where `buffer` is `false` or a preallocated buffer.
# NOTE(review): `buffer=true` appears to be converted to a concrete buffer by
# the DataLoader constructor before reaching here — confirm.
function eachobsparallel(
    data;
    executor::Executor = _default_executor(),
    buffer = false,
    channelsize = Threads.nthreads())
    if buffer == false
        return _eachobsparallel_unbuffered(data, executor; channelsize)
    else
        return _eachobsparallel_buffered(buffer, data, executor; channelsize)
    end
end


function _eachobsparallel_buffered(
buffer,
data,
executor;
channelsize=Threads.nthreads())
buffer = getobs(data, 1)
buffers = [buffer]
foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize)

Expand Down

0 comments on commit 56f29dc

Please sign in to comment.