diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl
index 0bd1d73..4f54f7a 100644
--- a/src/indexing/collection_indexer.jl
+++ b/src/indexing/collection_indexer.jl
@@ -50,126 +50,171 @@ function CollectionIndexer(
 end
 
 """
-    _sample_pids(indexer::CollectionIndexer)
+    encode_passages(
+        config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
+
+Encode a list of passages using `checkpoint`.
+
+The given `passages` are run through the underlying BERT model and the linear layer to
+generate the embeddings, after doing relevant document-specific preprocessing.
+See [`docFromText`](@ref) for more details.
+
+# Arguments
+
+  - `config`: The [`ColBERTConfig`](@ref) to be used.
+  - `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages.
+  - `passages`: A list of strings representing the passages to be encoded.
+
+# Returns
+
+A tuple `embs, doclens` where:
+
+  - `embs::AbstractMatrix{Float32}`: The full embedding matrix, of shape `(D, N)`,
+    where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages.
+  - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage,
+    i.e. the total number of attended tokens for each document passage.
+"""
+function encode_passages(
+        config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
+    @info "Encoding $(length(passages)) passages."
+
+    if length(passages) == 0
+        error("The list of passages to encode is empty!")
+    end
+
+    embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}()
+    # batching here to avoid storing intermediate embeddings on GPU
+    # batching also occurs inside docFromText to do batch packing optimizations
+    for passages_batch in batch(passages, config.index_bsize * 50)
+        embs_, doclens_ = docFromText(config, checkpoint, passages_batch,
+            config.index_bsize)
+        push!(embs, embs_)
+        append!(doclens, vec(doclens_))
+    end
+    embs = cat(embs..., dims = 2)
+    embs, doclens
+end
+
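# NOTE (illustrative sketch, not part of the patch): a minimal use of the new
# `encode_passages`, assuming a `config::ColBERTConfig` and a loaded
# `checkpoint::Checkpoint` are already in scope (both are hypothetical here;
# this hunk does not construct them):
#
#     passages = ["ColBERT represents each passage as a bag of token embeddings.",
#                 "Late interaction scores a query against those embeddings."]
#     embs, doclens = encode_passages(config, checkpoint, passages)
#     size(embs, 2) == sum(doclens)    # one embedding column per attended token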
+"""
+    _sample_pids(num_documents::Int)
 
 Sample PIDs from the collection to be used to compute clusters using a ``k``-means clustering algorithm.
 
 # Arguments
 
-  - `indexer`: The collection indexer object containing the collection of passages to be indexed.
+  - `num_documents`: The total number of documents in the collection. It is assumed that each document has an ID
+    (aka PID) in the range of integers between `1` and `num_documents` (both inclusive).
 
 # Returns
 
 A `Set` of `Int`s containing the sampled PIDs.
 """
-function _sample_pids(indexer::CollectionIndexer)
-    num_passages = length(indexer.config.collection.data)
+function _sample_pids(num_documents::Int)
     typical_doclen = 120
-    num_sampled_pids = 16 * sqrt(typical_doclen * num_passages)
-    num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_passages))
-
-    sampled_pids = Set(sample(1:num_passages, num_sampled_pids))
+    num_sampled_pids = 16 * sqrt(typical_doclen * num_documents)
+    num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_documents))
+    sampled_pids = Set(sample(1:num_documents, num_sampled_pids))
     @info "# of sampled PIDs = $(length(sampled_pids))"
     sampled_pids
 end
 
 """
-    _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int})
-
-Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref), compute the average document length using the embeddings, and save the sampled embeddings to disk.
+    _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
+        collection::Vector{String}, sampled_pids::Set{Int})
 
-The embeddings for the sampled documents are saved in a file named `sample.jld2` with it's path specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the total number of embeddings over all documents.
+Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref).
 
-Sample the passages with `pid` in `sampled_pids` from the `collection` and compute the average passage length. The function returns a tuple containing the embedded passages and the average passage length.
+The embeddings for the sampled documents are saved in a file named `sample.jld2` with its path
+specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the
+embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the
+total number of embeddings over all documents.
 
 # Arguments
 
-  - `indexer`: An instance of `CollectionIndexer`.
+  - `config`: The [`ColBERTConfig`](@ref) to be used.
+  - `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages.
+  - `collection`: The underlying collection of passages to get the samples from.
   - `sampled_pids`: Set of PIDs sampled by [`_sample_pids`](@ref).
 
 # Returns
 
 The average document length (i.e number of attended tokens) computed from the sampled documents.
 """
-function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int})
+function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
+        collection::Vector{String}, sampled_pids::Set{Int})
     # collect all passages with pids in sampled_pids
-    collection = indexer.config.collection
     sorted_sampled_pids = sort(collect(sampled_pids))
-    local_sample = collection.data[sorted_sampled_pids]
+    local_sample = collection[sorted_sampled_pids]
 
-    local_sample_embs, local_sample_doclens = encode_passages(indexer.encoder, local_sample)
+    # get the local sample embeddings
+    local_sample_embs, local_sample_doclens = encode_passages(
+        config, checkpoint, local_sample)
     @debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)"
 
     @assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
+    @assert length(local_sample) == length(local_sample_doclens)
 
-    indexer.num_sample_embs = size(local_sample_embs)[2]
-    indexer.avg_doclen_est = length(local_sample_doclens) > 0 ?
-                             sum(local_sample_doclens) / length(local_sample_doclens) : 0
+    num_sample_embs = size(local_sample_embs)[2]
+    avg_doclen_est = length(local_sample_doclens) > 0 ?
+                     sum(local_sample_doclens) / length(local_sample_doclens) : 0
 
-    sample_path = joinpath(indexer.config.index_path, "sample.jld2")
-    @info "avg_doclen_est = $(indexer.avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
+    sample_path = joinpath(config.index_path, "sample.jld2")
+    @info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
     @info "Saving sampled embeddings to $(sample_path)."
     JLD2.save(sample_path, Dict("local_sample_embs" => local_sample_embs))
 
-    indexer.avg_doclen_est
+    avg_doclen_est
 end
 
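# NOTE (illustrative sketch, not part of the patch): the sampling budget in
# `_sample_pids` is `min(1 + floor(16 * sqrt(120 * num_documents)), num_documents)`.
# For a hypothetical collection of 1_000_000 documents this comes to ≈ 175_272
# sampled PIDs; for 10_000 documents the bound exceeds the collection size, so
# every PID is sampled. A hypothetical end-to-end use of the two helpers above:
#
#     pids = _sample_pids(length(collection))
#     avg_doclen_est = _sample_embeddings(config, checkpoint, collection, pids)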
 """
-    _save_plan(indexer::CollectionIndexer)
+    setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
 
-Save the indexing plan to a JSON file.
+Initialize the index by computing some indexing-specific estimates and save the indexing plan to disk.
 
-Information about the number of chunks, number of clusters, estimated number of embeddings over all documents and the estimated average document length is saved to a file named `plan.json`, with directory specified by the indexing directory.
+The number of chunks into which the document embeddings will be stored is simply computed using the
+number of documents and the size of a chunk. A set of PIDs used for initializing the centroids for
+the embedding clusters is sampled using the [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref)
+functions, and these samples are used to calculate the average document lengths and the estimated number
+of embeddings which will be computed across all documents. Finally, the number of clusters to be used
+for indexing is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``,
+and the indexing plan is saved to `plan.json`, with the path being specified by the indexing directory.
 
 # Arguments
 
-  - `indexer`: The `CollectionIndexer` object that contains the index plan to be saved.
+  - `config`: The [`ColBERTConfig`](@ref) being used to set up the indexing.
+  - `checkpoint`: The [`Checkpoint`](@ref) used to compute embeddings.
+  - `collection`: The underlying collection of passages to initialize the index for.
 """
-function _save_plan(indexer::CollectionIndexer)
-    @info "Saving the index plan to $(indexer.plan_path)."
-    # TODO: export the config as json as well
-    open(indexer.plan_path, "w") do io
+function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
+    chunksize = min(25000, 1 + fld(length(collection), config.nranks))
+    num_chunks = cld(length(collection), chunksize)
+
+    # sample passages for training centroids later
+    sampled_pids = _sample_pids(length(collection))
+    avg_doclen_est = _sample_embeddings(config, checkpoint, collection, sampled_pids)
+
+    # computing the number of partitions, i.e. clusters
+    num_passages = length(collection)
+    num_embeddings_est = num_passages * avg_doclen_est
+    num_partitions = Int(floor(2^(floor(log2(16 * sqrt(num_embeddings_est))))))
+
+    @info "Creating $(num_partitions) clusters."
+    @info "Estimated $(num_embeddings_est) embeddings."
+
+    @info "Saving the index plan to $(joinpath(config.index_path, "plan.json"))."
+    open(joinpath(config.index_path, "plan.json"), "w") do io
         JSON.print(io,
            Dict(
-                "num_chunks" => indexer.num_chunks,
-                "num_partitions" => indexer.num_partitions,
-                "num_embeddings_est" => indexer.num_embeddings_est,
-                "avg_doclen_est" => indexer.avg_doclen_est
+                "num_chunks" => num_chunks,
+                "num_partitions" => num_partitions,
+                "num_embeddings_est" => num_embeddings_est,
+                "avg_doclen_est" => avg_doclen_est
            ),
            4 # indent
        )
    end
-end
-
-"""
-    setup(indexer::CollectionIndexer)
-
-Initialize `indexer` by computing some indexing-specific estimates and save the indexing plan to disk.
-
-The number of chunks into which the document embeddings will be stored (`indexer.num_chunks`) is simply computed using the number of documents and the size of a chunk obtained from [`get_chunksize`](@ref). A bunch of pids used for initializing the centroids for the embedding clusters are sampled using the [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref) functions, and these samples are used to calculate the average document lengths and the estimated number of embeddings which will be computed across all documents. Finally, the number of clusters (`indexer.num_partitions`) to be used for indexing is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``, and the indexing plan is saved to `plan.json` (see [`_save_plan`](@ref)) in the indexing directory.
-
-# Arguments
-
-  - `indexer::CollectionIndexer`: The indexer to be initialized.
-"""
-function setup(indexer::CollectionIndexer)
-    collection = indexer.config.collection
-    indexer.num_chunks = Int(ceil(length(collection.data) / get_chunksize(
-        collection, indexer.config.nranks)))
-
-    # sample passages for training centroids later
-    sampled_pids = _sample_pids(indexer)
-    avg_doclen_est = _sample_embeddings(indexer, sampled_pids)
-
-    # computing the number of partitions, i.e clusters
-    num_passages = length(indexer.config.collection.data)
-    indexer.num_embeddings_est = num_passages * avg_doclen_est
-    indexer.num_partitions = Int(floor(2^(floor(log2(16 *
-        sqrt(indexer.num_embeddings_est))))))
-
-    @info "Creating $(indexer.num_partitions) clusters."
-    @info "Estimated $(indexer.num_embeddings_est) embeddings."
-    _save_plan(indexer)
+    @info "Saving the config to the indexing path."
+    ColBERT.save(config)
 end
 
 """
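# NOTE (illustrative sketch, not part of the patch): how the plan estimates in
# the new `setup` work out for a hypothetical collection of 10_000 passages
# with `config.nranks = 1` and `avg_doclen_est ≈ 120`:
#
#     chunksize          = min(25000, 1 + fld(10_000, 1))                     # 10_001
#     num_chunks         = cld(10_000, 10_001)                                # 1
#     num_embeddings_est = 10_000 * 120                                       # 1_200_000
#     num_partitions     = Int(floor(2^(floor(log2(16 * sqrt(1_200_000))))))  # 2^14 = 16_384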