diff --git a/examples/indexing.jl b/examples/indexing.jl
index d9285a1..48ea65f 100644
--- a/examples/indexing.jl
+++ b/examples/indexing.jl
@@ -21,6 +21,7 @@ indexer = Indexer(config)
 # then big example
 config = ColBERTConfig(
     use_gpu = true,
+    checkpoint = "/home/codetalker7/models/colbertv2.0/",
    collection = "./downloads/lotte/lifestyle/dev/collection.tsv",
    index_path = "./lotte_lifestyle_index/"
 )
diff --git a/src/ColBERT.jl b/src/ColBERT.jl
index d0c06ba..f29650a 100644
--- a/src/ColBERT.jl
+++ b/src/ColBERT.jl
@@ -3,6 +3,7 @@ using Clustering
 using CSV
 using Dates
 using Flux
+using .Iterators
 using JLD2
 using JSON
 using LinearAlgebra
@@ -23,8 +24,10 @@ export ColBERTConfig
 # models, document/query tokenizers
 include("local_loading.jl")
+include("modelling/tokenization/tokenizer_utils.jl")
 include("modelling/tokenization/doc_tokenization.jl")
 include("modelling/tokenization/query_tokenization.jl")
+include("modelling/embedding_utils.jl")
 include("modelling/checkpoint.jl")
 export BaseColBERT, Checkpoint
diff --git a/src/indexing.jl b/src/indexing.jl
index 0d27f33..b80091e 100644
--- a/src/indexing.jl
+++ b/src/indexing.jl
@@ -1,7 +1,10 @@
 struct Indexer
     config::ColBERTConfig
-    checkpoint::Checkpoint
+    bert::HF.HGFBertModel
+    linear::Layers.Dense
+    tokenizer::TextEncoders.AbstractTransformerTextEncoder
     collection::Vector{String}
+    skiplist::Vector{Int}
 end
 
 """
@@ -19,14 +22,32 @@ An [`Indexer`] wrapping a [`ColBERTConfig`](@ref), a [`Checkpoint`](@ref) and a
collection of documents to index.
 """
 function Indexer(config::ColBERTConfig)
-    base_colbert = BaseColBERT(config.checkpoint)
-    checkpoint = Checkpoint(base_colbert, config)
+    tokenizer, bert, linear = load_hgf_pretrained_local(config.checkpoint)
+    bert = bert |> Flux.gpu
+    linear = linear |> Flux.gpu
     collection = readlines(config.collection)
-
-    @info "Loaded ColBERT layers from the $(checkpoint) HuggingFace checkpoint."
+    punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"));
+                               tokenizer.padsym]
+    skiplist = config.mask_punctuation ?
+               [lookup(tokenizer.vocab, sym) for sym in punctuations_and_padsym] :
+               [lookup(tokenizer.vocab, tokenizer.padsym)]
+
+    # configuring the tokenizer; using doc_maxlen
+    process = tokenizer.process
+    truncpad_pipe = Pipeline{:token}(
+        TextEncodeBase.trunc_and_pad(
+            config.doc_maxlen - 1, "[PAD]", :tail, :tail),
+        :token)
+    process = process[1:4] |> truncpad_pipe |> process[6:end]
+    tokenizer = TextEncoders.BertTextEncoder(
+        tokenizer.tokenizer, tokenizer.vocab, process;
+        startsym = tokenizer.startsym, endsym = tokenizer.endsym,
+        padsym = tokenizer.padsym, trunc = tokenizer.trunc)
+
+    @info "Loaded ColBERT layers from the $(config.checkpoint) HuggingFace checkpoint."
     @info "Loaded $(length(collection)) documents from $(config.collection)."
-    Indexer(config, checkpoint, collection)
+    Indexer(config, bert, linear, tokenizer, collection, skiplist)
 end
 
 """
@@ -43,40 +64,79 @@ function index(indexer::Indexer)
        @info "Index at $(indexer.config.index_path) already exists! Skipping indexing."
        return
    end
-
-    # getting and saving the indexing plan
    isdir(indexer.config.index_path) || mkdir(indexer.config.index_path)
-    sample, sample_heldout, plan_dict = setup(
-        indexer.config, indexer.checkpoint, indexer.collection)
+
+    # sampling passages and getting their embeddings
+    @info "Sampling PIDs for clustering and generating their embeddings."
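+    # (_sample_embeddings embeds only a random subset of passages, chosen by
+    # _sample_pids; avg_doclen_est is the average number of attended tokens
+    # per passage, used below to estimate the total embedding count)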
+    @time avg_doclen_est, sample = _sample_embeddings(
+        indexer.bert, indexer.linear, indexer.tokenizer,
+        indexer.config.dim, indexer.config.index_bsize,
+        indexer.config.doc_token_id, indexer.skiplist, indexer.collection)
+
+    # splitting the sample into a clustering set and a heldout set
+    @info "Splitting the sampled embeddings into a heldout set."
+    @time sample, sample_heldout = _heldout_split(sample)
+    @assert sample isa AbstractMatrix{Float32} "$(typeof(sample))"
+    @assert sample_heldout isa AbstractMatrix{Float32} "$(typeof(sample_heldout))"
+
+    # generating the indexing setup
+    plan_dict = setup(indexer.collection, avg_doclen_est, size(sample, 2),
+        indexer.config.chunksize, indexer.config.nranks)
    @info "Saving the index plan to $(joinpath(indexer.config.index_path, "plan.json"))."
    open(joinpath(indexer.config.index_path, "plan.json"), "w") do io
        JSON.print(io,
            plan_dict,
-            4 # indent
+            4
        )
    end
    @info "Saving the config to the indexing path."
    ColBERT.save(indexer.config)
 
    # training/clustering
-    @assert sample isa AbstractMatrix{Float32} "$(typeof(sample))"
-    @assert sample_heldout isa AbstractMatrix{Float32} "$(typeof(sample_heldout))"
    @info "Training the clusters."
-    centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
+    @time centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
        sample, sample_heldout, plan_dict["num_partitions"],
        indexer.config.nbits, indexer.config.kmeans_niters)
    save_codec(
        indexer.config.index_path, centroids, bucket_cutoffs,
        bucket_weights, avg_residual)
-    sample, sample_heldout, centroids = nothing, nothing, nothing # these are big arrays
+    sample, sample_heldout = nothing, nothing # these are big arrays
 
    # indexing
    @info "Building the index."
-    index(indexer.config, indexer.checkpoint, indexer.collection)
+    @time index(indexer.config.index_path, indexer.bert, indexer.linear,
+        indexer.tokenizer, indexer.collection, indexer.config.dim,
+        indexer.config.index_bsize, indexer.config.doc_token_id,
+        indexer.skiplist, plan_dict["num_chunks"], plan_dict["chunksize"],
+        centroids, bucket_cutoffs, indexer.config.nbits)
+
+    # collect embedding offsets and more metadata for chunks
+    chunk_emb_counts = load_chunk_metadata_property(
+        indexer.config.index_path, "num_embeddings")
+    num_embeddings, embeddings_offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @info "Updating chunk metadata and indexing plan."
+    plan_dict["num_embeddings"] = num_embeddings
+    plan_dict["embeddings_offsets"] = embeddings_offsets
+    open(joinpath(indexer.config.index_path, "plan.json"), "w") do io
+        JSON.print(io,
+            plan_dict,
+            4
+        )
+    end
+    save_chunk_metadata_property(
+        indexer.config.index_path, "embedding_offset", embeddings_offsets)
+
+    # build and save the ivf
+    @info "Building the centroid to embedding IVF."
+    codes = load_codes(indexer.config.index_path)
+    ivf, ivf_lengths = _build_ivf(codes, plan_dict["num_partitions"])
+
+    @info "Saving the IVF."
+    ivf_path = joinpath(indexer.config.index_path, "ivf.jld2")
+    ivf_lengths_path = joinpath(indexer.config.index_path, "ivf_lengths.jld2")
+    JLD2.save_object(ivf_path, ivf)
+    JLD2.save_object(ivf_lengths_path, ivf_lengths)
 
-    # finalizing
-    @info "Running some final checks."
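+    # (a complete index directory holds the config, the codec files
+    # (centroids, bucket cutoffs/weights, avg_residual), the per-chunk
+    # codes/residuals/doclens/metadata files, the plan, and the IVF)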
+    # check if all relevant files are saved
    _check_all_files_are_saved(indexer.config.index_path)
-    _collect_embedding_id_offset(indexer.config.index_path)
-    _build_ivf(indexer.config.index_path)
 end
diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl
index 61979d7..b6a3d07 100644
--- a/src/indexing/codecs/residual.jl
+++ b/src/indexing/codecs/residual.jl
@@ -17,67 +17,427 @@ A `Vector{UInt32}` of codes, where each code corresponds to the nearest centroid
 # Examples
 ```julia-repl
-julia> using ColBERT, Flux, CUDA;
-
-julia> centroids = rand(Float32, 128, 500);
-
-julia> embs = rand(Float32, 128, 10000);
-
-julia> ColBERT.compress_into_codes(centroids, embs)
-10000-element Vector{UInt32}:
- 0x000000e0
- 0x000000fe
- 0x0000015b
- 0x00000183
- 0x0000017b
- 0x0000002b
- 0x000001ab
- 0x00000160
- 0x000001ab
- 0x000000c4
- 0x000000fe
- 0x000000e0
- 0x000000e0
- 0x00000174
- 0x00000186
- 0x000000e0
- 0x000000e0
+julia> using ColBERT: compress_into_codes;
+
+julia> using Flux, CUDA, Random;
+
+julia> Random.seed!(0);
+
+julia> centroids = rand(Float32, 128, 500) |> Flux.gpu;
+
+julia> embs = rand(Float32, 128, 10000) |> Flux.gpu;
+
+julia> codes = zeros(UInt32, size(embs, 2)) |> Flux.gpu;
+
+julia> @time compress_into_codes!(codes, centroids, embs);
+  0.003489 seconds (4.51 k allocations: 117.117 KiB)
+
+julia> codes
+10000-element CuArray{UInt32, 1, CUDA.DeviceMemory}:
+ 0x00000194
+ 0x00000194
+ 0x0000000b
+ 0x000001d9
+ 0x0000011f
+ 0x00000098
+ 0x0000014e
+ 0x00000012
+ 0x000000a0
+ 0x00000098
+ 0x000001a7
+ 0x00000098
+ 0x000001a7
+ 0x00000194
 ⋮
- 0x000000fe
- 0x000001ec
- 0x00000105
- 0x00000174
- 0x000000e0
- 0x0000015b
- 0x00000008
- 0x00000174
- 0x00000147
- 0x000000e0
- 0x0000002b
- 0x000000e0
- 0x000000b4
- 0x00000011
- 0x00000186
- 0x00000008
- 0x000000fe
+ 0x00000199
+ 0x000001a7
+ 0x0000014e
+ 0x000001a7
+ 0x000001a7
+ 0x000001a7
+ 0x000000ec
+ 0x00000098
+ 0x000001d9
+ 0x00000098
+ 0x000001d9
+ 0x000001d9
+ 0x00000012
 ```
 """
-function compress_into_codes(
-        centroids::AbstractMatrix{Float32}, embs::AbstractMatrix{Float32})
+function compress_into_codes!(
+        codes::AbstractVector{UInt32}, centroids::AbstractMatrix{Float32},
+        embs::AbstractMatrix{Float32};
+        bsize::Int = 1000)
    _, n = size(embs)
-    codes = Vector{UInt32}()
-    bsize = div(1 << 29, size(centroids)[2])
-    for offset in 1:bsize:n # batch on the second dimension
+    length(codes) == n ||
+        throw(DimensionMismatch("length(codes) must be equal " *
+                                "to the number of embeddings!"))
+    for offset in 1:bsize:size(embs, 2)
        offset_end = min(n, offset + bsize - 1)
-        dot_products = (Flux.gpu(embs[
-            :, offset:offset_end]))' * Flux.gpu(centroids)
-        indices = getindex.(argmax(dot_products, dims = 2), 2) |> Flux.cpu
-        append!(codes, indices)
+        dot_products = (embs[:, offset:offset_end])' * centroids # (num_embs, num_centroids)
+        indices = getindex.(argmax(dot_products, dims = 2), 2)
+        codes[offset:offset_end] .= indices
    end
-    @assert length(codes) == n
-    "length(codes): $(length(codes)), size(embs): $(size(embs))"
-    @assert codes isa AbstractVector{UInt32} "$(typeof(codes))"
-    codes
+end
+
+"""
+
+# Examples
+
+```julia-repl
+julia> using ColBERT: _binarize;
+
+julia> using Flux, CUDA, Random;
+
+julia> Random.seed!(0);
+
+julia> nbits = 5;
+
+julia> data = rand(0:2^nbits - 1, 100, 200000) |> Flux.gpu
100×200000 CuArray{Int64, 2, CUDA.DeviceMemory}:
 12 23 11 6 5 2 27 1 0 4 15 8 24 … 4 25 22 18 4 0 15 16 3 25 4 13
 2 11 29 8 31 3 15 1 8 1 22 22 10 25 25 1 12 21 13 27 20 23 24 9 14
 27 4 4 15 4 9 19 4 3 10 27 14 
3 10 8 18 19 12 9 29 23 8 15 30 21 + 2 7 4 5 25 16 27 23 5 24 26 19 9 22 1 21 12 31 20 4 31 26 21 25 6 + 21 18 25 9 9 17 6 20 16 13 14 2 2 28 13 11 9 22 4 2 22 27 24 9 31 + 3 26 22 8 24 23 29 19 13 3 2 20 14 … 22 18 18 5 16 5 9 3 21 19 17 23 + 3 13 5 9 8 12 24 26 8 10 14 1 21 14 25 18 5 1 4 13 0 14 11 16 8 + 22 20 22 6 25 1 29 23 9 21 13 27 6 11 21 4 31 14 14 5 27 17 6 27 19 + 9 2 7 2 16 1 23 15 2 17 30 18 4 26 5 20 31 18 8 20 13 23 26 29 25 + 0 6 20 8 0 18 9 28 8 30 6 2 21 0 7 25 23 19 2 6 27 13 3 6 22 + 17 2 0 13 26 6 7 8 14 20 11 9 17 … 29 4 28 22 1 10 29 20 11 20 30 8 + 28 5 0 30 1 26 23 9 29 9 29 2 15 27 8 13 11 27 6 11 7 19 4 7 28 + 8 9 16 29 22 8 9 19 30 20 4 0 1 1 25 14 16 17 26 28 31 25 4 22 23 + 10 9 31 22 20 15 1 9 26 2 0 1 27 23 21 15 22 29 29 1 24 30 22 17 22 + 13 8 23 9 1 6 2 28 18 1 15 5 12 28 27 3 6 22 3 20 24 3 2 2 29 + 28 22 19 7 20 28 25 13 3 13 17 31 28 … 18 17 19 6 20 11 31 9 28 9 19 1 + 23 1 7 14 6 14 0 9 1 9 12 30 24 23 2 13 9 0 20 17 4 16 22 27 11 + 4 19 8 31 14 30 2 13 27 16 29 10 30 29 25 28 31 13 11 8 12 30 13 10 7 + 18 26 30 6 31 6 15 11 10 31 21 24 11 19 19 29 17 13 5 3 28 29 31 22 13 + 14 29 18 14 25 10 28 28 15 8 5 14 5 10 17 13 23 0 26 25 13 15 26 3 5 + 0 4 24 23 20 16 25 9 17 27 15 0 10 … 5 18 2 2 30 17 8 11 27 11 15 27 + 15 2 22 8 6 8 16 2 8 24 26 15 30 27 12 28 31 26 18 4 10 5 16 23 16 + 20 20 29 24 1 9 18 31 16 3 9 17 31 8 4 4 15 13 16 0 10 31 28 8 29 + 2 3 2 23 15 21 6 8 21 7 17 15 17 7 15 19 25 3 2 11 26 16 12 11 27 + 13 21 22 20 15 0 22 2 30 14 14 20 26 13 23 14 18 0 24 21 17 8 11 26 22 + ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ + 9 7 1 1 28 28 10 16 23 18 26 9 7 … 14 5 12 3 6 25 20 5 13 3 20 10 + 28 25 21 8 31 4 25 7 27 26 19 4 9 15 26 2 23 14 16 29 17 11 29 12 18 + 4 15 20 2 3 10 6 9 13 22 5 28 21 12 11 12 14 14 9 13 31 12 6 9 21 + 9 24 2 4 27 14 4 15 19 2 14 30 3 17 5 6 2 23 15 11 1 0 10 0 28 + 20 0 26 8 21 7 1 7 22 10 10 5 31 23 5 20 11 29 12 25 14 13 5 25 15 + 2 9 27 28 25 7 27 30 20 5 10 2 28 … 21 19 22 30 24 0 10 19 10 30 22 9 + 10 2 31 10 12 13 16 10 5 28 16 4 16 3 1 31 20 19 16 19 30 31 14 5 20 + 14 2 20 19 16 25 4 1 15 31 22 17 8 12 19 9 29 30 20 13 19 14 18 7 22 + 20 3 27 23 9 21 20 10 14 3 5 26 22 19 19 11 3 22 19 24 12 27 12 28 17 + 1 27 27 10 8 29 17 14 19 6 6 12 6 10 6 24 29 26 11 2 25 7 6 1 28 + 11 19 5 1 7 19 8 17 27 4 4 7 0 … 13 29 0 15 15 2 2 6 24 0 5 18 + 17 31 31 23 24 18 0 31 6 22 20 31 23 16 5 8 17 6 20 23 21 26 15 27 30 + 1 6 30 31 8 3 28 31 10 23 23 24 26 12 30 10 3 25 24 12 20 8 7 14 11 + 26 22 23 21 24 7 2 19 10 27 21 14 7 7 27 1 29 7 23 30 24 12 9 12 14 + 28 26 8 28 10 18 23 28 10 19 31 26 17 18 20 23 8 31 15 18 10 24 28 7 23 + 1 7 15 22 23 0 21 19 28 10 15 13 7 … 21 15 16 1 16 9 25 23 1 24 20 5 + 21 7 30 30 5 0 27 26 6 7 30 2 16 2 16 6 9 6 4 12 4 12 18 28 17 + 11 16 0 20 20 13 18 19 21 7 24 4 26 1 26 7 16 0 2 3 2 22 27 25 15 + 9 20 31 24 14 29 28 26 29 31 7 28 12 28 0 12 3 17 7 0 30 25 22 23 20 + 19 21 30 16 15 20 31 2 2 8 27 20 29 27 13 2 27 8 17 19 15 9 22 3 27 + 13 17 6 4 9 1 18 2 21 27 13 14 12 … 28 21 4 2 11 13 31 13 25 25 29 21 + 2 17 19 15 17 1 12 0 11 9 16 1 13 25 21 28 22 7 13 3 20 7 6 26 21 + 13 6 5 11 12 2 2 3 4 16 30 14 19 16 5 5 19 17 3 31 26 19 2 11 15 + 20 30 21 30 13 26 7 9 11 18 3 0 15 3 14 15 1 9 16 1 16 0 2 2 11 + 3 24 6 16 10 3 7 17 0 30 9 14 1 29 4 8 4 17 14 27 0 17 12 4 19 + +julia> _binarize(data, nbits) +5×100×200000 CuArray{Bool, 3, CUDA.DeviceMemory}: +[:, :, 1] = + 0 0 1 0 1 1 1 0 1 0 1 0 0 0 1 0 1 0 0 … 0 0 0 1 1 1 1 0 0 1 1 1 1 1 1 0 1 0 1 + 0 1 1 1 0 1 1 1 0 0 0 0 0 1 0 0 1 0 1 1 1 0 0 1 0 0 1 0 0 0 1 0 1 0 1 0 0 1 + 1 0 
0 0 1 0 0 1 0 0 0 1 0 0 1 1 1 1 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 1 0 + 1 0 1 0 0 0 0 0 1 0 0 1 1 1 1 1 0 0 0 1 1 0 0 1 0 0 1 1 0 0 1 1 0 1 0 1 0 0 + 0 0 1 0 1 0 0 1 0 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0 1 0 1 1 0 1 0 0 1 0 0 0 1 0 + +[:, :, 2] = + 1 1 0 1 0 0 1 0 0 0 0 1 1 1 0 0 1 1 0 … 0 0 1 1 1 1 0 0 0 1 1 0 0 1 1 1 0 0 0 + 1 1 0 1 1 1 0 0 1 1 1 0 0 0 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 1 1 0 + 1 0 1 1 0 0 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0 1 1 0 0 1 1 0 + 0 1 0 0 0 1 1 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 0 0 1 1 + 1 0 0 0 1 1 0 1 0 0 0 0 0 0 0 1 0 1 1 0 0 0 1 1 1 0 1 1 0 0 1 1 1 1 1 0 1 1 + +[:, :, 3] = + 1 1 0 0 1 0 1 0 1 0 0 0 0 1 1 1 1 0 0 … 1 0 1 1 1 1 0 1 0 1 0 0 1 0 0 1 1 1 0 + 1 0 0 0 0 1 0 1 1 0 0 0 0 1 1 1 1 0 1 1 0 1 1 0 1 1 1 0 1 1 0 1 1 1 1 0 0 1 + 0 1 1 1 0 1 1 1 1 1 0 0 0 1 1 0 1 0 1 1 1 0 0 1 1 1 1 0 1 1 0 1 1 1 0 1 1 1 + 1 1 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 1 1 1 0 1 1 0 1 1 0 1 1 1 0 1 1 0 0 0 0 0 + 0 1 0 0 1 1 0 1 0 1 0 0 1 1 1 1 0 0 1 1 1 1 1 0 1 1 1 0 0 1 0 1 1 0 1 0 1 0 + +;;; … + +[:, :, 199998] = + 1 0 1 1 0 1 1 0 0 1 0 0 0 0 0 1 0 1 1 … 0 0 0 0 0 1 1 1 0 0 0 1 0 0 1 0 0 0 0 + 0 0 1 0 0 1 1 1 1 1 0 0 0 1 1 0 1 0 1 1 1 0 1 0 1 1 0 0 0 1 1 1 1 0 1 1 1 0 + 0 0 1 1 0 0 0 1 0 0 1 1 1 1 0 0 1 1 1 1 0 1 1 0 1 1 0 1 0 0 0 1 1 0 1 0 0 1 + 1 1 1 0 1 0 1 0 1 0 0 0 0 0 0 1 0 1 1 1 0 1 0 0 1 0 1 1 1 0 1 0 0 1 0 0 0 1 + 1 1 0 1 1 1 0 0 1 0 1 0 0 1 0 0 1 0 1 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 + +[:, :, 199999] = + 0 1 0 1 1 1 0 1 1 0 0 1 0 1 0 1 1 0 0 … 1 1 0 1 1 1 0 0 1 0 0 1 1 1 1 0 1 0 0 + 0 0 1 0 0 0 0 1 0 1 1 1 1 0 1 1 1 1 1 0 1 0 0 0 1 1 0 1 0 0 0 1 1 0 1 1 1 0 + 1 0 1 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1 1 1 1 0 1 0 1 1 1 1 1 0 1 0 1 0 0 0 1 + 0 1 1 1 1 0 0 1 1 0 1 0 0 0 0 0 1 1 0 0 0 1 0 0 1 1 1 0 0 1 1 0 0 1 1 1 0 0 + 0 0 1 1 0 1 1 1 1 0 1 0 1 1 0 1 1 0 1 0 0 1 0 0 1 0 0 0 1 1 1 1 0 1 1 0 0 0 + +[:, :, 200000] = + 1 0 1 0 1 1 0 1 1 0 0 0 1 0 1 1 1 1 1 … 0 0 1 0 0 0 1 0 1 1 1 1 0 1 1 1 1 1 1 + 0 1 0 1 1 1 0 1 0 1 0 0 1 1 0 0 1 1 0 0 1 0 0 1 1 1 1 1 0 0 1 0 1 0 0 1 1 1 + 1 1 1 1 1 1 0 0 0 1 0 1 1 1 1 0 0 1 1 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 1 1 0 0 + 1 1 0 0 1 0 1 0 1 0 1 1 0 0 1 0 1 0 1 0 0 0 1 0 1 1 1 0 0 0 1 0 1 0 0 1 1 0 + 0 0 1 0 1 1 0 1 1 1 0 1 1 1 1 0 0 0 0 1 1 1 1 1 1 0 0 1 0 1 0 1 1 1 1 0 0 1 +``` +""" +function _binarize(data::AbstractMatrix{T}, nbits::Int) where {T <: Integer} + all(in(0:(1 << nbits - 1)), data) || + throw(DomainError("All values in the matrix should be in " * + "range [0, 2^nbits - 1]!")) + data = stack(fill(data, nbits), dims = 1) # (nbits, dim, batch_size) + positionbits = similar(data, nbits) # respects device + copyto!(positionbits, map(Base.Fix1(<<, 1), 0:(nbits - 1))) # (nbits, 1) + positionbits = reshape(positionbits, nbits, 1, 1) # (nbits, 1, 1) + data .= fld.(data, positionbits) # divide by 2^bit for each bit position + data .= data .& 1 # apply mod 1 to binarize + map(Bool, data) +end + +""" + +# Examples + +```julia-repl +julia> using ColBERT: _binarize, _unbinarize; + +julia> using Flux, CUDA, Random; + +julia> Random.seed!(0); + +julia> nbits = 5; + +julia> data = rand(0:2^nbits - 1, 100, 200000) |> Flux.gpu + +julia> binarized_data = _binarize(data, nbits); + +julia> unbinarized_data = _unbinarize(binarized_data); + +julia> isequal(unbinarized_data, data) +true +``` +""" +function _unbinarize(data::AbstractArray{Bool, 3}) + nbits = size(data, 1) + positionbits = similar(data, Int, nbits) # respects device + copyto!(positionbits, map(Base.Fix1(<<, 1), 0:(nbits - 1))) # (nbits, 1) + positionbits = 
reshape(positionbits, nbits, 1, 1) # (nbits, 1, 1) + integer_data = sum(data .* positionbits, dims = 1) + reshape(integer_data, size(integer_data)[2:end]) +end + +""" + +# Examples + +```julia-repl +julia> using ColBERT: _bucket_indices; + +julia> using Random; Random.seed!(0); + +julia> data = rand(50, 50) |> Flux.gpu; +50×50 CuArray{Float32, 2, CUDA.DeviceMemory}: + 0.455238 0.828104 0.735106 0.042069 … 0.916387 0.10078 0.00907127 + 0.547642 0.100748 0.993553 0.0275458 0.0954245 0.351846 0.548682 + 0.773354 0.908416 0.703694 0.839846 0.613082 0.605597 0.660227 + 0.940585 0.932748 0.150822 0.920883 0.754362 0.843869 0.0453409 + 0.0296477 0.123079 0.409406 0.672372 0.19912 0.106127 0.945276 + 0.746943 0.149248 0.864755 0.116243 … 0.541295 0.224275 0.660706 + 0.746801 0.743713 0.64608 0.446445 0.951642 0.583662 0.338174 + 0.97667 0.722362 0.692789 0.646206 0.089323 0.305554 0.454803 + 0.329335 0.785124 0.254097 0.271299 0.320879 0.000438984 0.161356 + 0.672001 0.532197 0.869579 0.182068 0.289906 0.068645 0.142121 + 0.0997382 0.523732 0.315933 0.935547 … 0.819027 0.770597 0.654065 + 0.230139 0.997278 0.455917 0.566976 0.0180972 0.275211 0.0619634 + 0.631256 0.709048 0.810256 0.754144 0.452911 0.358555 0.116042 + 0.096652 0.454081 0.715283 0.923417 0.498907 0.781054 0.841858 + 0.69801 0.0439444 0.27613 0.617714 0.589872 0.708365 0.0266968 + 0.470257 0.654557 0.351769 0.812597 … 0.323819 0.621386 0.63478 + 0.114864 0.897316 0.0243141 0.910847 0.232374 0.861399 0.844008 + 0.984812 0.491806 0.356395 0.501248 0.651833 0.173494 0.38356 + 0.730758 0.970359 0.456407 0.8044 0.0385577 0.306404 0.705577 + 0.117333 0.233628 0.332989 0.0857914 0.224095 0.747571 0.387572 + ⋮ ⋱ + 0.908402 0.609104 0.108874 0.430905 … 0.00564743 0.964602 0.541285 + 0.570179 0.10114 0.210174 0.945569 0.149051 0.785343 0.241959 + 0.408136 0.221389 0.425872 0.204654 0.238413 0.583185 0.271998 + 0.526989 0.0401535 0.686314 0.534208 0.29416 0.488244 0.747676 + 0.129952 0.716592 0.352166 0.584363 0.0850619 0.161153 0.243575 + 0.0256413 0.0831649 0.179467 0.799997 … 0.229072 0.711857 0.326977 + 0.939913 0.21433 0.223666 0.914527 0.425202 0.129862 0.766065 + 0.600877 0.516631 0.753827 0.674017 0.665329 0.622929 0.645962 + 0.223773 0.257933 0.854171 0.259882 0.298119 0.231662 0.824881 + 0.268817 0.468576 0.218589 0.835418 0.802857 0.0159643 0.0330232 + 0.408092 0.361884 0.849442 0.527004 … 0.0500168 0.427498 0.70482 + 0.740789 0.952265 0.722908 0.0856596 0.507305 0.32629 0.117663 + 0.873501 0.587707 0.894573 0.355338 0.345011 0.0693833 0.457268 + 0.758824 0.162728 0.608327 0.902837 0.492069 0.716635 0.459272 + 0.922832 0.950539 0.51935 0.52672 0.725665 0.36443 0.936056 + 0.239929 0.3754 0.247219 0.92438 … 0.0763809 0.737196 0.712317 + 0.76676 0.182714 0.866055 0.749239 0.132254 0.755823 0.0869469 + 0.378313 0.0392607 0.93354 0.908511 0.733769 0.552135 0.351491 + 0.811121 0.891591 0.610976 0.0427439 0.0258436 0.482621 0.193291 + 0.109315 0.474986 0.140528 0.776382 0.609791 0.49946 0.116989 + +julia> bucket_cutoffs = sort(rand(5)) |> Flux.gpu; +5-element CuArray{Float32, 1, CUDA.DeviceMemory}: + 0.42291805 + 0.7075339 + 0.8812783 + 0.89976573 + 0.9318977 + +julia> _bucket_indices(data, bucket_cutoffs) +50×50 CuArray{Int64, 2, CUDA.DeviceMemory}: + 1 2 2 0 1 0 2 0 0 2 0 1 1 0 … 0 0 0 1 1 0 2 2 4 0 4 0 0 + 1 0 5 0 1 4 1 2 0 0 5 1 0 0 0 0 1 2 4 2 0 0 0 2 0 0 1 + 2 4 1 2 1 0 5 0 1 1 0 0 0 1 2 5 1 1 1 1 1 1 0 5 1 1 1 + 5 5 0 4 0 0 1 2 4 0 4 1 0 0 5 5 4 2 1 0 2 0 1 0 2 2 0 + 0 0 0 1 0 0 1 1 0 2 0 1 2 0 1 0 2 0 2 0 2 1 1 5 0 0 5 + 2 0 2 0 1 
0 1 0 2 4 2 2 0 2 … 0 1 0 4 0 5 0 0 0 2 1 0 1 + 2 2 1 1 1 0 3 0 2 0 1 1 5 0 2 0 0 0 0 1 0 5 5 1 5 1 0 + 5 2 1 1 2 5 0 0 1 3 0 1 0 1 0 0 0 0 0 1 4 0 1 0 0 0 1 + 0 2 0 0 1 1 0 5 2 0 2 2 2 2 0 0 5 5 0 0 2 2 0 2 0 0 0 + 1 1 2 0 2 4 5 5 1 0 2 2 2 0 0 0 1 1 1 0 0 1 1 2 0 0 0 + 0 1 0 5 0 0 2 0 2 0 0 3 0 0 … 1 2 0 5 0 1 2 0 0 0 2 2 1 + 0 5 1 1 2 1 0 1 1 0 0 1 1 0 5 0 0 2 2 0 3 1 1 4 0 0 0 + 1 2 2 2 2 1 1 5 0 0 0 1 0 5 0 1 1 0 0 0 2 0 2 0 1 0 0 + 0 1 2 4 1 2 1 2 0 2 2 0 0 0 0 1 0 1 0 1 3 1 1 1 1 2 2 + 1 0 0 1 4 0 2 2 5 4 0 3 0 1 3 0 0 0 0 5 0 1 2 0 1 2 0 + 1 1 0 2 0 1 5 3 1 2 5 2 1 2 … 1 1 2 0 0 0 2 1 2 3 0 1 1 + 0 3 0 4 0 0 0 0 0 0 0 0 0 1 1 1 1 2 0 1 0 2 3 0 0 2 2 + 5 1 0 1 2 0 2 0 0 2 0 0 1 0 1 4 0 2 0 0 0 0 1 0 1 0 0 + 2 5 1 2 0 1 0 2 5 1 1 1 5 0 1 1 0 0 2 0 1 0 4 0 0 0 1 + 0 0 0 0 0 2 3 1 0 1 1 0 1 2 0 1 1 1 1 0 0 0 5 1 0 2 0 + ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ + 4 1 0 1 4 1 2 0 1 0 0 1 0 2 … 0 0 0 0 0 2 0 2 0 1 0 5 1 + 1 0 0 5 2 2 5 0 0 3 5 0 1 5 1 2 0 1 2 0 0 0 1 0 0 2 0 + 0 0 1 0 0 1 4 0 0 1 0 5 1 5 1 1 2 0 2 0 1 1 2 4 0 1 0 + 1 0 1 1 0 0 0 0 1 0 0 0 0 4 0 0 1 0 3 5 0 1 1 1 0 1 2 + 0 2 0 1 0 0 2 0 2 1 1 2 1 1 0 0 0 1 1 1 0 0 1 2 0 0 0 + 0 0 0 2 5 2 2 0 0 5 5 4 1 0 … 0 0 2 1 5 0 1 0 1 0 0 2 0 + 5 0 0 4 0 1 0 0 0 1 2 2 0 0 1 0 0 0 1 1 4 0 5 1 1 0 2 + 1 1 2 1 1 1 0 0 0 0 0 2 1 0 0 5 0 1 0 0 1 2 0 0 1 1 1 + 0 0 2 0 0 1 1 4 0 2 2 0 5 1 1 1 1 1 5 0 3 2 2 1 0 0 2 + 0 1 0 2 2 1 1 0 1 0 1 0 0 2 5 0 1 0 5 0 0 2 2 0 2 0 0 + 0 0 2 1 0 1 1 1 1 2 4 0 1 2 … 1 1 1 1 0 0 5 1 0 0 0 1 1 + 2 5 2 0 0 0 2 0 2 0 0 0 0 0 4 0 5 5 0 2 0 0 0 0 1 0 0 + 2 1 3 0 1 1 0 0 4 0 0 1 1 0 1 1 0 4 1 1 0 2 0 3 0 0 1 + 2 0 1 4 1 0 0 1 0 2 1 0 0 0 5 1 0 0 1 1 0 0 2 0 1 2 1 + 4 5 1 1 1 1 0 0 0 1 1 0 5 2 5 0 2 2 1 1 1 5 2 1 2 0 5 + 0 0 0 4 2 1 0 3 0 3 2 0 1 2 … 0 1 0 2 0 0 2 5 2 0 0 2 2 + 2 0 2 2 1 0 0 3 1 1 0 5 2 0 2 0 2 0 5 1 0 0 1 0 0 2 0 + 0 0 5 4 1 0 2 2 2 0 1 1 2 5 0 0 0 0 1 0 0 1 0 1 2 1 0 + 2 3 1 0 0 2 0 0 5 0 5 0 1 1 0 0 5 2 0 1 0 5 2 1 0 1 0 + 0 1 0 2 1 0 2 2 1 0 1 4 1 1 5 1 0 1 4 1 1 1 1 1 1 1 0 +``` +""" +function _bucket_indices(data::AbstractMatrix{T}, + bucket_cutoffs::AbstractVector{S}) where {T <: Number, S <: Number} + map(Base.Fix1(searchsortedfirst, bucket_cutoffs), data) .- 1 +end + +""" + +# Examples + +```julia-repl +julia> using ColBERT: _packbits; + +julia> using Random; Random.seed!(0); + +julia> bitsarray = rand(Bool, 2, 128, 200000); + +julia> _packbits(bitsarray) +32×200000 Matrix{UInt8}: + 0x2e 0x93 0x5a 0xbd 0xd1 0x89 0x2c 0x39 0x6a … 0xed 0xdb 0x45 0x95 0xf8 0x64 0x57 0x5b 0x06 + 0x3f 0x45 0x0c 0x2a 0x14 0xdb 0x16 0x2b 0x00 0x70 0xba 0x3c 0x40 0x56 0xa6 0xbe 0x33 0x3d + 0xbd 0x61 0xa3 0xa7 0xb4 0xe7 0x1e 0xf8 0xa7 0xf0 0x70 0xaf 0xc0 0xeb 0xa3 0x34 0x6d 0x81 + 0x15 0x9d 0x02 0xa5 0x7b 0x84 0xde 0x2f 0x28 0xa7 0xf2 0x51 0xb3 0xe7 0x01 0xbf 0x6f 0x5a + 0xaf 0x76 0x8f 0x55 0x81 0x2f 0xa5 0xcc 0x03 0xe7 0xea 0x17 0xf2 0x07 0x45 0x40 0x40 0xd8 + 0xd2 0xd4 0x25 0xcc 0x41 0xc6 0x87 0x7e 0xfd … 0x5a 0xe6 0xed 0x28 0x26 0x8b 0x39 0x3b 0x4b + 0xb3 0xbe 0x08 0xdb 0x73 0x3d 0x58 0x04 0xda 0x7b 0xf7 0xab 0x1f 0x2d 0x7b 0x71 0x12 0xdf + 0x6f 0x86 0x20 0x90 0xa5 0x0f 0xc7 0xeb 0x79 0x19 0x92 0x74 0x59 0x4b 0xfe 0xe2 0xb9 0xef + 0x4b 0x93 0x7c 0x02 0x4f 0x40 0xad 0xe3 0x4f 0x9c 0x9c 0x69 0xd1 0xf8 0xd9 0x9e 0x00 0x70 + 0x77 0x5d 0x05 0xa6 0x2c 0xaa 0x9d 0xf6 0x8d 0xa9 0x4e 0x46 0x70 0xd9 0x47 0x80 0x06 0x7e + 0x6e 0x7e 0x0f 0x3c 0xe7 0xaf 0x12 0xbf 0x0a … 0x3f 0xaf 0xe8 0x57 0x26 0x4b 0x2c 0x3f 0x01 + 0x72 0xb1 0xea 0xde 0x97 0x1d 0xf4 0x4c 0x89 0x47 0x98 0xc5 0xb6 0x47 0xaf 0x95 0xb1 0x74 + 0xc6 0x2b 0x51 0x95 0x30 0xab 0xdc 0x29 0x79 
0x5c 0x7b 0xc3 0xf4 0x6a 0xa6 0x09 0x39 0x96 + 0xeb 0xef 0x6f 0x70 0x8d 0x1f 0xb9 0x95 0x4e 0xd0 0xf5 0x68 0x0a 0x04 0x63 0x5b 0x45 0xf5 + 0xef 0xca 0xb7 0xd4 0x31 0x14 0x34 0x96 0x0c 0x1e 0x6a 0xce 0xf2 0xa3 0xa0 0xbe 0x92 0x9c + 0xda 0x91 0x53 0xd1 0x43 0xfa 0x59 0x7a 0x0c … 0x0f 0x7a 0xa0 0x4a 0x19 0xc6 0xd3 0xbb 0x7a + 0x9a 0x81 0xdb 0xee 0xce 0x7e 0x4a 0xb5 0x2a 0x3c 0x3e 0xaa 0xdc 0xa6 0xd5 0xae 0x23 0xb2 + 0x82 0x2b 0xab 0x06 0xfd 0x8a 0x4a 0xba 0x80 0xb6 0x1a 0x62 0xa0 0x29 0x97 0x61 0x6e 0xf7 + 0xb8 0xe6 0x0d 0x21 0x38 0x3a 0x97 0x55 0x58 0x46 0x01 0xe1 0x82 0x34 0xa3 0xfa 0x54 0xb3 + 0x09 0xc7 0x2f 0x7b 0x82 0x0c 0x26 0x4d 0xa4 0x1e 0x64 0xc2 0x55 0x41 0x6b 0x14 0x5c 0x0b + 0xf1 0x2c 0x3c 0x0a 0xf1 0x76 0xd4 0x57 0x42 … 0x44 0xb1 0xac 0xb4 0xa2 0x40 0x1e 0xbb 0x44 + 0xf8 0x0d 0x6d 0x09 0xb0 0x80 0xe3 0x5e 0x18 0xb3 0x43 0x22 0x82 0x0e 0x50 0xfb 0xf6 0x7b + 0xf0 0x32 0x02 0x28 0x36 0x00 0x4f 0x84 0x2b 0xe8 0xcc 0x89 0x07 0x2f 0xf4 0xcb 0x41 0x53 + 0x53 0x9b 0x01 0xf3 0xb2 0x13 0x6a 0x43 0x88 0x22 0xd8 0x33 0xa2 0xab 0xaf 0xe1 0x02 0xf7 + 0x59 0x60 0x4a 0x1a 0x9c 0x29 0xb1 0x1b 0xea 0xe9 0xd6 0x07 0x78 0xc6 0xdf 0x16 0xff 0x87 + 0xba 0x98 0xff 0x98 0xc3 0xa3 0x7d 0x7c 0x75 … 0xfe 0x75 0x4d 0x43 0x8e 0x5e 0x32 0xb0 0x97 + 0x7b 0xc9 0xcf 0x4c 0x99 0xad 0xf1 0x0e 0x0d 0x9f 0xf2 0x92 0x75 0x86 0xd6 0x08 0x74 0x8d + 0x7c 0xd4 0xe7 0x53 0xd3 0x23 0x25 0xce 0x3a 0x19 0xdb 0x14 0xa2 0xf1 0x01 0xd4 0x27 0x20 + 0x2a 0x63 0x51 0xcd 0xab 0xc3 0xb5 0xc1 0x74 0xa5 0xa4 0xe1 0xfa 0x13 0xab 0x1f 0x8f 0x9a + 0x93 0xbe 0xf4 0x54 0x2b 0xb9 0x41 0x9d 0xa8 0xbf 0xb7 0x2b 0x1c 0x09 0x36 0xa5 0x7b 0xdc + 0xdc 0x93 0x23 0xf8 0x90 0xaf 0xfb 0xd1 0xcc … 0x54 0x09 0x8c 0x14 0xfe 0xa7 0x5d 0xd7 0x6d + 0xaf 0x93 0xa2 0x29 0xf9 0x5b 0x24 0xd5 0x2a 0xf1 0x7f 0x3a 0xf5 0x8f 0xd4 0x6e 0x67 0x5b +``` +""" +function _packbits(bitsarray::AbstractArray{Bool, 3}) + nbits, dim, batch_size = size(bitsarray) + dim % 8 == 0 || + throw(DomainError("dim should be a multiple of 8!")) + bitsarray_packed = reinterpret(UInt8, BitArray(vec(bitsarray)).chunks) + reshape(bitsarray_packed[1:(prod(size(bitsarray)) >> 3)], + ((dim >> 3) * nbits, batch_size)) +end + +""" +# Examples + +```julia-repl +julia> using ColBERT: _unpackbits; + +julia> using Random; Random.seed!(0); + +julia> dim, nbits = 128, 2; + +julia> bitsarray = rand(Bool, nbits, dim, 200000); + +julia> packedbits = _packbits(bitsarray); + +julia> unpackedarray = _unpackbits(packedbits, nbits); + +julia> isequal(bitsarray, unpackedarray) +``` +""" +function _unpackbits(packedbits::AbstractMatrix{UInt8}, nbits::Int) + prod(size(packedbits, 1)) % nbits == 0 || + throw(DomainError("The first dimension of packbits should be " * + "a multiple of nbits!")) # resultant matrix will have an nbits-wide dimension + _, batch_size = size(packedbits) + dim = div(size(packedbits, 1), nbits) << 3 + pad_amt = 64 - length(vec(packedbits)) % 64 + chunks = reinterpret( + UInt64, [vec(packedbits); repeat([zero(UInt8)], pad_amt)]) + bitsvector = BitVector(undef, length(chunks) << 6) + bitsvector.chunks = chunks + reshape( + bitsvector[1:(length(vec(packedbits)) << 3)], nbits, dim, batch_size) end """ @@ -97,35 +457,81 @@ using `nbits` bits. # Returns A `AbstractMatrix{UInt8}` of compressed integer residual vectors. 
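+
+Each of the `dim` bucket indices is encoded in `nbits` bits, so a compressed
+residual vector takes up `(dim / 8) * nbits` bytes; for example, with
+`dim = 128` and `nbits = 2`, each embedding compresses into
+`(128 / 8) * 2 = 32` bytes, as in the example below.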
+ +# Examples + +```julia-repl +julia> using ColBERT: binarize; + +julia> using Statistics, Random; + +julia> Random.seed!(0); + +julia> dim, nbits = 128, 2; # encode residuals in 2 bits + +julia> residuals = rand(Float32, dim, 200000); + +julia> quantiles = collect(0:(2^nbits - 1)) / 2^nbits; + +julia> bucket_cutoffs = Float32.(quantile(residuals, quantiles[2:end])) +3-element Vector{Float32}: + 0.2502231 + 0.5001043 + 0.75005275 + +julia> binarize(dim, nbits, bucket_cutoffs, residuals) +32×200000 Matrix{UInt8}: + 0xb4 0xa2 0x0f 0xd5 0xe2 0xd3 0x03 0xbe 0xe3 … 0x44 0xf5 0x8c 0x62 0x59 0xdc 0xc9 0x9e 0x57 + 0xce 0x7e 0x23 0xd8 0xea 0x96 0x23 0x3e 0xe1 0xfb 0x29 0xa5 0xab 0x28 0xc3 0xed 0x60 0x90 + 0xb1 0x3e 0x96 0xc9 0x84 0x73 0x2c 0x28 0x22 0x27 0x6e 0xca 0x19 0xcd 0x9f 0x1a 0xf4 0xe4 + 0xd8 0x85 0x26 0xe2 0xf8 0xfc 0x59 0xef 0x9a 0x51 0xcf 0x06 0x09 0xec 0x0f 0x96 0x94 0x9d + 0xa7 0xfe 0xe2 0x9a 0xa1 0x5e 0xb0 0xd3 0x98 0x41 0x64 0x7b 0x0c 0xa6 0x69 0x26 0x35 0x05 + 0x12 0x66 0x0c 0x17 0x05 0xff 0xf2 0x35 0xc0 … 0xa6 0xb7 0xda 0x20 0xb4 0xfe 0x33 0xfc 0xa1 + 0x1b 0xa5 0xbc 0xa0 0xc7 0x1c 0xdc 0x43 0x12 0x38 0x81 0x12 0xb1 0x53 0x52 0x50 0x92 0x41 + 0x5b 0xea 0xbe 0x84 0x81 0xed 0xf5 0x83 0x7d 0x4a 0xc8 0x7f 0x95 0xab 0x34 0xcb 0x35 0x15 + 0xd3 0x0a 0x18 0xc8 0xea 0x34 0x31 0xcc 0x79 0x39 0x3c 0xec 0xe2 0x6a 0xb2 0x59 0x62 0x74 + 0x1b 0x01 0xee 0xe7 0xda 0xa9 0xe4 0xe6 0xc5 0x75 0x10 0xa1 0xe1 0xe5 0x50 0x23 0xfe 0xa3 + 0xe8 0x38 0x28 0x7c 0x9f 0xd5 0xf7 0x69 0x73 … 0x4e 0xbc 0x52 0xa0 0xca 0x8b 0xe9 0xaf 0xae + 0x2a 0xa2 0x12 0x1c 0x03 0x21 0x6a 0x6e 0xdb 0xa3 0xe3 0x62 0xb9 0x69 0xc0 0x39 0x48 0x9a + 0x76 0x44 0xce 0xd7 0xf7 0x02 0xbd 0xa1 0x7f 0xee 0x5d 0xea 0x9e 0xbe 0x78 0x51 0xbc 0xa3 + 0xb2 0xe6 0x09 0x33 0x5b 0xd1 0xad 0x1e 0x9e 0x2c 0x36 0x09 0xd3 0x60 0x81 0x0f 0xe0 0x9e + 0xb8 0x18 0x94 0x0a 0x83 0xd0 0x01 0xe1 0x0f 0x76 0x35 0x6d 0x87 0xfe 0x9e 0x9c 0x69 0xe8 + 0x8c 0x6c 0x24 0xf5 0xa9 0xe2 0xbd 0x21 0x83 … 0x1d 0x77 0x11 0xea 0xc1 0xc8 0x09 0xd7 0x4b + 0x97 0x23 0x9f 0x7a 0x8a 0xd1 0x34 0xc6 0xe7 0xe2 0xd0 0x46 0xab 0xbe 0xb3 0x92 0xeb 0xd8 + 0x10 0x6f 0xce 0x60 0x17 0x2a 0x4f 0x4a 0xb3 0xde 0x79 0xea 0x28 0xa7 0x08 0x68 0x81 0x9c + 0xae 0xc9 0xc8 0xbf 0x48 0x33 0xa3 0xca 0x8d 0x78 0x4e 0x0e 0xe2 0xe2 0x23 0x08 0x47 0xe6 + 0x41 0x29 0x8e 0xff 0x66 0xcc 0xd8 0x58 0x59 0x92 0xd8 0xef 0x9c 0x3c 0x51 0xd4 0x65 0x64 + 0xb5 0xc4 0x2d 0x30 0x14 0x54 0xd4 0x79 0x62 … 0xff 0xc1 0xed 0xe4 0x62 0xa4 0x12 0xb7 0x47 + 0xcf 0x9a 0x9a 0xd7 0x6f 0xdf 0xad 0x3a 0xf8 0xe5 0x63 0x85 0x0f 0xaf 0x62 0xab 0x67 0x86 + 0x3e 0xc7 0x92 0x54 0x8d 0xef 0x0b 0xd5 0xbb 0x64 0x5a 0x4d 0x10 0x2e 0x8f 0xd4 0xb0 0x68 + 0x7e 0x56 0x3c 0xb5 0xbd 0x63 0x4b 0xf4 0x8a 0x66 0xc7 0x1a 0x39 0x20 0xa4 0x50 0xac 0xed + 0x3c 0xbc 0x81 0x67 0xb8 0xaf 0x84 0x38 0x8e 0x6e 0x8f 0x3b 0xaf 0xae 0x03 0x0a 0x53 0x55 + 0x3d 0x45 0x76 0x98 0x7f 0x34 0x7d 0x23 0x29 … 0x24 0x3a 0x6b 0x8a 0xb4 0x3c 0x2d 0xe2 0x3a + 0xed 0x41 0xe6 0x86 0xf3 0x61 0x12 0xc5 0xde 0xd1 0x26 0x11 0x36 0x57 0x6c 0x35 0x38 0xe2 + 0x11 0x57 0x82 0x9b 0x19 0x1f 0x56 0xd7 0x06 0x1e 0x2b 0xd9 0x76 0xa1 0x68 0x27 0xb1 0xde + 0x89 0xb3 0xeb 0x86 0xbb 0x57 0xda 0xd3 0x5b 0x0e 0x79 0x4c 0x8c 0x57 0x3d 0xf0 0x98 0xb7 + 0xbf 0xc2 0xac 0xf0 0xed 0x69 0x0e 0x19 0x12 0xfe 0xab 0xcd 0xfc 0x72 0x76 0x5c 0x58 0x8b + 0xe9 0x7b 0xf6 0x22 0xa0 0x60 0x23 0xc9 0x33 … 0x77 0xc7 0xdf 0x8a 0xb9 0xef 0xe3 0x03 0x8a + 0x6b 0x26 0x08 0x53 0xc3 0x17 0xc4 0x33 0x2e 0xc6 0xb8 0x1e 0x54 0xcd 0xeb 0xb9 0x5f 0x38 +``` """ function binarize(dim::Int, nbits::Int, 
        bucket_cutoffs::Vector{Float32}, residuals::AbstractMatrix{Float32})
-    num_embeddings = size(residuals)[2]
+    # bucket indices will be encoded in nbits bits, so they lie in the
+    # range [0, 2^nbits - 1]; correspondingly, length(bucket_cutoffs)
+    # should be 2^nbits - 1
+    dim % 8 == 0 || throw(DomainError("dim should be a multiple of 8!"))
+    length(bucket_cutoffs) == (1 << nbits) - 1 ||
+        throw(DomainError("length(bucket_cutoffs) should be 2^nbits - 1!"))
 
-    if dim % (nbits * 8) != 0
-        error("The embeddings dimension must be a multiple of nbits * 8!")
-    end
+    # get the bucket indices
+    bucket_indices = _bucket_indices(residuals, bucket_cutoffs) # (dim, batch_size)
 
-    # need to subtract one here, to preserve the number of options (2 ^ nbits)
-    bucket_indices = (x -> searchsortedfirst(bucket_cutoffs, x)).(residuals) .-
-                     1 # torch.bucketize
-    bucket_indices = stack([bucket_indices for i in 1:nbits], dims = 1) # add an nbits-wide extra dimension
-    positionbits = fill(1, (nbits, 1, 1))
-    for i in 1:nbits
-        positionbits[i, :, :] .= 1 << (i - 1)
-    end
-
-    bucket_indices = Int.(floor.(bucket_indices ./ positionbits)) # divide by 2^bit for each bit position
-    bucket_indices = bucket_indices .& 1 # apply mod 1 to binarize
-    residuals_packed = reinterpret(UInt8, BitArray(vec(bucket_indices)).chunks) # flatten out the bits, and pack them into UInt8
-    residuals_packed = reshape(
-        residuals_packed, (Int(dim / 8) * nbits, num_embeddings)) # reshape back to get compressions for each embedding
-    @assert ndims(residuals_packed) == 2
-    "ndims(residuals_packed): $(ndims(residuals_packed))"
-    @assert size(residuals_packed)[2] == size(residuals)[2]
-    "size(residuals_packed): $(size(residuals_packed)), size(residuals): $(size(residuals))"
-    @assert residuals_packed isa AbstractMatrix{UInt8} "$(typeof(residuals_packed))"
+    # representing each index in nbits bits
+    bucket_indices = _binarize(bucket_indices, nbits) # (nbits, dim, batch_size)
 
+    # pack bits into UInt8's for each embedding
+    residuals_packed = _packbits(bucket_indices)
    residuals_packed
 end
@@ -152,100 +558,227 @@ is its nearest centroid, the residual vector is defined to be
 
# Returns
 
A tuple containing a vector of codes and the compressed residuals matrix.
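+
+The returned codes form a `Vector{UInt32}` holding the nearest-centroid ID for
+each embedding, and the compressed residuals take `(dim / 8) * nbits` bytes per
+embedding (see [`binarize`](@ref)); embeddings are processed in batches of
+`bsize` to bound memory usage.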
+
+# Examples
+
+```julia-repl
+julia> using ColBERT: compress;
+
+julia> using Random; Random.seed!(0);
+
+julia> nbits, dim = 2, 128;
+
+julia> embs = rand(Float32, dim, 100000);
+
+julia> centroids = embs[:, randperm(size(embs, 2))[1:10000]];
+
+julia> bucket_cutoffs = Float32.(sort(rand(2^nbits - 1)))
+3-element Vector{Float32}:
+ 0.08594067
+ 0.0968812
+ 0.44113323
+
+julia> @time codes, compressed_residuals = compress(
+           centroids, bucket_cutoffs, dim, nbits, embs);
+  4.277926 seconds (1.57 k allocations: 4.238 GiB, 6.46% gc time)
+```
"""
 function compress(centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32},
-        dim::Int, nbits::Int, embs::AbstractMatrix{Float32})
-    codes, residuals = Vector{UInt32}(), Vector{Matrix{UInt8}}()
-    bsize = 1 << 18
-    for offset in 1:bsize:size(embs)[2] # batch on second dimension
-        batch = embs[:, offset:min(size(embs)[2], offset + bsize - 1)]
-        codes_ = compress_into_codes(centroids, batch) # get centroid codes
-        centroids_ = centroids[:, codes_] # get corresponding centroids
-        residuals_ = batch - centroids_
-        append!(codes, codes_)
-        push!(residuals, binarize(dim, nbits, bucket_cutoffs, residuals_))
+        dim::Int, nbits::Int, embs::AbstractMatrix{Float32}; bsize::Int = 10000)
+    codes = zeros(UInt32, size(embs, 2))
+    compressed_residuals = Matrix{UInt8}(
+        undef, div(dim, 8) * nbits, size(embs, 2))
+    for offset in 1:bsize:size(embs, 2)
+        offset_end = min(size(embs, 2), offset + bsize - 1)
+        @views batch_embs = embs[:, offset:offset_end]
+        @views batch_codes = codes[offset:offset_end]
+        @views batch_compressed_residuals = compressed_residuals[
+            :, offset:offset_end]
+        compress_into_codes!(batch_codes, centroids, batch_embs)
+        @views batch_centroids = centroids[:, batch_codes]
+        batch_residuals = batch_embs - batch_centroids
+        batch_compressed_residuals .= binarize(
+            dim, nbits, bucket_cutoffs, batch_residuals)
    end
-    residuals = cat(residuals..., dims = 2)
+    codes, compressed_residuals
+end
 
-    @assert ndims(codes)==1 "ndims(codes): $(ndims(codes))"
-    @assert ndims(residuals)==2 "ndims(residuals): $(ndims(residuals))"
-    @assert length(codes)==size(embs)[2] "length(codes): $(length(codes)), size(embs): $(size(embs))"
-    @assert size(residuals)[2]==size(embs)[2] "size(residuals): $(size(residuals)), size(embs): $(size(embs))"
-    @assert codes isa AbstractVector{UInt32} "$(typeof(codes))"
-    @assert residuals isa AbstractMatrix{UInt8} "$(typeof(residuals))"
+"""
 
-    codes, residuals
-end
+# Examples
+
+```julia-repl
+julia> using ColBERT: binarize, decompress_residuals;
+
+julia> using Statistics, Flux, CUDA, Random;
+
+julia> Random.seed!(0);
+
+julia> dim, nbits = 128, 2; # encode residuals in 2 bits
+
+julia> residuals = rand(Float32, dim, 200000);
+
+julia> quantiles = collect(0:(2^nbits - 1)) / 2^nbits;
+
+julia> bucket_cutoffs = Float32.(quantile(residuals, quantiles[2:end]))
+3-element Vector{Float32}:
+ 0.2502231
+ 0.5001043
+ 0.75005275
+
+julia> bucket_weights = Float32.(quantile(residuals, quantiles .+ 0.5 / 2^nbits))
+4-element Vector{Float32}:
+ 0.1250611
+ 0.37511465
+ 0.62501323
+ 0.87501866
+
+julia> binary_residuals = binarize(dim, nbits, bucket_cutoffs, residuals);
+
+julia> decompressed_residuals = decompress_residuals(
+           dim, nbits, bucket_weights, binary_residuals)
+128×200000 Matrix{Float32}:
+ 0.125061 0.625013 0.875019 0.375115 0.625013 0.875019 … 0.375115 0.125061 0.375115 0.625013 0.875019
+ 0.375115 0.125061 0.875019 0.375115 0.125061 0.125061 0.625013 0.875019 0.625013 0.875019 0.375115
+ 0.875019 0.625013 0.125061 0.375115 0.625013 
0.375115 0.375115 0.375115 0.125061 0.375115 0.375115 + 0.625013 0.625013 0.125061 0.875019 0.875019 0.875019 0.375115 0.875019 0.875019 0.625013 0.375115 + 0.625013 0.625013 0.875019 0.125061 0.625013 0.625013 0.125061 0.875019 0.375115 0.125061 0.125061 + 0.875019 0.875019 0.125061 0.625013 0.625013 0.375115 … 0.625013 0.125061 0.875019 0.125061 0.125061 + 0.125061 0.875019 0.625013 0.375115 0.625013 0.375115 0.625013 0.125061 0.625013 0.625013 0.375115 + 0.875019 0.375115 0.125061 0.875019 0.875019 0.625013 0.125061 0.875019 0.875019 0.375115 0.625013 + 0.375115 0.625013 0.625013 0.375115 0.125061 0.875019 0.375115 0.875019 0.625013 0.125061 0.125061 + 0.125061 0.875019 0.375115 0.625013 0.375115 0.125061 0.875019 0.875019 0.625013 0.375115 0.375115 + 0.875019 0.875019 0.375115 0.125061 0.125061 0.875019 … 0.125061 0.375115 0.375115 0.875019 0.625013 + 0.625013 0.125061 0.625013 0.875019 0.625013 0.375115 0.875019 0.625013 0.125061 0.875019 0.875019 + 0.125061 0.375115 0.625013 0.625013 0.125061 0.125061 0.125061 0.875019 0.625013 0.125061 0.375115 + 0.625013 0.375115 0.375115 0.125061 0.625013 0.875019 0.875019 0.875019 0.375115 0.375115 0.875019 + 0.375115 0.125061 0.625013 0.625013 0.875019 0.875019 0.625013 0.125061 0.375115 0.375115 0.375115 + 0.875019 0.625013 0.125061 0.875019 0.875019 0.875019 … 0.875019 0.125061 0.625013 0.625013 0.625013 + 0.875019 0.625013 0.625013 0.625013 0.375115 0.625013 0.625013 0.375115 0.625013 0.375115 0.375115 + 0.375115 0.875019 0.125061 0.625013 0.125061 0.875019 0.375115 0.625013 0.375115 0.375115 0.375115 + 0.625013 0.875019 0.625013 0.375115 0.625013 0.375115 0.625013 0.625013 0.625013 0.875019 0.125061 + 0.625013 0.875019 0.875019 0.625013 0.625013 0.375115 0.625013 0.375115 0.125061 0.125061 0.125061 + 0.625013 0.625013 0.125061 0.875019 0.375115 0.875019 … 0.125061 0.625013 0.875019 0.125061 0.375115 + 0.125061 0.375115 0.875019 0.375115 0.375115 0.875019 0.375115 0.875019 0.125061 0.875019 0.125061 + 0.375115 0.625013 0.125061 0.375115 0.125061 0.875019 0.875019 0.875019 0.875019 0.875019 0.625013 + 0.125061 0.375115 0.125061 0.125061 0.125061 0.875019 0.625013 0.875019 0.125061 0.875019 0.625013 + 0.875019 0.375115 0.125061 0.125061 0.875019 0.125061 0.875019 0.625013 0.125061 0.625013 0.375115 + 0.625013 0.375115 0.875019 0.125061 0.375115 0.875019 … 0.125061 0.125061 0.125061 0.125061 0.125061 + 0.375115 0.625013 0.875019 0.625013 0.125061 0.375115 0.375115 0.375115 0.375115 0.375115 0.125061 + ⋮ ⋮ ⋱ ⋮ + 0.875019 0.375115 0.375115 0.625013 0.875019 0.375115 0.375115 0.875019 0.875019 0.125061 0.625013 + 0.875019 0.125061 0.875019 0.375115 0.875019 0.875019 0.875019 0.875019 0.625013 0.625013 0.875019 + 0.125061 0.375115 0.375115 0.625013 0.375115 0.125061 0.625013 0.125061 0.125061 0.875019 0.125061 + 0.375115 0.375115 0.625013 0.625013 0.875019 0.375115 0.875019 0.125061 0.375115 0.125061 0.625013 + 0.875019 0.125061 0.375115 0.375115 0.125061 0.125061 … 0.375115 0.875019 0.375115 0.625013 0.125061 + 0.625013 0.125061 0.625013 0.125061 0.875019 0.625013 0.375115 0.625013 0.875019 0.875019 0.625013 + 0.875019 0.375115 0.875019 0.625013 0.875019 0.375115 0.375115 0.375115 0.125061 0.125061 0.875019 + 0.375115 0.875019 0.625013 0.875019 0.375115 0.875019 0.375115 0.125061 0.875019 0.375115 0.625013 + 0.125061 0.375115 0.125061 0.625013 0.625013 0.875019 0.125061 0.625013 0.375115 0.125061 0.875019 + 0.375115 0.375115 0.125061 0.375115 0.375115 0.375115 … 0.625013 0.625013 0.625013 0.875019 0.375115 + 0.125061 0.375115 0.625013 0.625013 
0.125061 0.125061 0.625013 0.375115 0.125061 0.625013 0.875019 + 0.375115 0.875019 0.875019 0.625013 0.875019 0.875019 0.875019 0.375115 0.125061 0.125061 0.875019 + 0.625013 0.125061 0.625013 0.375115 0.625013 0.375115 0.375115 0.875019 0.125061 0.625013 0.375115 + 0.125061 0.875019 0.625013 0.125061 0.875019 0.375115 0.375115 0.875019 0.875019 0.375115 0.875019 + 0.625013 0.625013 0.875019 0.625013 0.625013 0.375115 … 0.375115 0.125061 0.875019 0.625013 0.625013 + 0.875019 0.625013 0.125061 0.125061 0.375115 0.375115 0.625013 0.625013 0.125061 0.125061 0.875019 + 0.875019 0.125061 0.875019 0.125061 0.875019 0.625013 0.125061 0.375115 0.875019 0.625013 0.625013 + 0.875019 0.125061 0.625013 0.875019 0.625013 0.625013 0.875019 0.875019 0.375115 0.375115 0.125061 + 0.625013 0.875019 0.625013 0.875019 0.875019 0.375115 0.375115 0.375115 0.375115 0.375115 0.625013 + 0.375115 0.875019 0.625013 0.625013 0.125061 0.125061 … 0.375115 0.875019 0.875019 0.875019 0.625013 + 0.625013 0.625013 0.375115 0.125061 0.125061 0.125061 0.625013 0.875019 0.125061 0.125061 0.625013 + 0.625013 0.875019 0.875019 0.625013 0.625013 0.625013 0.875019 0.625013 0.625013 0.125061 0.125061 + 0.875019 0.375115 0.875019 0.125061 0.625013 0.375115 0.625013 0.875019 0.875019 0.125061 0.625013 + 0.875019 0.625013 0.125061 0.875019 0.875019 0.875019 0.375115 0.875019 0.375115 0.875019 0.125061 + 0.625013 0.375115 0.625013 0.125061 0.125061 0.375115 … 0.875019 0.625013 0.625013 0.875019 0.625013 + 0.625013 0.625013 0.125061 0.375115 0.125061 0.375115 0.125061 0.625013 0.875019 0.375115 0.875019 + 0.375115 0.125061 0.125061 0.375115 0.875019 0.125061 0.875019 0.875019 0.625013 0.375115 0.125061 +``` +""" function decompress_residuals( dim::Int, nbits::Int, bucket_weights::Vector{Float32}, binary_residuals::AbstractMatrix{UInt8}) - @assert ndims(binary_residuals)==2 "ndims(binary_residuals): $(ndims(binary_residuals))" - @assert size(binary_residuals)[1]==(dim / 8) * nbits "size(binary_residuals): $(size(binary_residuals)), (dim / 8) * nbits: $((dim / 8) * nbits)" + dim % 8 == 0 || throw(DomainError("dim should be a multiple of 8!")) + size(binary_residuals, 1) == div(dim, 8) * nbits || + throw(DomainError("The dimension each residual in binary_residuals " * + "should be (dim / 8) * nbits!")) + length(bucket_weights) == (1 << nbits) || + throw(DomainError("bucket_weights should have length 2^nbits!")) + + # unpacking bits + unpacked_bits = _unpackbits(binary_residuals, nbits) # (nbits, dim, batch_size) + + # unbinarze the packed bits, and add 1 to get bin indices + unpacked_bits = _unbinarize(unpacked_bits) + unpacked_bits = unpacked_bits .+ 1 # (dim, batch_size) + + # get the residuals from the bucket weights + all(in(1:length(bucket_weights)), unpacked_bits) || + throw(BoundsError("All the unpacked indices in binary_residuals should " * + "be in range 1:length(bucket_weights)!")) + decompressed_residuals = bucket_weights[unpacked_bits] + decompressed_residuals +end - # unpacking UInt8 into bits - unpacked_bits = BitVector() - for byte in vec(binary_residuals) - append!(unpacked_bits, [byte & (0x1 << n) != 0 for n in 0:7]) - end +""" +# Examples - # reshaping into dims (nbits, dim, num_embeddings); inverse of what binarize does - unpacked_bits = reshape( - unpacked_bits, nbits, dim, size(binary_residuals)[2]) +```julia-repl +julia> using ColBERT: compress, decompress; - # get decimal value for coordinate of the nbits-wide dimension; again, inverse of binarize - positionbits = fill(1, (nbits, 1, 1)) - for i in 1:nbits - 
positionbits[i, :, :] .= 1 << (i - 1) - end +julia> using Random; Random.seed!(0); - # multiply by 2^(i - 1) for the ith bit, and take sum to get the original bin back - unpacked_bits = unpacked_bits .* positionbits - unpacked_bits = sum(unpacked_bits, dims = 1) - unpacked_bits = unpacked_bits .+ 1 # adding 1 to get correct bin indices +julia> nbits, dim = 2, 128; - # reshaping to get rid of the nbits wide dimension - unpacked_bits = reshape(unpacked_bits, size(unpacked_bits)[2:end]...) - embeddings = bucket_weights[unpacked_bits] +julia> embs = rand(Float32, dim, 100000); - @assert ndims(embeddings)==2 "ndims(embeddings): $(ndims(embeddings))" - @assert size(embeddings)[2]==size(binary_residuals)[2] "size(embeddings): $(size(embeddings)), size(binary_residuals): $(size(binary_residuals)) " - @assert embeddings isa AbstractMatrix{Float32} "$(typeof(embeddings))" +julia> centroids = embs[:, randperm(size(embs, 2))[1:10000]]; - embeddings -end +julia> bucket_cutoffs = Float32.(sort(rand(2^nbits - 1))) +3-element Vector{Float32}: + 0.08594067 + 0.0968812 + 0.44113323 + +julia> bucket_weights = Float32.(sort(rand(2^nbits))); +4-element Vector{Float32}: + 0.10379179 + 0.25756857 + 0.27798286 + 0.47973529 +julia> @time codes, compressed_residuals = compress( + centroids, bucket_cutoffs, dim, nbits, embs); + 4.277926 seconds (1.57 k allocations: 4.238 GiB, 6.46% gc time) + +julia> @time decompressed_embeddings = decompress( + dim, nbits, centroids, bucket_weights, codes, compressed_residuals); +0.237170 seconds (276.40 k allocations: 563.049 MiB, 50.93% compilation time) +``` +""" function decompress( dim::Int, nbits::Int, centroids::Matrix{Float32}, bucket_weights::Vector{Float32}, - codes::Vector{UInt32}, residuals::AbstractMatrix{UInt8}) - @assert ndims(codes)==1 "ndims(codes): $(ndims(codes))" - @assert ndims(residuals)==2 "ndims(residuals): $(ndims(residuals))" - @assert length(codes)==size(residuals)[2] "length(codes): $(length(codes)), size(residuals): $(size(residuals))" - - # decompress in batches - D = Vector{AbstractMatrix{Float32}}() - bsize = 1 << 15 + codes::Vector{UInt32}, residuals::AbstractMatrix{UInt8}; bsize::Int = 10000) + length(codes) == size(residuals, 2) || + throw(DomainError("The number of codes should be equal to the number of " * + "residual embeddings!")) + all(in(1:size(centroids, 2)), codes) || + throw(DomainError("All the codes must be in the valid range of centroid " * + "IDs!")) + embeddings = Matrix{Float32}(undef, dim, length(codes)) for batch_offset in 1:bsize:length(codes) - batch_codes = codes[batch_offset:min( - batch_offset + bsize - 1, length(codes))] - batch_residuals = residuals[ - :, batch_offset:min(batch_offset + bsize - 1, length(codes))] - - centroids_ = centroids[:, batch_codes] + batch_offset_end = min(length(codes), batch_offset + bsize - 1) + @views batch_embeddings = embeddings[ + :, batch_offset:batch_offset_end] + @views batch_codes = codes[batch_offset:batch_offset_end] + @views batch_residuals = residuals[:, batch_offset:batch_offset_end] + @views centroids_ = centroids[:, batch_codes] residuals_ = decompress_residuals( dim, nbits, bucket_weights, batch_residuals) - batch_embeddings = centroids_ + residuals_ - batch_embeddings = mapslices( - v -> iszero(v) ? 
v : normalize(v), batch_embeddings, dims = 1) - push!(D, batch_embeddings) + batch_embeddings .= centroids_ + residuals_ + _normalize_array!(batch_embeddings, dims = 1) end - embeddings = cat(D..., dims = 2) - - @assert ndims(embeddings)==2 "ndims(embeddings): $(ndims(embeddings))" - @assert size(embeddings)[2]==length(codes) "size(embeddings): $(size(embeddings)), length(codes): $(length(codes))" - @assert embeddings isa AbstractMatrix{Float32} "$(typeof(embeddings))" - embeddings end diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 752f12e..f885019 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -1,56 +1,3 @@ -""" - encode_passages( - config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String}) - -Encode a list of passages using `checkpoint`. - -The given `passages` are run through the underlying BERT model and the linear layer to -generate the embeddings, after doing relevant document-specific preprocessing. -See [`docFromText`](@ref) for more details. - -# Arguments - - - `config`: The [`ColBERTConfig`](@ref) to be used. - - `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages. - - `passages`: A list of strings representing the passages to be encoded. - -# Returns - -A tuple `embs, doclens` where: - - - `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`, - where `D` is the embedding dimension and `N` is the total number of embeddings - across all the passages. - - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage, - i.e the total number of attended tokens for each document passage. -""" -function encode_passages( - config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String}) - @info "Encoding $(length(passages)) passages." - - if length(passages) == 0 - error("The list of passages to encode is empty!") - end - - embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}() - # batching here to avoid storing intermediate embeddings on GPU - # batching also occurs inside docFromText to do batch packing optimizations - for passage_offset in 1:(config.passages_batch_size):length(passages) - passage_end_offset = min( - length(passages), passage_offset + config.passages_batch_size - 1) - embs_, doclens_ = docFromText( - config, checkpoint, passages[passage_offset:passage_end_offset], - config.index_bsize) - @assert embs_ isa Matrix{Float32} - @assert doclens_ isa Vector{Int} - push!(embs, embs_) - append!(doclens, vec(doclens_)) - embs_, doclens_ = nothing, nothing - end - embs = cat(embs..., dims = 2) - embs, doclens -end - """ _sample_pids(num_documents::Int) @@ -100,26 +47,43 @@ from the sampled documents, and the embedding matrix for the local samples. The shape `(D, N)`, where `D` is the embedding dimension (`128`) and `N` is the total number of embeddings over all the sampled passages. 
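+
+# Examples
+
+A sketch of the call, not a runnable doctest: it assumes `bert`, `linear` and
+`tokenizer` were loaded via [`load_hgf_pretrained_local`](@ref) and `skiplist`
+was built as in the [`Indexer`](@ref) constructor; the checkpoint path and the
+`dim = 128`, `index_bsize = 128` and `"[unused1]"` document-marker values are
+illustrative.
+
+```julia-repl
+julia> tokenizer, bert, linear = load_hgf_pretrained_local(
+           "/home/codetalker7/models/colbertv2.0/");
+
+julia> collection = readlines("./downloads/lotte/lifestyle/dev/collection.tsv");
+
+julia> avg_doclen_est, sample = _sample_embeddings(
+           bert, linear, tokenizer, 128, 128, "[unused1]", skiplist, collection);
+```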
""" -function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, - collection::Vector{String}, sampled_pids::Set{Int}) +function _sample_embeddings(bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + dim::Int, index_bsize::Int, doc_token::String, + skiplist::Vector{Int}, collection::Vector{String}) # collect all passages with pids in sampled_pids + sampled_pids = _sample_pids(length(collection)) sorted_sampled_pids = sort(collect(sampled_pids)) local_sample = collection[sorted_sampled_pids] # get the local sample embeddings local_sample_embs, local_sample_doclens = encode_passages( - config, checkpoint, local_sample) - @debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)" - @assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))" + bert, linear, tokenizer, local_sample, + dim, index_bsize, doc_token, skiplist) + + @assert size(local_sample_embs, 2)==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))" @assert length(local_sample) == length(local_sample_doclens) avg_doclen_est = length(local_sample_doclens) > 0 ? - sum(local_sample_doclens) / length(local_sample_doclens) : - 0 + Float32(sum(local_sample_doclens) / + length(local_sample_doclens)) : + zero(Float32) @info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))" avg_doclen_est, local_sample_embs end +function _heldout_split( + sample::AbstractMatrix{Float32}; heldout_fraction::Float32 = 0.05f0) + num_sample_embs = size(sample, 2) + sample = sample[:, shuffle(1:num_sample_embs)] + heldout_size = Int(max( + 1, floor(min(50000, heldout_fraction * num_sample_embs)))) + sample, sample_heldout = sample[ + :, 1:(num_sample_embs - heldout_size)], + sample[:, (num_sample_embs - heldout_size + 1):num_sample_embs] + sample, sample_heldout +end + """ setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) @@ -144,46 +108,23 @@ proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``. A `Dict` containing the indexing plan. """ -function setup(config::ColBERTConfig, checkpoint::Checkpoint, - collection::Vector{String}) - chunksize = 0 - chunksize = ismissing(config.chunksize) ? - min(25000, 1 + fld(length(collection), config.nranks)) : - config.chunksize +function setup(collection::Vector{String}, avg_doclen_est::Float32, + num_clustering_embs::Int, chunksize::Union{Missing, Int}, nranks::Int) + chunksize = ismissing(chunksize) ? 
+                min(25000, 1 + fld(length(collection), nranks)) :
+                chunksize
    num_chunks = cld(length(collection), chunksize)
 
-    # sample passages for training centroids later
-    sampled_pids = _sample_pids(length(collection))
-    avg_doclen_est, local_sample_embs = _sample_embeddings(
-        config, checkpoint, collection, sampled_pids)
-
-    # splitting the local sample into heldout set
-    num_local_sample_embs = size(local_sample_embs, 2)
-    local_sample_embs = local_sample_embs[
-        :, shuffle(1:num_local_sample_embs)]
-
-    # split the sample to get a heldout set
-    heldout_fraction = 0.05
-    heldout_size = Int(max(
-        1, floor(min(
-        50000, heldout_fraction * num_local_sample_embs))))
-    sample, sample_heldout = local_sample_embs[
-        :, 1:(num_local_sample_embs - heldout_size)],
-    local_sample_embs[
-        :, (num_local_sample_embs - heldout_size + 1):num_local_sample_embs]
-    @debug "Split sample sizes: sample size: $(size(sample)), \t sample_heldout size: $(size(sample_heldout))"
-
    # computing the number of partitions, i.e clusters
    num_passages = length(collection)
    num_embeddings_est = num_passages * avg_doclen_est
-    num_partitions = Int(min(size(sample, 2),
+    num_partitions = Int(min(num_clustering_embs,
        floor(2^(floor(log2(16 * sqrt(num_embeddings_est)))))))
 
    @info "Creating $(num_partitions) clusters."
    @info "Estimated $(num_embeddings_est) embeddings."
 
-    sample, sample_heldout,
-    Dict(
+    Dict{String, Any}(
        "chunksize" => chunksize,
        "num_chunks" => num_chunks,
        "num_partitions" => num_partitions,
@@ -193,6 +134,19 @@ function setup(config::ColBERTConfig, checkpoint::Checkpoint,
    )
 end
 
+function _bucket_cutoffs_and_weights(
+        nbits::Int, heldout_avg_residual::AbstractMatrix{Float32})
+    num_options = 1 << nbits
+    quantiles = collect(0:(num_options - 1)) / num_options
+    bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end],
+    quantiles .+ (0.5 / num_options)
+    bucket_cutoffs = Float32.(quantile(
+        heldout_avg_residual, bucket_cutoffs_quantiles))
+    bucket_weights = Float32.(quantile(
+        heldout_avg_residual, bucket_weights_quantiles))
+    bucket_cutoffs, bucket_weights
+end
+
 """
-    _compute_avg_residuals(
+    _compute_avg_residuals!(
        nbits::Int, centroids::AbstractMatrix{Float32},
@@ -215,32 +169,23 @@ Compute the average residuals and other statistics of the held-out sample embedd
 
# Returns
 
A tuple `bucket_cutoffs, bucket_weights, avg_residual`, which will be used in
compression/decompression of residuals.
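+
+# Examples
+
+A minimal sketch with random data (the codes vector is overwritten in place;
+outputs are omitted):
+
+```julia-repl
+julia> using ColBERT: _compute_avg_residuals!;
+
+julia> using Random; Random.seed!(0);
+
+julia> nbits = 2;
+
+julia> centroids = rand(Float32, 128, 500);
+
+julia> heldout = rand(Float32, 128, 1000);
+
+julia> codes = zeros(UInt32, size(heldout, 2));
+
+julia> bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!(
+           nbits, centroids, heldout, codes);
+```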
""" -function _compute_avg_residuals( +function _compute_avg_residuals!( nbits::Int, centroids::AbstractMatrix{Float32}, - heldout::AbstractMatrix{Float32}) - codes = compress_into_codes(centroids, heldout) # get centroid codes - @assert codes isa AbstractVector{UInt32} "$(typeof(codes))" - heldout_reconstruct = Flux.gpu(centroids[:, codes]) # get corresponding centroids - heldout_avg_residual = Flux.gpu(heldout) - heldout_reconstruct # compute the residual - + heldout::AbstractMatrix{Float32}, codes::AbstractVector{UInt32}) + length(codes) == size(heldout, 2) || + throw(DimensionMismatch("length(codes) must be equal to the number " * + "of embeddings in heldout!")) + + compress_into_codes!(codes, centroids, heldout) # get centroid codes + heldout_reconstruct = centroids[:, codes] # get corresponding centroids + heldout_avg_residual = heldout - heldout_reconstruct # compute the residual avg_residual = mean(abs.(heldout_avg_residual), dims = 2) # for each dimension, take mean of absolute values of residuals # computing bucket weights and cutoffs - num_options = 2^nbits - quantiles = Vector(0:(num_options - 1)) / num_options - bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end], - quantiles .+ (0.5 / num_options) + bucket_cutoffs, bucket_weights = _bucket_cutoffs_and_weights( + nbits, heldout_avg_residual) - bucket_cutoffs = Float32.(quantile( - heldout_avg_residual, bucket_cutoffs_quantiles)) - bucket_weights = Float32.(quantile( - heldout_avg_residual, bucket_weights_quantiles)) - @assert bucket_cutoffs isa AbstractVector{Float32} "$(typeof(bucket_cutoffs))" - @assert bucket_weights isa AbstractVector{Float32} "$(typeof(bucket_weights))" - - @info "Got bucket_cutoffs_quantiles = $(bucket_cutoffs_quantiles) and bucket_weights_quantiles = $(bucket_weights_quantiles)" @info "Got bucket_cutoffs = $(bucket_cutoffs) and bucket_weights = $(bucket_weights)" - bucket_cutoffs, bucket_weights, mean(avg_residual) end @@ -269,22 +214,21 @@ A `Dict` containing the residual codec, i.e information used to compress/decompr function train( sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, num_partitions::Int, nbits::Int, kmeans_niters::Int) - _, n = size(sample) + # computing clusters sample = sample |> Flux.gpu - centroids = sample[:, randperm(n)[1:num_partitions]] + centroids = sample[:, randperm(size(sample, 2))[1:num_partitions]] # TODO: put point_bsize in the config! kmeans_gpu_onehot!( sample, centroids, num_partitions; max_iters = kmeans_niters) - @assert size(centroids, 2) == num_partitions - "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)" - @assert centroids isa AbstractMatrix{Float32} "$(typeof(centroids))" - centroids = centroids |> Flux.cpu - bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals( - nbits, centroids, heldout) + # computing average residuals + heldout = heldout |> Flux.gpu + codes = zeros(UInt32, size(heldout, 2)) |> Flux.gpu + bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!( + nbits, centroids, heldout, codes) @info "avg_residual = $(avg_residual)" - centroids, bucket_cutoffs, bucket_weights, avg_residual + Flux.cpu(centroids), bucket_cutoffs, bucket_weights, avg_residual end """ @@ -303,128 +247,86 @@ along with relevant metadata (see [`save_chunk`](@ref)). - `checkpoint`: The [`Checkpoint`](@ref) to compute embeddings. - `collection`: The collection to index. 
""" -function index(config::ColBERTConfig, checkpoint::Checkpoint, - collection::Vector{String}) - codec = load_codec(config.index_path) - plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json")) - for (chunk_idx, passage_offset) in zip(1:plan_metadata["num_chunks"], - 1:plan_metadata["chunksize"]:length(collection)) +function index(index_path::String, bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + collection::Vector{String}, dim::Int, index_bsize::Int, + doc_token::String, skiplist::Vector{Int}, num_chunks::Int, + chunksize::Int, centroids::AbstractMatrix{Float32}, + bucket_cutoffs::AbstractVector{Float32}, nbits::Int) + for (chunk_idx, passage_offset) in zip( + 1:num_chunks, 1:chunksize:length(collection)) passage_end_offset = min( - length(collection), passage_offset + plan_metadata["chunksize"] - 1) - embs, doclens = encode_passages( - config, checkpoint, collection[passage_offset:passage_end_offset]) + length(collection), passage_offset + chunksize - 1) + + # get embeddings for batch + embs, doclens = encode_passages(bert, linear, tokenizer, + collection[passage_offset:passage_end_offset], + dim, index_bsize, doc_token, skiplist) @assert embs isa AbstractMatrix{Float32} "$(typeof(embs))" @assert doclens isa AbstractVector{Int} "$(typeof(doclens))" + # compress embeddings + codes, residuals = compress(centroids, bucket_cutoffs, dim, nbits, embs) + + # save the chunk @info "Saving chunk $(chunk_idx): \t $(passage_end_offset - passage_offset + 1) passages and $(size(embs)[2]) embeddings. From passage #$(passage_offset) onward." - save_chunk(config, codec, chunk_idx, passage_offset, embs, doclens) - embs, doclens = nothing, nothing + save_chunk( + index_path, codes, residuals, chunk_idx, passage_offset, doclens) end end -""" - check_chunk_exists(saver::IndexSaver, chunk_idx::Int) - -Check if the index chunk exists for the given `chunk_idx`. - -# Arguments - - - `saver`: The `IndexSaver` object that contains the indexing settings. - - `chunk_idx`: The index of the chunk to check. - -# Returns - -A boolean indicating whether all relevant files for the chunk exist. -""" -function check_chunk_exists(index_path::String, chunk_idx::Int) - path_prefix = joinpath(index_path, string(chunk_idx)) - codes_path = "$(path_prefix).codes.jld2" - residuals_path = "$(path_prefix).residuals.jld2" - doclens_path = joinpath(index_path, "doclens.$(chunk_idx).jld2") - metadata_path = joinpath(index_path, "$(chunk_idx).metadata.json") - - for file in [codes_path, residuals_path, doclens_path, metadata_path] - if !isfile(file) - return false - end - end - - true -end - function _check_all_files_are_saved(index_path::String) - plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + @info "Checking if all index files are saved." - @info "Checking if all files are saved." - for chunk_idx in 1:(plan_metadata["num_chunks"]) - if !(check_chunk_exists(index_path, chunk_idx)) - @error "Some files for chunk $(chunk_idx) are missing!" - end + # first get the plan + isfile(joinpath(index_path, "plan.json")) || begin + @info "plan.json is missing from the index!" + return false end - @info "Found all files!" -end - -function _collect_embedding_id_offset(index_path::String) - @assert isfile(joinpath(index_path, "plan.json")) "Fatal: plan.json doesn't exist!" plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) - @info "Collecting embedding ID offsets." 
- embedding_offset = 1 - embeddings_offsets = Vector{Int}() - for chunk_idx in 1:(plan_metadata["num_chunks"]) - metadata_path = joinpath( - index_path, "$(chunk_idx).metadata.json") - - chunk_metadata = open(metadata_path, "r") do io - chunk_metadata = JSON.parse(io) - end - - chunk_metadata["embedding_offset"] = embedding_offset - push!(embeddings_offsets, embedding_offset) - - embedding_offset += chunk_metadata["num_embeddings"] - - open(metadata_path, "w") do io - JSON.print(io, chunk_metadata, 4) - end - end - num_embeddings = embedding_offset - 1 - @assert length(embeddings_offsets) == plan_metadata["num_chunks"] - - @info "Saving the indexing metadata." - plan_metadata["num_embeddings"] = num_embeddings - plan_metadata["embeddings_offsets"] = embeddings_offsets - open(joinpath(index_path, "plan.json"), "w") do io - JSON.print(io, - plan_metadata, - 4 - ) + # get the non-chunk files + files = [ + joinpath(index_path, "config.json"), + joinpath(index_path, "centroids.jld2"), + joinpath(index_path, "bucket_cutoffs.jld2"), + joinpath(index_path, "bucket_weights.jld2"), + joinpath(index_path, "avg_residual.jld2"), + joinpath(index_path, "ivf.jld2"), + joinpath(index_path, "ivf_lengths.jld2") + ] + + # get the chunk files + for chunk_idx in 1:plan_metadata["num_chunks"] + append!(files, + [ + joinpath(index_path, "$(chunk_idx).codes.jld2"), + joinpath(index_path, "$(chunk_idx).residuals.jld2"), + joinpath(index_path, "doclens.$(chunk_idx).jld2"), + joinpath(index_path, "$(chunk_idx).metadata.json") + ]) end -end - -function _build_ivf(index_path::String) - plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) - @info "Building the centroid to embedding IVF." - codes = Vector{UInt32}() - - @info "Loading codes for each embedding." - for chunk_idx in 1:(plan_metadata["num_chunks"]) - chunk_codes = JLD2.load_object(joinpath( - index_path, "$(chunk_idx).codes.jld2")) - append!(codes, chunk_codes) + # check for any missing files + missing_files = findall(!isfile, files) + isempty(missing_files) || begin + @info "$(files[missing_files]) are missing!" + return false end - @assert codes isa AbstractVector{UInt32} "$(typeof(codes))" - @info "Sorting the codes." - ivf, values = sortperm(codes), sort(codes) + @info "Found all files!" + true +end - @info "Getting unique codes and their counts." - ivf_lengths = counts(values, 1:(plan_metadata["num_partitions"])) +function _collect_embedding_id_offset(chunk_emb_counts::Vector{Int}) + length(chunk_emb_counts) > 0 || return 0, zeros(Int, 1) + chunk_embedding_offsets = [1; _head(chunk_emb_counts)] + chunk_embedding_offsets = cumsum(chunk_embedding_offsets) + sum(chunk_emb_counts), chunk_embedding_offsets +end - @info "Saving the IVF." 
- ivf_path = joinpath(index_path, "ivf.jld2") - ivf_lengths_path = joinpath(index_path, "ivf_lengths.jld2") - JLD2.save_object(ivf_path, ivf) - JLD2.save_object(ivf_lengths_path, ivf_lengths) +function _build_ivf(codes::Vector{UInt32}, num_partitions::Int) + ivf, values = sortperm(codes), sort(codes) + ivf_lengths = counts(values, num_partitions) + ivf, ivf_lengths end diff --git a/src/loaders.jl b/src/loaders.jl index ac25ada..7edacc6 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -110,3 +110,31 @@ function load_compressed_embs(index_path::String) codes, residuals end + +function load_chunk_metadata_property(index_path::String, property::String) + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + plan_metadata["num_chunks"] > 0 || return [] + vector = nothing + for chunk_idx in 1:plan_metadata["num_chunks"] + chunk_metadata = JSON.parsefile(joinpath( + index_path, "$(chunk_idx).metadata.json")) + if isnothing(vector) + vector = [chunk_metadata[property]] + else + append!(vector, chunk_metadata[property]) + end + end + vector +end + +function load_codes(index_path::String) + @info "Loading codes for each embedding." + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + codes = Vector{UInt32}() + for chunk_idx in 1:(plan_metadata["num_chunks"]) + chunk_codes = JLD2.load_object(joinpath( + index_path, "$(chunk_idx).codes.jld2")) + append!(codes, chunk_codes) + end + codes +end diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index ff9c57d..da61a10 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -1,243 +1,3 @@ -""" - BaseColBERT(; - bert::HuggingFace.HGFBertModel, linear::Layers.Dense, - tokenizer::TextEncoders.AbstractTransformerTextEncoder) - -A struct representing the BERT model, linear layer, and the tokenizer used to compute -embeddings for documents and queries. - -# Arguments - - - `bert`: The pre-trained BERT model used to generate the embeddings. - - `linear`: The linear layer used to project the embeddings to a specific dimension. - - `tokenizer`: The tokenizer to used by the BERT model. - -# Returns - -A [`BaseColBERT`](@ref) object. - -# Examples - -```julia-repl -julia> using ColBERT, CUDA; - -julia> base_colbert = BaseColBERT("/home/codetalker7/models/colbertv2.0/"); - -julia> base_colbert.bert -HGFBertModel( - Chain( - CompositeEmbedding( - token = Embed(768, 30522), # 23_440_896 parameters - position = ApplyEmbed(.+, FixedLenPositionEmbed(768, 512)), # 393_216 parameters - segment = ApplyEmbed(.+, Embed(768, 2), Transformers.HuggingFace.bert_ones_like), # 1_536 parameters - ), - DropoutLayer( - LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters - ), - ), - Transformer<12>( - PostNormTransformerBlock( - DropoutLayer( - SelfAttention( - MultiheadQKVAttenOp(head = 12, p = nothing), - Fork<3>(Dense(W = (768, 768), b = true)), # 1_771_776 parameters - Dense(W = (768, 768), b = true), # 590_592 parameters - ), - ), - LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters - DropoutLayer( - Chain( - Dense(σ = NNlib.gelu, W = (768, 3072), b = true), # 2_362_368 parameters - Dense(W = (3072, 768), b = true), # 2_360_064 parameters - ), - ), - LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters - ), - ), # Total: 192 arrays, 85_054_464 parameters, 40.422 KiB. - Branch{(:pooled,) = (:hidden_state,)}( - BertPooler(Dense(σ = NNlib.tanh_fast, W = (768, 768), b = true)), # 590_592 parameters - ), -) # Total: 199 arrays, 109_482_240 parameters, 43.578 KiB. 
- -julia> base_colbert.linear -Dense(W = (768, 128), b = true) # 98_432 parameters - -julia> base_colbert.tokenizer -TrfTextEncoder( -├─ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_uncased_tokenizer, WordPiece(vocab_size = 30522, unk = [UNK], max_char = 100)), 5 patterns)), -├─ vocab = Vocab{String, SizedArray}(size = 30522, unk = [UNK], unki = 101), -├─ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)), -├─ annotate = annotate_strings, -├─ onehot = lookup_first, -├─ decode = nestedcall(remove_conti_prefix), -├─ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)), -└─ process = Pipelines: - ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source) - ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token) - ╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]: Input[1]: [SEP]: (Input[2]: [SEP]:)...)(target.token) - ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token) - ╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token) - ╰─ target[token] := TextEncodeBase.nested2batch(target.token) - ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment) - ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment) - ╰─ target[sequence_mask] := identity(target.attention_mask) - ╰─ target := (target.token, target.segment, target.attention_mask, target.sequence_mask) -``` -""" -struct BaseColBERT - bert::HF.HGFBertModel - linear::Layers.Dense - tokenizer::TextEncoders.AbstractTransformerTextEncoder -end - -function BaseColBERT(modelpath::AbstractString) - tokenizer, bert_model, linear = load_hgf_pretrained_local(modelpath) - bert_model = bert_model |> Flux.gpu - linear = linear |> Flux.gpu - BaseColBERT(bert_model, linear, tokenizer) -end - -""" - Checkpoint(model::BaseColBERT, config::ColBERTConfig) - -A wrapper for [`BaseColBERT`](@ref), containing information for generating embeddings -for docs and queries. - -If the `config` is set to mask punctuations, then the `skiplist` property of the created -[`Checkpoint`](@ref) will be set to a list of token IDs of punctuations. Otherwise, it will be empty. - -# Arguments - - - `model`: The [`BaseColBERT`](@ref) to be wrapped. - - `config`: The underlying [`ColBERTConfig`](@ref). - -# Returns - -The created [`Checkpoint`](@ref). 
- -# Examples - -Continuing from the example for [`BaseColBERT`](@ref): - -```julia-repl -julia> checkpoint = Checkpoint(base_colbert, config) - -julia> checkpoint.skiplist # by default, all punctuations -32-element Vector{Int64}: - 1000 - 1001 - 1002 - 1003 - 1004 - 1005 - 1006 - 1007 - 1008 - 1009 - 1010 - 1011 - 1012 - 1013 - ⋮ - 1028 - 1029 - 1030 - 1031 - 1032 - 1033 - 1034 - 1035 - 1036 - 1037 - 1064 - 1065 - 1066 - 1067 -``` -""" -struct Checkpoint - model::BaseColBERT - skiplist::Vector{Int64} -end - -function Checkpoint(model::BaseColBERT, config::ColBERTConfig) - if config.mask_punctuation - punctuation_list = string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~")) - skiplist = [TextEncodeBase.lookup(model.tokenizer.vocab, punct) - for punct in punctuation_list] - else - skiplist = Vector{Int64}() - end - Checkpoint(model, skiplist) -end - -""" - mask_skiplist(tokenizer::TextEncoders.AbstractTransformerTextEncoder, - integer_ids::AbstractMatrix{Int32}, skiplist::Union{Missing, Vector{Int64}}) - -Create a mask for the given `integer_ids`, based on the provided `skiplist`. -If the `skiplist` is not missing, then any token IDs in the list will be filtered out along with the padding token. -Otherwise, all tokens are included in the mask. - -# Arguments - - - `tokenizer`: The underlying tokenizer. - - `integer_ids`: An `Array` of token IDs for the documents. - - `skiplist`: A list of token IDs to skip in the mask. - -# Returns - -An array of booleans indicating whether the corresponding token ID -is included in the mask or not. The array has the same shape as -`integer_ids`, i.e `(L, N)`, where `L` is the maximum length of -any document in `integer_ids` and `N` is the number of documents. - -# Examples - -Continuing with the example for [`tensorize_docs`](@ref) and the -`skiplist` from the example in [`Checkpoint`](@ref). 
- -```julia-repl -julia> integer_ids = batches[1][1]; - -julia> ColBERT.mask_skiplist( - checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist) -21×3 BitMatrix: - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 0 1 0 - 0 0 1 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 - 0 0 0 -``` -""" -function mask_skiplist( - tokenizer::TextEncoders.AbstractTransformerTextEncoder, - integer_ids::AbstractMatrix{Int32}, skiplist::Union{ - Missing, Vector{Int64}}) - filter = integer_ids .!= - TextEncodeBase.lookup(tokenizer.vocab, tokenizer.padsym) - for token_id in skiplist - filter = filter .& (integer_ids .!= token_id) - end - filter -end - """ doc( config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, @@ -288,421 +48,269 @@ julia> mask 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ``` """ -function doc( - config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, - integer_mask::AbstractMatrix{Bool}) - integer_ids = integer_ids |> Flux.gpu - integer_mask = integer_mask |> Flux.gpu - - D = checkpoint.model.bert((token = integer_ids, - attention_mask = NeuralAttentionlib.GenericSequenceMask(integer_mask))).hidden_state - D = checkpoint.model.linear(D) - - mask = mask_skiplist( - checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist) - mask = reshape(mask, (1, size(mask)...)) # equivalent of unsqueeze - @assert isequal(size(mask)[2:end], size(D)[2:end]) - "size(mask): $(size(mask)), size(D): $(size(D))" - @assert mask isa AbstractArray{Bool} "$(typeof(mask))" - - D = D .* mask # clear out embeddings of masked tokens - - if !config.use_gpu - # doing this because normalize gives exact results - D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding - else - # TODO: try to do some tests to see the gap between this and LinearAlgebra.normalize - # mapreduce doesn't give exact normalization - norms = map(sqrt, mapreduce(abs2, +, D, dims = 1)) - norms[norms .== 0] .= 1 # avoid division by 0 - @assert isequal(size(norms)[2:end], size(D)[2:end]) - @assert size(norms)[1] == 1 - - D = D ./ norms - end +function doc(bert::HF.HGFBertModel, linear::Layers.Dense, + integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}) + linear(bert((token = integer_ids, + attention_mask = NeuralAttentionlib.GenericSequenceMask(bitmask))).hidden_state) +end + +function _doc_embeddings_and_doclens( + bert::HF.HGFBertModel, linear::Layers.Dense, skiplist::Vector{Int}, + integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}) + D = doc(bert, linear, integer_ids, bitmask) # (dim, doc_maxlen, current_batch_size) + mask = _clear_masked_embeddings!(D, integer_ids, skiplist) # (1, doc_maxlen, current_batch_size) + + # normalize each embedding in D; along dims = 1 + _normalize_array!(D, dims = 1) + + # get the doclens by unsqueezing the mask + mask = reshape(mask, size(mask)[2:end]) # (doc_maxlen, current_batch_size) + doclens = vec(sum(mask, dims = 1)) + + # flatten out embeddings, i.e get embeddings for each token in each passage + D = _flatten_embeddings(D) # (dim, total_num_embeddings) + + # remove embeddings for masked tokens + D = _remove_masked_tokens(D, mask) # (dim, total_num_masked_embeddings) - D, mask + @assert ndims(D)==2 "ndims(D): $(ndims(D))" + @assert size(D, 2)==sum(doclens) "size(D): $(size(D)), sum(doclens): $(sum(doclens))" + @assert D isa AbstractMatrix{Float32} "$(typeof(D))" + @assert doclens isa AbstractVector{Int64} "$(typeof(doclens))" + + D, doclens 
+end
+
+function _query_embeddings(
+        bert::HF.HGFBertModel, linear::Layers.Dense, skiplist::Vector{Int},
+        integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool})
+    Q = doc(bert, linear, integer_ids, bitmask) # (dim, query_maxlen, current_batch_size)
+
+    # skiplist only contains the pad symbol by default
+    _ = _clear_masked_embeddings!(Q, integer_ids, skiplist)
+
+    # normalize each embedding in Q; along dims = 1
+    _normalize_array!(Q, dims = 1)
+
+    @assert ndims(Q)==3 "ndims(Q): $(ndims(Q))"
+    @assert(isequal(size(Q)[2:end], size(integer_ids)),
+        "size(Q): $(size(Q)), size(integer_ids): $(size(integer_ids))")
+    @assert Q isa AbstractArray{Float32} "$(typeof(Q))"
+
+    Q
 end
 
 """
-    docFromText(config::ColBERTConfig, checkpoint::Checkpoint,
-        docs::Vector{String}, bsize::Union{Missing, Int})
+    encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense,
+        tokenizer::TextEncoders.AbstractTransformerTextEncoder,
+        passages::Vector{String}, dim::Int, index_bsize::Int,
+        doc_token::String, skiplist::Vector{Int})
 
-Get ColBERT embeddings for `docs` using `checkpoint`.
+Encode a list of passages.
 
-This function also applies ColBERT-style document pre-processing for each document in `docs`.
+The given `passages` are run through the underlying BERT model and the linear layer to
+generate the embeddings, after doing relevant document-specific preprocessing (see
+[`tensorize_docs`](@ref)).
 
 # Arguments
 
-- `config`: The [`ColBERTConfig`](@ref) being used.
-- `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings.
-- `docs`: A list of documents to get the embeddings for.
-- `bsize`: A batch size for processing documents in batches.
+  - `bert`: The pre-trained BERT component used to compute the embeddings.
+  - `linear`: The linear layer projecting the BERT embeddings to `dim` dimensions.
+  - `tokenizer`: The tokenizer used to get token IDs and attention masks.
+  - `passages`: A list of strings representing the passages to be encoded.
+  - `dim`: The embedding dimension.
+  - `index_bsize`: The batch size to be fed into the transformer.
+  - `doc_token`: The document marker token, i.e. `[D]`.
+  - `skiplist`: A list of token IDs to be masked out, i.e. punctuation and the pad symbol.
 
 # Returns
 
-A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector`
-of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding
-dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings
-across all documents in `docs`.
+A tuple `embs, doclens` where:
 
-# Examples
+  - `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`,
+    where `D` is the embedding dimension and `N` is the total number of embeddings
+    across all the passages.
+  - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage,
+    i.e. the total number of attended tokens for each document passage.
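+
+For orientation, a bare-bones call looks like this (assuming the model, tokenizer
+and skiplist have been set up as in an [`Indexer`](@ref); a full REPL session
+follows in the examples below):
+
+```julia
+embs, doclens = encode_passages(
+    bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist)
+```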
-Continuing from the example in [`Checkpoint`](@ref): +# Examples ```julia-repl -julia> docs = [ - "hello world", - "thank you!", - "a", - "this is some longer text, so length should be longer", -]; +julia> using ColBERT: load_hgf_pretrained_local, ColBERTConfig, encode_passages; + +julia> using CUDA, Flux, Transformers, TextEncodeBase; + +julia> config = ColBERTConfig(); + +julia> dim = config.dim +128 + +julia> index_bsize = 128; # this is the batch size to be fed in the transformer + +julia> doc_maxlen = config.doc_maxlen +300 + +julia> doc_token = config.doc_token_id +"[unused1]" + +julia> tokenizer, bert, linear = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/"); + +julia> process = tokenizer.process; + +julia> truncpad_pipe = Pipeline{:token}( + TextEncodeBase.trunc_and_pad(doc_maxlen - 1, "[PAD]", :tail, :tail), + :token); + +julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; + +julia> tokenizer = TextEncoders.BertTextEncoder( + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); + +julia> bert = bert |> Flux.gpu; + +julia> linear = linear |> Flux.gpu; + +julia> passages = readlines("./downloads/lotte/lifestyle/dev/collection.tsv")[1:1000]; + +julia> punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~")); + tokenizer.padsym]; + +julia> skiplist = [lookup(tokenizer.vocab, sym) + for sym in punctuations_and_padsym]; + +julia> @time embs, doclens = encode_passages( + bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist) # second run stats +[ Info: Encoding 1000 passages. + 25.247094 seconds (29.65 M allocations: 1.189 GiB, 37.26% gc time, 0.00% compilation time) +(Float32[-0.08001435 -0.10785186 … -0.08651956 -0.12118215; 0.07319974 0.06629379 … 0.0929825 0.13665271; … ; -0.037957724 -0.039623592 … 0.031274226 0.063107446; 0.15484622 0.16779025 … 0.11533891 0.11508792], [279, 117, 251, 105, 133, 170, 181, 115, 190, 132 … 76, 204, 199, 244, 256, 125, 251, 261, 262, 263]) -julia> embs, doclens = ColBERT.docFromText(config, checkpoint, docs, config.index_bsize) -(Float32[0.07590997 0.00056472444 … -0.09958261 -0.03259005; 0.08413661 -0.016337946 … -0.061889287 -0.017708546; … ; -0.11584533 0.016651645 … 0.0073241345 0.09233974; 0.043868616 0.084660925 … -0.0294838 -0.08536169], [5 5 4 13]) - -julia> embs -128×27 Matrix{Float32}: - 0.0759101 0.00056477 -0.0256841 0.0847256 … 0.0321216 -0.0811892 -0.0995827 -0.03259 - 0.0841366 -0.0163379 -0.0573766 0.0125381 0.0838632 -0.0118507 -0.0618893 -0.0177087 - -0.0301104 -0.0128124 0.0137095 0.00290062 0.0347227 0.0138398 -0.0573847 0.177861 - 0.0375674 0.216562 0.220287 -0.011 -0.0213431 -0.110819 0.00425487 -0.00131534 - 0.0252677 0.151702 0.189658 -0.104252 -0.0654913 -0.0272064 0.0350983 -0.0381015 - 0.00608619 -0.0415363 -0.0479571 0.00884466 … 0.00207629 0.122848 0.0747105 0.0836628 - -0.185256 -0.106582 -0.0394912 -0.119268 0.163837 0.0352982 -0.0405874 -0.064156 - -0.0816655 -0.142809 -0.15595 -0.109608 0.0882721 0.0565001 -0.134649 0.00380792 - 0.00471225 0.00444501 0.0144707 0.0682628 0.0386771 0.0112827 0.0253297 0.0665075 - -0.121564 -0.189994 -0.173724 -0.0678208 -0.0832335 0.0151939 -0.119054 -0.0980481 - 0.157599 0.0919844 0.0748075 -0.122389 … 0.0599421 0.0330669 0.0205288 0.0184296 - 0.0132481 -0.0430333 -0.0679477 0.0918445 0.14166 0.0404866 0.0575921 0.101701 - 0.0695786 0.0281928 0.000234582 0.0570102 -0.137199 -0.0378472 -0.0531831 
-0.123457 - -0.0933987 -0.0390347 -0.0274184 -0.0452961 0.14876 0.0279156 0.0309748 0.00298152 - 0.0458562 0.0729707 0.0336343 0.189599 0.0570071 0.103661 0.00905471 0.127777 - 0.00452595 0.05959 0.0768679 -0.036913 … 0.0768966 0.148845 0.0569493 0.293592 - -0.0385804 -0.00754613 0.0375564 0.00207589 -0.0161775 0.133667 0.266788 0.0394272 - ⋮ ⋱ ⋮ - 0.0510928 -0.138272 -0.111771 -0.192081 -0.0312752 -0.00646487 -0.0171807 -0.0618908 - 0.128495 0.181198 0.131882 -0.064132 -0.00662879 -0.00408871 0.027459 0.0343185 - -0.0961544 -0.0223997 0.025595 -0.12089 0.0042998 0.0117906 -0.0813832 0.0382321 - 0.0285496 0.0556695 0.0805605 -0.0728611 … 0.138845 -0.0139292 -0.14533 -0.017602 - 0.0112119 -0.164717 -0.188169 0.0315999 0.112653 0.071643 -0.0662124 0.164667 - -0.0017815 0.0600865 0.0858722 0.00955078 -0.0506793 0.120243 0.0490749 0.0562548 - -0.0261784 0.0343851 0.0447504 -0.105545 -0.0713677 0.0469064 0.040038 -0.0536368 - -0.0696538 -0.020624 -0.0465219 -0.121079 -0.0636235 0.0441996 0.0842775 0.0567261 - -0.0940355 -0.106123 -0.0424943 0.0650131 … 0.00190927 0.00334517 0.00795241 -0.0439884 - 0.0567849 -0.0312434 -0.0715616 0.136271 -0.0648593 -0.113022 0.0616157 -0.0738149 - -0.0143086 0.105833 0.0762297 0.0102708 -0.162572 -0.142671 -0.0430241 -0.0831737 - 0.0447039 0.0783602 0.0957613 0.0603179 0.0415507 -0.0413788 0.0315282 -0.171445 - 0.129225 0.112544 0.0815397 -0.00357054 0.097503 0.120684 0.107231 0.119762 - 0.00020747 -0.124472 -0.120445 -0.0102294 … -0.24173 -0.0930788 -0.0519734 0.0837617 - -0.115845 0.0166517 0.0199255 -0.044735 -0.0353863 0.0577463 0.00732411 0.0923398 - 0.0438687 0.0846609 0.0960215 0.112225 -0.178799 -0.096704 -0.0294837 -0.0853618 - -julia> doclens -4-element Vector{Int64}: - 5 - 5 - 4 - 13 ``` """ -function docFromText(config::ColBERTConfig, checkpoint::Checkpoint, - docs::Vector{String}, bsize::Union{Missing, Int}) - if ismissing(bsize) - # integer_ids, integer_mask = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize) - # doc(checkpoint, integer_ids, integer_mask) - error("Currently bsize cannot be missing!") - else - integer_ids, integer_mask = tensorize_docs( - config, checkpoint.model.tokenizer, docs) - - # we sort passages by length to do batch packing for more efficient use of the GPU - integer_ids, integer_mask, reverse_indices = _sort_by_length( - integer_ids, integer_mask, bsize) - - @assert length(reverse_indices) == length(docs) - "length(reverse_indices): $(length(reverse_indices)), length(batch_text): $(length(docs))" - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" - @assert reverse_indices isa Vector{Int64} "$(typeof(reverse_indices))" - - # aggregate all embeddings - D, mask = Vector{AbstractArray{Float32}}(), - Vector{AbstractArray{Bool}}() - for passage_offset in 1:bsize:length(docs) - passage_end_offset = min(length(docs), passage_offset + bsize - 1) - D_, mask_ = doc( - config, checkpoint, integer_ids[ - :, passage_offset:passage_end_offset], - integer_mask[:, passage_offset:passage_end_offset]) - push!(D, D_) - push!(mask, mask_) - D_, mask_ = nothing, nothing - end - - # concat embeddings and masks, and put them in the original order - D, mask = cat(D..., dims = 3)[:, :, reverse_indices], - cat(mask..., dims = 3)[:, :, reverse_indices] - mask = reshape(mask, size(mask)[2:end]) - - # get doclens, i.e number of attended tokens for each passage - doclens = vec(sum(mask, dims = 1)) - - # flatten out embeddings, i.e 
get embeddings for each token in each passage - D = reshape(D, size(D)[1], prod(size(D)[2:end])) - - # remove embeddings for masked tokens - D = D[:, reshape(mask, prod(size(mask)))] - - @assert ndims(D)==2 "ndims(D): $(ndims(D))" - @assert size(D)[2]==sum(doclens) "size(D): $(size(D)), sum(doclens): $(sum(doclens))" - @assert D isa AbstractMatrix{Float32} "$(typeof(D))" - @assert doclens isa AbstractVector{Int64} "$(typeof(doclens))" - - Flux.cpu(D), Flux.cpu(doclens) +function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + passages::Vector{String}, dim::Int, index_bsize::Int, + doc_token::String, skiplist::Vector{Int}) + @info "Encoding $(length(passages)) passages." + length(passages) == 0 && return rand(Float32, dim, 0), rand(Int, 0) + + # batching here to avoid storing intermediate embeddings on GPU + embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}() + for passage_offset in 1:index_bsize:length(passages) + passage_end_offset = min( + length(passages), passage_offset + index_bsize - 1) + + # get the token IDs and attention mask + integer_ids, bitmask = tensorize_docs( + doc_token, tokenizer, passages[passage_offset:passage_end_offset]) + + integer_ids = integer_ids |> Flux.gpu + bitmask = bitmask |> Flux.gpu + + # run the tokens and attention mask through the transformer + # and mask the skiplist tokens + D, doclens_ = _doc_embeddings_and_doclens( + bert, linear, skiplist, integer_ids, bitmask) + + push!(embs, Flux.cpu(D)) + append!(doclens, Flux.cpu(doclens_)) end + embs = cat(embs..., dims = 2) + embs, doclens end """ - query( - config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, - integer_mask::AbstractMatrix{Bool}) + encode_query(searcher::Searcher, query::String) -Compute the hidden state of the BERT and linear layers of ColBERT for queries. +Encode a search query to a matrix of embeddings using the provided `searcher`. The encoded query can then be used to search the collection. # Arguments - - `config`: The [`ColBERTConfig`](@ref) to be used. - - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. - - `integer_ids`: An array of token IDs to be fed into the BERT model. - - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. + - `searcher`: A Searcher object that contains information about the collection and the index. + - `query`: The search query to encode. # Returns -`Q`, where `Q` is an array containing the normalized embeddings for each token in the query matrix. -It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), -and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any query and `N` is -the total number of queries. +An array containing the embeddings for each token in the query. Also see [queryFromText](@ref) to see the size of the array. # Examples -Continuing from the queries example for [`tensorize_queries`](@ref) and [`Checkpoint`](@ref): +Here's an example using the `config` and `checkpoint` from the example for [`Checkpoint`](@ref). 
```julia-repl -julia> ColBERT.query(config, checkpoint, integer_ids, integer_mask)[:, :, 1] -128×32×1 CuArray{Float32, 3, CUDA.DeviceMemory}: -[:, :, 1] = - 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117006 0.115806 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 -0.0116249 0.0173332 0.0165187 0.0168762 0.0178042 0.0200356 - -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0465292 -0.0693319 -0.0737462 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138701 -0.0409766 -0.177391 -0.113141 -0.118738 -0.126037 -0.126829 -0.13149 - -0.0231787 0.0532214 0.0607473 0.0279048 0.0634681 0.112296 0.111831 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 0.0592031 … 0.0167847 0.0148605 0.0150612 0.0133353 0.0126583 - -0.0290508 0.143255 0.0306142 0.0426579 0.129972 -0.17502 -0.169493 -0.164401 -0.161857 -0.160327 - 0.0921475 0.058833 0.250449 0.234636 0.0412965 0.0590262 0.0642577 0.0664076 0.0659837 0.0711358 - 0.0279402 -0.0278357 0.144855 0.147958 -0.0268559 0.161106 0.157629 0.154552 0.155525 0.163634 - -0.0768143 -0.00587302 0.00543038 0.00443376 -0.0134111 -0.126912 -0.123969 -0.11757 -0.112495 -0.11112 - -0.0184337 0.00668561 -0.191863 -0.161345 0.0222466 … -0.103246 -0.10374 -0.107664 -0.107267 -0.114564 - 0.0112104 0.0214651 -0.0923963 -0.0823052 0.0600248 0.103589 0.103387 0.106261 0.105065 0.10409 - 0.110971 0.272576 0.148319 0.143233 0.239578 0.11224 0.107913 0.109914 0.112652 0.108365 - -0.131066 0.0376254 -0.0164237 -0.000193318 0.00344707 -0.0893371 -0.0919217 -0.0969305 -0.0935498 -0.096145 - -0.0402605 0.0350559 0.0162864 0.0269105 0.00968855 -0.0623393 -0.0670097 -0.070679 -0.0655848 -0.0564059 - 0.0799973 0.0482302 0.0712078 0.0792903 0.0108783 … 0.00820444 0.00854873 0.00889943 0.00932721 0.00751066 - -0.137565 -0.0369116 -0.065728 -0.0664102 -0.0238012 0.029041 0.0292468 0.0297059 0.0278639 0.0257616 - 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0979401 -0.057629 -0.053911 -0.0566325 -0.0568765 -0.0581378 - 0.0656851 0.0195639 0.0288789 0.0559219 0.0315515 0.0472323 0.054771 0.0596156 0.0541802 0.0525933 - 0.0668634 -0.00400549 0.0297102 0.0505045 -0.00082792 0.0414113 0.0400276 0.0361149 0.0325914 0.0260693 - -0.0691096 0.0348577 -0.000312685 0.0232462 -0.00250495 … -0.141874 -0.142026 -0.132163 -0.129679 -0.131122 - -0.0273036 0.0653352 0.0332689 0.017918 0.0875479 0.0500921 0.0471914 0.0469949 0.0434268 0.0442646 - -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0468719 -0.0772672 -0.0805913 -0.0809244 -0.0823798 -0.081472 - ⋮ ⋱ ⋮ - 0.0506199 0.00290888 0.047947 0.063503 -0.0072114 0.0360347 0.0326486 0.033966 0.0327732 0.0261081 - -0.0288586 -0.150171 -0.0699125 -0.108002 -0.142865 -0.0775934 -0.072192 -0.0697569 -0.0715358 -0.0683193 - -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0524162 0.0457267 0.0532778 0.0649795 0.0697126 0.0808413 - 0.0445508 0.0296366 0.0325647 0.0521935 0.0436496 0.129031 0.126605 0.12324 0.120497 0.117703 - -0.127301 -0.0224252 -0.00579415 -0.00877803 -0.0140665 … -0.080026 -0.080839 -0.0823464 -0.0803394 -0.0856279 - 0.0304881 0.0396951 0.0798097 0.0736797 0.0800866 0.0426674 0.0411406 0.0460205 0.0460111 0.0532082 - 0.0488798 0.252244 0.0866849 0.098552 0.251561 -0.0236942 -0.035116 -0.0395483 -0.0463498 -0.0494207 - -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0352487 -0.0476357 -0.0435388 -0.0404835 -0.0410673 -0.0367272 - 0.023548 -0.00147361 0.0629259 0.106951 0.0406627 0.00627022 0.00403014 -0.000107777 -0.000898423 0.00296315 - -0.0574151 -0.0875744 -0.103787 -0.114166 -0.103979 … 
-0.0708782 -0.0700138 -0.0687795 -0.070967 -0.0636385 - 0.0280373 0.149767 -0.0899733 -0.0732524 0.162316 0.022177 0.0183834 0.0201251 0.0197228 0.0219051 - -0.0617143 -0.0573989 -0.0973785 -0.0805046 -0.0525925 0.0997715 0.102691 0.107432 0.108591 0.109502 - -0.0859687 0.0623054 0.0974813 0.126841 0.0595557 0.0187937 0.0191363 0.0182794 0.0230548 0.031103 - 0.0392044 0.0162653 0.0926306 0.104054 0.0509464 0.0559883 0.0553617 0.0491496 0.0484319 0.0438133 - -0.0340362 -0.0278067 -0.0181035 -0.0282369 -0.0490531 … -0.0564175 -0.0562518 -0.0617946 -0.0631367 -0.0675882 - 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0456515 0.0676478 0.0698765 0.0724731 0.0780165 0.0746229 - -0.117425 0.162483 0.11039 0.136364 0.135339 -0.00432259 -0.00508357 -0.00538224 -0.00685447 -0.00194357 - -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00295334 -0.00671544 -0.00322498 -0.00518066 -0.00600254 -0.0077147 - 0.0893984 0.0695061 -0.049941 -0.035411 0.0767663 0.0913505 0.0964841 0.0960931 0.0961892 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.180245 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 -``` -""" -function query( - config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, - integer_mask::AbstractMatrix{Bool}) - integer_ids = integer_ids |> Flux.gpu - integer_mask = integer_mask |> Flux.gpu - - Q = checkpoint.model.bert((token = integer_ids, - attention_mask = NeuralAttentionlib.GenericSequenceMask(integer_mask))).hidden_state - Q = checkpoint.model.linear(Q) - - # only skip the pad symbol, i.e an empty skiplist - mask = mask_skiplist( - checkpoint.model.tokenizer, integer_ids, Vector{Int64}()) - mask = reshape(mask, (1, size(mask)...)) # equivalent of unsqueeze - @assert isequal(size(mask)[2:end], size(Q)[2:end]) - "size(mask): $(size(mask)), size(Q): $(size(Q))" - @assert mask isa AbstractArray{Bool} "$(typeof(mask))" - - Q = Q .* mask - - if !config.use_gpu - # doing this because normalize gives exact results - Q = mapslices(v -> iszero(v) ? v : normalize(v), Q, dims = 1) # normalize each embedding - else - # TODO: try to do some tests to see the gap between this and LinearAlgebra.normalize - # mapreduce doesn't give exact normalization - norms = map(sqrt, mapreduce(abs2, +, Q, dims = 1)) - norms[norms .== 0] .= 1 # avoid division by 0 - @assert isequal(size(norms)[2:end], size(Q)[2:end]) - @assert size(norms)[1] == 1 - - Q = Q ./ norms - end +julia> using ColBERT: load_hgf_pretrained_local, ColBERTConfig, encode_queries; - @assert ndims(Q)==3 "ndims(Q): $(ndims(Q))" - @assert isequal(size(Q)[2:end], size(integer_ids)) - "size(Q): $(size(Q)), size(integer_ids): $(size(integer_ids))" - @assert Q isa AbstractArray{Float32} "$(typeof(Q))" +julia> using CUDA, Flux, Transformers, TextEncodeBase; - Q -end +julia> config = ColBERTConfig(); -""" - queryFromText(config::ColBERTConfig, - checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int}) +julia> dim = config.dim +128 -Get ColBERT embeddings for `queries` using `checkpoint`. +julia> index_bsize = 128; # this is the batch size to be fed in the transformer -This function also applies ColBERT-style query pre-processing for each query in `queries`. 
+julia> query_maxlen = config.query_maxlen +300 -# Arguments +julia> query_token = config.query_token_id +"[unused1]" - - `config`: The [`ColBERTConfig`](@ref) to be used. - - `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings. - - `queries`: A list of queries to get the embeddings for. - - `bsize`: A batch size for processing queries in batches. +julia> tokenizer, bert, linear = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/"); -# Returns +julia> process = tokenizer.process; -`embs`, where `embs` is an array of embeddings. The array `embs` has shape `(D, L, N)`, -where `D` is the embedding dimension (`128` for ColBERT's linear layer), `L` is the -maximum length of any query in the batch, and `N` is the total number of queries in `queries`. +julia> truncpad_pipe = Pipeline{:token}( + TextEncodeBase.trunc_or_pad(query_maxlen - 1, "[PAD]", :tail, :tail), + :token); -# Examples +julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; -Continuing from the example in [`Checkpoint`](@ref): +julia> tokenizer = TextEncoders.BertTextEncoder( + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); -```julia-repl -julia> queries = ["what are white spots on raspberries?"]; +julia> bert = bert |> Flux.gpu; -julia> ColBERT.queryFromText(config, checkpoint, queries, 128) -128×32×1 Array{Float32, 3}: -[:, :, 1] = - 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117734 0.117006 0.115806 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 -0.0116249 0.0181126 0.0173332 0.0165187 0.0168762 0.0178042 0.0200356 - -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0465292 -0.0672796 -0.0693319 -0.0737462 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138701 -0.0409766 -0.177391 -0.10489 -0.113141 -0.118738 -0.126037 -0.126829 -0.13149 - -0.0231787 0.0532214 0.0607473 0.0279048 0.0634681 0.113961 0.112296 0.111831 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 0.0592031 … 0.0174852 0.0167847 0.0148605 0.0150612 0.0133353 0.0126583 - -0.0290508 0.143255 0.0306142 0.0426579 0.129972 -0.175238 -0.17502 -0.169493 -0.164401 -0.161857 -0.160327 - 0.0921475 0.058833 0.250449 0.234636 0.0412965 0.0555153 0.0590262 0.0642577 0.0664076 0.0659837 0.0711358 - 0.0279402 -0.0278357 0.144855 0.147958 -0.0268559 0.162062 0.161106 0.157629 0.154552 0.155525 0.163634 - -0.0768143 -0.00587302 0.00543038 0.00443376 -0.0134111 -0.128129 -0.126912 -0.123969 -0.11757 -0.112495 -0.11112 - -0.0184337 0.00668561 -0.191863 -0.161345 0.0222466 … -0.102283 -0.103246 -0.10374 -0.107664 -0.107267 -0.114564 - 0.0112104 0.0214651 -0.0923963 -0.0823052 0.0600248 0.103233 0.103589 0.103387 0.106261 0.105065 0.10409 - 0.110971 0.272576 0.148319 0.143233 0.239578 0.109759 0.11224 0.107913 0.109914 0.112652 0.108365 - -0.131066 0.0376254 -0.0164237 -0.000193318 0.00344707 -0.0862689 -0.0893371 -0.0919217 -0.0969305 -0.0935498 -0.096145 - -0.0402605 0.0350559 0.0162864 0.0269105 0.00968855 -0.0587467 -0.0623393 -0.0670097 -0.070679 -0.0655848 -0.0564059 - 0.0799973 0.0482302 0.0712078 0.0792903 0.0108783 … 0.00501423 0.00820444 0.00854873 0.00889943 0.00932721 0.00751066 - -0.137565 -0.0369116 -0.065728 -0.0664102 -0.0238012 0.0250844 0.029041 0.0292468 0.0297059 0.0278639 0.0257616 - 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0979401 -0.0583169 -0.057629 -0.053911 -0.0566325 -0.0568765 -0.0581378 - 0.0656851 0.0195639 0.0288789 0.0559219 0.0315515 
0.03907 0.0472323 0.054771 0.0596156 0.0541802 0.0525933 - 0.0668634 -0.00400549 0.0297102 0.0505045 -0.00082792 0.0399623 0.0414113 0.0400276 0.0361149 0.0325914 0.0260693 - -0.0691096 0.0348577 -0.000312685 0.0232462 -0.00250495 … -0.146082 -0.141874 -0.142026 -0.132163 -0.129679 -0.131122 - -0.0273036 0.0653352 0.0332689 0.017918 0.0875479 0.0535029 0.0500921 0.0471914 0.0469949 0.0434268 0.0442646 - -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0468719 -0.0741133 -0.0772672 -0.0805913 -0.0809244 -0.0823798 -0.081472 - -0.0262739 0.109895 0.0117273 0.0222689 0.100869 0.0119844 0.0132486 0.012956 0.0175875 0.013171 0.0195091 - 0.0861164 0.0799029 0.00381147 0.0170927 0.103322 0.0238912 0.0209658 0.0226638 0.0209905 0.0230679 0.0221191 - 0.125112 0.0880232 0.0351989 0.022897 0.0862715 … -0.0219898 -0.0238914 -0.0207844 -0.0229276 -0.0238033 -0.0236367 - ⋮ ⋱ ⋮ - -0.158838 0.0415251 -0.0584126 -0.0373528 0.0819274 -0.212757 -0.214835 -0.213414 -0.212899 -0.215478 -0.210674 - -0.039636 -0.0837763 -0.0837142 -0.0597521 -0.0868467 0.0309127 0.0339911 0.03399 0.0313526 0.0316408 0.0309661 - 0.0755214 0.0960326 0.0858578 0.0614626 0.111979 … 0.102411 0.101302 0.108277 0.109034 0.107593 0.111863 - 0.0506199 0.00290888 0.047947 0.063503 -0.0072114 0.0388324 0.0360347 0.0326486 0.033966 0.0327732 0.0261081 - -0.0288586 -0.150171 -0.0699125 -0.108002 -0.142865 -0.0811611 -0.0775934 -0.072192 -0.0697569 -0.0715358 -0.0683193 - -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0524162 0.046386 0.0457267 0.0532778 0.0649795 0.0697126 0.0808413 - 0.0445508 0.0296366 0.0325647 0.0521935 0.0436496 0.125633 0.129031 0.126605 0.12324 0.120497 0.117703 - -0.127301 -0.0224252 -0.00579415 -0.00877803 -0.0140665 … -0.0826691 -0.080026 -0.080839 -0.0823464 -0.0803394 -0.0856279 - 0.0304881 0.0396951 0.0798097 0.0736797 0.0800866 0.0448139 0.0426674 0.0411406 0.0460205 0.0460111 0.0532082 - 0.0488798 0.252244 0.0866849 0.098552 0.251561 -0.0212669 -0.0236942 -0.035116 -0.0395483 -0.0463498 -0.0494207 - -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0352487 -0.0486577 -0.0476357 -0.0435388 -0.0404835 -0.0410673 -0.0367272 - 0.023548 -0.00147361 0.0629259 0.106951 0.0406627 0.00599323 0.00627022 0.00403014 -0.000107777 -0.000898423 0.00296315 - -0.0574151 -0.0875744 -0.103787 -0.114166 -0.103979 … -0.0697383 -0.0708782 -0.0700138 -0.0687795 -0.070967 -0.0636385 - 0.0280373 0.149767 -0.0899733 -0.0732524 0.162316 0.0233808 0.022177 0.0183834 0.0201251 0.0197228 0.0219051 - -0.0617143 -0.0573989 -0.0973785 -0.0805046 -0.0525925 0.0936075 0.0997715 0.102691 0.107432 0.108591 0.109502 - -0.0859687 0.0623054 0.0974813 0.126841 0.0595557 0.0244905 0.0187937 0.0191363 0.0182794 0.0230548 0.031103 - 0.0392044 0.0162653 0.0926306 0.104054 0.0509464 0.0516558 0.0559883 0.0553617 0.0491496 0.0484319 0.0438133 - -0.0340362 -0.0278067 -0.0181035 -0.0282369 -0.0490531 … -0.0528032 -0.0564175 -0.0562518 -0.0617946 -0.0631367 -0.0675882 - 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0456515 0.0670016 0.0676478 0.0698765 0.0724731 0.0780165 0.0746229 - -0.117425 0.162483 0.11039 0.136364 0.135339 -0.00589512 -0.00432259 -0.00508357 -0.00538224 -0.00685447 -0.00194357 - -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00295334 -0.0122461 -0.00671544 -0.00322498 -0.00518066 -0.00600254 -0.0077147 - 0.0893984 0.0695061 -0.049941 -0.035411 0.0767663 0.0880484 0.0913505 0.0964841 0.0960931 0.0961892 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0318565 -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 - -0.0443452 
-0.192203 -0.0187912 -0.0247794 -0.180245 -0.0800835 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0577217 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 -``` -""" -function queryFromText(config::ColBERTConfig, - checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{ - Missing, Int}) - if ismissing(bsize) - error("Currently bsize cannot be missing!") - end +julia> linear = linear |> Flux.gpu; - # configure the tokenizer to truncate or pad to query_maxlen - tokenizer = checkpoint.model.tokenizer - process = tokenizer.process - truncpad_pipe = Pipeline{:token}( - TextEncodeBase.trunc_or_pad( - config.query_maxlen, "[PAD]", :tail, :tail), - :token) - process = process[1:4] |> truncpad_pipe |> process[6:end] - tokenizer = Transformers.TextEncoders.BertTextEncoder( - tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, - endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc) - - # get ids and masks, embeddings and returning the concatenated tensors - integer_ids, integer_mask = tensorize_queries(config, tokenizer, queries) - - # aggregate all embeddings - Q = Vector{AbstractArray{Float32}}() - for query_offset in 1:bsize:length(queries) - query_end_offset = min(length(queries), query_offset + bsize - 1) - Q_ = query( - config, checkpoint, integer_ids[:, query_offset:query_end_offset], - integer_mask[:, query_offset:query_end_offset]) - push!(Q, Q_) - Q_ = nothing - end - Q = cat(Q..., dims = 3) +julia> skiplist = [lookup(tokenizer.vocab, tokenizer.padsym)] +1-element Vector{Int64}: + 1 - @assert ndims(Q)==3 "ndims(Q): $(ndims(Q))" - @assert Q isa AbstractArray{Float32} "$(typeof(Q))" +julia> attend_to_mask_tokens = config.attend_to_mask_tokens + +julia> queries = [ + "what are white spots on raspberries?", + "here is another query!", +]; - Flux.cpu(Q) +julia> @time encode_queries(bert, linear, tokenizer, queries, dim, index_bsize, + query_token, attend_to_mask_tokens, skiplist); +[ Info: Encoding 2 queries. + 0.029858 seconds (27.58 k allocations: 781.727 KiB, 0.00% compilation time) +``` +""" +function encode_queries(bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + queries::Vector{String}, dim::Int, + index_bsize::Int, query_token::String, attend_to_mask_tokens::Bool, + skiplist::Vector{Int}) + # we assume that tokenizer is configured to truncate or pad to query_maxlen - 1 + @info "Encoding $(length(queries)) queries." 
+ length(queries) == 0 && return rand(Float32, dim, 0) + + # batching here to avoid storing intermediate embeddings on GPU + embs = Vector{AbstractArray{Float32, 3}}() + for query_offset in 1:index_bsize:length(queries) + query_end_offset = min( + length(queries), query_offset + index_bsize - 1) + + # get the token IDs and attention mask + integer_ids, bitmask = tensorize_queries( + query_token, attend_to_mask_tokens, tokenizer, + queries[query_offset:query_end_offset]) # (query_maxlen, current_batch_size) + + integer_ids = integer_ids |> Flux.gpu + bitmask = bitmask |> Flux.gpu + + # run the tokens and attention mask through the transformer + Q = _query_embeddings( + bert, linear, skiplist, integer_ids, bitmask) # (dim, query_maxlen, current_batch_size) + + push!(embs, Flux.cpu(Q)) + end + embs = cat(embs..., dims = 3) end diff --git a/src/modelling/embedding_utils.jl b/src/modelling/embedding_utils.jl new file mode 100644 index 0000000..a3de520 --- /dev/null +++ b/src/modelling/embedding_utils.jl @@ -0,0 +1,205 @@ +""" + mask_skiplist(tokenizer::TextEncoders.AbstractTransformerTextEncoder, + integer_ids::AbstractMatrix{Int32}, skiplist::Union{Missing, Vector{Int64}}) + +Create a mask for the given `integer_ids`, based on the provided `skiplist`. +If the `skiplist` is not missing, then any token IDs in the list will be filtered out along with the padding token. +Otherwise, all tokens are included in the mask. + +# Arguments + + - `tokenizer`: The underlying tokenizer. + - `integer_ids`: An `Array` of token IDs for the documents. + - `skiplist`: A list of token IDs to skip in the mask. + +# Returns + +An array of booleans indicating whether the corresponding token ID +is included in the mask or not. The array has the same shape as +`integer_ids`, i.e `(L, N)`, where `L` is the maximum length of +any document in `integer_ids` and `N` is the number of documents. + +# Examples + +In this example, we'll mask out all punctuations as well as the pad symbol +of a tokenizer. + +```julia-repl +julia> using ColBERT: mask_skiplist; + +julia> using TextEncodeBase + +julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer"); + +julia> punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~")); + tokenizer.padsym]; + +julia> skiplist = [lookup(tokenizer.vocab, sym) + for sym in punctuations_and_padsym] +33-element Vector{Int64}: + 1000 + 1001 + 1002 + 1003 + 1004 + 1005 + 1006 + 1007 + 1008 + 1009 + 1010 + 1011 + 1012 + 1013 + 1014 + 1025 + 1026 + 1027 + 1028 + 1029 + 1030 + 1031 + 1032 + 1033 + 1034 + 1035 + 1036 + 1037 + 1064 + 1065 + 1066 + 1067 + 1 + +julia> batch_text = [ + "no punctuation text", + "this, batch,! of text contains puncts! but is larger so that? the other text contains pad symbol;" +]; + +julia> integer_ids, _ = tensorize_docs("[unused1]", tokenizer, batch_text) + +julia> integer_ids +27×2 Matrix{Int32}: + 102 102 + 3 3 + 2054 2024 + 26137 1011 + 6594 14109 + 14506 1011 + 3794 1000 + 103 1998 + 1 3794 + 1 3398 + 1 26137 + 1 16650 + 1 1000 + 1 2022 + 1 2004 + 1 3470 + 1 2062 + 1 2009 + 1 1030 + 1 1997 + 1 2061 + 1 3794 + 1 3398 + 1 11688 + 1 6455 + 1 1026 + 1 103 + +julia> decode(tokenizer, integer_ids) +27×2 Matrix{String}: + " [CLS]" " [CLS]" + " [unused1]" " [unused1]" + " no" " this" + " pun" " ," + "ct" " batch" + "uation" " ," + " text" " !" + " [SEP]" " of" + " [PAD]" " text" + " [PAD]" " contains" + " [PAD]" " pun" + " [PAD]" "cts" + " [PAD]" " !" 
+ " [PAD]" " but" + " [PAD]" " is" + " [PAD]" " larger" + " [PAD]" " so" + " [PAD]" " that" + " [PAD]" " ?" + " [PAD]" " the" + " [PAD]" " other" + " [PAD]" " text" + " [PAD]" " contains" + " [PAD]" " pad" + " [PAD]" " symbol" + " [PAD]" " ;" + " [PAD]" " [SEP]" + +julia> mask_skiplist(integer_ids, skiplist) +27×2 BitMatrix: + 1 1 + 1 1 + 1 1 + 1 0 + 1 1 + 1 0 + 1 0 + 1 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 0 + 0 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 0 + 0 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 0 + 0 1 +``` +""" +function mask_skiplist!(mask::AbstractMatrix{Bool}, + integer_ids::AbstractMatrix{Int32}, skiplist::Vector{Int64}) + for token_id in skiplist + mask .= mask .& (integer_ids .!= token_id) + end +end + +function _clear_masked_embeddings!(D::AbstractArray{Float32, 3}, + integer_ids::AbstractMatrix{Int32}, skiplist::Vector{Int}) + isequal(size(D)[2:end], size(integer_ids)) || + throw(DomainError("The number of embeddings in D and tokens " * + "in integer_ids must be equal!")) + # set everything to true + mask = similar(integer_ids, Bool) # respects the device as well + mask .= true + mask_skiplist!(mask, integer_ids, skiplist) # (doc_maxlen, current_batch_size) + mask = reshape(mask, (1, size(mask)...)) # (1, doc_maxlen, current_batch_size) + + # clear embeddings + D .= D .* mask # clear embeddings of masked tokens + mask +end + +function _flatten_embeddings(D::AbstractArray{Float32, 3}) + reshape(D, size(D, 1), prod(size(D)[2:end])) +end + +function _remove_masked_tokens( + D::AbstractMatrix{Float32}, mask::AbstractMatrix{Bool}) + size(D, 2) == prod(size(mask)) || + throw(DimensionMismatch("The total number of embeddings " * " + in D must be equal to the total number of tokens represented by mask!")) + D[:, vec(mask)] +end diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index 736be6c..0ace0eb 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -1,5 +1,5 @@ """ - tensorize_docs(config::ColBERTConfig, + tensorize_docs(doc_token_id::String, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) @@ -26,11 +26,37 @@ A tuple containing the following is returned: # Examples ```julia-repl -julia> using ColBERT, Transformers; +julia> using ColBERT: tensorize_docs, load_hgf_pretrained_local; -julia> config = ColBERTConfig(); +julia> using Transformers, Transformers.TextEncoders, TextEncodeBase; -julia> tokenizer = Transformers.load_tokenizer(config.checkpoint); +julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer") + +# configure the tokenizers maxlen and padding/truncation +julia> doc_maxlen = 20; + +julia> process = tokenizer.process +Pipelines: + target[token] := TextEncodeBase.nestedcall(string_getvalue, source) + target[token] := Transformers.TextEncoders.grouping_sentence(target.token) + target[(token, segment)] := SequenceTemplate{String}([CLS]: Input[1]: [SEP]: (Input[2]: [SEP]:)...)(target.token) + target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token) + target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token) + target[token] := TextEncodeBase.nested2batch(target.token) + target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment) + target[segment] := TextEncodeBase.nested2batch(target.segment) + target[sequence_mask] := identity(target.attention_mask) + target := (target.token, target.segment, 
target.attention_mask, target.sequence_mask) + +julia> truncpad_pipe = Pipeline{:token}( + TextEncodeBase.trunc_and_pad(doc_maxlen - 1, "[PAD]", :tail, :tail), + :token); + +julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; + +julia> tokenizer = TextEncoders.BertTextEncoder( + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); julia> batch_text = [ "hello world", @@ -40,11 +66,12 @@ julia> batch_text = [ "this is an even longer document. this is some longer text, so length should be longer", ]; -julia> integer_ids, integer_mask = ColBERT.tensorize_docs(config, tokenizer, batch_text) -(Int32[102 102 … 102 102; 3 3 … 3 3; … ; 1 1 … 1 2937; 1 1 … 1 103], Bool[1 1 … 1 1; 1 1 … 1 1; … ; 0 0 … 0 1; 0 0 … 0 1]) +julia> integer_ids, bitmask = tensorize_docs( + "[unused1]", tokenizer, batch_text) +(Int32[102 102 … 102 102; 3 3 … 3 3; … ; 1 1 … 1 2023; 1 1 … 1 2937], Bool[1 1 … 1 1; 1 1 … 1 1; … ; 0 0 … 0 1; 0 0 … 0 1]) julia> integer_ids -21×5 reinterpret(Int32, ::Matrix{PrimitiveOneHot.OneHot{0x0000773a}}): +20×5 Matrix{Int32}: 102 102 102 102 102 3 3 3 3 3 7593 4068 1038 2024 2024 @@ -65,10 +92,9 @@ julia> integer_ids 1 1 1 1 2324 1 1 1 1 2023 1 1 1 1 2937 - 1 1 1 1 103 -julia> integer_mask -21×5 Matrix{Bool}: +julia> bitmask +20×5 Matrix{Bool}: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 @@ -89,31 +115,42 @@ julia> integer_mask 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 - 0 0 0 0 1 +julia> TextEncoders.decode(tokenizer, integer_ids) +20×5 Matrix{String}: + "[CLS]" "[CLS]" "[CLS]" "[CLS]" "[CLS]" + "[unused1]" "[unused1]" "[unused1]" "[unused1]" "[unused1]" + "hello" "thank" "a" "this" "this" + "world" "you" "[SEP]" "is" "is" + "[SEP]" "!" "[PAD]" "some" "an" + "[PAD]" "[SEP]" "[PAD]" "longer" "even" + "[PAD]" "[PAD]" "[PAD]" "text" "longer" + "[PAD]" "[PAD]" "[PAD]" "," "document" + "[PAD]" "[PAD]" "[PAD]" "so" "." + "[PAD]" "[PAD]" "[PAD]" "length" "this" + "[PAD]" "[PAD]" "[PAD]" "should" "is" + "[PAD]" "[PAD]" "[PAD]" "be" "some" + "[PAD]" "[PAD]" "[PAD]" "longer" "longer" + "[PAD]" "[PAD]" "[PAD]" "[SEP]" "text" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "," + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "so" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "length" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "should" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "be" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "longer" ``` """ -function tensorize_docs(config::ColBERTConfig, +function tensorize_docs(doc_token::String, tokenizer::TextEncoders.AbstractTransformerTextEncoder, - batch_text::Vector{String}) - # placeholder for [D] marker token - batch_text = [". 
" * doc for doc in batch_text] - - # getting the integer ids and masks - encoded_text = Transformers.TextEncoders.encode(tokenizer, batch_text) - ids, mask = encoded_text.token, encoded_text.attention_mask - integer_ids = reinterpret(Int32, ids) - integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] - - # adding the [D] marker token ID - D_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.doc_token_id) - integer_ids[2, :] .= D_marker_token_id - - @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(integer_mask)" - @assert isequal(size(integer_ids)[2], length(batch_text)) - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" - - integer_ids, integer_mask + batch_text::AbstractArray{String}) + # we assume that tokenizer is configured to have maxlen: doc_maxlen - 1 + integer_ids, bitmask = _integer_ids_and_mask(tokenizer, batch_text) + + # adding the [D] marker token ID as the second token + # first one is always the "[CLS]" token + D_marker_token_id = lookup(tokenizer.vocab, doc_token) |> Int32 + integer_ids = _add_marker_row(integer_ids, D_marker_token_id) + bitmask = _add_marker_row(bitmask, true) + + integer_ids, bitmask end diff --git a/src/modelling/tokenization/query_tokenization.jl b/src/modelling/tokenization/query_tokenization.jl index efbfe56..43d5cb9 100644 --- a/src/modelling/tokenization/query_tokenization.jl +++ b/src/modelling/tokenization/query_tokenization.jl @@ -1,5 +1,6 @@ """ - tensorize_queries(config::ColBERTConfig, +using TextEncodeBase: tokenize + tensorize_queries(query_token::String, attend_to_mask_tokens::Bool, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) @@ -30,130 +31,167 @@ config. Note that, at the time of writing this package, configuring tokenizers i clean interface; so, we have to manually configure the tokenizer. ```julia-repl -julia> using ColBERT, Transformers, TextEncodeBase; +julia> using ColBERT: tensorize_queries, load_hgf_pretrained_local; -julia> config = ColBERTConfig(); +julia> using Transformers, Transformers.TextEncoders, TextEncodeBase; -julia> tokenizer = Transformers.load_tokenizer(config.checkpoint); +julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer"); + +# configure the tokenizers maxlen and padding/truncation +julia> query_maxlen = 32; julia> process = tokenizer.process; julia> truncpad_pipe = Pipeline{:token}( - TextEncodeBase.trunc_or_pad(config.query_maxlen, "[PAD]", :tail, :tail), - :token); + TextEncodeBase.trunc_or_pad(query_maxlen - 1, "[PAD]", :tail, :tail), + :token); julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; julia> tokenizer = TextEncoders.BertTextEncoder( - tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, - endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); -julia> queries = [ +julia> batch_text = [ "what are white spots on raspberries?", - "what do rabbits eat?" + "what do rabbits eat?", + "this is a really long query. 
I'm deliberately making this long"* + "so that you can actually see that this is really truncated at 32 tokens"* + "and that the other two queries are padded to get 32 tokens."* + "this makes this a nice query as an example." ]; -julia> integer_ids, integer_mask = ColBERT.tensorize_queries(config, tokenizer, queries); - -julia> 32×2 reinterpret(Int32, ::Matrix{OneHot{0x0000773a}}): - 102 102 - 2 2 - 2055 2055 - 2025 2080 - 2318 20404 - 7517 4522 - 2007 1030 - 20711 103 - 2362 104 - 20969 104 - 1030 104 - 103 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - -julia> integer_mask -32×2 Matrix{Bool}: - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 0 - 1 0 - 1 0 - 1 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - +julia> integer_ids, bitmask = tensorize_queries( + "[unused0]", false, tokenizer, batch_text); +(Int32[102 102 102; 2 2 2; … ; 104 104 8792; 104 104 2095], Bool[1 1 1; 1 1 1; … ; 0 0 1; 0 0 1]) + +julia> integer_ids +32×3 Matrix{Int32}: + 102 102 102 + 2 2 2 + 2055 2055 2024 + 2025 2080 2004 + 2318 20404 1038 + 7517 4522 2429 + 2007 1030 2147 + 20711 103 23033 + 2362 104 1013 + 20969 104 1046 + 1030 104 1006 + 103 104 1050 + 104 104 9970 + 104 104 2438 + 104 104 2024 + 104 104 2147 + 104 104 6500 + 104 104 2009 + 104 104 2018 + 104 104 2065 + 104 104 2942 + 104 104 2157 + 104 104 2009 + 104 104 2024 + 104 104 2004 + 104 104 2429 + 104 104 25450 + 104 104 2013 + 104 104 3591 + 104 104 19205 + 104 104 8792 + 104 104 2095 + +julia> bitmask +32×3 Matrix{Bool}: + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 0 1 + 1 0 1 + 1 0 1 + 1 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + +julia> TextEncoders.decode(tokenizer, integer_ids) +32×3 Matrix{String}: + "[CLS]" "[CLS]" "[CLS]" + "[unused0]" "[unused0]" "[unused0]" + "what" "what" "this" + "are" "do" "is" + "white" "rabbits" "a" + "spots" "eat" "really" + "on" "?" "long" + "ras" "[SEP]" "query" + "##p" "[MASK]" "." + "##berries" "[MASK]" "i" + "?" "[MASK]" "'" + "[SEP]" "[MASK]" "m" + "[MASK]" "[MASK]" "deliberately" + "[MASK]" "[MASK]" "making" + "[MASK]" "[MASK]" "this" + "[MASK]" "[MASK]" "long" + "[MASK]" "[MASK]" "##so" + "[MASK]" "[MASK]" "that" + "[MASK]" "[MASK]" "you" + "[MASK]" "[MASK]" "can" + "[MASK]" "[MASK]" "actually" + "[MASK]" "[MASK]" "see" + "[MASK]" "[MASK]" "that" + "[MASK]" "[MASK]" "this" + "[MASK]" "[MASK]" "is" + "[MASK]" "[MASK]" "really" + "[MASK]" "[MASK]" "truncated" + "[MASK]" "[MASK]" "at" + "[MASK]" "[MASK]" "32" + "[MASK]" "[MASK]" "token" + "[MASK]" "[MASK]" "##san" + "[MASK]" "[MASK]" "##d" ``` """ -function tensorize_queries(config::ColBERTConfig, +function tensorize_queries(query_token::String, attend_to_mask_tokens::Bool, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) - # placeholder for [Q] marker token - batch_text = [". 
" * query for query in batch_text] - - # getting the integer ids and masks - encoded_text = Transformers.TextEncoders.encode(tokenizer, batch_text) - ids, mask = encoded_text.token, encoded_text.attention_mask - integer_ids = reinterpret(Int32, ids) - integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] - @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(size(integer_mask))" - @assert isequal( - size(integer_ids)[1], config.query_maxlen) "size(integer_ids): $(size(integer_ids)), query_maxlen: $(query_tokenizer.config.query_maxlen)" - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" + # we assume that tokenizer is configured to have maxlen: query_maxlen - 1 + integer_ids, bitmask = _integer_ids_and_mask(tokenizer, batch_text) # adding the [Q] marker token ID and [MASK] augmentation Q_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.query_token_id) - mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") - integer_ids[2, :] .= Q_marker_token_id - integer_ids[integer_ids .== 1] .= mask_token_id - - if config.attend_to_mask_tokens - integer_mask[integer_ids .== mask_token_id] .= 1 - @assert isequal(sum(integer_mask), prod(size(integer_mask))) "sum(integer_mask): $(sum(integer_mask)), prod(size(integer_mask)): $(prod(size(integer_mask)))" + tokenizer.vocab, query_token) |> Int32 + mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") |> Int32 + pad_token_id = TextEncodeBase.lookup( + tokenizer.vocab, tokenizer.config.padsym) |> Int32 + integer_ids = _add_marker_row(integer_ids, Q_marker_token_id) + bitmask = _add_marker_row(bitmask, true) + integer_ids[integer_ids .== pad_token_id] .= mask_token_id + + if attend_to_mask_tokens + bitmask[integer_ids .== mask_token_id] .= true + @assert isequal(sum(bitmask), prod(size(bitmask))) + "sum(integer_mask): $(sum(bitmask)), prod(size(integer_mask)): $(prod(size(bitmask)))" end - integer_ids, integer_mask + integer_ids, bitmask end diff --git a/src/modelling/tokenization/tokenizer_utils.jl b/src/modelling/tokenization/tokenizer_utils.jl new file mode 100644 index 0000000..fe60feb --- /dev/null +++ b/src/modelling/tokenization/tokenizer_utils.jl @@ -0,0 +1,143 @@ +""" + _integer_ids_and_mask( + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + batch_text::AbstractVector{String}) + +Run `batch_text` through `tokenizer` to get matrices of tokens and attention mask. + +# Arguments + + - `tokenizer`: The tokenizer to be used to tokenize the texts. + - `batch_text`: The list of texts to tokenize. + +# Returns + +A tuple `integer_ids, bitmask`, where `integer_ids` is a Matrix containing token IDs +and `bitmask` is the attention mask. + +# Examples + +```julia-repl +julia> using ColBERT: _integer_ids_and_mask, load_hgf_pretrained_local; + +julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer"); + +julia> batch_text = [ + "hello world", + "thank you!", + "a", + "this is some longer text, so length should be longer", + "this is an even longer document. 
this is some longer text, so length should be longer",
+];
+
+julia> integer_ids, bitmask = _integer_ids_and_mask(tokenizer, batch_text);
+
+julia> integer_ids
+20×5 Matrix{Int32}:
+  102   102   102   102   102
+ 7593  4068  1038  2024  2024
+ 2089  2018   103  2004  2004
+  103  1000     1  2071  2020
+    1   103     1  2937  2131
+    1     1     1  3794  2937
+    1     1     1  1011  6255
+    1     1     1  2062  1013
+    1     1     1  3092  2024
+    1     1     1  2324  2004
+    1     1     1  2023  2071
+    1     1     1  2937  2937
+    1     1     1   103  3794
+    1     1     1     1  1011
+    1     1     1     1  2062
+    1     1     1     1  3092
+    1     1     1     1  2324
+    1     1     1     1  2023
+    1     1     1     1  2937
+    1     1     1     1   103
+
+julia> bitmask
+20×5 BitMatrix:
+ 1  1  1  1  1
+ 1  1  1  1  1
+ 1  1  1  1  1
+ 1  1  0  1  1
+ 0  1  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  1  1
+ 0  0  0  0  1
+ 0  0  0  0  1
+ 0  0  0  0  1
+ 0  0  0  0  1
+ 0  0  0  0  1
+ 0  0  0  0  1
+ 0  0  0  0  1
+```
+"""
+function _integer_ids_and_mask(
+        tokenizer::TextEncoders.AbstractTransformerTextEncoder,
+        batch_text::AbstractVector{String})
+    encoded_text = TextEncoders.encode(tokenizer, batch_text)
+    ids, length_mask = encoded_text.token, encoded_text.attention_mask
+    integer_ids = reinterpret(Int32, ids) |> Matrix{Int32}
+    bitmask = length_mask .* trues(1, size(integer_ids)...)     # (1, max_len, batch_size)
+    bitmask = reshape(bitmask, size(bitmask)[2:end]...)         # (max_len, batch_size)
+
+    @assert isequal(size(integer_ids), size(bitmask)) "size(integer_ids): $(size(integer_ids)), size(bitmask): $(size(bitmask))"
+    @assert isequal(size(integer_ids, 2), length(batch_text)) "size(integer_ids): $(size(integer_ids)), length(batch_text): $(length(batch_text))"
+    @assert integer_ids isa Matrix{Int32} "$(typeof(integer_ids))"
+    @assert bitmask isa BitMatrix "typeof(bitmask): $(typeof(bitmask))"
+
+    integer_ids, bitmask
+end
+
+"""
+    _add_marker_row(data::AbstractMatrix{T}, marker::T) where {T}
+
+Add a row containing `marker` as the second row of `data`.
+
+# Arguments
+
+  - `data`: The matrix in which the row is to be added.
+  - `marker`: The marker to be added.
+
+# Returns
+
+A matrix equal to `data`, with the second row being filled with `marker`.
+
+# Examples
+
+```julia-repl
+julia> using ColBERT: _add_marker_row;
+
+julia> x = ones(Float32, 5, 5)
+5×5 Matrix{Float32}:
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+
+julia> _add_marker_row(x, zero(Float32))
+6×5 Matrix{Float32}:
+ 1.0  1.0  1.0  1.0  1.0
+ 0.0  0.0  0.0  0.0  0.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+ 1.0  1.0  1.0  1.0  1.0
+```
+
+"""
+function _add_marker_row(data::AbstractMatrix{T}, marker::T) where {T}
+    [data[begin:min(1, size(data, 1)), :]; fill(marker, (1, size(data, 2)));
+     data[2:end, :]]
+end
diff --git a/src/savers.jl b/src/savers.jl
index 7f1a3d5..bd31161 100644
--- a/src/savers.jl
+++ b/src/savers.jl
@@ -49,12 +49,9 @@ number of embeddings and the passage offsets are saved in a file named `
             length(doclens),
             "num_embeddings" => length(codes)
         ),
-            4 # indent
+            4
         )
     end
 end
@@ -121,3 +118,20 @@ function save(config::ColBERTConfig)
     )
     end
 end
+
+function save_chunk_metadata_property(
+        index_path::String, property::String, properties::Vector{T}) where {T}
+    plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json"))
+    @assert plan_metadata["num_chunks"] == length(properties)
+    for chunk_idx in 1:length(properties)
+        chunk_metadata = JSON.parsefile(joinpath(
+            index_path, "$(chunk_idx).metadata.json"))
+        chunk_metadata[property] = properties[chunk_idx]
+        open(joinpath(index_path, "$(chunk_idx).metadata.json"), "w") do io
+            JSON.print(io,
+                chunk_metadata,
+                4
+            )
+        end
+    end
+end
diff --git a/src/search/ranking.jl b/src/search/ranking.jl
index 6cfe42f..ce1392c 100644
--- a/src/search/ranking.jl
+++ b/src/search/ranking.jl
@@ -1,29 +1,38 @@
 """
-Return a candidate set of `pids` for the query matrix `Q`. This is done as follows: the nearest `nprobe` centroids for each query embedding are found. This list is then flattened and the unique set of these centroids is built. Using the `ivf`, the list of all unique embedding IDs contained in these centroids is computed. Finally, these embedding IDs are converted to `pids` using `emb2pid`. This list of `pids` is the final candidate set.
+    _cids_to_eids!(eids::Vector{Int}, centroid_ids::Vector{Int},
+        ivf::Vector{Int}, ivf_lengths::Vector{Int})
+
+Get the set of embedding IDs contained in `centroid_ids`.
 """
+function _cids_to_eids!(eids::Vector{Int}, centroid_ids::Vector{Int},
+        ivf::Vector{Int}, ivf_lengths::Vector{Int})
+    @assert length(eids) == sum(ivf_lengths[centroid_ids])
+    centroid_ivf_offsets = cumsum([1; _head(ivf_lengths)])
+    eid_offsets = cumsum([1; _head(ivf_lengths[centroid_ids])])
+    for (idx, centroid_id) in enumerate(centroid_ids)
+        eid_offset = eid_offsets[idx]
+        batch_length = ivf_lengths[centroid_id]
+        ivf_offset = centroid_ivf_offsets[centroid_id]
+        eids[eid_offset:(eid_offset + batch_length - 1)] .= ivf[ivf_offset:(ivf_offset + batch_length - 1)]
+    end
+end
+
 function retrieve(
-        ivf::Vector{Int}, ivf_lengths::Vector{Int}, centroids::Matrix{Float32},
+        ivf::Vector{Int}, ivf_lengths::Vector{Int}, centroids::AbstractMatrix{Float32},
         emb2pid::Vector{Int}, nprobe::Int, Q::AbstractMatrix{Float32})
-    # score of each query embedding with each centroid and take top nprobe centroids
-    cells = Flux.gpu(transpose(Q)) * Flux.gpu(centroids) |> Flux.cpu
+    # score each query embedding against each centroid
+    cells = Q' * centroids # (num_query_embeddings, num_centroids)
+    # TODO: how to take topk entries using GPU code?
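+    # Illustrative note (an editorial sketch, not part of the diff's API): `_topk`
+    # from src/utils.jl does a CPU partial sort per slice. For a hypothetical
+    # score matrix and nprobe = 2,
+    #     _topk([0.3 0.9 0.5; 0.8 0.1 0.6], 2; dims = 2) == [2 3; 1 3]
+    # i.e. row i holds the indices of the two highest-scoring centroids for
+    # query embedding i, which is why `cells` is moved to the CPU first.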
-    cells = mapslices(
-        row -> partialsortperm(row, 1:(nprobe), rev = true),
-        cells, dims = 2) # take top nprobe centroids for each query
+    cells = cells |> Flux.cpu
+    cells = _topk(cells, nprobe, dims = 2) # (num_query_embeddings, nprobe)
     centroid_ids = sort(unique(vec(cells)))

     # get all embedding IDs contained in centroid_ids using ivf
-    centroid_ivf_offsets = cat(
-        [1], 1 .+ cumsum(ivf_lengths)[1:end .!= end], dims = 1)
-    eids = Vector{Int}()
-    for centroid_id in centroid_ids
-        offset = centroid_ivf_offsets[centroid_id]
-        length = ivf_lengths[centroid_id]
-        append!(eids, ivf[offset:(offset + length - 1)])
-    end
-    @assert isequal(length(eids), sum(ivf_lengths[centroid_ids]))
-    "length(eids): $(length(eids)), sum(ranker.ivf_lengths[centroid_ids]):" *
-    "$(sum(ivf_lengths[centroid_ids]))"
+    eids = Vector{Int}(undef, sum(ivf_lengths[centroid_ids])) # length: sum(ivf_lengths[centroid_ids])
+    _cids_to_eids!(eids, centroid_ids, ivf, ivf_lengths)
+
+    # get unique eids
     eids = sort(unique(eids))

     # get pids from the emb2pid mapping
@@ -34,56 +43,39 @@ end
 function _collect_compressed_embs_for_pids(
         doclens::Vector{Int}, codes::Vector{UInt32},
         residuals::Matrix{UInt8}, pids::Vector{Int})
+    # get offsets of pids in codes and residuals and the resultant arrays
+    pid_offsets = cumsum([1; _head(doclens)])
+    offsets = cumsum([1; _head(doclens[pids])])
+
+    # collecting the codes and residuals for pids
     num_embeddings = sum(doclens[pids])
     codes_packed = zeros(UInt32, num_embeddings)
-    residuals_packed = zeros(UInt8, size(residuals)[1], num_embeddings)
-    pid_offsets = cat([1], 1 .+ cumsum(doclens)[1:end .!= end], dims = 1)
-    offset = 1
-    for pid in pids
+    residuals_packed = zeros(UInt8, size(residuals, 1), num_embeddings)
+    for (idx, pid) in enumerate(pids)
+        offset = offsets[idx]
         pid_offset = pid_offsets[pid]
         num_embs_pid = doclens[pid]
-        codes_packed[offset:(offset + num_embs_pid - 1)] = codes[pid_offset:(pid_offset + num_embs_pid - 1)]
-        residuals_packed[:, offset:(offset + num_embs_pid - 1)] = residuals[
+        codes_packed[offset:(offset + num_embs_pid - 1)] .= codes[
+            pid_offset:(pid_offset + num_embs_pid - 1)]
+        residuals_packed[:, offset:(offset + num_embs_pid - 1)] .= residuals[
             :, pid_offset:(pid_offset + num_embs_pid - 1)]
-        offset += num_embs_pid
     end
-    @assert offset==num_embeddings + 1 "offset: $(offset), num_embs + 1: $(num_embeddings + 1)"
     codes_packed, residuals_packed
 end

-function maxsim(
-        Q::Matrix{Float32}, D::Matrix{Float32}, pids::Vector{Int}, doclens::Vector{Int})
+function maxsim(Q::AbstractMatrix{Float32}, D::AbstractMatrix{Float32},
+        pids::Vector{Int}, doclens::Vector{Int})
     scores = zeros(Float32, length(pids))
     num_embeddings = sum(doclens[pids])
-    query_doc_scores = Flux.gpu(transpose(Q)) * Flux.gpu(D) # (num_query_tokens, num_embeddings)
-    offset = 1
+    query_doc_scores = Q' * D
+    offsets = cumsum([1; _head(doclens[pids])])
     for (idx, pid) in enumerate(pids)
         num_embs_pids = doclens[pid]
+        offset = offsets[idx]
         offset_end = min(num_embeddings, offset + num_embs_pids - 1)
         pid_scores = query_doc_scores[:, offset:offset_end]
         scores[idx] = sum(maximum(pid_scores, dims = 2))
-        offset += num_embs_pids
     end
-    @assert offset==num_embeddings + 1 "offset: $(offset), num_embs + 1: $(num_embeddings + 1)"
     scores
 end
-
-"""
-
-  Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this.
-""" -function score_pids(config::ColBERTConfig, centroids::Matrix{Float32}, - bucket_weights::Vector{Float32}, doclens::Vector{Int}, codes::Vector{UInt32}, - residuals::Matrix{UInt8}, Q::Matrix{Float32}, pids::Vector{Int}) - codes_packed, residuals_packed = _collect_compressed_embs_for_pids( - doclens, codes, residuals, pids) - D_packed = decompress( - config.dim, config.nbits, centroids, bucket_weights, - codes_packed, residuals_packed) - @assert ndims(D_packed)==2 "ndims(D_packed): $(ndims(D_packed))" - @assert size(D_packed)[1] == config.dim - "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" - @assert size(D_packed)[2] == sum(doclens[pids]) - "size(D_packed): $(size(D_packed)), num_embs: $(sum(doclens[pids]))" - @assert D_packed isa AbstractMatrix{Float32} "$(typeof(D_packed))" - maxsim(Q, D_packed, pids, doclens) -end diff --git a/src/searching.jl b/src/searching.jl index 83c1bd3..c4046d6 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -1,26 +1,18 @@ struct Searcher config::ColBERTConfig - checkpoint::Checkpoint - centroids::Matrix{Float32} - bucket_cutoffs::Vector{Float32} - bucket_weights::Vector{Float32} + bert::HF.HGFBertModel + linear::Layers.Dense + tokenizer::TextEncoders.AbstractTransformerTextEncoder + centroids::AbstractMatrix{Float32} + bucket_cutoffs::AbstractVector{Float32} + bucket_weights::AbstractVector{Float32} ivf::Vector{Int} ivf_lengths::Vector{Int} doclens::Vector{Int} codes::Vector{UInt32} residuals::Matrix{UInt8} emb2pid::Vector{Int} -end - -function _build_emb2pid(doclens::Vector{Int}) - num_embeddings = sum(doclens) - emb2pid = zeros(Int, num_embeddings) - offset_doclens = 1 - for (pid, dlength) in enumerate(doclens) - emb2pid[offset_doclens:(offset_doclens + dlength - 1)] .= pid - offset_doclens += dlength - end - emb2pid + skiplist::Vector{Int} end function Searcher(index_path::String) @@ -28,21 +20,52 @@ function Searcher(index_path::String) error("Index at $(index_path) does not exist! Please build the index first and try again.") end + @info "Loading config from $(index_path)." config = load_config(index_path) - base_colbert = BaseColBERT(config.checkpoint) - checkpoint = Checkpoint(base_colbert, config) - @info "Loaded ColBERT layers from the $(config.checkpoint) HuggingFace checkpoint." + + @info "Loading ColBERT layers from the $(config.checkpoint) HuggingFace checkpoint." + tokenizer, bert, linear = load_hgf_pretrained_local(config.checkpoint) + bert = bert |> Flux.gpu + linear = linear |> Flux.gpu + + # configuring the tokenizer; with query_maxlen - 1 + process = tokenizer.process + truncpad_pipe = Pipeline{:token}( + TextEncodeBase.trunc_or_pad( + config.query_maxlen - 1, "[PAD]", :tail, :tail), + :token) + process = process[1:4] |> truncpad_pipe |> process[6:end] + tokenizer = TextEncoders.BertTextEncoder( + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc) + + # loading the codec + @info "Loading codec." 
codec = load_codec(index_path) + codec["centroids"] = codec["centroids"] |> Flux.gpu + codec["bucket_cutoffs"] = codec["bucket_cutoffs"] |> Flux.gpu + codec["bucket_weights"] = codec["bucket_weights"] |> Flux.gpu + + # loading the ivf ivf = JLD2.load_object(joinpath(index_path, "ivf.jld2")) ivf_lengths = JLD2.load_object(joinpath(index_path, "ivf_lengths.jld2")) + + # loading the doclens and compressed embeddings doclens = load_doclens(index_path) codes, residuals = load_compressed_embs(index_path) + + # building emb2pid @info "Building the emb2pid mapping." emb2pid = _build_emb2pid(doclens) + # by default, only include the pad symbol in the skiplist + skiplist = [lookup(tokenizer.vocab, tokenizer.padsym)] + Searcher( config, - checkpoint, + bert, + linear, + tokenizer, codec["centroids"], codec["bucket_cutoffs"], codec["bucket_weights"], @@ -51,109 +74,55 @@ function Searcher(index_path::String) doclens, codes, residuals, - emb2pid + emb2pid, + skiplist ) end -""" - encode_query(searcher::Searcher, query::String) - -Encode a search query to a matrix of embeddings using the provided `searcher`. The encoded query can then be used to search the collection. - -# Arguments - - - `searcher`: A Searcher object that contains information about the collection and the index. - - `query`: The search query to encode. - -# Returns - -An array containing the embeddings for each token in the query. Also see [queryFromText](@ref) to see the size of the array. - -# Examples - -Here's an example using the `config` and `checkpoint` from the example for [`Checkpoint`](@ref). - -```julia-repl -julia> encode_query(config, checkpoint, "what are white spots on raspberries?") -128×32×1 Array{Float32, 3}: -[:, :, 1] = - 0.0158568 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200356 - -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138701 -0.0409766 -0.126037 -0.126829 -0.13149 - -0.0231787 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583 - -0.0290508 0.143255 0.0306142 0.0426579 -0.164401 -0.161857 -0.160327 - 0.0921475 0.058833 0.250449 0.234636 0.0664076 0.0659837 0.0711358 - 0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634 - -0.0768143 -0.00587302 0.00543038 0.00443376 -0.11757 -0.112495 -0.11112 - -0.0184337 0.00668561 -0.191863 -0.161345 … -0.107664 -0.107267 -0.114564 - 0.0112104 0.0214651 -0.0923963 -0.0823052 0.106261 0.105065 0.10409 - 0.110971 0.272576 0.148319 0.143233 0.109914 0.112652 0.108365 - -0.131066 0.0376254 -0.0164237 -0.000193318 -0.0969305 -0.0935498 -0.096145 - -0.0402605 0.0350559 0.0162864 0.0269105 -0.070679 -0.0655848 -0.0564059 - 0.0799973 0.0482302 0.0712078 0.0792903 … 0.00889943 0.00932721 0.00751066 - -0.137565 -0.0369116 -0.065728 -0.0664102 0.0297059 0.0278639 0.0257616 - 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0566325 -0.0568765 -0.0581378 - 0.0656851 0.0195639 0.0288789 0.0559219 0.0596156 0.0541802 0.0525933 - 0.0668634 -0.00400549 0.0297102 0.0505045 0.0361149 0.0325914 0.0260693 - -0.0691096 0.0348577 -0.000312685 0.0232462 … -0.132163 -0.129679 -0.131122 - -0.0273036 0.0653352 0.0332689 0.017918 0.0469949 0.0434268 0.0442646 - -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0809244 -0.0823798 -0.081472 - -0.0262739 0.109895 0.0117273 0.0222689 0.0175875 0.013171 0.0195091 - 0.0861164 0.0799029 0.00381147 0.0170927 0.0209905 0.0230679 0.0221191 - ⋮ 
⋱ ⋮ - -0.039636 -0.0837763 -0.0837142 -0.0597521 0.0313526 0.0316408 0.0309661 - 0.0755214 0.0960326 0.0858578 0.0614626 … 0.109034 0.107593 0.111863 - 0.0506199 0.00290888 0.047947 0.063503 0.033966 0.0327732 0.0261081 - -0.0288586 -0.150171 -0.0699125 -0.108002 -0.0697569 -0.0715358 -0.0683193 - -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0649795 0.0697126 0.0808413 - 0.0445508 0.0296366 0.0325647 0.0521935 0.12324 0.120497 0.117703 - -0.127301 -0.0224252 -0.00579415 -0.00877803 … -0.0823464 -0.0803394 -0.0856279 - 0.0304881 0.0396951 0.0798097 0.0736797 0.0460205 0.0460111 0.0532082 - 0.0488798 0.252244 0.0866849 0.098552 -0.0395483 -0.0463498 -0.0494207 - -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0404835 -0.0410673 -0.0367272 - 0.023548 -0.00147361 0.0629259 0.106951 -0.000107777 -0.000898423 0.00296315 - -0.0574151 -0.0875744 -0.103787 -0.114166 … -0.0687795 -0.070967 -0.0636385 - 0.0280373 0.149767 -0.0899733 -0.0732524 0.0201251 0.0197228 0.0219051 - -0.0617143 -0.0573989 -0.0973785 -0.0805046 0.107432 0.108591 0.109502 - -0.0859687 0.0623054 0.0974813 0.126841 0.0182794 0.0230548 0.031103 - 0.0392044 0.0162653 0.0926306 0.104054 0.0491496 0.0484319 0.0438133 - -0.0340362 -0.0278067 -0.0181035 -0.0282369 … -0.0617946 -0.0631367 -0.0675882 - 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0724731 0.0780165 0.0746229 - -0.117425 0.162483 0.11039 0.136364 -0.00538224 -0.00685447 -0.00194357 - -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00518066 -0.00600254 -0.0077147 - 0.0893984 0.0695061 -0.049941 -0.035411 0.0960931 0.0961892 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.0220611 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699094 -0.0684748 -0.0662903 - 0.100019 -0.0618588 0.106134 0.0989047 -0.055676 -0.0556784 -0.0595709 -``` -""" -function encode_query( - config::ColBERTConfig, checkpoint::Checkpoint, query::String) - queries = [query] - queryFromText(config, checkpoint, queries, config.index_bsize) +function _build_emb2pid(doclens::Vector{Int}) + num_embeddings = sum(doclens) + emb2pid = zeros(Int, num_embeddings) + embs2pid_offsets = cumsum([1; _head(doclens)]) + for (pid, dlength) in enumerate(doclens) + offset = embs2pid_offsets[pid] + emb2pid[offset:(offset + dlength - 1)] .= pid + end + @assert all(!=(0), emb2pid) + emb2pid end function search(searcher::Searcher, query::String, k::Int) - Q = encode_query(searcher.config, searcher.checkpoint, query) + Q = encode_queries(searcher.bert, searcher.linear, + searcher.tokenizer, [query], searcher.config.dim, + searcher.config.index_bsize, searcher.config.query_token, + searcher.config.attend_to_mask_tokens, searcher.skiplist) + @assert size(Q, 3)==1 "size(Q): $(size(Q))" + @assert(isequal(size(Q, 2), searcher.config.query_maxlen), + "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)") + + # squeeze out last dim and move to gpu + Q = reshape(Q, size(Q)[1:end .!= end]...) 
|> Flux.gpu + + # get candidate pids + pids = retrieve(searcher.ivf, searcher.ivf_lengths, searcher.centroids, + searcher.emb2pid, searcher.config.nprobe, Q) - if size(Q)[3] > 1 - error("Only one query is supported at the moment!") - end - @assert size(Q)[3]==1 "size(Q): $(size(Q))" - @assert isequal(size(Q)[2], searcher.config.query_maxlen) - "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)" # Q: (128, 32, 1) + # get compressed embeddings for the candidate pids + codes_packed, residuals_packed = _collect_compressed_embs_for_pids( + searcher.doclens, searcher.codes, searcher.residuals, pids) - Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension - @assert isequal(length(size(Q)), 2) "size(Q): $(size(Q))" + # decompress these embeddings and move to gpu + D_packed = decompress(searcher.config.dim, searcher.config.nbits, + Flux.cpu(searcher.centroids), Flux.cpu(searcher.bucket_weights), + codes_packed, residuals_packed) |> Flux.gpu + @assert(size(D_packed, 2)==sum(searcher.doclens[pids]), + "size(D_packed): $(size(D_packed)), num_embs: $(sum(searcher.doclens[pids]))") + @assert D_packed isa AbstractMatrix{Float32} "$(typeof(D_packed))" - pids = retrieve(searcher.ivf, searcher.ivf_lengths, searcher.centroids, - searcher.emb2pid, searcher.config.nprobe, Q) - scores = score_pids( - searcher.config, searcher.centroids, searcher.bucket_weights, - searcher.doclens, searcher.codes, searcher.residuals, Q, pids) + # get maxsim scores for the candidate pids + scores = maxsim(Q, D_packed, pids, searcher.doclens) + # sort scores and candidate pids, and return the top k indices = sortperm(scores, rev = true) pids, scores = pids[indices], scores[indices] pids[1:k], scores[1:k] diff --git a/src/utils.jl b/src/utils.jl index 58ddd71..88e12e2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,51 +1,56 @@ -""" - _sort_by_length( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - -Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. - -# Arguments - - - `integer_ids`: The token IDs of documents to be sorted. - - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). - - `bsize`: The size of batches to be considered. - -# Returns - -Depending upon `bsize`, the following are returned: - - - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the - `integer_ids` and `integer_mask` are returned unchanged. - - If the number of documents is larger than `bsize`, then the passages are first sorted - by the number of attended tokens (figured out from the `integer_mask`), and then the - sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of - `reverse_indices`, i.e a mapping from the documents to their indices in the original - order. 
-""" -function _sort_by_length( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - batch_size = size(integer_ids)[2] - if batch_size <= bsize - # if the number of passages fits the batch size, do nothing - integer_ids, integer_mask, Vector(1:batch_size) - end - - lengths = vec(sum(integer_mask; dims = 1)) # number of attended tokens in each passage - indices = sortperm(lengths) # get the indices which will sort lengths - reverse_indices = sortperm(indices) # invert the indices list - - integer_ids[:, indices], integer_mask[:, indices], reverse_indices -end +# """ +# _sort_by_length( +# integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) +# +# Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. +# +# # Arguments +# +# - `integer_ids`: The token IDs of documents to be sorted. +# - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). +# - `bsize`: The size of batches to be considered. +# +# # Returns +# +# Depending upon `bsize`, the following are returned: +# +# - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the +# `integer_ids` and `integer_mask` are returned unchanged. +# - If the number of documents is larger than `bsize`, then the passages are first sorted +# by the number of attended tokens (figured out from the `integer_mask`), and then the +# sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of +# `reverse_indices`, i.e a mapping from the documents to their indices in the original +# order. +# """ +# function _sort_by_length( +# integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}, batch_size::Int) +# size(integer_ids, 2) <= batch_size && +# return integer_ids, bitmask, Vector(1:size(integer_ids, 2)) +# lengths = vec(sum(bitmask; dims = 1)) # number of attended tokens in each passage +# indices = sortperm(lengths) # get the indices which will sort lengths +# reverse_indices = sortperm(indices) # invert the indices list +# @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" +# @assert bitmask isa BitMatrix "$(typeof(bitmask))" +# @assert reverse_indices isa Vector{Int} "$(typeof(reverse_indices))" +# integer_ids[:, indices], bitmask[:, indices], reverse_indices +# end function compute_distances_kernel!(batch_distances::AbstractMatrix{Float32}, batch_data::AbstractMatrix{Float32}, centroids::AbstractMatrix{Float32}) + isequal(size(batch_distances), (size(centroids, 2), size(batch_data, 2))) || + throw(DimensionMismatch("batch_distances should have size " * + "(num_centroids, point_bsize)!")) + isequal(size(batch_data, 1), size(centroids, 1)) || + throw(DimensionMismatch("batch_data and centroids should have " * + "the same embedding dimension!")) + batch_distances .= 0.0f0 # Compute squared distances: (a-b)^2 = a^2 + b^2 - 2ab # a^2 term - sum_sq_data = sum(batch_data .^ 2, dims = 1) # (1, point_bsize) + sum_sq_data = sum(batch_data .^ 2, dims = 1) # (1, point_bsize) # b^2 term - sum_sq_centroids = sum(centroids .^ 2, dims = 1)' # (num_centroids, 1) + sum_sq_centroids = sum(centroids .^ 2, dims = 1)' # (num_centroids, 1) # -2ab term mul!(batch_distances, centroids', batch_data, -2.0f0, 1.0f0) # (num_centroids, point_bsize) # Compute (a-b)^2 = a^2 + b^2 - 2ab @@ -56,20 +61,30 @@ end function update_centroids_kernel!(new_centroids::AbstractMatrix{Float32}, batch_data::AbstractMatrix{Float32}, batch_one_hot::AbstractMatrix{Float32}) + 
isequal( + size(new_centroids), (size(batch_data, 1), (size(batch_one_hot, 1)))) || + throw(DimensionMismatch("new_centroids should have the right shape " * + "for multiplying batch_data and batch_one_hot! ")) mul!(new_centroids, batch_data, batch_one_hot', 1.0f0, 1.0f0) end function assign_clusters_kernel!(batch_assignments::AbstractVector{Int32}, batch_distances::AbstractMatrix{Float32}) + length(batch_assignments) == size(batch_distances, 2) || + throw(DimensionMismatch("length(batch_assignments) " * + "should be equal to the point " * + "batch size of batch_distances!")) _, min_indices = findmin(batch_distances, dims = 1) batch_assignments .= getindex.(min_indices, 1) |> vec end function onehot_encode!(batch_one_hot::AbstractArray{Float32}, batch_assignments::AbstractVector{Int32}, k::Int) - # Create a range array for columns - col_indices = Vector(1:length(batch_assignments)) |> Flux.gpu - # Use broadcasting to set the appropriate elements to 1 + isequal(size(batch_one_hot), (k, length(batch_assignments))) || + throw(DimensionMismatch("batch_one_hot should have shape " * + "(k, length(batch_assignments))!")) + col_indices = similar(batch_assignments, length(batch_assignments)) # respects device + copyto!(col_indices, collect(1:length(batch_assignments))) batch_one_hot[batch_assignments .+ (col_indices .- 1) .* k] .= 1 end @@ -238,7 +253,14 @@ julia> centroids function kmeans_gpu_onehot!( data::AbstractMatrix{Float32}, centroids::AbstractMatrix{Float32}, k::Int; max_iters::Int = 10, tol::Float32 = 1.0f-4, point_bsize::Int = 1000) - @assert size(centroids)[2] == k + # TODO: move point_bsize to config? + size(centroids, 2) == k || + throw(DimensionMismatch("size(centroids, 2) must be k!")) + + # randomly initialize centroids + centroids .= data[:, randperm(size(data, 2))[1:k]] + + # allocations d, n = size(data) # dimension, number of inputs assignments = Vector{Int32}(undef, n) |> Flux.gpu distances = Matrix{Float32}(undef, k, point_bsize) |> Flux.gpu @@ -294,3 +316,21 @@ function kmeans_gpu_onehot!( Flux.cpu(assignments) end + +function _normalize_array!( + X::AbstractArray{T}; dims::Int = 1) where {T <: AbstractFloat} + norms = sqrt.(sum(abs2, X, dims = dims)) + epsilon = eps(T) + X ./= (norms .+ epsilon) +end + +function _topk(data::Matrix{T}, k::Int; dims::Int = 1) where {T <: Number} + # TODO: only works on CPU; make it work on GPUs? + # partialsortperm is not available in CUDA.jl + dims in [1, 2] || throw(DomainError("dims must be 1 or 2!")) + mapslices(v -> partialsortperm(v, 1:k, rev = true), data, dims = dims) +end + +function _head(v::Vector) + length(v) > 0 ? 
collect(take(v, length(v) - 1)) : similar(v, 0) +end diff --git a/test/Project.toml b/test/Project.toml index ab18642..29425ae 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,11 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -Test = "1.6" \ No newline at end of file +Test = "1.6" diff --git a/test/indexing/codecs/residual.jl b/test/indexing/codecs/residual.jl new file mode 100644 index 0000000..56180c7 --- /dev/null +++ b/test/indexing/codecs/residual.jl @@ -0,0 +1,1007 @@ +using ColBERT: _normalize_array!, compress_into_codes!, _binarize, _unbinarize, + _bucket_indices, _packbits, _unpackbits, binarize, compress, + decompress_residuals, decompress + +@testset "compress_into_codes!" begin + # In most tests, we'll need unit vectors (so that dot-products become cosines) + # Test 1: Edge case, 1 centroid and 1 embedding + embs = rand(Float32, rand(1:128), 1) + centroids = embs + codes = zeros(UInt32, 1) + bsize = rand(1:(size(embs, 2) + 5)) + compress_into_codes!(codes, centroids, embs; bsize = bsize) + @test isequal(codes, UInt32[1]) + + # Test 2: Edge case, equal # of centroids and embedings + embs = rand(Float32, rand(1:128), rand(1:20)) + _normalize_array!(embs; dims = 1) + perm = randperm(size(embs, 2)) + centroids = embs[:, perm] + codes = zeros(UInt32, size(embs, 2)) + bsize = rand(1:(size(embs, 2) + 5)) + compress_into_codes!(codes, centroids, embs; bsize = bsize) + @test isequal(codes, sortperm(perm)) # sortperm(perm) -> inverse mapping + + # Test 3: sample centroids randomly from embeddings + embs = rand(Float32, rand(1:128), rand(1:20)) + _normalize_array!(embs; dims = 1) + perm = collect(take(randperm(size(embs, 2)), rand(1:size(embs, 2)))) + centroids = embs[:, perm] + codes = zeros(UInt32, size(embs, 2)) + bsize = rand(1:(size(embs, 2) + 5)) + compress_into_codes!(codes, centroids, embs) + @test all(in(1:size(centroids, 2)), codes) # in the right range + @test isequal(codes[perm], collect(1:length(perm))) # centroids have the right mappings + + # Test 4: Build embs by extending centroids + dim = rand(1:128) + tol = 1.0e-5 + scale = rand(2:5) # scaling factor + centroids = rand(Float32, dim, rand(1:20)) + _normalize_array!(centroids; dims = 1) + extension_mapping = rand(1:size(centroids, 2), scale * size(centroids, 2)) + embs = zeros(Float32, dim, length(extension_mapping)) + for (idx, col) in enumerate(eachcol(embs)) + # get some random noise + noise = -tol + 2 * tol * rand() + col .= centroids[:, extension_mapping[idx]] .+ noise + end + codes = zeros(UInt32, size(embs, 2)) + bsize = rand(1:(size(embs, 2) + 5)) + compress_into_codes!(codes, centroids, embs) + @test isequal(codes, extension_mapping) + + # Test 5: Check that an error is thrown if the lengths don't match + codes = zeros(UInt32, rand(1:(size(embs, 2) - 1))) + @test_throws DimensionMismatch compress_into_codes!(codes, centroids, embs) +end + +@testset "_binarize" begin + # defining the datapacks + datapacks = [ + ( + # Test 1: Basic functionality with a 2x2 matrix and 3 bits + data = [0 1; 2 3], + nbits = 3, + expected_output = reshape( + Bool[0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0], 3, 2, 2) + ), + ( + # Test 2: 2x2 matrix with 2 bits + data = [0 1; 2 3], + nbits = 2, + expected_output = reshape( + Bool[0, 0, 0, 1, 1, 0, 1, 1], 
2, 2, 2) + ), + ( + # Test 3: Single value matrix + data = reshape([7], 1, 1), + nbits = 3, + expected_output = reshape(Bool[1, 1, 1], 3, 1, 1) + ), + ( + # Test 4: Edge case with nbits = 1 + data = [0 1; 0 1], + nbits = 1, + expected_output = reshape(Bool[0, 0, 1, 1], 1, 2, 2) + ) + ] + + for (data, nbits, expected_output) in datapacks + try + @test isequal(_binarize(data, nbits), expected_output) + catch + @show data, nbits, expected_output + end + end + + # Test 5: Invalid input with out-of-range values (should throw an error) + data = [0 1; 4 2] # 4 is out of range for 2 bits + nbits = 2 + @test_throws DomainError _binarize(data, nbits) + + # Test 6: Testing correct shapes and types + for int_type in INT_TYPES + nbits = rand(1:5) + dim = rand(1:500) + batch_size = rand(1:20) + data = map(Base.Fix1(convert, int_type), + rand(0:((1 << nbits) - 1), dim, batch_size)) + output = _binarize(data, nbits) + @test output isa Array{Bool, 3} + @test isequal(size(output), (nbits, dim, batch_size)) + end +end + +@testset "_unbinarize" begin + # Test 1: All bits are 0, should return zeros + nbits = rand(1:10) + data = falses(nbits, rand(1:20), rand(1:20)) + @test isequal(_unbinarize(data), zeros(Int, size(data, 2), size(data, 3))) + + # Test 2: All bits set to 1 + nbits = rand(1:10) + data = trues(nbits, rand(1:20), rand(1:20)) + @test isequal(_unbinarize(data), + (1 << nbits - 1) * ones(Int, size(data)[2], size(data)[3])) + + # Test 3: Edge case, single element + data = reshape(Bool[1, 0, 0, 1, 1], 5, 1, 1) + @test isequal(_unbinarize(data), reshape([25], 1, 1)) + + # Test 4: Multiple bits forming non-zero integers + # inputs matrix: [55 20; 49 24] + data = reshape( + Bool[1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, + 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], + 6, + 2, + 2) + @test isequal(_unbinarize(data), [55 20; 49 24]) + + # Test 4: Edge case with empty array + data = reshape(Bool[], 0, 0, 0) + @test isequal(_unbinarize(data), Matrix{Int}(undef, 0, 0)) + + # Test 5: Checking shapes and types + nbits = rand(1:20) + dim = rand(1:20) + batch_size = rand(1:20) + data = rand(Bool, nbits, dim, batch_size) + @test isequal(size(_unbinarize(data)), (dim, batch_size)) +end + +@testset "_unbinarize inverts _binarize" begin + # Take any random integer matrix, apply the ops and test equality with result + nbits = rand(1:20) + data = rand(0:((1 << nbits) - 1), rand(1:20), rand(1:20)) + binarized_data = _binarize(data, nbits) + unbinarzed_data = _unbinarize(binarized_data) + @test isequal(data, unbinarzed_data) +end + +@testset "_bucket_indices" begin + # defining datapacks + datapacks = [ + ( + # Test 1: Test with a matrix + data = [1 6; 3 12], + bucket_cutoffs = [0, 5, 10, 15], + expected = [1 2; 1 3] + ), + ( + + # Test 2: Edge case with empty data + data = Matrix{Float32}(undef, 0, 0), + bucket_cutoffs = [0, 10, 20], + expected = Matrix{Int}(undef, 0, 0) + ), + ( + # Test 3: Edge case with empty bucket_cutoffs + data = [5 15], + bucket_cutoffs = Float32[], + expected = [0 0] + ), + ( + # Test 4: with floats + data = [1.1 2.5 7.8], + bucket_cutoffs = [0.0, 2.0, 5.0, 10.0], + expected = [1 2 3] + ) + ] + + for (data, bucket_cutoffs, expected) in datapacks + try + @test isequal(_bucket_indices(data, bucket_cutoffs), expected) + catch + @show data, bucket_cutoffs, expected + end + end + + # Test 5: Check that range of indices is correct + data = rand(Float32, rand(1:20), rand(1:20)) + bucket_cutoffs = sort(rand(rand(1:100))) + @test all( + in(0:(length(bucket_cutoffs))), _bucket_indices(data, bucket_cutoffs)) + + # Test 
6: Checking shapes, dimensions and types + for (T, S) in collect(product( + [INT_TYPES; FLOAT_TYPES], [INT_TYPES; FLOAT_TYPES])) + data = rand(T, rand(1:20), rand(1:20)) + bucket_cutoffs = sort(rand(rand(1:100))) + @test isequal(size(_bucket_indices(data, bucket_cutoffs)), size(data)) + end +end + +@testset "_packbits" begin + # In all tests, remember: BitArray.chunks reverses the endianess + # Test 1: Basic case with 1x64x1 array (one 64-bit block) + bitsarray = reshape( + Bool[1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, + 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, + 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0], + (1, 64, 1)) + expected = reshape( + UInt8[0b11011011; 0b01101010; 0b00111011; 0b11011010; 0b11100000; + 0b01011001; 0b11111011; 0b01010001], + 8, + 1) + @test isequal(_packbits(bitsarray), expected) + + # Test 2: All bits are 0s + bitsarray = falses(rand(1:20), 8 * rand(1:20), rand(1:20)) + expected = zeros( + UInt8, div(div(prod(size(bitsarray)), size(bitsarray, 3)), 8), + size(bitsarray, 3)) + @test isequal(_packbits(bitsarray), expected) + + # Test 3: All bits are 1s + bitsarray = trues(rand(1:20), 8 * rand(1:20), rand(1:20)) + expected = 0xff * ones( + UInt8, div(div(prod(size(bitsarray)), size(bitsarray, 3)), 8), + size(bitsarray, 3)) + @test isequal(_packbits(bitsarray), expected) + + # Test 4: Alternating bits; each byte is 0b10101010, or 0xaa + # Again, BitArray.chunks reverses endianess; so the byte is really 0b01010101 + # or 0x55 + bitsarray = trues(rand(1:20), 8 * rand(1:20), rand(1:20)) + bitsarray[collect(2:2:prod(size(bitsarray)))] .= 0 + expected = 0x55 * ones( + UInt8, div(div(prod(size(bitsarray)), size(bitsarray, 3)), 8), + size(bitsarray, 3)) + @test isequal(_packbits(bitsarray), expected) + + # Test 4: Edge case with an empty array (should be empty) + bitsarray = reshape(Bool[], (0, 0, 0)) + expected = Matrix{UInt8}(undef, 0, 0) + @test isequal(_packbits(bitsarray), expected) + + # Test 5: Ensure that proper errors are thrown + @test_throws DomainError _packbits(trues(3, 7, 5)) # dim not a multiple of 64 + + # Test 6: Test shapes and types + dim = 8 * rand(1:20) + nbits = rand(1:20) + batch_size = rand(1:20) + bitsarray = rand(Bool, nbits, dim, batch_size) + output = _packbits(bitsarray) + @test output isa Matrix{UInt8} + @test isequal(size(output), (div(dim * nbits, 8), batch_size)) +end + +@testset "_unpackbits" begin + # Again, remember: BitArray.chunks reverses the endianess + # Test 1: Basic case with 1x8 matrix, with nbits = 1 + nbits = 1 + packed_bits = reshape( + UInt8[0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111, + 0b00101010, 0b00010001, 0b11111111, 0b01000000, + 0b10000000, 0b11001000, 0b00100001, 0b01010111 + ], + 1, 64) + + expected = reshape( + Bool[ + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 
1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 0, + 1, + 0, + 1, + 0 + ], + nbits, + 8 * div(prod(size(packed_bits)), + nbits * size(packed_bits, 2)), size(packed_bits, 2)) + unpacked_bits = _unpackbits(packed_bits, nbits) + @test isequal(unpacked_bits, expected) + + # Test 2: All zeros + nbits = rand(1:10) + packed_bits = zeros(UInt8, nbits * rand(1:20), rand(1:20)) + expected = falses( + nbits, 8 * div(prod(size(packed_bits)), nbits * size(packed_bits, 2)), + size(packed_bits, 2)) + @test isequal(_unpackbits(packed_bits, nbits), expected) + + # Test 3: All ones + nbits = rand(1:10) + packed_bits = 0xff * ones(UInt8, nbits * rand(1:20), rand(1:20)) + expected = trues( + nbits, 8 * div(prod(size(packed_bits)), nbits * size(packed_bits, 2)), + size(packed_bits, 2)) + @test isequal(_unpackbits(packed_bits, nbits), expected) + + # Test 4: types and shapes + nbits = rand(1:10) + batch_size = rand(1:20) + packed_bits = rand(UInt8, nbits * rand(1:20), batch_size) + dim = 8 * div(prod(size(packed_bits)), nbits * batch_size) + output = _unpackbits(packed_bits, nbits) + @test output isa AbstractArray{Bool, 3} + @test isequal(size(output), (nbits, dim, 
batch_size)) +end + +@testset "_unpackbits inverts _packbits" begin + nbits = rand(1:10) + bitsarray = rand(Bool, nbits, 8 * rand(1:20), rand(1:20)) + packed_bits = _packbits(bitsarray) + unpacked_bits = _unpackbits(packed_bits, nbits) + @test isequal(unpacked_bits, bitsarray) +end + +@testset "binarize" begin + # Test 1: Checking the types and dimensions + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + residuals = rand(Float32, dim, rand(1:100)) + binarized_residuals = binarize(dim, nbits, bucket_cutoffs, residuals) + @test binarized_residuals isa Matrix{UInt8} + @test isequal(size(residuals, 2), size(binarized_residuals, 2)) + @test isequal(size(binarized_residuals, 1), div(dim, 8) * nbits) + + # Test 2: Checking correct errors being thrown + @test_throws DomainError binarize( + 7, 7, sort(rand(Float32, 1 << 7 - 1)), rand(Float32, 7, 10)) # dim not multiple of 8 + @test_throws DomainError binarize( + 8, 8, sort(rand(Float32, 2^8 - 2)), rand(Float32, 8, 10)) # incorrect length for bucket_cutoffs +end + +@testset "compress" begin + # Test 1: Edge case, 1 centroid and 1 embedding + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + embs = rand(Float32, dim, 1) + centroids = embs + bsize = rand(1:(size(embs, 2) + 5)) + codes, residuals = compress( + centroids, bucket_cutoffs, dim, nbits, embs; bsize = bsize) + @test isequal(codes, UInt32[1]) + @test all(isequal(zero(UInt8)), residuals) + + # Test 2: Edge case, equal # of centroids and embeddings + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + embs = rand(Float32, dim, rand(1:20)) + _normalize_array!(embs; dims = 1) + perm = randperm(size(embs, 2)) + centroids = embs[:, perm] + bsize = rand(1:(size(embs, 2) + 5)) + codes, residuals = compress( + centroids, bucket_cutoffs, dim, nbits, embs; bsize = bsize) + @test isequal(codes, sortperm(perm)) # sortperm(perm) -> inverse mapping + @test all(isequal(zero(UInt8)), residuals) + + # Test 3: sample centroids randomly from embeddings + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + embs = rand(Float32, dim, rand(1:20)) + _normalize_array!(embs; dims = 1) + perm = collect(take(randperm(size(embs, 2)), rand(1:size(embs, 2)))) + centroids = embs[:, perm] + bsize = rand(1:(size(embs, 2) + 5)) + codes, residuals = compress( + centroids, bucket_cutoffs, dim, nbits, embs; bsize = bsize) + @test all(in(1:size(centroids, 2)), codes) # in the right range + @test isequal(codes[perm], collect(1:length(perm))) # centroids have the right mappings + @test all(isequal(zero(UInt8)), residuals[:, perm]) # centroids have zero residuals + + # Test 4: Build embs by extending centroids + tol = 1.0e-5 + scale = rand(2:5) # scaling factor + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + centroids = rand(Float32, dim, rand(1:20)) + _normalize_array!(centroids; dims = 1) + extension_mapping = rand(1:size(centroids, 2), scale * size(centroids, 2)) + embs = zeros(Float32, dim, length(extension_mapping)) + for (idx, col) in enumerate(eachcol(embs)) + # get some random noise + noise = -tol + 2 * tol * rand() + col .= centroids[:, extension_mapping[idx]] .+ noise + end + bsize = rand(1:(size(embs, 2) + 5)) + codes, residuals = compress( + centroids, bucket_cutoffs, dim, nbits, embs; bsize = bsize) + @test isequal(codes, extension_mapping) + # TODO somehow test that 
all the absmax of all residuals is atmost tol + + # Test 5: Shapes and types + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + embs = rand(Float32, dim, rand(1:20)) + _normalize_array!(embs; dims = 1) + perm = collect(take(randperm(size(embs, 2)), rand(1:size(embs, 2)))) + centroids = embs[:, perm] + bsize = rand(1:(size(embs, 2) + 5)) + codes, residuals = compress( + centroids, bucket_cutoffs, dim, nbits, embs; bsize = bsize) + @test codes isa Vector{UInt32} + @test residuals isa Matrix{UInt8} + @test isequal(length(codes), size(embs, 2)) + @test isequal(size(residuals, 2), size(embs, 2)) + @test isequal(size(residuals, 1), div(dim, 8) * nbits) +end + +@testset "decompress_residuals" begin + # Test 1: Checking types and dimensions, and correct range of values + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_weights = sort(rand(Float32, 1 << nbits)) + binarized_residuals = rand(UInt8, div(dim, 8) * nbits, rand(1:20)) + residuals = decompress_residuals( + dim, nbits, bucket_weights, binarized_residuals) + @test residuals isa Matrix{Float32} + @test isequal(size(residuals, 2), size(binarized_residuals, 2)) + @test isequal(size(residuals, 1), dim) + @test all(in(bucket_weights), residuals) + + # Test 2: Checking correct errors being thrown + @test_throws DomainError decompress_residuals( + 7, 7, sort(rand(Float32, 1 << 7)), + rand(UInt8, div(7, 8) * 7, rand(1:20))) # dim not a multiple of 8 + @test_throws DomainError decompress_residuals( + 8, 8, sort(rand(Float32, (1 << 8) - 1)), + rand(UInt8, div(8, 8) * 8, rand(1:20))) # bucket_weights not having correct length + @test_throws DomainError decompress_residuals( + 8, 8, sort(rand(Float32, 1 << 8)), rand(UInt8, 7, 64 * rand(1:100))) # binarized_residuals having an incorrect dim +end + +@testset "decompress_residuals inverts binarize" begin + # not exactly inverse, but a close inverse + dim = 8 * rand(1:20) + nbits = rand(1:20) + bucket_cutoffs = sort(rand(Float32, (1 << nbits) - 1)) + bucket_weights = sort(rand(Float32, 1 << nbits)) + residuals = rand(Float32, dim, rand(1:100)) + + # map each residual to it's expected weight + expected_indices = map( + Base.Fix1(searchsortedfirst, bucket_cutoffs), residuals) + expected = bucket_weights[expected_indices] + binarized_residuals = binarize(dim, nbits, bucket_cutoffs, residuals) + decompressed_residuals = decompress_residuals( + dim, nbits, bucket_weights, binarized_residuals) + @test isequal(expected, decompressed_residuals) +end + +@testset "decompress" begin + # Test 1: Types and shapes, and right range of values + dim = 8 * rand(1:20) + nbits = rand(1:20) + batch_size = rand(1:100) + bucket_weights = sort(rand(Float32, 1 << nbits)) + centroids = rand(Float32, dim, rand(1:100)) + codes = UInt32.(rand(1:size(centroids, 2), batch_size)) + binarized_residuals = rand(UInt8, div(dim, 8) * nbits, batch_size) + bsize = rand(1:(batch_size + 5)) + embeddings = decompress(dim, nbits, centroids, bucket_weights, + codes, binarized_residuals; bsize = bsize) + @test embeddings isa Matrix{Float32} + @test isequal(size(embeddings), (dim, length(codes))) +end diff --git a/test/indexing/collection_indexer.jl b/test/indexing/collection_indexer.jl new file mode 100644 index 0000000..e987e1c --- /dev/null +++ b/test/indexing/collection_indexer.jl @@ -0,0 +1,304 @@ +using ColBERT: _sample_pids, _heldout_split, setup, _bucket_cutoffs_and_weights, + _normalize_array!, _compute_avg_residuals!, train, + _check_all_files_are_saved, _collect_embedding_id_offset, + 
_build_ivf

+@testset "_sample_pids tests" begin
+    # Test 1: can't sample more pids than the number of documents
+    num_documents = rand(0:100000)
+    pids = _sample_pids(num_documents)
+    @test length(pids) <= num_documents
+
+    # Test 2: Edge case, when there are only 0 or 1 documents
+    num_documents = rand(0:1)
+    pids = _sample_pids(num_documents)
+    @test length(pids) <= num_documents
+end
+
+@testset "_heldout_split" begin
+    # Test 1: A basic test with a large size
+    sample = rand(Float32, rand(1:20), 100000)
+    for heldout_fraction in Float32.(collect(0.1:0.1:1.0))
+        sample_train, sample_heldout = _heldout_split(
+            sample; heldout_fraction = heldout_fraction)
+        heldout_size = min(50000, Int(floor(100000 * heldout_fraction)))
+        @test size(sample_train, 2) == 100000 - heldout_size
+        @test size(sample_heldout, 2) == heldout_size
+    end
+
+    # Test 2: Edge case with 1 column, should return empty train and full heldout
+    sample = rand(Float32, 3, 1)
+    heldout_fraction = 0.5f0
+    sample_train, sample_heldout = _heldout_split(
+        sample; heldout_fraction = heldout_fraction)
+    @test size(sample_train, 2) == 0    # No columns should be left in the train set
+    @test size(sample_heldout, 2) == 1  # All columns in the heldout set
+end
+
+@testset "setup" begin
+    # Test 1: Number of documents and chunksize should not be altered
+    collection = string.(rand('a':'z', rand(1:1000)))
+    avg_doclen_est = Float32(100 * rand())
+    nranks = rand(1:10)
+    num_clustering_embs = rand(1:1000)
+    chunksize = rand(1:20)
+    plan_dict = setup(
+        collection, avg_doclen_est, num_clustering_embs, chunksize, nranks)
+    @test plan_dict["avg_doclen_est"] == avg_doclen_est
+    @test plan_dict["chunksize"] == chunksize
+    @test plan_dict["num_documents"] == length(collection)
+    @test plan_dict["num_embeddings_est"] == avg_doclen_est * length(collection)
+
+    # Test 2: Tests for number of chunks
+    avg_doclen_est = 1.0f0
+    nranks = rand(1:10)
+    num_clustering_embs = rand(1:1000)
+
+    ## without remainders
+    chunksize = rand(1:20)
+    collection = string.(rand('a':'z', chunksize * rand(1:100)))
+    plan_dict = setup(
+        collection, avg_doclen_est, num_clustering_embs, chunksize, nranks)
+    @test plan_dict["num_chunks"] == div(length(collection), chunksize)
+
+    ## with remainders
+    chunksize = rand(1:20) + 1
+    collection = string.(rand(
+        'a':'z', chunksize * rand(1:100) + rand(1:(chunksize - 1))))
+    plan_dict = setup(
+        collection, avg_doclen_est, num_clustering_embs, chunksize, nranks)
+    @test plan_dict["num_chunks"] == div(length(collection), chunksize) + 1
+
+    # Test 3: Tests for number of clusters
+    collection = string.(rand('a':'z', rand(1:1000)))
+    avg_doclen_est = Float32(100 * rand())
+    nranks = rand(1:10)
+    num_clustering_embs = rand(1:10000)
+    chunksize = rand(1:20)
+    plan_dict = setup(
+        collection, avg_doclen_est, num_clustering_embs, chunksize, nranks)
+    @test plan_dict["num_partitions"] <= num_clustering_embs
+    @test plan_dict["num_partitions"] <=
+          16 * sqrt(avg_doclen_est * length(collection))
+end
+
+@testset "_bucket_cutoffs_and_weights" begin
+    # Test 1: Basic test with a 3x2 matrix and nbits=2
+    heldout_avg_residual = [0.0f0 0.2f0; 0.4f0 0.6f0; 0.8f0 1.0f0]
+    nbits = 2
+    cutoffs, weights = _bucket_cutoffs_and_weights(nbits, heldout_avg_residual)
+    expected_cutoffs = Float32[0.25, 0.5, 0.75]
+    expected_weights = Float32[0.125, 0.375, 0.625, 0.875]
+    @test cutoffs ≈ expected_cutoffs
+    @test weights ≈ expected_weights
+
+    # Test 2: Uniform values
+    value = rand(Float32)
+    heldout_avg_residual = value * ones(Float32, rand(1:20), rand(1:20))
+    nbits = rand(1:10)
+    cutoffs, weights = 
+
+@testset "_compute_avg_residuals!" begin
+    # Test 1: centroids and heldout have the same columns, just permuted
+    nbits = rand(1:20)
+    centroids = rand(Float32, rand(2:20), rand(1:20))
+    _normalize_array!(centroids; dims = 1)
+    perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
+    heldout = centroids[:, perm]
+    codes = Vector{UInt32}(undef, size(heldout, 2))
+    bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!(
+        nbits, centroids, heldout, codes)
+    @test all(iszero, bucket_cutoffs)
+    @test all(iszero, bucket_weights)
+    @test iszero(avg_residual)
+
+    # Test 2: heldout columns are perturbed within some tolerance level
+    tol = 1e-5
+    nbits = rand(1:20)
+    centroids = rand(Float32, rand(2:20), rand(1:20))
+    _normalize_array!(centroids; dims = 1)
+    perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
+    heldout = centroids[:, perm]
+    for col in eachcol(heldout)
+        col .+= -tol + 2 * tol * rand()
+    end
+    codes = Vector{UInt32}(undef, size(heldout, 2))
+    bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!(
+        nbits, centroids, heldout, codes)
+    @test all(<=(tol), bucket_cutoffs)
+    @test all(<=(tol), bucket_weights)
+    @test avg_residual <= tol
+
+    # Test 3: Shapes and types
+    nbits = rand(1:20)
+    dim = rand(1:20)
+    centroids = rand(Float32, dim, rand(1:20))
+    heldout = rand(Float32, dim, rand(1:20))
+    codes = Vector{UInt32}(undef, size(heldout, 2))
+    bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!(
+        nbits, centroids, heldout, codes)
+    @test length(bucket_cutoffs) == (1 << nbits) - 1
+    @test length(bucket_weights) == 1 << nbits
+    @test bucket_cutoffs isa Vector{Float32}
+    @test bucket_weights isa Vector{Float32}
+    @test avg_residual isa Float32
+
+    # Test 4: The correct errors are thrown
+    nbits = 2
+    centroids = Float32[1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]    # (3, 3) matrix
+    heldout = Float32[1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]      # (3, 3) matrix
+    codes = UInt32[0, 0]    # length is 2, but `heldout` has 3 columns
+    # check for the DimensionMismatch error
+    @test_throws DimensionMismatch _compute_avg_residuals!(
+        nbits, centroids, heldout, codes)
+end
+
+@testset "train" begin
+    # Test 1: When all inputs are the same; also testing shapes and types
+    dim = rand(2:20)
+    nbits = rand(1:5)
+    kmeans_niters = rand(1:5)
+    sample = ones(Float32, dim, rand(1:20))
+    heldout = ones(Float32, dim, rand(1:size(sample, 2)))
+    num_partitions = rand(1:size(sample, 2))
+    centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
+        sample, heldout, num_partitions, nbits, kmeans_niters)
+    @test all(iszero, bucket_cutoffs)
+    @test all(iszero, bucket_weights)
+    @test iszero(avg_residual)
+    @test centroids isa Matrix{Float32}
+    @test bucket_cutoffs isa Vector{Float32}
+    @test bucket_weights isa Vector{Float32}
+    @test avg_residual isa Float32
+    @test isequal(size(centroids), (dim, num_partitions))
+    @test length(bucket_cutoffs) == (1 << nbits) - 1
+    @test length(bucket_weights) == (1 << nbits)
+end
+
+@testset "_check_all_files_are_saved" begin
+    temp_dir = mktempdir()
+
+    # Create plan.json with the required structure
+    plan_data = Dict(
+        "num_chunks" => 2,
+        "avg_doclen_est" => 10,
+        "num_documents" => 100,
+        "num_embeddings_est" => 200,
+        "num_embeddings" => 200,
+        "embeddings_offsets" => [0, 100],
+        "num_partitions" => 4,
+        "chunksize" => 50
+    )
+    open(joinpath(temp_dir, "plan.json"), "w") do f
+        JSON.print(f, plan_data)
+    end
+
+    # Create non-chunk files
+    non_chunk_files = [
+        "config.json",
+        "centroids.jld2",
+        "bucket_cutoffs.jld2",
+        "bucket_weights.jld2",
+        "avg_residual.jld2",
+        "ivf.jld2",
+        "ivf_lengths.jld2"
+    ]
+    for file in non_chunk_files
+        touch(joinpath(temp_dir, file))
+    end
+
+    # Create chunk files
+    for chunk_idx in 1:plan_data["num_chunks"]
+        chunk_metadata = Dict(
+            "num_passages" => 50,
+            "num_embeddings" => 100,
+            "passage_offset" => 0
+        )
+        open(joinpath(temp_dir, "$(chunk_idx).metadata.json"), "w") do f
+            JSON.print(f, chunk_metadata)
+        end
+        touch(joinpath(temp_dir, "$(chunk_idx).codes.jld2"))
+        touch(joinpath(temp_dir, "$(chunk_idx).residuals.jld2"))
+        touch(joinpath(temp_dir, "doclens.$(chunk_idx).jld2"))
+    end
+
+    # Test 1: Check that all files exist
+    @test _check_all_files_are_saved(temp_dir)
+
+    # Test 2: Remove one file at a time and check that the function returns false
+    all_files = [
+        "config.json",
+        non_chunk_files...,
+        "$(1).codes.jld2", "$(1).residuals.jld2", "doclens.1.jld2", "1.metadata.json",
+        "$(2).codes.jld2", "$(2).residuals.jld2", "doclens.2.jld2", "2.metadata.json"
+    ]
+
+    for file in all_files
+        rm(joinpath(temp_dir, file))
+        @test !_check_all_files_are_saved(temp_dir)
+        touch(joinpath(temp_dir, file))    # recreate the file for the next iteration
+    end
+    rm(joinpath(temp_dir, "plan.json"))
+    @test !_check_all_files_are_saved(temp_dir)
+
+    # Clean up
+    rm(temp_dir, recursive = true)
+end
+
+@testset "_collect_embedding_id_offset" begin
+    # Test 1: Small test with fixed values
+    chunk_emb_counts = [3, 5, 2]
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum == 10
+    @test offsets == [1, 4, 9]
+
+    # Test 2: Edge case with empty inputs
+    chunk_emb_counts = Int[]
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum == 0    # no elements, so the sum is 0
+    @test offsets == [0]    # when empty, it should return [0]
+
+    # Test 3: All elements are ones
+    chunk_emb_counts = ones(Int, rand(1:20))
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum == length(chunk_emb_counts)
+    @test offsets == collect(1:length(chunk_emb_counts))
+
+    # Test 4: Types of the outputs
+    chunk_emb_counts = rand(Int, rand(1:20))
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum isa Int
+    @test offsets isa Vector{Int}
+end
+
+@testset "_build_ivf" begin
+    # Test 1: Typical input case
+    codes = UInt32[5, 3, 8, 2, 5, 5, 4, 2, 2, 1, 3]
+    num_partitions = 10
+    ivf, ivf_lengths = _build_ivf(codes, num_partitions)
+    @test ivf == [10, 4, 8, 9, 2, 11, 7, 1, 5, 6, 3]
+    @test ivf_lengths == [1, 3, 2, 1, 3, 0, 0, 1, 0, 0]
+
+    # Test 2: Testing types, shapes and the range of values
+    num_partitions = rand(1:1000)
+    codes = UInt32.(rand(1:num_partitions, 10000))    # a large array with random values
+    ivf, ivf_lengths = _build_ivf(codes, num_partitions)
+    @test length(ivf) == length(codes)
+    @test sum(ivf_lengths) == length(codes)
+    @test length(ivf_lengths) == num_partitions
+    @test all(in(ivf), codes)
+    @test ivf isa Vector{Int}
+    @test ivf_lengths isa Vector{Int}
+end
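# Aside: the two helpers above admit compact reference formulations, shown here
# as sketches that reproduce the expected values in the tests (they are not
# necessarily the package's actual implementations):
function collect_embedding_id_offset_sketch(chunk_emb_counts::Vector{Int})
    # 1-based id of the first embedding in each chunk, plus the total count
    isempty(chunk_emb_counts) && return 0, [0]
    offsets = [1; 1 .+ cumsum(chunk_emb_counts[1:(end - 1)])]
    sum(chunk_emb_counts), offsets
end

function build_ivf_sketch(codes::Vector{UInt32}, num_partitions::Int)
    ivf = sortperm(codes)    # embedding ids grouped by centroid; sortperm is stable
    ivf_lengths = zeros(Int, num_partitions)
    for code in codes
        ivf_lengths[code] += 1    # number of embeddings listed under each centroid
    end
    ivf, ivf_lengths
end

# e.g. build_ivf_sketch(UInt32[5, 3, 8, 2, 5, 5, 4, 2, 2, 1, 3], 10) reproduces
# the ivf and ivf_lengths expected in Test 1 above.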
diff --git a/test/infra/config.jl b/test/infra/config.jl
new file mode 100644
index 0000000..271db33
--- /dev/null
+++ b/test/infra/config.jl
@@ -0,0 +1,18 @@
+@testset "config.jl" begin
+    index_path = "./test_index"
+    config = ColBERTConfig(index_path = index_path)
+    key_vals = Dict([field => getproperty(
+                         config, field)
+                     for field in fieldnames(ColBERTConfig)])
+
+    ColBERT.save(config)
+    @test isfile(joinpath(
+        index_path, "config.json"))
+
+    config = ColBERT.load_config(index_path)
+    @test config isa ColBERTConfig
+    for field in fieldnames(ColBERTConfig)
+        @test isequal(
+            getproperty(config, field), key_vals[field])
+    end
+end
diff --git a/test/modelling/embedding_utils.jl b/test/modelling/embedding_utils.jl
new file mode 100644
index 0000000..52ce3fa
--- /dev/null
+++ b/test/modelling/embedding_utils.jl
@@ -0,0 +1,156 @@
+using ColBERT: mask_skiplist!, _clear_masked_embeddings!, _flatten_embeddings,
+               _remove_masked_tokens
+
+@testset "mask_skiplist!" begin
+    # Test Case 1: Simple case with no skips
+    mask = trues(3, 3)
+    integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
+    skiplist = Int[]
+    expected_mask = trues(3, 3)
+    mask_skiplist!(mask, integer_ids, skiplist)
+    @test mask == expected_mask
+
+    # Test Case 2: Skip one value
+    mask = trues(3, 3)
+    integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
+    skiplist = [5]
+    expected_mask = [true true true; true false true; true true true]
+    mask_skiplist!(mask, integer_ids, skiplist)
+    @test mask == expected_mask
+
+    # Test Case 3: Skip multiple values
+    mask = trues(3, 3)
+    integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
+    skiplist = [2, 6, 9]
+    expected_mask = [true false true; true true false; true true false]
+    mask_skiplist!(mask, integer_ids, skiplist)
+    @test mask == expected_mask
+
+    # Test Case 4: All values in the skiplist
+    mask = trues(3, 3)
+    integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
+    skiplist = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+    expected_mask = falses(3, 3)
+    mask_skiplist!(mask, integer_ids, skiplist)
+    @test mask == expected_mask
+
+    # Test Case 5: Empty integer_ids matrix
+    mask = trues(0, 0)
+    integer_ids = rand(Int32, 0, 0)
+    skiplist = [1]
+    expected_mask = trues(0, 0)
+    mask_skiplist!(mask, integer_ids, skiplist)
+    @test mask == expected_mask
+
+    # Test Case 6: Skiplist with no matching values
+    mask = trues(3, 3)
+    integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
+    skiplist = [10, 11]
+    expected_mask = trues(3, 3)
+    mask_skiplist!(mask, integer_ids, skiplist)
+    @test mask == expected_mask
+end
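# Aside: the masking rule reduces to membership against the skiplist: a
# position stays `true` (i.e. is kept) exactly when its token id does not
# appear in the skiplist. A quick equivalence check (a sketch, not the
# package's implementation):
using ColBERT: mask_skiplist!

integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
skiplist = [2, 6, 9]
mask = trues(size(integer_ids))
mask_skiplist!(mask, integer_ids, skiplist)
@assert mask == map(!in(skiplist), integer_ids)    # keep iff not in the skiplist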
+
+@testset "_clear_masked_embeddings!" begin
+    # Test Case 1: No skiplist entries
+    dim, len, bsize = rand(1:20, 3)
+    D = rand(Float32, dim, len, bsize)
+    integer_ids = rand(Int32, len, bsize)
+    skiplist = Int[]
+    expected_D = copy(D)
+    _clear_masked_embeddings!(D, integer_ids, skiplist)
+    @test D == expected_D
+
+    # Test Case 2: Single skiplist entry
+    dim, len, bsize = rand(1:20, 3)
+    D = rand(Float32, dim, len, bsize)
+    integer_ids = rand(Int32, len, bsize)
+    skiplist = Int[integer_ids[rand(1:(len * bsize))]]
+    expected_D = copy(D)
+    expected_D[:, findall(in(skiplist), integer_ids)] .= 0.0f0
+    _clear_masked_embeddings!(D, integer_ids, skiplist)
+    @test D == expected_D
+
+    # Test Case 3: Multiple skiplist entries
+    dim, len, bsize = rand(1:20, 3)
+    D = rand(Float32, dim, len, bsize)
+    integer_ids = rand(Int32, len, bsize)
+    skiplist = unique(Int.(rand(vec(integer_ids), rand(1:(len * bsize)))))
+    expected_D = copy(D)
+    expected_D[:, findall(in(skiplist), integer_ids)] .= 0.0f0
+    _clear_masked_embeddings!(D, integer_ids, skiplist)
+    @test D == expected_D
+
+    # Test Case 4: Skip all tokens
+    dim, len, bsize = rand(1:20, 3)
+    D = rand(Float32, dim, len, bsize)
+    integer_ids = rand(Int32, len, bsize)
+    skiplist = unique(Int.(vec(integer_ids)))
+    expected_D = similar(D)
+    expected_D .= 0.0f0
+    _clear_masked_embeddings!(D, integer_ids, skiplist)
+    @test D == expected_D
+
+    # Test Case 5: Skiplist with no matching tokens
+    dim, len, bsize = rand(1:20, 3)
+    D = rand(Float32, dim, len, bsize)
+    integer_ids = Int32.(rand(1:100, len, bsize))
+    skiplist = unique(rand(101:1000, rand(1:20)))
+    expected_D = copy(D)
+    _clear_masked_embeddings!(D, integer_ids, skiplist)
+    @test D == expected_D
+
+    # Test 6: Types and shapes
+    dim, len, bsize = rand(1:20, 3)
+    D = rand(Float32, dim, len, bsize)
+    integer_ids = rand(Int32, len, bsize)
+    skiplist = unique(rand(Int, rand(1:20)))
+    mask = _clear_masked_embeddings!(D, integer_ids, skiplist)
+    @test mask isa Array{Bool, 3}
+    @test isequal(size(mask), (1, size(D)[2:end]...))
+end
+
+@testset "_flatten_embeddings" begin
+    # Test Case 1: Generic case; each position along the len axis is filled with a constant
+    dim, len, bsize = rand(1:20, 3)
+    D = Array{Float32}(undef, dim, len, bsize)
+    for idx in 1:len
+        D[:, idx, :] .= idx
+    end
+    expected = Matrix{Float32}(undef, dim, len * bsize)
+    for idx in 1:len
+        expected[:, [idx + k * len for k in 0:(bsize - 1)]] .= idx
+    end
+    @test _flatten_embeddings(D) == expected
+
+    # Test Case 2: Edge case with a 0x3x2 array (should return a 0x6 array)
+    D = Float32[]
+    D = reshape(D, 0, 3, 2)
+    expected_output = reshape(Float32[], 0, 6)
+    @test _flatten_embeddings(D) == expected_output
+end
+
+@testset "_remove_masked_tokens" begin
+    # Test 1: Generic case; build a skiplist, and manually build the expected matrix
+    dim, len, bsize = rand(1:20, 3)
+    mask = trues(len, bsize)
+    skiplist = unique(rand(1:len, rand(1:len)))
+    for id in skiplist
+        mask[id, :] .= false
+    end
+    D = Matrix{Float32}(undef, dim, len * bsize)
+    for idx in 1:len
+        D[:, [idx + k * len for k in 0:(bsize - 1)]] .= idx
+    end
+    expected = rand(Float32, dim, 0)
+    for emb_id in 1:size(D, 2)
+        if !(D[1, emb_id] in skiplist)
+            expected = hcat(expected, D[:, emb_id])
+        end
+    end
+    @test _remove_masked_tokens(D, mask) == expected
+
+    # Test 2: Test for errors
+    @test_throws DimensionMismatch _remove_masked_tokens(
+        rand(Float32, 12, 20), rand(Bool, 4, 4))
+end
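# Aside: Test Case 1 of _flatten_embeddings pins down the column order: a
# (dim, len, bsize) tensor maps to a (dim, len * bsize) matrix, with batch k
# occupying columns (k - 1) * len + 1 through k * len. That is exactly the
# column-major layout of reshape, so (assuming the same memory layout) an
# equivalent sketch is:
D = rand(Float32, 8, 4, 3)          # (dim, len, bsize)
flat = reshape(D, size(D, 1), :)    # (8, 12); batch k fills columns (k - 1) * 4 + 1 : k * 4
@assert size(flat) == (8, 12)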
diff --git a/test/modelling/tokenization/tokenizer_utils.jl b/test/modelling/tokenization/tokenizer_utils.jl
new file mode 100644
index 0000000..d4572f7
--- /dev/null
+++ b/test/modelling/tokenization/tokenizer_utils.jl
@@ -0,0 +1,19 @@
+using ColBERT: _add_marker_row

+
+@testset "_add_marker_row" begin
+    for type in [INT_TYPES; FLOAT_TYPES]
+        # Test 1: Generic case
+        num_rows, num_cols = rand(1:20), rand(1:20)
+        x = rand(type, num_rows, num_cols)
+        x = _add_marker_row(x, zero(type))
+        @test isequal(size(x), (num_rows + 1, num_cols))
+        @test isequal(x[2, :], repeat([zero(type)], num_cols))
+
+        # Test 2: Edge case, empty array
+        num_cols = rand(1:20)
+        x = rand(type, 0, num_cols)
+        x = _add_marker_row(x, zero(type))
+        @test isequal(size(x), (1, num_cols))
+        @test isequal(x[1, :], repeat([zero(type)], num_cols))
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6e9e2c2..1a1af4e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,4 +1,28 @@
 using ColBERT
+using .Iterators
+using JSON
+using LinearAlgebra
+using Logging
+using Random
 using Test
 
-include("Aqua.jl")
+# turn off logging
+logger = NullLogger()
+global_logger(logger)
+
+const INT_TYPES = [
+    Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128]
+const FLOAT_TYPES = [Float16, Float32, Float64]
+
+# include("Aqua.jl")
+
+# indexing operations
+include("indexing/codecs/residual.jl")
+include("indexing/collection_indexer.jl")
+
+# modelling operations
+include("modelling/tokenization/tokenizer_utils.jl")
+include("modelling/embedding_utils.jl")
+
+# utils
+include("utils.jl")
diff --git a/test/utils.jl b/test/utils.jl
new file mode 100644
index 0000000..a109e5c
--- /dev/null
+++ b/test/utils.jl
@@ -0,0 +1,213 @@
+using ColBERT: compute_distances_kernel!, update_centroids_kernel!,
+               assign_clusters_kernel!, onehot_encode!, kmeans_gpu_onehot!,
+               _normalize_array!, _topk, _head
+
+@testset "compute_distances_kernel!" begin
+    # Test 1: when all entries are the same
+    dim = rand(1:20)
+    batch_data = ones(Float32, dim, rand(1:20))
+    centroids = ones(Float32, dim, rand(1:20))
+    batch_distances = Matrix{Float32}(
+        undef, size(centroids, 2), size(batch_data, 2))
+    compute_distances_kernel!(batch_distances, batch_data, centroids)
+    @test all(iszero, batch_distances)
+
+    # Test 2: Edge case, a single point and centroid
+    batch_data = reshape(Float32[1.0; 2.0], 2, 1)
+    centroids = reshape(Float32[2.0; 3.0], 2, 1)
+    batch_distances = Matrix{Float32}(undef, 1, 1)
+    compute_distances_kernel!(batch_distances, batch_data, centroids)
+    @test batch_distances ≈ Float32[2]
+
+    # Test 3: Special case, where the columns are scaled by their index
+    dim = rand(1:20)
+    bsize = rand(1:20)
+    batch_data = ones(Float32, dim, bsize)
+    centroids = ones(Float32, dim, bsize)
+    for idx in 1:bsize
+        batch_data[:, idx] .*= idx
+        centroids[:, idx] .*= idx
+    end
+    expected_distances = ones(Float32, bsize, bsize)
+    for (i, j) in product(1:bsize, 1:bsize)
+        expected_distances[i, j] = dim * (i - j)^2
+    end
+    batch_distances = Matrix{Float32}(undef, bsize, bsize)
+    compute_distances_kernel!(batch_distances, batch_data, centroids)
+    @test isequal(expected_distances, batch_distances)
+
+    # Test 4: The correct errors are thrown
+    batch_data = Float32[1.0 2.0; 3.0 4.0]    # 2x2 matrix
+    centroids = Float32[1.0 0.0; 0.0 1.0]     # 2x2 matrix
+    batch_distances = zeros(Float32, 3, 2)    # incorrect size: should be 2x2
+    @test_throws DimensionMismatch compute_distances_kernel!(
+        batch_distances, batch_data, centroids)
+
+    batch_data = Float32[1.0 2.0; 3.0 4.0]    # 2x2 matrix
+    centroids = Float32[1.0 0.0 1.0; 0.0 1.0 0.0; 1.0 1.0 1.0]    # 3x3 matrix, with a different row count
+    batch_distances = zeros(Float32, 3, 2)    # matches 3x2, but the embedding dim is wrong
+    @test_throws DimensionMismatch compute_distances_kernel!(
+        batch_distances, batch_data, centroids)
+end
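# Aside: per Tests 2 and 3 above, the kernel fills batch_distances[i, j] with
# the squared Euclidean distance between centroid i and data point j (hence
# dim * (i - j)^2 in Test 3). A naive reference sketch of the same computation,
# not the package's optimized kernel:
function compute_distances_sketch(batch_data::Matrix{Float32},
        centroids::Matrix{Float32})
    batch_distances = Matrix{Float32}(
        undef, size(centroids, 2), size(batch_data, 2))
    for j in axes(batch_data, 2), i in axes(centroids, 2)
        # squared Euclidean distance; no square root is taken
        batch_distances[i, j] = sum(abs2, centroids[:, i] - batch_data[:, j])
    end
    batch_distances
end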
+
+@testset "update_centroids_kernel!" begin
+    # Test 1: Generic test to see if results are accumulated correctly
+    dim = rand(1:20)
+    num_centroids = rand(1:20)
+    num_points = rand(1:20)
+    point_to_centroid = rand(1:num_centroids, num_points)
+    new_centroids = ones(Float32, dim, num_centroids)
+    batch_data = ones(Float32, dim, num_points)
+    batch_one_hot = zeros(Float32, num_centroids, num_points)
+    for idx in 1:num_points
+        batch_one_hot[point_to_centroid[idx], idx] = 1.0f0
+    end
+    expected = zeros(Float32, dim, num_centroids)
+    for centroid in point_to_centroid
+        expected[:, centroid] .+= 1.0f0
+    end
+    update_centroids_kernel!(new_centroids, batch_data, batch_one_hot)
+    @test isequal(new_centroids, expected .+ 1.0f0)
+
+    # Test 2: Error for an incorrect `new_centroids` size
+    batch_data = Float32[1.0 2.0; 3.0 4.0]      # 2x2 matrix
+    batch_one_hot = Float32[1.0 0.0; 0.0 1.0]   # 2x2 matrix (one-hot encoded)
+    new_centroids = zeros(Float32, 3, 2)        # incorrect size: should be 2x2
+    @test_throws DimensionMismatch update_centroids_kernel!(
+        new_centroids, batch_data, batch_one_hot)
+
+    # Test 3: Error for an incorrect `batch_one_hot` size
+    batch_data = Float32[1.0 2.0; 3.0 4.0]              # 2x2 matrix
+    batch_one_hot = Float32[1.0 0.0 0.0; 0.0 1.0 0.0]   # incorrect size: should be 2x2, not 2x3
+    new_centroids = zeros(Float32, 2, 2)                # correct size; the error should be triggered by batch_one_hot
+    @test_throws DimensionMismatch update_centroids_kernel!(
+        new_centroids, batch_data, batch_one_hot)
+end
+
+@testset "assign_clusters_kernel!" begin
+    # Test 1: Testing the correct minimum assignment with random permutations
+    num_points = rand(1:100)
+    batch_assignments = Vector{Int32}(undef, num_points)
+    batch_distances = Matrix{Float32}(undef, rand(1:100), num_points)
+    expected_assignments = Vector{Int32}(undef, num_points)
+    for (idx, col) in enumerate(eachcol(batch_distances))
+        perm = randperm(size(batch_distances, 1))
+        col .= Float32.(perm)
+        expected_assignments[idx] = sortperm(perm)[1]
+    end
+    assign_clusters_kernel!(batch_assignments, batch_distances)
+    @test isequal(expected_assignments, batch_assignments)
+
+    # Test 2: Check the DimensionMismatch error
+    batch_distances = Float32[1.0 2.0;
+                              4.0 5.0]
+    batch_assignments = Int32[0]
+    @test_throws DimensionMismatch assign_clusters_kernel!(
+        batch_assignments, batch_distances)
+end
+
+@testset "onehot_encode!" begin
+    # Test 1: Basic functionality
+    k = rand(1:100)
+    batch_assignments = Int32.(collect(1:k))
+    batch_one_hot = zeros(Float32, k, k)
+    onehot_encode!(batch_one_hot, batch_assignments, k)
+    @test isequal(batch_one_hot, I(k))
+
+    # Test 2: A slightly convoluted example
+    batch_assignments = Int32[4, 2, 3, 1]
+    batch_one_hot = zeros(Float32, 4, 4)
+    onehot_encode!(batch_one_hot, batch_assignments, 4)
+    @test batch_one_hot == Float32[0 0 0 1;
+                                   0 1 0 0;
+                                   0 0 1 0;
+                                   1 0 0 0]
+    # Test 3: Edge case with k = 1
+    batch_assignments = Int32[1, 1, 1]
+    batch_one_hot = zeros(Float32, 1, 3)
+    onehot_encode!(batch_one_hot, batch_assignments, 1)
+    @test batch_one_hot == Float32[1 1 1]
+
+    # Test 4: Dimension mismatch error
+    batch_assignments = Int32[1, 2]
+    batch_one_hot = zeros(Float32, 3, 3)
+    @test_throws DimensionMismatch onehot_encode!(
+        batch_one_hot, batch_assignments, 3)
+end
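# Aside: together these kernels make up one k-means step: compute distances,
# assign each point to its nearest centroid, one-hot encode the assignments,
# and accumulate per-cluster sums as batch_data * batch_one_hot'. Test 1 of
# update_centroids_kernel! checks exactly this accumulation; a dense sketch
# (toy sizes, illustrative only):
batch_data = rand(Float32, 4, 10)        # (dim, num_points)
batch_one_hot = zeros(Float32, 3, 10)    # (num_centroids, num_points)
for j in 1:10
    batch_one_hot[rand(1:3), j] = 1.0f0    # a random hard assignment
end
new_centroids = zeros(Float32, 4, 3)
# column c accumulates the sum of all points assigned to centroid c
new_centroids .+= batch_data * batch_one_hot'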
+
+@testset "kmeans_gpu_onehot!" begin
+    # Test 1: When all points are centroids
+    data = rand(Float32, rand(1:100), rand(1:100))
+    centroids = similar(data)
+    point_bsize = rand(1:size(data, 2))
+    cluster_ids = kmeans_gpu_onehot!(data, centroids, size(data, 2))
+    @test isequal(centroids[:, cluster_ids], data)
+end
+
+@testset "_normalize_array!" begin
+    # column normalization
+    X = rand(Float32, rand(1:100), rand(1:100))
+    _normalize_array!(X, dims = 1)
+    for col in eachcol(X)
+        @test isapprox(norm(col), 1)
+    end
+
+    # row normalization
+    X = rand(Float32, rand(1:100), rand(1:100))
+    _normalize_array!(X, dims = 2)
+    for row in eachrow(X)
+        @test isapprox(norm(row), 1)
+    end
+end
+
+@testset "_topk" begin
+    # Test 1: Basic functionality with k = 2, along dimension 1 (columns)
+    data = [3.0 1.0 4.0;
+            1.0 5.0 9.0;
+            2.0 6.0 5.0]
+    k = 2
+    result = _topk(data, k, dims = 1)
+    @test result == [1 3 2;
+                     3 2 3]
+
+    # Test 2: Basic functionality with k = 2, along dimension 2 (rows)
+    result = _topk(data, k, dims = 2)
+    @test result == [3 1;
+                     3 2;
+                     2 3]
+
+    # Test 3: Check the DomainError for an invalid dims value
+    @test_throws DomainError _topk(data, k, dims = 3)
+end
+
+@testset "_head" begin
+    # Test 1: Basic functionality with a non-empty vector
+    v = [1, 2, 3, 4]
+    result = _head(v)
+    @test result == [1, 2, 3]
+
+    # Test 2: Edge case with a single-element vector
+    v = [10]
+    result = _head(v)
+    @test result == Int[]
+
+    # Test 3: Edge case with an empty vector
+    v = Int[]
+    result = _head(v)
+    @test result == Int[]
+
+    # Test 4: Test with a vector of strings
+    v = ["a", "b", "c"]
+    result = _head(v)
+    @test result == ["a", "b"]
+
+    # Test 5: Test with a vector of floating-point numbers
+    v = [1.5, 2.5, 3.5]
+    result = _head(v)
+    @test result == [1.5, 2.5]
+
+    # Test 6: Test with a vector of characters
+    v = ['a', 'b', 'c']
+    result = _head(v)
+    @test result == ['a', 'b']
+end
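# Aside: per the tests above, _topk returns the indices of the k largest
# entries along the requested dimension, and _head behaves like v[1:(end - 1)]
# clamped at the empty vector. A reference sketch for the former (an equivalent
# computation, not necessarily the package's own):
function topk_sketch(data::Matrix{<:Real}, k::Int; dims::Int = 1)
    dims in (1, 2) || throw(DomainError(dims, "dims must be 1 or 2"))
    slices = dims == 1 ? eachcol(data) : eachrow(data)
    idxs = [partialsortperm(collect(slice), 1:k; rev = true) for slice in slices]
    dims == 1 ? reduce(hcat, idxs) : permutedims(reduce(hcat, idxs))
end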