Skip to content

Commit

Permalink
make dataloaders iterations inferred (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Feb 1, 2025
1 parent 1813426 commit 2e43380
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function DataLoader(

# Wrapping with ObsView in order to work around
# issue https://github.com/FluxML/Flux.jl/issues/1935
_data = ObsView(data)
_data = ObsView(data, collect(1:numobs(data)))
if batchsize > 0
_data = BatchView(_data; batchsize, partial, collate)
end
Expand Down
2 changes: 1 addition & 1 deletion src/obsview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ struct ObsView{Tdata, I<:Union{Int,AbstractVector}} <: AbstractDataContainer
function ObsView(data::T, indices::I) where {T,I}
1 <= minimum(indices) || throw(BoundsError(data, indices))
maximum(indices) <= numobs(data) || throw(BoundsError(data, indices))
new{T,I}(data, indices)
return new{T,I}(data, indices)
end
end

Expand Down
19 changes: 12 additions & 7 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Y2 = [1:5;]

d = DataLoader(X2, batchsize=2)
@test_broken @inferred(first(d)) isa Array
@test @inferred(first(d)) isa Array
batches = collect(d)
@test_broken eltype(d) == typeof(X2)
@test eltype(batches) == typeof(X2)
Expand All @@ -15,15 +15,15 @@
@test batches[3] == X2[:,5:5]

d = DataLoader(X2, batchsize=2, partial=false)
# @inferred first(d)
@inferred first(d)
batches = collect(d)
@test_broken eltype(d) == typeof(X2)
@test length(batches) == 2
@test batches[1] == X2[:,1:2]
@test batches[2] == X2[:,3:4]

d = DataLoader((X2,), batchsize=2, partial=false)
# @inferred first(d)
@inferred first(d)
batches = collect(d)
@test_broken eltype(d) == Tuple{typeof(X2)}
@test eltype(batches) == Tuple{typeof(X2)}
Expand All @@ -32,7 +32,7 @@
@test batches[2] == (X2[:,3:4],)

d = DataLoader((X2, Y2), batchsize=2)
# @inferred first(d)
@inferred first(d)
batches = collect(d)
@test_broken eltype(d) == Tuple{typeof(X2), typeof(Y2)}
@test eltype(batches) == Tuple{typeof(X2), typeof(Y2)}
Expand All @@ -49,7 +49,7 @@

# test with NamedTuple
d = DataLoader((x=X2, y=Y2), batchsize=2)
# @inferred first(d)
@inferred first(d)
batches = collect(d)
@test_broken eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X2), typeof(Y2)}}
@test eltype(batches) == NamedTuple{(:x, :y), Tuple{typeof(X2), typeof(Y2)}}
Expand Down Expand Up @@ -133,7 +133,7 @@

dloader = DataLoader(data, batchsize=2)
c = collect(dloader)
@test eltype(c) == UnitRange{Int64}
@test eltype(c) == Vector{Int64}
@test c[1] == 1:2

dloader = DataLoader(data, batchsize=2, shuffle=true)
Expand All @@ -160,34 +160,39 @@
X_ = rand(10, 20)

d = DataLoader(X_, collate=false, batchsize = 2)
@inferred first(d)
for (i, x) in enumerate(d)
@test x == [getobs(X_, 2i-1), getobs(X_, 2i)]
end

d = DataLoader(X_, collate=nothing, batchsize = 2)
@inferred first(d)
for (i, x) in enumerate(d)
@test x == hcat(getobs(X_, 2i-1), getobs(X_, 2i))
end

d = DataLoader(X_, collate=true, batchsize = 2)
@inferred first(d)
for (i, x) in enumerate(d)
@test x == hcat(getobs(X_, 2i-1), getobs(X_, 2i))
end

d = DataLoader((X_, X_), collate=false, batchsize = 2)
@inferred first(d)
for (i, x) in enumerate(d)
@test x isa Vector
all((isa).(x, Tuple))
end

d = DataLoader((X_, X_), collate=true, batchsize = 2)
@inferred first(d)
for (i, x) in enumerate(d)
@test all(==(hcat(getobs(X_, 2i-1), getobs(X_, 2i))), x)
end

@testset "nothing vs. true" begin
d = CustomRangeIndex(10)
@test first(DataLoader(d, batchsize = 2, collate=nothing)) isa UnitRange
@test first(DataLoader(d, batchsize = 2, collate=nothing)) isa Vector
@test first(DataLoader(d, batchsize = 2, collate=true)) isa Vector
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct CustomRangeIndex
end
Base.length(r::CustomRangeIndex) = r.n
Base.getindex(r::CustomRangeIndex, idx::Int) = idx
Base.getindex(r::CustomRangeIndex, idxs::UnitRange) = idxs
Base.getindex(r::CustomRangeIndex, idxs::AbstractVector{Int}) = idxs

# --------------------------------------------------------------------

Expand Down

0 comments on commit 2e43380

Please sign in to comment.