Commit
Simplifying the signatures of the setup functions; only using primitive
types and the most important types for ColBERT (i.e. `Checkpoint` and
`ColBERTConfig`). This makes testing easier.
codetalker7 committed Aug 11, 2024
1 parent 48720bd commit 20f0402
Showing 1 changed file with 112 additions and 67 deletions.
179 changes: 112 additions & 67 deletions src/indexing/collection_indexer.jl
@@ -50,126 +50,171 @@ function CollectionIndexer(
end

"""
_sample_pids(indexer::CollectionIndexer)
encode_passages(
config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
Encode a list of passages using `checkpoint`.
The given `passages` are run through the underlying BERT model and the linear layer to
generate the embeddings, after doing relevant document-specific preprocessing.
See [`docFromText`](@ref) for more details.
# Arguments
- `config`: The [`ColBERTConfig`](@ref) to be used.
- `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages.
- `passages`: A list of strings representing the passages to be encoded.
# Returns
A tuple `embs, doclens` where:
- `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`,
where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages.
- `doclens::AbstractVector{Int}`: A vector of document lengths for each passage,
i.e. the total number of attended tokens for each document passage.
"""
function encode_passages(
config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
@info "Encoding $(length(passages)) passages."

if length(passages) == 0
error("The list of passages to encode is empty!")
end

embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}()
# batching here to avoid storing intermediate embeddings on GPU
# batching also occurs inside docFromText to do batch packing optimizations
for passages_batch in batch(passages, config.index_bsize * 50)
embs_, doclens_ = docFromText(config, checkpoint, passages_batch,
config.index_bsize)
push!(embs, embs_)
append!(doclens, vec(doclens_))
end
embs = cat(embs..., dims = 2)
embs, doclens
end
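# Illustrative usage sketch, not part of this commit. It assumes a `config::ColBERTConfig`
# and a loaded `checkpoint::Checkpoint` are already available (their construction is not
# shown in this diff) and that the model fits on the configured device.
#
#   passages = ["A short passage.", "Another, slightly longer passage."]
#   embs, doclens = encode_passages(config, checkpoint, passages)
#   size(embs, 1)    # embedding dimension (128 for the default ColBERT linear layer)
#   size(embs, 2)    # total number of attended tokens, equal to sum(doclens)
#   length(doclens)  # == length(passages)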

"""
_sample_pids(num_documents::Int)
Sample PIDs from the collection to be used to compute clusters using a ``k``-means clustering algorithm.
# Arguments
- `indexer`: The collection indexer object containing the collection of passages to be indexed.
- `num_documents`: The total number of documents in the collection. It is assumed that each document has an ID
(aka PID) in the range of integers between `1` and `num_documents` (both inclusive).
# Returns
A `Set` of `Int`s containing the sampled PIDs.
"""
function _sample_pids(indexer::CollectionIndexer)
num_passages = length(indexer.config.collection.data)
function _sample_pids(num_documents::Int)
typical_doclen = 120
num_sampled_pids = 16 * sqrt(typical_doclen * num_passages)
num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_passages))

sampled_pids = Set(sample(1:num_passages, num_sampled_pids))
num_sampled_pids = 16 * sqrt(typical_doclen * num_documents)
num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_documents))
sampled_pids = Set(sample(1:num_documents, num_sampled_pids))
@info "# of sampled PIDs = $(length(sampled_pids))"
sampled_pids
end
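# Back-of-the-envelope sketch of the sampling heuristic above (illustrative, not part of
# this commit): the sample size grows as 16·sqrt(120 · num_documents) and is capped at the
# collection size.
for num_documents in (10_000, 1_000_000)
    n = Int(min(1 + floor(16 * sqrt(120 * num_documents)), num_documents))
    println("num_documents = $num_documents  =>  $n sampled PIDs")
end
# For 10_000 documents the cap is hit and the whole collection is sampled;
# for 1_000_000 documents roughly 175_000 PIDs are sampled.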

"""
_sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int})
Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref), compute the average document length using the embeddings, and save the sampled embeddings to disk.
_sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
collection::Vector{String}, sampled_pids::Set{Int})
The embeddings for the sampled documents are saved in a file named `sample.jld2` with its path specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the total number of embeddings over all documents.
Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref).
Sample the passages with `pid` in `sampled_pids` from the `collection`, encode them, and compute the average passage length. The sampled embeddings are saved to disk and the average passage length is returned.
The embeddings for the sampled documents are saved in a file named `sample.jld2` with its path
specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the
embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the
total number of embeddings over all documents.
# Arguments
- `indexer`: An instance of `CollectionIndexer`.
- `config`: The [`ColBERTConfig`](@ref) to be used.
- `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages.
- `collection`: The underlying collection of passages to get the samples from.
- `sampled_pids`: Set of PIDs sampled by [`_sample_pids`](@ref).
# Returns
The average document length (i.e. the number of attended tokens) computed from the sampled documents.
"""
function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int})
function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
collection::Vector{String}, sampled_pids::Set{Int})
# collect all passages with pids in sampled_pids
collection = indexer.config.collection
sorted_sampled_pids = sort(collect(sampled_pids))
local_sample = collection.data[sorted_sampled_pids]
local_sample = collection[sorted_sampled_pids]

local_sample_embs, local_sample_doclens = encode_passages(indexer.encoder, local_sample)
# get the local sample embeddings
local_sample_embs, local_sample_doclens = encode_passages(
config, checkpoint, local_sample)
@debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)"
@assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
@assert length(local_sample) == length(local_sample_doclens)

indexer.num_sample_embs = size(local_sample_embs)[2]
indexer.avg_doclen_est = length(local_sample_doclens) > 0 ?
sum(local_sample_doclens) / length(local_sample_doclens) : 0
num_sample_embs = size(local_sample_embs)[2]
avg_doclen_est = length(local_sample_doclens) > 0 ?
sum(local_sample_doclens) / length(local_sample_doclens) : 0

sample_path = joinpath(indexer.config.index_path, "sample.jld2")
@info "avg_doclen_est = $(indexer.avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
sample_path = joinpath(config.index_path, "sample.jld2")
@info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
@info "Saving sampled embeddings to $(sample_path)."
JLD2.save(sample_path, Dict("local_sample_embs" => local_sample_embs))

indexer.avg_doclen_est
avg_doclen_est
end
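# Illustrative check, not part of this commit: the sampled embeddings written by
# `_sample_embeddings` can be read back with JLD2. `index_path` below stands for whatever
# directory `config.index_path` points at.
#
#   using JLD2
#   local_sample_embs = JLD2.load(joinpath(index_path, "sample.jld2"), "local_sample_embs")
#   size(local_sample_embs, 1)  # embedding dimension, 128 by default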

"""
_save_plan(indexer::CollectionIndexer)
setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
Save the indexing plan to a JSON file.
Initialize the index by computing some indexing-specific estimates and save the indexing plan to disk.
Information about the number of chunks, the number of clusters, the estimated number of embeddings over all documents, and the estimated average document length is saved to a file named `plan.json` in the indexing directory.
The number of chunks into which the document embeddings will be stored is computed from the
number of documents and the chunk size. A set of PIDs used for initializing the centroids of
the embedding clusters is sampled via [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref),
and these samples are used to estimate the average document length and the total number of
embeddings across all documents. Finally, the number of clusters to be used for indexing is
computed as the largest power of two not exceeding ``16\\sqrt{\\text{estimated number of embeddings}}``,
and the indexing plan is saved to `plan.json` in the indexing directory.
# Arguments
- `indexer`: The `CollectionIndexer` object that contains the index plan to be saved.
- `config`: The [`ColBERTConfig`](@ref) being used to set up the indexing.
- `checkpoint`: The [`Checkpoint`](@ref) used to compute embeddings.
- `collection`: The underlying collection of passages to initialize the index for.
"""
function _save_plan(indexer::CollectionIndexer)
@info "Saving the index plan to $(indexer.plan_path)."
# TODO: export the config as json as well
open(indexer.plan_path, "w") do io
function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
chunksize = min(25000, 1 + fld(length(collection), config.nranks))
num_chunks = cld(length(collection), chunksize)

# sample passages for training centroids later
sampled_pids = _sample_pids(length(collection))
avg_doclen_est = _sample_embeddings(config, checkpoint, collection, sampled_pids)

# computing the number of partitions, i.e. clusters
num_passages = length(collection)
num_embeddings_est = num_passages * avg_doclen_est
num_partitions = Int(floor(2^(floor(log2(16 * sqrt(num_embeddings_est))))))

@info "Creating $(num_partitions) clusters."
@info "Estimated $(num_embeddings_est) embeddings."

@info "Saving the index plan to $(joinpath(config.index_path, "plan.json"))."
open(joinpath(config.index_path, "plan.json"), "w") do io
JSON.print(io,
Dict(
"num_chunks" => indexer.num_chunks,
"num_partitions" => indexer.num_partitions,
"num_embeddings_est" => indexer.num_embeddings_est,
"avg_doclen_est" => indexer.avg_doclen_est
"num_chunks" => num_chunks,
"num_partitions" => num_partitions,
"num_embeddings_est" => num_embeddings_est,
"avg_doclen_est" => avg_doclen_est
),
4 # indent
)
end
end
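# Worked example of the plan estimates computed above (illustrative, not part of this
# commit); the numbers are made up: 100_000 passages, nranks = 1, and an estimated
# average document length of 120 attended tokens.
num_passages, nranks, avg_doclen_est = 100_000, 1, 120.0
chunksize = min(25000, 1 + fld(num_passages, nranks))                        # 25000
num_chunks = cld(num_passages, chunksize)                                    # 4
num_embeddings_est = num_passages * avg_doclen_est                           # 1.2e7
num_partitions = Int(floor(2^(floor(log2(16 * sqrt(num_embeddings_est))))))  # 32768
# i.e. num_partitions is the largest power of two not exceeding 16·sqrt(1.2e7) ≈ 55_426.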

"""
setup(indexer::CollectionIndexer)
Initialize `indexer` by computing some indexing-specific estimates and save the indexing plan to disk.
The number of chunks into which the document embeddings will be stored (`indexer.num_chunks`) is computed from the number of documents and the chunk size obtained from [`get_chunksize`](@ref). A set of PIDs used for initializing the centroids of the embedding clusters is sampled via [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref), and these samples are used to estimate the average document length and the total number of embeddings across all documents. Finally, the number of clusters (`indexer.num_partitions`) to be used for indexing is computed as the largest power of two not exceeding ``16\\sqrt{\\text{estimated number of embeddings}}``, and the indexing plan is saved to `plan.json` (see [`_save_plan`](@ref)) in the indexing directory.
# Arguments
- `indexer::CollectionIndexer`: The indexer to be initialized.
"""
function setup(indexer::CollectionIndexer)
collection = indexer.config.collection
indexer.num_chunks = Int(ceil(length(collection.data) / get_chunksize(
collection, indexer.config.nranks)))

# sample passages for training centroids later
sampled_pids = _sample_pids(indexer)
avg_doclen_est = _sample_embeddings(indexer, sampled_pids)

# computing the number of partitions, i.e. clusters
num_passages = length(indexer.config.collection.data)
indexer.num_embeddings_est = num_passages * avg_doclen_est
indexer.num_partitions = Int(floor(2^(floor(log2(16 *
sqrt(indexer.num_embeddings_est))))))

@info "Creating $(indexer.num_partitions) clusters."
@info "Estimated $(indexer.num_embeddings_est) embeddings."

_save_plan(indexer)
@info "Saving the config to the indexing path."
ColBERT.save(config)
end

"""
