diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 0a6339b..e90b9be 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -149,11 +149,13 @@ function train(indexer::CollectionIndexer) end function index(indexer::CollectionIndexer) + load_codec!(indexer.saver) # load the codec objects batches = enumerate_batches(indexer.config.resource_settings.collection, nranks = indexer.config.run_settings.nranks) for (chunk_idx, offset, passages) in batches # TODO: add functionality to not re-write chunks if they already exist! # TODO: add multiprocessing to this step! embs, doclens = encode_passages(indexer.encoder, passages) - @info "Saving chunk $(chunk_idx): \t $(length(passages)) passages and $(size(embs)[2]) embeddings. From offset #$(offset) onward." + @info "Saving chunk $(chunk_idx): \t $(length(passages)) passages and $(size(embs)[2]) embeddings. From offset #$(offset) onward." + save_chunk(indexer.saver, chunk_idx, offset, embs, doclens) end end diff --git a/src/indexing/index_saver.jl b/src/indexing/index_saver.jl index c08b994..7d8d944 100644 --- a/src/indexing/index_saver.jl +++ b/src/indexing/index_saver.jl @@ -31,6 +31,33 @@ function save_codec(saver::IndexSaver) ) end -# function save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, embs::Matrix{Float64}, doclens::Vector{Int}) -# compressed_embs = compress -# end +function save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, embs::Matrix{Float64}, doclens::Vector{Int}) + codes, residuals = compress(saver.codec, embs) + path_prefix = joinpath(saver.config.indexing_settings.index_path, string(chunk_idx)) + + # saving the compressed embeddings + codes_path = "$(path_prefix).codes.jld2" + residuals_path = "$(path_prefix).residuals.jld2" + @info "Saving compressed codes to $(codes_path) and residuals to $(residuals_path)" + save(codes_path, Dict("codes" => codes)) + save(residuals_path, Dict("residuals" => residuals)) + + # saving doclens + doclens_path = joinpath(saver.config.indexing_settings.index_path, "doclens.$(chunk_idx).jld2") + @info "Saving doclens to $(doclens_path)" + save(doclens_path, Dict("doclens" => doclens)) + + # the metadata + metadata_path = joinpath(saver.config.indexing_settings.index_path, "$(chunk_idx).metadata.json") + @info "Saving metadata to $(metadata_path)" + open(metadata_path, "w") do io + JSON.print(io, + Dict( + "passage_offset" => offset, + "num_passages" => length(doclens), + "num_embeddings" => length(codes), + ), + 4 # indent + ) + end +end