From ccaa17a7b9358605afc49b9ade9f309e5b620561 Mon Sep 17 00:00:00 2001
From: Siddhant Chaudhary
Date: Mon, 26 Aug 2024 18:01:11 +0530
Subject: [PATCH] Refactoring the `setup` code.

---
 src/indexing/collection_indexer.jl | 66 ++++++++++++++----------------
 1 file changed, 30 insertions(+), 36 deletions(-)

diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl
index 2dd3424..1de5537 100644
--- a/src/indexing/collection_indexer.jl
+++ b/src/indexing/collection_indexer.jl
@@ -47,26 +47,43 @@ from the sampled documents, and the embedding matrix for the local samples. The
 shape is `(D, N)`, where `D` is the embedding dimension (`128`) and `N` is the
 total number of embeddings over all the sampled passages.
 """
-function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
-        collection::Vector{String}, sampled_pids::Set{Int})
+function _sample_embeddings(bert::HF.HGFBertModel, linear::Layers.Dense,
+        tokenizer::TextEncoders.AbstractTransformerTextEncoder,
+        dim::Int, index_bsize::Int, doc_token::String,
+        skiplist::Vector{Int}, collection::Vector{String})
     # collect all passages with pids in sampled_pids
+    sampled_pids = _sample_pids(length(collection))
     sorted_sampled_pids = sort(collect(sampled_pids))
     local_sample = collection[sorted_sampled_pids]
 
     # get the local sample embeddings
     local_sample_embs, local_sample_doclens = encode_passages(
-        config, checkpoint, local_sample)
-    @debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)"
-    @assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
+        bert, linear, tokenizer, local_sample,
+        dim, index_bsize, doc_token, skiplist)
+
+    @assert size(local_sample_embs, 2)==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
     @assert length(local_sample) == length(local_sample_doclens)
 
     avg_doclen_est = length(local_sample_doclens) > 0 ?
-                     sum(local_sample_doclens) / length(local_sample_doclens) :
-                     0
+                     Float32(sum(local_sample_doclens) /
+                             length(local_sample_doclens)) :
+                     zero(Float32)
     @info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
     avg_doclen_est, local_sample_embs
 end
 
+function _heldout_split(
+        sample::AbstractMatrix{Float32}; heldout_fraction::Float32 = 0.05f0)
+    num_sample_embs = size(sample, 2)
+    sample = sample[:, shuffle(1:num_sample_embs)]
+    heldout_size = Int(max(
+        1, floor(min(50000, heldout_fraction * num_sample_embs))))
+    sample, sample_heldout = sample[
+        :, 1:(num_sample_embs - heldout_size)],
+    sample[:, (num_sample_embs - heldout_size + 1):num_sample_embs]
+    sample, sample_heldout
+end
+
 """
     setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
 
@@ -91,45 +108,22 @@ proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``.
 A `Dict` containing the indexing plan.
 """
-function setup(config::ColBERTConfig, checkpoint::Checkpoint,
-        collection::Vector{String})
-    chunksize = 0
-    chunksize = ismissing(config.chunksize) ?
-                min(25000, 1 + fld(length(collection), config.nranks)) :
-                config.chunksize
+function setup(collection::Vector{String}, avg_doclen_est::Float32,
+        num_clustering_embs::Int, chunksize::Union{Missing, Int}, nranks::Int)
+    chunksize = ismissing(chunksize) ?
+                min(25000, 1 + fld(length(collection), nranks)) :
+                chunksize
     num_chunks = cld(length(collection), chunksize)
 
-    # sample passages for training centroids later
-    sampled_pids = _sample_pids(length(collection))
-    avg_doclen_est, local_sample_embs = _sample_embeddings(
-        config, checkpoint, collection, sampled_pids)
-
-    # splitting the local sample into heldout set
-    num_local_sample_embs = size(local_sample_embs, 2)
-    local_sample_embs = local_sample_embs[
-        :, shuffle(1:num_local_sample_embs)]
-
-    # split the sample to get a heldout set
-    heldout_fraction = 0.05
-    heldout_size = Int(max(
-        1, floor(min(
-        50000, heldout_fraction * num_local_sample_embs))))
-    sample, sample_heldout = local_sample_embs[
-        :, 1:(num_local_sample_embs - heldout_size)],
-    local_sample_embs[
-        :, (num_local_sample_embs - heldout_size + 1):num_local_sample_embs]
-    @debug "Split sample sizes: sample size: $(size(sample)), \t sample_heldout size: $(size(sample_heldout))"
-
     # computing the number of partitions, i.e. clusters
     num_passages = length(collection)
     num_embeddings_est = num_passages * avg_doclen_est
-    num_partitions = Int(min(size(sample, 2),
+    num_partitions = Int(min(num_clustering_embs,
         floor(2^(floor(log2(16 * sqrt(num_embeddings_est)))))))
 
     @info "Creating $(num_partitions) clusters."
     @info "Estimated $(num_embeddings_est) embeddings."
 
-    sample, sample_heldout, Dict(
+    Dict(
         "chunksize" => chunksize,
        "num_chunks" => num_chunks,
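
For reference, a minimal sketch of how the refactored pieces are intended to compose after this patch. The `bert`, `linear`, `tokenizer`, and `skiplist` bindings are assumed to be constructed elsewhere (e.g. when loading the checkpoint), and the literal `index_bsize = 64` and `doc_token = "[unused1]"` values below are illustrative assumptions, not values taken from this diff; only `dim = 128` is stated in the docstring:

    # Sample passages and embed them; after this patch `avg_doclen_est`
    # comes back as a Float32 estimate of the average document length.
    avg_doclen_est, sample_embs = _sample_embeddings(
        bert, linear, tokenizer, 128, 64, "[unused1]",
        skiplist, collection)

    # Shuffle the sampled embeddings and split off a held-out set
    # (5% by default, capped at 50000 embeddings).
    sample, sample_heldout = _heldout_split(sample_embs)

    # Build the indexing plan from plain values: `missing` lets `setup`
    # derive the chunksize, and `nranks = 1` means a single worker.
    plan = setup(collection, avg_doclen_est, size(sample, 2), missing, 1)
    @info "Planning $(plan["num_chunks"]) chunks of size $(plan["chunksize"])"

With `_sample_pids` called inside `_sample_embeddings` and the shuffle/split moved into `_heldout_split`, `setup` no longer touches the model or the config: it plans purely from the collection and a few scalars, which keeps each step independently testable.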