Merge pull request #32 from JuliaGenAI/unit_testing
Unit tests + more design changes + function-level optimizations.
codetalker7 authored Sep 8, 2024
2 parents 266613d + 1bda1a8 commit 7731a61
Showing 23 changed files with 3,671 additions and 1,353 deletions.
examples/indexing.jl: 1 change (1 addition, 0 deletions)
@@ -21,6 +21,7 @@ indexer = Indexer(config)
# then big example
config = ColBERTConfig(
use_gpu = true,
checkpoint = "/home/codetalker7/models/colbertv2.0/",
collection = "./downloads/lotte/lifestyle/dev/collection.tsv",
index_path = "./lotte_lifestyle_index/"
)
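For orientation, a minimal sketch of how the example configuration above is consumed, based only on the `Indexer(config)` and `index(indexer)` entry points that appear later in this commit and assuming both are available after `using ColBERT`; the paths are the machine-specific placeholders from the example, not values you should copy.

using ColBERT

config = ColBERTConfig(
    use_gpu = true,
    checkpoint = "/home/codetalker7/models/colbertv2.0/",               # example path from above
    collection = "./downloads/lotte/lifestyle/dev/collection.tsv",
    index_path = "./lotte_lifestyle_index/"
)
indexer = Indexer(config)   # loads the model, tokenizer and collection
index(indexer)              # builds the index and saves it under index_path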
src/ColBERT.jl: 3 changes (3 additions, 0 deletions)
@@ -3,6 +3,7 @@ using Clustering
using CSV
using Dates
using Flux
using .Iterators
using JLD2
using JSON
using LinearAlgebra
@@ -23,8 +24,10 @@ export ColBERTConfig

# models, document/query tokenizers
include("local_loading.jl")
include("modelling/tokenization/tokenizer_utils.jl")
include("modelling/tokenization/doc_tokenization.jl")
include("modelling/tokenization/query_tokenization.jl")
include("modelling/embedding_utils.jl")
include("modelling/checkpoint.jl")
export BaseColBERT, Checkpoint

src/indexing.jl: 100 changes (80 additions, 20 deletions)
@@ -1,7 +1,10 @@
struct Indexer
config::ColBERTConfig
checkpoint::Checkpoint
bert::HF.HGFBertModel
linear::Layers.Dense
tokenizer::TextEncoders.AbstractTransformerTextEncoder
collection::Vector{String}
skiplist::Vector{Int}
end

"""
@@ -19,14 +22,32 @@ An [`Indexer`] wrapping a [`ColBERTConfig`](@ref), a [`Checkpoint`](@ref) and
a collection of documents to index.
"""
function Indexer(config::ColBERTConfig)
base_colbert = BaseColBERT(config.checkpoint)
checkpoint = Checkpoint(base_colbert, config)
tokenizer, bert, linear = load_hgf_pretrained_local(config.checkpoint)
bert = bert |> Flux.gpu
linear = linear |> Flux.gpu
collection = readlines(config.collection)

@info "Loaded ColBERT layers from the $(checkpoint) HuggingFace checkpoint."
punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"));
tokenizer.padsym]
skiplist = config.mask_punctuation ?
[lookup(tokenizer.vocab, sym) for sym in punctuations_and_padsym] :
[lookup(tokenizer.vocab, tokenizer.padsym)]

# configuring the tokenizer; using doc_maxlen
process = tokenizer.process
truncpad_pipe = Pipeline{:token}(
TextEncodeBase.trunc_and_pad(
config.doc_maxlen - 1, "[PAD]", :tail, :tail),
:token)
process = process[1:4] |> truncpad_pipe |> process[6:end]
tokenizer = TextEncoders.BertTextEncoder(
tokenizer.tokenizer, tokenizer.vocab, process;
startsym = tokenizer.startsym, endsym = tokenizer.endsym,
padsym = tokenizer.padsym, trunc = tokenizer.trunc)

@info "Loaded ColBERT layers from the $(config.checkpoint) HuggingFace checkpoint."
@info "Loaded $(length(collection)) documents from $(config.collection)."

Indexer(config, checkpoint, collection)
Indexer(config, bert, linear, tokenizer, collection, skiplist)
end

"""
@@ -43,40 +64,79 @@ function index(indexer::Indexer)
@info "Index at $(indexer.config.index_path) already exists! Skipping indexing."
return
end

# getting and saving the indexing plan
isdir(indexer.config.index_path) || mkdir(indexer.config.index_path)
sample, sample_heldout, plan_dict = setup(
indexer.config, indexer.checkpoint, indexer.collection)

# sampling passages and getting their embeddings
@info "Sampling PIDs for clustering and generating their embeddings."
@time avg_doclen_est, sample = _sample_embeddings(
indexer.bert, indexer.linear, indexer.tokenizer,
indexer.config.dim, indexer.config.index_bsize,
indexer.config.doc_token_id, indexer.skiplist, indexer.collection)

# splitting the sample to a clustering set and a heldout set
@info "Splitting the sampled embeddings to a heldout set."
@time sample, sample_heldout = _heldout_split(sample)
@assert sample isa AbstractMatrix{Float32} "$(typeof(sample))"
@assert sample_heldout isa AbstractMatrix{Float32} "$(typeof(sample_heldout))"

# generating the indexing setup
plan_dict = setup(indexer.collection, avg_doclen_est, size(sample, 2),
indexer.config.chunksize, indexer.config.nranks)
@info "Saving the index plan to $(joinpath(indexer.config.index_path, "plan.json"))."
open(joinpath(indexer.config.index_path, "plan.json"), "w") do io
JSON.print(io,
plan_dict,
4 # indent
4
)
end
@info "Saving the config to the indexing path."
ColBERT.save(indexer.config)

# training/clustering
@assert sample isa AbstractMatrix{Float32} "$(typeof(sample))"
@assert sample_heldout isa AbstractMatrix{Float32} "$(typeof(sample_heldout))"
@info "Training the clusters."
centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
@time centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
sample, sample_heldout, plan_dict["num_partitions"],
indexer.config.nbits, indexer.config.kmeans_niters)
save_codec(
indexer.config.index_path, centroids, bucket_cutoffs,
bucket_weights, avg_residual)
sample, sample_heldout, centroids = nothing, nothing, nothing # these are big arrays
sample, sample_heldout = nothing, nothing # these are big arrays

# indexing
@info "Building the index."
index(indexer.config, indexer.checkpoint, indexer.collection)
@time index(indexer.config.index_path, indexer.bert, indexer.linear,
indexer.tokenizer, indexer.collection, indexer.config.dim,
indexer.config.index_bsize, indexer.config.doc_token_id,
indexer.skiplist, plan_dict["num_chunks"], plan_dict["chunksize"],
centroids, bucket_cutoffs, indexer.config.nbits)

# collect embedding offsets and more metadata for chunks
chunk_emb_counts = load_chunk_metadata_property(
indexer.config.index_path, "num_embeddings")
num_embeddings, embeddings_offsets = _collect_embedding_id_offset(chunk_emb_counts)
@info "Updating chunk metadata and indexing plan"
plan_dict["num_embeddings"] = num_embeddings
plan_dict["embeddings_offsets"] = embeddings_offsets
open(joinpath(indexer.config.index_path, "plan.json"), "w") do io
JSON.print(io,
plan_dict,
4
)
end
save_chunk_metadata_property(
indexer.config.index_path, "embedding_offset", embeddings_offsets)

# build and save the ivf
@info "Building the centroid to embedding IVF."
codes = load_codes(indexer.config.index_path)
ivf, ivf_lengths = _build_ivf(codes, plan_dict["num_partitions"])

@info "Saving the IVF."
ivf_path = joinpath(indexer.config.index_path, "ivf.jld2")
ivf_lengths_path = joinpath(indexer.config.index_path, "ivf_lengths.jld2")
JLD2.save_object(ivf_path, ivf)
JLD2.save_object(ivf_lengths_path, ivf_lengths)

# finalizing
@info "Running some final checks."
# check if all relevant files are saved
_check_all_files_are_saved(indexer.config.index_path)
_collect_embedding_id_offset(indexer.config.index_path)
_build_ivf(indexer.config.index_path)
end
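To make the IVF step above concrete, here is a rough sketch of what building an inverted file over centroid codes amounts to; the actual `_build_ivf` in this commit takes different arguments and may differ in its details, so treat this as an assumption-laden illustration only.

# Sketch: codes[i] is the centroid assigned to embedding i.
# The IVF lists embedding IDs grouped by centroid, plus per-centroid group sizes,
# so search can jump from a candidate centroid straight to its member embeddings.
function build_ivf_sketch(codes::Vector{Int}, num_partitions::Int)
    ivf = sortperm(codes)                      # embedding IDs ordered by their centroid
    ivf_lengths = zeros(Int, num_partitions)
    for c in codes
        ivf_lengths[c] += 1                    # how many embeddings fall in each centroid
    end
    return ivf, ivf_lengths
end

ivf, ivf_lengths = build_ivf_sketch([3, 1, 3, 2, 1], 3)
# ivf_lengths == [2, 1, 2]; ivf groups IDs by centroid, e.g. [2, 5, 4, 1, 3]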