Commit

Refactoring the setup code.
codetalker7 committed Aug 26, 2024
1 parent 9fa7dc2 commit ccaa17a
Showing 1 changed file with 30 additions and 36 deletions.
66 changes: 30 additions & 36 deletions src/indexing/collection_indexer.jl
@@ -47,26 +47,43 @@
from the sampled documents, and the embedding matrix for the local samples. The
shape `(D, N)`, where `D` is the embedding dimension (`128`) and `N` is the total number
of embeddings over all the sampled passages.
"""
-function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint,
-        collection::Vector{String}, sampled_pids::Set{Int})
+function _sample_embeddings(bert::HF.HGFBertModel, linear::Layers.Dense,
+        tokenizer::TextEncoders.AbstractTransformerTextEncoder,
+        dim::Int, index_bsize::Int, doc_token::String,
+        skiplist::Vector{Int}, collection::Vector{String})
# collect all passages with pids in sampled_pids
+    sampled_pids = _sample_pids(length(collection))
sorted_sampled_pids = sort(collect(sampled_pids))
local_sample = collection[sorted_sampled_pids]

# get the local sample embeddings
local_sample_embs, local_sample_doclens = encode_passages(
-        config, checkpoint, local_sample)
-    @debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)"
-    @assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
+        bert, linear, tokenizer, local_sample,
+        dim, index_bsize, doc_token, skiplist)

+    @assert size(local_sample_embs, 2)==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))"
+    @assert length(local_sample) == length(local_sample_doclens)

avg_doclen_est = length(local_sample_doclens) > 0 ?
-                     sum(local_sample_doclens) / length(local_sample_doclens) :
-                     0
+                     Float32(sum(local_sample_doclens) /
+                             length(local_sample_doclens)) :
+                     zero(Float32)
@info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))"
avg_doclen_est, local_sample_embs
end
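After this change the sampler is self-contained: it draws the pids itself and takes the model pieces directly instead of a `ColBERTConfig` and `Checkpoint`. A minimal usage sketch; how `bert`, `linear`, `tokenizer` and `skiplist` are obtained, the literal values, and the `"[unused1]"` marker are assumptions for illustration, not part of this commit:

    # Hypothetical call; 128 and 64 mirror common ColBERT defaults.
    avg_doclen_est, sample_embs = _sample_embeddings(
        bert,          # HF.HGFBertModel backbone
        linear,        # Layers.Dense projection to the embedding dim
        tokenizer,     # TextEncoders.AbstractTransformerTextEncoder
        128,           # dim: embedding dimension
        64,            # index_bsize: encoding batch size
        "[unused1]",   # doc_token (assumed document marker token)
        skiplist,      # token ids to mask out of the doc embeddings
        collection)    # Vector{String} of all passages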

+function _heldout_split(
+        sample::AbstractMatrix{Float32}; heldout_fraction::Float32 = 0.05f0)
+    num_sample_embs = size(sample, 2)
+    sample = sample[:, shuffle(1:num_sample_embs)]
+    heldout_size = Int(max(
+        1, floor(min(50000, heldout_fraction * num_sample_embs))))
+    sample, sample_heldout = sample[
+        :, 1:(num_sample_embs - heldout_size)],
+    sample[:, (num_sample_embs - heldout_size + 1):num_sample_embs]
+    sample, sample_heldout
+end
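For intuition, a worked sketch of the split arithmetic on random data (the sample below is purely illustrative): with 100_000 sampled embeddings and the default `heldout_fraction = 0.05f0`, `heldout_size = max(1, floor(min(50000, 0.05f0 * 100_000))) = 5_000`, so after shuffling the columns split 95_000 / 5_000:

    using Random   # provides `shuffle`, used inside `_heldout_split`
    embs = rand(Float32, 128, 100_000)   # fake embeddings, dim 128
    sample, heldout = _heldout_split(embs)
    size(sample)    # (128, 95000)
    size(heldout)   # (128, 5000)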

"""
setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String})
@@ -91,45 +108,22 @@
proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``.
A `Dict` containing the indexing plan.
"""
-function setup(config::ColBERTConfig, checkpoint::Checkpoint,
-        collection::Vector{String})
-    chunksize = 0
-    chunksize = ismissing(config.chunksize) ?
-                min(25000, 1 + fld(length(collection), config.nranks)) :
-                config.chunksize
+function setup(collection::Vector{String}, avg_doclen_est::Float32,
+        num_clustering_embs::Int, chunksize::Union{Missing, Int}, nranks::Int)
+    chunksize = ismissing(chunksize) ?
+                min(25000, 1 + fld(length(collection), nranks)) :
+                chunksize
num_chunks = cld(length(collection), chunksize)

-    # sample passages for training centroids later
-    sampled_pids = _sample_pids(length(collection))
-    avg_doclen_est, local_sample_embs = _sample_embeddings(
-        config, checkpoint, collection, sampled_pids)
-
-    # splitting the local sample into heldout set
-    num_local_sample_embs = size(local_sample_embs, 2)
-    local_sample_embs = local_sample_embs[
-        :, shuffle(1:num_local_sample_embs)]
-
-    # split the sample to get a heldout set
-    heldout_fraction = 0.05
-    heldout_size = Int(max(
-        1, floor(min(
-            50000, heldout_fraction * num_local_sample_embs))))
-    sample, sample_heldout = local_sample_embs[
-        :, 1:(num_local_sample_embs - heldout_size)],
-    local_sample_embs[
-        :, (num_local_sample_embs - heldout_size + 1):num_local_sample_embs]
-    @debug "Split sample sizes: sample size: $(size(sample)), \t sample_heldout size: $(size(sample_heldout))"

# computing the number of partitions, i.e clusters
num_passages = length(collection)
num_embeddings_est = num_passages * avg_doclen_est
-    num_partitions = Int(min(size(sample, 2),
+    num_partitions = Int(min(num_clustering_embs,
floor(2^(floor(log2(16 * sqrt(num_embeddings_est)))))))

@info "Creating $(num_partitions) clusters."
@info "Estimated $(num_embeddings_est) embeddings."

-    sample, sample_heldout,
Dict(
"chunksize" => chunksize,
"num_chunks" => num_chunks,

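To make the new `setup` plan arithmetic concrete, a worked sketch under assumed numbers (1_000_000 passages, an estimated average doclen of 60 tokens, one rank, and no user-supplied chunksize; none of these values come from the commit):

    chunksize  = min(25_000, 1 + fld(1_000_000, 1))     # -> 25_000
    num_chunks = cld(1_000_000, chunksize)              # -> 40
    num_embeddings_est = 1_000_000 * 60.0f0             # -> 6.0f7
    # num_partitions is the largest power of 2 below
    # 16 * sqrt(num_embeddings_est) ≈ 123_935.5, capped by
    # num_clustering_embs (the sampled embeddings available):
    floor(2^(floor(log2(16 * sqrt(num_embeddings_est)))))   # -> 65536.0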