diff --git a/examples/data.jl b/examples/data.jl index 904d7dd..992f970 100644 --- a/examples/data.jl +++ b/examples/data.jl @@ -62,3 +62,4 @@ encoder = ColBERT.CollectionEncoder(config, checkPoint) indexer = CollectionIndexer(config, encoder, ColBERT.IndexSaver(config=config)) ColBERT.setup(indexer) ColBERT.train(indexer) +ColBERT.index(indexer, chunksize = 3) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index e90b9be..72942ea 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -148,9 +148,9 @@ function train(indexer::CollectionIndexer) save_codec(indexer.saver) end -function index(indexer::CollectionIndexer) +function index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) load_codec!(indexer.saver) # load the codec objects - batches = enumerate_batches(indexer.config.resource_settings.collection, nranks = indexer.config.run_settings.nranks) + batches = enumerate_batches(indexer.config.resource_settings.collection, chunksize = chunksize, nranks = indexer.config.run_settings.nranks) for (chunk_idx, offset, passages) in batches # TODO: add functionality to not re-write chunks if they already exist! # TODO: add multiprocessing to this step!