# Provenance: commit 7c6f6a0c6dc361fecfa949d2a1c01f6b4a01b541
# Author:     Siddhant Chaudhary
# Date:       Thu, 30 May 2024 00:54:31 +0530
# Subject:    Implementing `enumerate_batches`, a function to enumerate batches
#             for a `Collection`.
# Target:     src/data/collection.jl

"""
    get_chunksize(collection::Collection, nranks::Int)

Return the number of entries of `collection.data` to put in one chunk when the
collection is split across `nranks` ranks.

The result is `min(25000, 1 + fld(length(collection.data), nranks))` and is always a
positive `Int`, so it can be used directly to build index ranges.

# Throws
- `ArgumentError` if `nranks` is not positive.
"""
function get_chunksize(collection::Collection, nranks::Int)
    nranks > 0 || throw(ArgumentError("nranks must be positive, got $nranks"))
    # Integer floor division keeps the result an `Int`; float division followed by
    # `floor` would return a `Float64`, which cannot be used as an array index.
    return min(25000, 1 + fld(length(collection.data), nranks))
end

"""
    enumerate_batches(collection::Collection, chunksize = missing, nranks = missing)

Split `collection.data` into consecutive batches of at most `chunksize` entries.

If `chunksize` is `missing`, it is computed with `get_chunksize(collection, nranks)`,
in which case `nranks` must be given.

# Returns
A `Vector{Tuple{Int, Int, Vector{String}}}` with one `(chunk_idx, offset, batch)` tuple
per batch, where `chunk_idx` counts batches from 1, `offset` is the 1-based index in
`collection.data` of the first entry of the batch, and `batch` is a copy of the
corresponding slice. An empty collection yields an empty vector.

# Throws
- `ErrorException` if both `chunksize` and `nranks` are `missing`.
- `ArgumentError` if the effective chunk size is not positive.
"""
function enumerate_batches(
        collection::Collection,
        chunksize::Union{Int, Missing} = missing,
        nranks::Union{Int, Missing} = missing)
    if ismissing(chunksize)
        if ismissing(nranks)
            error("Atleast one of the arguments chunksize or nranks must be specified!")
        end
        chunksize = get_chunksize(collection, nranks)
    end
    # A non-positive step would never advance past the end of the data.
    chunksize > 0 ||
        throw(ArgumentError("chunksize must be positive, got $chunksize"))

    num_passages = length(collection.data)
    batches = Vector{Tuple{Int, Int, Vector{String}}}()
    sizehint!(batches, cld(num_passages, chunksize))

    # `1:chunksize:num_passages` visits the first index of every batch and is empty
    # when there is no data.
    for (chunk_idx, offset) in enumerate(1:chunksize:num_passages)
        # Written as `offset + (chunksize - 1)` so a very large `chunksize` cannot
        # overflow: `offset` only exceeds 1 when `chunksize < num_passages`.
        stop = min(num_passages, offset + (chunksize - 1))
        push!(batches, (chunk_idx, offset, collection.data[offset:stop]))
    end
    return batches
end