
Commit

Making the setup, train and _sample_embeddings functions test friendly; moving file saving and loading to a higher level `index` function.
codetalker7 committed Aug 15, 2024
1 parent 23cc67b commit 6c63834
Showing 1 changed file with 26 additions and 38 deletions.
64 changes: 26 additions & 38 deletions src/indexing/collection_indexer.jl
@@ -109,16 +109,15 @@ function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
@assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
@assert length(local_sample) == length(local_sample_doclens)

num_sample_embs = size(local_sample_embs)[2]
avg_doclen_est = length(local_sample_doclens) > 0 ?
sum(local_sample_doclens) / length(local_sample_doclens) : 0

sample_path = joinpath(config.index_path, "sample.jld2")
@info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
@info "Saving sampled embeddings to $(sample_path)."
JLD2.save_object(sample_path, local_sample_embs)

avg_doclen_est
Dict(
    "avg_doclen_est" => avg_doclen_est,
    "local_sample_embs" => local_sample_embs
)
end
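
With the JLD2 write removed from `_sample_embeddings`, persisting the sampled embeddings is now the caller's responsibility. A minimal sketch of what that caller-side save might look like, assuming the higher-level `index` function receives the returned Dict unchanged; the helper name `save_sample_embeddings` is illustrative and not part of this diff.

using JLD2

# Illustrative caller-side persistence; `local_sample_dict` is the Dict
# returned by `_sample_embeddings` above.
function save_sample_embeddings(index_path::String, local_sample_dict::Dict)
    sample_path = joinpath(index_path, "sample.jld2")
    @info "Saving sampled embeddings to $(sample_path)."
    JLD2.save_object(sample_path, local_sample_dict["local_sample_embs"])
end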

"""
@@ -141,41 +140,31 @@ and the indexing plan is saved to `plan.json`, with the path being specified by
- `collection`: The underlying collection of passages to initialize the index for.
"""
function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
isdir(config.index_path) || mkdir(config.index_path)

chunksize = 0
chunksize = ismissing(config.chunksize) ?
min(25000, 1 + fld(length(collection), config.nranks)) : config.chunksize
num_chunks = cld(length(collection), chunksize)

# sample passages for training centroids later
sampled_pids = _sample_pids(length(collection))
avg_doclen_est = _sample_embeddings(config, checkpoint, collection, sampled_pids)
local_sample_dict = _sample_embeddings(config, checkpoint, collection, sampled_pids)

# computing the number of partitions, i.e clusters
num_passages = length(collection)
num_embeddings_est = num_passages * avg_doclen_est
num_embeddings_est = num_passages * local_sample_dict["avg_doclen_est"]
num_partitions = Int(floor(2^(floor(log2(16 * sqrt(num_embeddings_est))))))

@info "Creating $(num_partitions) clusters."
@info "Estimated $(num_embeddings_est) embeddings."

@info "Saving the index plan to $(joinpath(config.index_path, "plan.json"))."
open(joinpath(config.index_path, "plan.json"), "w") do io
    JSON.print(io,
        Dict(
            "chunksize" => chunksize,
            "num_chunks" => num_chunks,
            "num_partitions" => num_partitions,
            "num_embeddings_est" => num_embeddings_est,
            "avg_doclen_est" => avg_doclen_est
        ),
        4 # indent
    )
end

@info "Saving the config to the indexing path."
ColBERT.save(config)
Dict(
    "chunksize" => chunksize,
    "num_chunks" => num_chunks,
    "num_partitions" => num_partitions,
    "num_embeddings_est" => num_embeddings_est,
    "avg_doclen_est" => local_sample_dict["avg_doclen_est"],
    "local_sample_embs" => local_sample_dict["local_sample_embs"]
)
end
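
`setup` now returns the indexing plan as a Dict instead of writing `plan.json` and the config to disk itself. A minimal sketch of the caller-side write, reusing the `JSON.print` call removed above; the helper name `save_plan` is illustrative, and the embedding matrix is deliberately kept out of the JSON file.

using JSON

# Illustrative caller-side persistence of the plan returned by `setup`.
function save_plan(index_path::String, plan::Dict)
    isdir(index_path) || mkdir(index_path)
    plan_path = joinpath(index_path, "plan.json")
    scalar_keys = ["chunksize", "num_chunks", "num_partitions",
        "num_embeddings_est", "avg_doclen_est"]
    @info "Saving the index plan to $(plan_path)."
    open(plan_path, "w") do io
        JSON.print(io, Dict(k => plan[k] for k in scalar_keys), 4)  # 4-space indent
    end
end

For a sense of scale of the partition formula: with num_embeddings_est = 1_000_000, 16 * sqrt(1_000_000) = 16_000 and floor(log2(16_000)) = 13, so num_partitions = 2^13 = 8192 clusters.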

"""
@@ -276,24 +265,23 @@ function, and the codec is saved on disk using [`save_codec`](@ref).
- `config`: The [`ColBERTConfig`](@ref) used to train the indexer.
"""
function train(config::ColBERTConfig)
sample, heldout = _concatenate_and_split_sample(config.index_path)
@assert sample isa AbstractMatrix{Float32} "$(typeof(sample))"
@assert heldout isa AbstractMatrix{Float32} "$(typeof(heldout))"

# loading the indexing plan
plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json"))

centroids = kmeans(sample, plan_metadata["num_partitions"],
maxiter = config.kmeans_niters, display = :iter).centers
@assert size(centroids)[2]==plan_metadata["num_partitions"] "size(centroids): $(size(centroids)), num_partitions: $(plan_metadata["num_partitions"])"
function train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, num_partitions::Int, nbits::Int, kmeans_niters::Int)
centroids = kmeans(sample, num_partitions,
    maxiter = kmeans_niters, display = :iter).centers
@assert size(centroids)[2]==num_partitions "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)"
@assert centroids isa AbstractMatrix{Float32} "$(typeof(centroids))"

bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals(
config.nbits, centroids, heldout)
nbits, centroids, heldout)
@info "avg_residual = $(avg_residual)"

save_codec(config.index_path, centroids, bucket_cutoffs, bucket_weights, avg_residual)
Dict(
    "centroids" => centroids,
    "bucket_cutoffs" => bucket_cutoffs,
    "bucket_weights" => bucket_weights,
    "avg_residual" => avg_residual,
)
end
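
With its file I/O stripped out, `train` is now driven purely through its arguments and return value. A rough sketch of a possible call site in the higher-level `index` function, assuming the plan has already been written to `plan.json` and reusing `_concatenate_and_split_sample` and `save_codec` with the signatures visible in the removed lines; the glue code itself is illustrative.

using JSON

# Illustrative call site; the surrounding `index` function is not part of this diff.
plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json"))
sample, heldout = _concatenate_and_split_sample(config.index_path)
codec = train(sample, heldout, plan_metadata["num_partitions"],
    config.nbits, config.kmeans_niters)
save_codec(config.index_path, codec["centroids"], codec["bucket_cutoffs"],
    codec["bucket_weights"], codec["avg_residual"])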

"""
