From 8dda46f3984afc17d9af0612d1536cff699c5780 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 10 Aug 2024 20:32:35 +0530 Subject: [PATCH 01/59] Moving all the config to just one struct + even better defaults. --- src/ColBERT.jl | 1 - src/infra/config.jl | 135 ++++++++++++++++++++------------------- src/infra/settings.jl | 144 ------------------------------------------ 3 files changed, 67 insertions(+), 213 deletions(-) delete mode 100644 src/infra/settings.jl diff --git a/src/ColBERT.jl b/src/ColBERT.jl index ebc7a58..f54a7af 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -22,7 +22,6 @@ include("data/queries.jl") export Collection, Queries # config and other infra -include("infra/settings.jl") include("infra/config.jl") export RunSettings, TokenizerSettings, ResourceSettings, DocSettings, QuerySettings, IndexingSettings, diff --git a/src/infra/config.jl b/src/infra/config.jl index 55a60a8..634f8cf 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -1,20 +1,34 @@ """ - ColBERTConfig(run_settings::RunSettings, tokenizer_settings::TokenizerSettings, - resource_settings::ResourceSettings, - doc_settings::DocSettings, query_settings::QuerySettings, - indexing_settings::IndexingSettings, search_settings::SearchSettings) + ColBERTConfig(; use_gpu::Bool, rank::Int, nranks::Int, query_token_id::String, + doc_token_id::String, query_token::String, doc_token::String, checkpoint::String, + collection::String, dim::Int, doc_maxlen::Int, mask_punctuation::Bool, + query_maxlen::Int, attend_to_mask_tokens::Bool, index_path::String, + index_bsize::Int, nbits::Int, kmeans_niters::Int, nprobe::Int, ncandidates::Int) Structure containing config for running and training various components. # Arguments - - `run_settings`: Sets the [`RunSettings`](@ref). - - `tokenizer_settings`: Sets the [`TokenizerSettings`](@ref). - - `resource_settings`: Sets the [`ResourceSettings`](@ref). - - `doc_settings`: Sets the [`DocSettings`](@ref). - - `query_settings`: Sets the [`QuerySettings`](@ref). - - `indexing_settings`: Sets the [`IndexingSettings`](@ref). - - `search_settings`: Sets the [`SearchSettings`](@ref). + - `use_gpu`: Whether to use a GPU or not. Default is `false`. + - `rank`: The index of the running GPU. Default is `0`. For now, the package only allows this to be `0`. + - `nranks`: The number of GPUs used in the run. Default is `1`. For now, the package only supports one GPU. + - `query_token_id`: Unique identifier for query tokens (defaults to `[unused0]`). + - `doc_token_id`: Unique identifier for document tokens (defaults to `[unused1]`). + - `query_token`: Token used to represent a query token (defaults to `[Q]`). + - `doc_token`: Token used to represent a document token (defaults to `[D]`). + - `checkpoint`: The path to the HuggingFace checkpoint of the underlying ColBERT model. Defaults to `"colbert-ir/colbertv2.0"`. + - `collection`: Path to the file containing the documents. Default is `""`. + - `dim`: The dimension of the document embedding space. Default is 128. + - `doc_maxlen`: The maximum length of a document before it is trimmed to fit. Default is 220. + - `mask_punctuation`: Whether or not to mask punctuation characters tokens in the document. Default is true. + - `query_maxlen`: The maximum length of queries after which they are trimmed. + - `attend_to_mask_tokens`: Whether or not to attend to mask tokens in the query. Default value is false. + - `index_path`: Path to save the index files. + - `index_bsize`: Batch size used for some parts of indexing. 
+ - `nbits`: Number of bits used to compress residuals. + - `kmeans_niters`: Number of iterations used for k-means clustering. + - `nprobe`: The number of nearest centroids to fetch during a search. Default is `2`. Also see [`retrieve`](@ref). + - `ncandidates`: The number of candidates to get during candidate generation in search. Default is `8192`. Also see [`retrieve`](@ref). # Returns @@ -22,70 +36,55 @@ A [`ColBERTConfig`](@ref) object. # Examples -The relevant files for this example can be found in the `examples/` folder of the project root. +Most users will just want to use the defaults for most settings. Here's a minimal example: ```julia-repl -julia> dataroot = "downloads/lotte" - -julia> dataset = "lifestyle" - -julia> datasplit = "dev" - -julia> path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") - -julia> collection = Collection(path) - -julia> length(collection.data) - -julia> nbits = 2 # encode each dimension with 2 bits - -julia> doc_maxlen = 300 # truncate passages at 300 tokens - -julia> checkpoint = "colbert-ir/colbertv2.0" # the HF checkpoint - -julia> index_root = "experiments/notebook/indexes" - -julia> index_name = "short_\$(dataset).\$(datasplit).\$(nbits)bits" - -julia> index_path = joinpath(index_root, index_name) - julia> config = ColBERTConfig( - RunSettings( - experiment = "notebook", - ), - TokenizerSettings(), - ResourceSettings( - checkpoint = checkpoint, - collection = collection, - index_name = index_name - ), - DocSettings( - doc_maxlen = doc_maxlen, - ), - QuerySettings(), - IndexingSettings( - index_path = index_path, - index_bsize = 3, - nbits = nbits, - kmeans_niters = 20 - ), - SearchSettings() + use_gpu=true, + collection="/home/codetalker7/documents", + index_path="./local_index" ); ``` """ Base.@kwdef struct ColBERTConfig - run_settings::RunSettings - tokenizer_settings::TokenizerSettings - resource_settings::ResourceSettings - doc_settings::DocSettings - query_settings::QuerySettings - indexing_settings::IndexingSettings - search_settings::SearchSettings + # run settings + use_gpu::Bool = false + rank::Int = 0 + nranks::Int = 1 + + # tokenization settings + query_token_id::String = "[unused0]" + doc_token_id::String = "[unused1]" + query_token::String = "[Q]" + doc_token::String = "[D]" + + # resource settings + checkpoint::String = "colbert-ir/colbertv2.0" + collection::String = "" + + # doc settings + dim::Int = 128 + doc_maxlen::Int = 220 + mask_punctuation::Bool = true + + # query settings + query_maxlen::Int = 32 + attend_to_mask_tokens::Bool = false + + # indexing settings + index_path::String = "" + index_bsize::Int = 64 + nbits::Int = 2 + kmeans_niters::Int = 20 + + # search settings + nprobe::Int = 2 + ncandidates::Int = 8192 end -# TODO: need to think of a better way to save the config later. -function save(config::ColBERTConfig) - config_path = joinpath(config.indexing_settings.index_path, "config.jld2") - JLD2.save(config_path, Dict("config" => config)) -end +# # TODO: need to think of a better way to save the config later. +# function save(config::ColBERTConfig) +# config_path = joinpath(config.indexing_settings.index_path, "config.jld2") +# JLD2.save(config_path, Dict("config" => config)) +# end diff --git a/src/infra/settings.jl b/src/infra/settings.jl deleted file mode 100644 index d251cdb..0000000 --- a/src/infra/settings.jl +++ /dev/null @@ -1,144 +0,0 @@ -""" - RunSettings([root, experiment, index_root, name, rank, nranks]) - -Structure holding all the settings necessary to describe the run environment. 
- -# Arguments - - - `root`: The root directory for the run. Default is an `"experiments"` folder in the current working directory. - - `experiment`: The name of the run. Default is `"default"`. - - `index_root`: The root directory for storing index. For now, there is no need to specify this as it is determined by the indexing component. - - `name`: The name of the run. Default is the current date and time. - - `use_gpu`: Whether to use a GPU or not. Default is `false`. - - `rank`: The index of the running GPU. Default is `0`. For now, the package only allows this to be `0`. - - `nranks`: The number of GPUs used in the run. Default is `1`. For now, the package only supports one GPU. - -# Returns - -A `RunSettings` object. -""" -Base.@kwdef struct RunSettings - root::String = joinpath(pwd(), "experiments") - experiment::String = "default" - index_root::Union{Nothing, String} = nothing - name::String = Dates.format(now(), "yyyy/mm/dd/HH.MM.SS") - use_gpu::Bool = false - rank::Int = 0 - nranks::Int = 1 -end - -""" - TokenizerSettings([query_token_id, doc_token_id, query_token, doc_token]) - -Structure to represent settings for the tokenization of queries and documents. - -# Arguments - - - `query_token_id`: Unique identifier for query tokens (defaults to `[unused0]`). - - `doc_token_id`: Unique identifier for document tokens (defaults to `[unused1]`). - - `query_token`: Token used to represent a query token (defaults to `[Q]`). - - `doc_token`: Token used to represent a document token (defaults to `[D]`). - -# Returns - -A `TokenizerSettings` object. -""" -Base.@kwdef struct TokenizerSettings - query_token_id::String = "[unused0]" - doc_token_id::String = "[unused1]" - query_token::String = "[Q]" - doc_token::String = "[D]" -end - -""" - ResourceSettings([checkpoint, collection, queries, index_name]) - -Structure to represent resource settings. - -# Arguments - - - `checkpoint`: The path to the HuggingFace checkpoint of the underlying ColBERT model. - - `collection`: The underlying collection of documents - - `queries`: The underlying collection of queries. - - `index_name`: The name of the index. - -# Returns - -A `ResourceSettings` object. -""" -Base.@kwdef struct ResourceSettings - checkpoint::Union{Nothing, String} = nothing - collection::Union{Nothing, Collection} = nothing - queries::Union{Nothing, String} = nothing - index_name::Union{Nothing, String} = nothing -end - -""" - DocSettings([dim, doc_maxlen, mask_punctuation]) - -Structure that defines the settings used for generating document embeddings. - -# Arguments - - - `dim`: The dimension of the document embedding space. Default is 128. - - `doc_maxlen`: The maximum length of a document before it is trimmed to fit. Default is 220. - - `mask_punctuation`: Whether or not to mask punctuation characters tokens in the document. Default is true. - -# Returns - -A `DocSettings` object. -""" -Base.@kwdef struct DocSettings - dim::Int = 128 - doc_maxlen::Int = 220 - mask_punctuation::Bool = true -end - -""" - QuerySettings([query_maxlen, attend_to_mask_tokens, interaction]) - -A structure representing the query settings used by the ColBERT model. - -# Arguments - - - `query_maxlen`: The maximum length of queries after which they are trimmed. - - `attend_to_mask_tokens`: Whether or not to attend to mask tokens in the query. Default value is false. - - `interaction`: The type of interaction used to compute the scores for the queries. Default value is "colbert". - -# Returns - -A `QuerySettings` object. 
-""" -Base.@kwdef struct QuerySettings - query_maxlen::Int = 32 - attend_to_mask_tokens::Bool = false - interaction::String = "colbert" -end - -""" - IndexingSettings([index_path, index_bsize, nbits, kmeans_niters]) - -Structure containing settings for indexing. - -# Arguments - - - `index_path`: Path to save the index files. - - `index_bsize::Int`: Batch size used for some parts of indexing. - - `nbits::Int`: Number of bits used to compress residuals. - - `kmeans_niters::Int`: Number of iterations used for k-means clustering. - -# Returns - -An `IndexingSettings` object. -""" -Base.@kwdef struct IndexingSettings - index_path::Union{Nothing, String} = nothing - index_bsize::Int = 64 - nbits::Int = 1 - kmeans_niters = 4 -end - -Base.@kwdef struct SearchSettings - nprobe::Int = 2 - ncandidates::Int = 8192 -end From e2e9274354dc8b2441e577b63a1a3b9cff75299a Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 10 Aug 2024 22:52:08 +0530 Subject: [PATCH 02/59] Adding functions to load and save the config to JSON. --- src/infra/config.jl | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/infra/config.jl b/src/infra/config.jl index 634f8cf..f4287fb 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -17,7 +17,7 @@ Structure containing config for running and training various components. - `query_token`: Token used to represent a query token (defaults to `[Q]`). - `doc_token`: Token used to represent a document token (defaults to `[D]`). - `checkpoint`: The path to the HuggingFace checkpoint of the underlying ColBERT model. Defaults to `"colbert-ir/colbertv2.0"`. - - `collection`: Path to the file containing the documents. Default is `""`. + - `collection`: Path to the file containing the documents. Default is `""`. - `dim`: The dimension of the document embedding space. Default is 128. - `doc_maxlen`: The maximum length of a document before it is trimmed to fit. Default is 220. - `mask_punctuation`: Whether or not to mask punctuation characters tokens in the document. Default is true. @@ -40,9 +40,9 @@ Most users will just want to use the defaults for most settings. Here's a minima ```julia-repl julia> config = ColBERTConfig( - use_gpu=true, - collection="/home/codetalker7/documents", - index_path="./local_index" + use_gpu = true, + collection = "/home/codetalker7/documents", + index_path = "./local_index" ); ``` @@ -83,8 +83,21 @@ Base.@kwdef struct ColBERTConfig ncandidates::Int = 8192 end -# # TODO: need to think of a better way to save the config later. 
-# function save(config::ColBERTConfig) -# config_path = joinpath(config.indexing_settings.index_path, "config.jld2") -# JLD2.save(config_path, Dict("config" => config)) -# end +function save(config::ColBERTConfig) + properties = [Pair{String, Any}(string(field), getproperty(config, field)) + for field in fieldnames(ColBERTConfig)] + isdir(config.index_path) || mkdir(config.index_path) + open(joinpath(config.index_path, "config.json"), "w+") do io + JSON.print( + io, + Dict(properties), + 4 + ) + end +end + +function load(index_path::String) + config_dict = JSON.parsefile(joinpath(index_path, "config.json")) + key_vals = collect(zip(Symbol.(keys(config_dict)), values(config_dict))) + eval(:(ColBERTConfig($([Expr(:kw, :($key), :($val)) for (key, val) in key_vals]...)))) +end From a5213cdadf6d5491e2faff84358e73ab15029d5a Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 10 Aug 2024 23:00:54 +0530 Subject: [PATCH 03/59] Changing name from `load` to `load_config` to avoid future ambiguitities. --- src/infra/config.jl | 56 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/src/infra/config.jl b/src/infra/config.jl index f4287fb..d243cba 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -38,7 +38,9 @@ A [`ColBERTConfig`](@ref) object. Most users will just want to use the defaults for most settings. Here's a minimal example: -```julia-repl +```jldoctest +julia> using ColBERT; + julia> config = ColBERTConfig( use_gpu = true, collection = "/home/codetalker7/documents", @@ -83,6 +85,30 @@ Base.@kwdef struct ColBERTConfig ncandidates::Int = 8192 end +""" + save(config::ColBERTConfig) + +Save a [`ColBERTConfig`](@ref) to disk in JSON. + +# Arguments + + - `config`: The [`ColBERTConfig`](@ref) to save. + +# Examples + +```jldoctest +julia> using ColBERT; + +julia> config = ColBERTConfig( + use_gpu = true, + collection = "/home/codetalker7/documents", + index_path = "./local_index" + ); + +julia> ColBERT.save(config); + +``` +""" function save(config::ColBERTConfig) properties = [Pair{String, Any}(string(field), getproperty(config, field)) for field in fieldnames(ColBERTConfig)] @@ -96,7 +122,33 @@ function save(config::ColBERTConfig) end end -function load(index_path::String) +""" + load_config(index_path::String) + +Load a [`ColBERTConfig`](@ref) from disk. + +# Arguments + + - `index_path`: The path of the directory where the config resides. + +# Examples + +```jldoctest +julia> using ColBERT; + +julia> config = ColBERTConfig( + use_gpu = true, + collection = "/home/codetalker7/documents", + index_path = "./local_index" + ); + +julia> ColBERT.save(config); + +julia> ColBERT.load_config("./local_index") +ColBERTConfig(true, 0, 1, "[unused0]", "[unused1]", "[Q]", "[D]", "colbert-ir/colbertv2.0", "/home/codetalker7/documents", 128, 220, true, 32, false, "./local_index", 64, 2, 20, 2, 8192) +``` +""" +function load_config(index_path::String) config_dict = JSON.parsefile(joinpath(index_path, "config.json")) key_vals = collect(zip(Symbol.(keys(config_dict)), values(config_dict))) eval(:(ColBERTConfig($([Expr(:kw, :($key), :($val)) for (key, val) in key_vals]...)))) From bdcbf514f8384b9c6f592fbd25bac3b608c0186d Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 12:50:18 +0530 Subject: [PATCH 04/59] Minor change in how config fields are accessed. 
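With every setting now living on one flat struct, each nested `config.<settings>.<field>` access in the codebase becomes a plain `config.<field>` access. A minimal sketch of the new pattern (the collection and index paths below are placeholders, not files that ship with the package):

```julia
using ColBERT

config = ColBERTConfig(
    collection = "/path/to/documents.tsv",   # placeholder path
    index_path = "./local_index"             # placeholder path
)

# before this patch:  config.indexing_settings.index_path, config.doc_settings.dim, ...
# after this patch:   the fields are read directly off the flat config
index_path = config.index_path
dim, nbits = config.dim, config.nbits

# residuals pack `nbits` bits per dimension into whole bytes; the search code
# asserts the same invariant
@assert (dim * nbits) % 8 == 0
```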
--- src/indexing.jl | 4 +-- src/indexing/codecs/residual.jl | 14 ++++---- src/indexing/collection_encoder.jl | 4 +-- src/indexing/collection_indexer.jl | 34 +++++++++---------- src/indexing/index_saver.jl | 12 +++---- src/modelling/checkpoint.jl | 12 +++---- .../tokenization/doc_tokenization.jl | 2 +- .../tokenization/query_tokenization.jl | 8 ++--- src/search/index_storage.jl | 10 +++--- src/searching.jl | 4 +-- 10 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 0b56f9d..29855c6 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -3,14 +3,14 @@ struct Indexer end function index(indexer::Indexer) - index_path = indexer.config.indexing_settings.index_path + index_path = indexer.config.index_path if isdir(index_path) @info "Index at $(index_path) already exists! Skipping indexing." return end config = indexer.config - checkpoint = config.resource_settings.checkpoint + checkpoint = config.checkpoint # loading the models @info "Loading ColBERT layers from HuggingFace." diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 287f2c7..1bae993 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -36,7 +36,7 @@ julia> codec = load_codec(index_path); ``` """ function load_codec(index_path::String) - config = load(joinpath(index_path, "config.jld2"), "config") + config = load_config(index_path) centroids = load(joinpath(index_path, "centroids.jld2"), "centroids") avg_residual = load(joinpath(index_path, "avg_residual.jld2"), "avg_residual") buckets = load(joinpath(index_path, "buckets.jld2")) @@ -95,8 +95,8 @@ Convert a matrix of residual vectors into a matrix of integer residual vector us A matrix of compressed integer residual vectors. """ function binarize(codec::ResidualCodec, residuals::AbstractMatrix{Float32}) - dim = codec.config.doc_settings.dim - nbits = codec.config.indexing_settings.nbits + dim = codec.config.dim + nbits = codec.config.nbits num_embeddings = size(residuals)[2] if dim % (nbits * 8) != 0 @@ -165,8 +165,8 @@ function compress(codec::ResidualCodec, embs::AbstractMatrix{Float32}) end function decompress_residuals(codec::ResidualCodec, binary_residuals::AbstractMatrix{UInt8}) - dim = codec.config.doc_settings.dim - nbits = codec.config.indexing_settings.nbits + dim = codec.config.dim + nbits = codec.config.nbits @assert ndims(binary_residuals)==2 "ndims(binary_residuals): $(ndims(binary_residuals))" @assert size(binary_residuals)[1]==(dim / 8) * nbits "size(binary_residuals): $(size(binary_residuals)), (dim / 8) * nbits: $((dim / 8) * nbits)" @@ -253,14 +253,14 @@ A vector of codes for the specified chunk. 
""" function load_codes(codec::ResidualCodec, chunk_idx::Int) codes_path = joinpath( - codec.config.indexing_settings.index_path, "$(chunk_idx).codes.jld2") + codec.config.index_path, "$(chunk_idx).codes.jld2") codes = JLD2.load(codes_path, "codes") codes end function load_residuals(codec::ResidualCodec, chunk_idx::Int) residual_path = joinpath( - codec.config.indexing_settings.index_path, "$(chunk_idx).residuals.jld2") + codec.config.index_path, "$(chunk_idx).residuals.jld2") residuals = JLD2.load(residual_path, "residuals") residuals end diff --git a/src/indexing/collection_encoder.jl b/src/indexing/collection_encoder.jl index 7d2f567..4b09903 100644 --- a/src/indexing/collection_encoder.jl +++ b/src/indexing/collection_encoder.jl @@ -46,9 +46,9 @@ function encode_passages(encoder::CollectionEncoder, passages::Vector{String}) embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}() # batching here to avoid storing intermediate embeddings on GPU # batching also occurs inside docFromText to do batch packing optimizations - for passages_batch in batch(passages, encoder.config.indexing_settings.index_bsize * 50) + for passages_batch in batch(passages, encoder.config.index_bsize * 50) embs_, doclens_ = docFromText(encoder.checkpoint, passages_batch, - encoder.config.indexing_settings.index_bsize) + encoder.config.index_bsize) push!(embs, embs_) append!(doclens, vec(doclens_)) end diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index d0051bb..0bd1d73 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -30,8 +30,8 @@ end function CollectionIndexer( config::ColBERTConfig, encoder::CollectionEncoder, saver::IndexSaver) - plan_path = joinpath(config.indexing_settings.index_path, "plan.json") - metadata_path = joinpath(config.indexing_settings.index_path, "metadata.json") + plan_path = joinpath(config.index_path, "plan.json") + metadata_path = joinpath(config.index_path, "metadata.json") CollectionIndexer( config, @@ -63,7 +63,7 @@ Sample PIDs from the collection to be used to compute clusters using a ``k``-mea A `Set` of `Int`s containing the sampled PIDs. """ function _sample_pids(indexer::CollectionIndexer) - num_passages = length(indexer.config.resource_settings.collection.data) + num_passages = length(indexer.config.collection.data) typical_doclen = 120 num_sampled_pids = 16 * sqrt(typical_doclen * num_passages) num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_passages)) @@ -93,7 +93,7 @@ The average document length (i.e number of attended tokens) computed from the sa """ function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) # collect all passages with pids in sampled_pids - collection = indexer.config.resource_settings.collection + collection = indexer.config.collection sorted_sampled_pids = sort(collect(sampled_pids)) local_sample = collection.data[sorted_sampled_pids] @@ -105,7 +105,7 @@ function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) indexer.avg_doclen_est = length(local_sample_doclens) > 0 ? sum(local_sample_doclens) / length(local_sample_doclens) : 0 - sample_path = joinpath(indexer.config.indexing_settings.index_path, "sample.jld2") + sample_path = joinpath(indexer.config.index_path, "sample.jld2") @info "avg_doclen_est = $(indexer.avg_doclen_est) \t length(local_sample) = $(length(local_sample))" @info "Saving sampled embeddings to $(sample_path)." 
JLD2.save(sample_path, Dict("local_sample_embs" => local_sample_embs)) @@ -152,16 +152,16 @@ The number of chunks into which the document embeddings will be stored (`indexer - `indexer::CollectionIndexer`: The indexer to be initialized. """ function setup(indexer::CollectionIndexer) - collection = indexer.config.resource_settings.collection + collection = indexer.config.collection indexer.num_chunks = Int(ceil(length(collection.data) / get_chunksize( - collection, indexer.config.run_settings.nranks))) + collection, indexer.config.nranks))) # sample passages for training centroids later sampled_pids = _sample_pids(indexer) avg_doclen_est = _sample_embeddings(indexer, sampled_pids) # computing the number of partitions, i.e clusters - num_passages = length(indexer.config.resource_settings.collection.data) + num_passages = length(indexer.config.collection.data) indexer.num_embeddings_est = num_passages * avg_doclen_est indexer.num_partitions = Int(floor(2^(floor(log2(16 * sqrt(indexer.num_embeddings_est)))))) @@ -189,7 +189,7 @@ The tuple `sample, sample_heldout`. """ function _concatenate_and_split_sample(indexer::CollectionIndexer) # load the sample embeddings - sample_path = joinpath(indexer.config.indexing_settings.index_path, "sample.jld2") + sample_path = joinpath(indexer.config.index_path, "sample.jld2") sample = JLD2.load(sample_path, "local_sample_embs") @debug "Original sample shape: $(size(sample))" @@ -236,7 +236,7 @@ function _compute_avg_residuals( avg_residual = mean(abs.(heldout_avg_residual), dims = 2) # for each dimension, take mean of absolute values of residuals # computing bucket weights and cutoffs - num_options = 2^indexer.config.indexing_settings.nbits + num_options = 2^indexer.config.nbits quantiles = Vector(0:(num_options - 1)) / num_options bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end], quantiles .+ (0.5 / num_options) @@ -269,7 +269,7 @@ function train(indexer::CollectionIndexer) @assert heldout isa AbstractMatrix{Float32} "$(typeof(heldout))" centroids = kmeans(sample, indexer.num_partitions, - maxiter = indexer.config.indexing_settings.kmeans_niters, display = :iter).centers + maxiter = indexer.config.kmeans_niters, display = :iter).centers @assert size(centroids)[2]==indexer.num_partitions "size(centroids): $(size(centroids)), indexer.num_partitions: $(indexer.num_partitions)" @assert centroids isa AbstractMatrix{Float32} "$(typeof(centroids))" @@ -298,8 +298,8 @@ The documents are processed in batches of size `chunksize` (see [`enumerate_batc function index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) load_codec!(indexer.saver) # load the codec objects batches = enumerate_batches( - indexer.config.resource_settings.collection, chunksize = chunksize, - nranks = indexer.config.run_settings.nranks) + indexer.config.collection, chunksize = chunksize, + nranks = indexer.config.nranks) for (chunk_idx, offset, passages) in batches # TODO: add functionality to not re-write chunks if they already exist! # TODO: add multiprocessing to this step! 
@@ -348,7 +348,7 @@ function _collect_embedding_id_offset(indexer::CollectionIndexer) embeddings_offsets = Vector{Int}() for chunk_idx in 1:(indexer.num_chunks) metadata_path = joinpath( - indexer.config.indexing_settings.index_path, "$(chunk_idx).metadata.json") + indexer.config.index_path, "$(chunk_idx).metadata.json") chunk_metadata = open(metadata_path, "r") do io chunk_metadata = JSON.parse(io) @@ -387,7 +387,7 @@ function _build_ivf(indexer::CollectionIndexer) ivf_lengths = counts(values, 1:(indexer.num_partitions)) @info "Saving the IVF." - ivf_path = joinpath(indexer.config.indexing_settings.index_path, "ivf.jld2") + ivf_path = joinpath(indexer.config.index_path, "ivf.jld2") JLD2.save(ivf_path, Dict( "ivf" => ivf, "ivf_lengths" => ivf_lengths @@ -396,7 +396,7 @@ end function _update_metadata(indexer::CollectionIndexer) @info "Saving the indexing metadata." - metadata_path = joinpath(indexer.config.indexing_settings.index_path, "metadata.json") + metadata_path = joinpath(indexer.config.index_path, "metadata.json") open(metadata_path, "w") do io JSON.print(io, @@ -406,7 +406,7 @@ function _update_metadata(indexer::CollectionIndexer) "num_partitions" => indexer.num_partitions, "num_embeddings" => indexer.num_embeddings, "avg_doclen" => Int(floor(indexer.num_embeddings / - length(indexer.config.resource_settings.collection.data))) + length(indexer.config.collection.data))) ), 4 ) diff --git a/src/indexing/index_saver.jl b/src/indexing/index_saver.jl index 0238632..fc343bd 100644 --- a/src/indexing/index_saver.jl +++ b/src/indexing/index_saver.jl @@ -25,7 +25,7 @@ The path of of the codec is inferred from the config stored in `saver`. - `saver`: An [`IndexSaver`](@ref) into which the codec is to be loaded. """ function load_codec!(saver::IndexSaver) - index_path = saver.config.indexing_settings.index_path + index_path = saver.config.index_path centroids = JLD2.load(joinpath(index_path, "centroids.jld2"), "centroids") avg_residual = JLD2.load(joinpath(index_path, "avg_residual.jld2"), "avg_residual") buckets = JLD2.load(joinpath(index_path, "buckets.jld2")) @@ -51,7 +51,7 @@ Also see [`train`](@ref). - `saver::IndexSaver`: The index saver to use. """ function save_codec(saver::IndexSaver) - index_path = saver.config.indexing_settings.index_path + index_path = saver.config.index_path centroids_path = joinpath(index_path, "centroids.jld2") avg_residual_path = joinpath(index_path, "avg_residual.jld2") buckets_path = joinpath(index_path, "buckets.jld2") @@ -87,7 +87,7 @@ The codes and compressed residuals for the chunk are saved in files named ` doclens)) # the metadata metadata_path = joinpath( - saver.config.indexing_settings.index_path, "$(chunk_idx).metadata.json") + saver.config.index_path, "$(chunk_idx).metadata.json") @info "Saving metadata to $(metadata_path)" open(metadata_path, "w") do io JSON.print(io, @@ -134,7 +134,7 @@ Check if the index chunk exists for the given `chunk_idx`. A boolean indicating whether all relevant files for the chunk exist. 
""" function check_chunk_exists(saver::IndexSaver, chunk_idx::Int) - index_path = saver.config.indexing_settings.index_path + index_path = saver.config.index_path path_prefix = joinpath(index_path, string(chunk_idx)) codes_path = "$(path_prefix).codes.jld2" residuals_path = "$(path_prefix).residuals.jld2" diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 9d78f29..6496a38 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -95,7 +95,7 @@ function BaseColBERT(checkpoint::String, config::ColBERTConfig) bert_model = HuggingFace.load_model( :bert, checkpoint, :model, bert_state_dict; config = bert_config) linear = HuggingFace._load_dense(bert_state_dict, "linear", bert_config.hidden_size, - config.doc_settings.dim, bert_config.initializer_range, true) + config.dim, bert_config.initializer_range, true) tokenizer = Transformers.load_tokenizer(checkpoint) bert_model = bert_model |> Flux.gpu @@ -173,7 +173,7 @@ end function Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, query_tokenizer::QueryTokenizer, config::ColBERTConfig) - if config.doc_settings.mask_punctuation + if config.mask_punctuation punctuation_list = string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~")) skiplist = [TextEncodeBase.lookup(model.tokenizer.vocab, punct) for punct in punctuation_list] @@ -276,7 +276,7 @@ julia> mask """ function doc(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) - use_gpu = checkpoint.config.run_settings.use_gpu + use_gpu = checkpoint.config.use_gpu integer_ids = integer_ids |> Flux.gpu integer_mask = integer_mask |> Flux.gpu @@ -339,7 +339,7 @@ julia> docs = [ "this is some longer text, so length should be longer", ]; -julia> embs, doclens = docFromText(checkPoint, docs, config.indexing_settings.index_bsize) +julia> embs, doclens = docFromText(checkPoint, docs, config.index_bsize) (Float32[0.07590997 0.00056472444 … -0.09958261 -0.03259005; 0.08413661 -0.016337946 … -0.061889287 -0.017708546; … ; -0.11584533 0.016651645 … 0.0073241345 0.09233974; 0.043868616 0.084660925 … -0.0294838 -0.08536169], [5 5 4 13]) julia> embs @@ -478,7 +478,7 @@ julia> query(checkPoint, integer_ids, integer_mask) """ function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) - use_gpu = checkpoint.config.run_settings.use_gpu + use_gpu = checkpoint.config.use_gpu integer_ids = integer_ids |> Flux.gpu integer_mask = integer_mask |> Flux.gpu @@ -579,7 +579,7 @@ function queryFromText( process = tokenizer.process truncpad_pipe = Pipeline{:token}( TextEncodeBase.trunc_or_pad( - checkpoint.config.query_settings.query_maxlen, "[PAD]", :tail, :tail), + checkpoint.config.query_maxlen, "[PAD]", :tail, :tail), :token) process = process[1:4] |> truncpad_pipe |> process[6:end] tokenizer = Transformers.TextEncoders.BertTextEncoder( diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index edbad76..8b052a0 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -21,7 +21,7 @@ end function DocTokenizer(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, config::ColBERTConfig) D_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.tokenizer_settings.doc_token_id) + tokenizer.vocab, config.doc_token_id) DocTokenizer(D_marker_token_id, config) end diff --git a/src/modelling/tokenization/query_tokenization.jl 
b/src/modelling/tokenization/query_tokenization.jl index 879128e..89ecd97 100644 --- a/src/modelling/tokenization/query_tokenization.jl +++ b/src/modelling/tokenization/query_tokenization.jl @@ -23,7 +23,7 @@ function QueryTokenizer( tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, config::ColBERTConfig) Q_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.tokenizer_settings.query_token_id) + tokenizer.vocab, config.query_token_id) mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") QueryTokenizer(Q_marker_token_id, mask_token_id, config) end @@ -60,7 +60,7 @@ julia> tokenizer = base_colbert.tokenizer; julia> process = tokenizer.process; julia> truncpad_pipe = Pipeline{:token}( - TextEncodeBase.trunc_or_pad(config.query_settings.query_maxlen, "[PAD]", :tail, :tail), + TextEncodeBase.trunc_or_pad(config.query_maxlen, "[PAD]", :tail, :tail), :token); julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; @@ -149,7 +149,7 @@ function tensorize(query_tokenizer::QueryTokenizer, integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(size(integer_mask))" @assert isequal( - size(integer_ids)[1], query_tokenizer.config.query_settings.query_maxlen) "size(integer_ids): $(size(integer_ids)), query_maxlen: $(query_tokenizer.config.query_settings.query_maxlen)" + size(integer_ids)[1], query_tokenizer.config.query_maxlen) "size(integer_ids): $(size(integer_ids)), query_maxlen: $(query_tokenizer.config.query_maxlen)" @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" @@ -157,7 +157,7 @@ function tensorize(query_tokenizer::QueryTokenizer, integer_ids[2, :] .= query_tokenizer.Q_marker_token_id integer_ids[integer_ids .== 1] .= query_tokenizer.mask_token_id - if query_tokenizer.config.query_settings.attend_to_mask_tokens + if query_tokenizer.config.attend_to_mask_tokens integer_mask[integer_ids .== query_tokenizer.mask_token_id] .= 1 @assert isequal(sum(integer_mask), prod(size(integer_mask))) "sum(integer_mask): $(sum(integer_mask)), prod(size(integer_mask)): $(prod(size(integer_mask)))" end diff --git a/src/search/index_storage.jl b/src/search/index_storage.jl index 82cb9db..fb207ef 100644 --- a/src/search/index_storage.jl +++ b/src/search/index_storage.jl @@ -21,7 +21,7 @@ function IndexScorer(index_path::String) @info "Loading the index from {index_path}." # loading the config from the index path - config = JLD2.load(joinpath(index_path, "config.jld2"))["config"] + config = load_config(index_path) # the metadata metadata_path = joinpath(index_path, "metadata.json") @@ -46,7 +46,7 @@ function IndexScorer(index_path::String) # loading all compressed embeddings num_embeddings = metadata["num_embeddings"] - dim, nbits = config.doc_settings.dim, config.indexing_settings.nbits + dim, nbits = config.dim, config.nbits @assert (dim * nbits) % 8==0 "(dim, nbits): $((dim, nbits))" codes = zeros(UInt32, num_embeddings) residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings) @@ -89,7 +89,7 @@ end Return a candidate set of `pids` for the query matrix `Q`. This is done as follows: the nearest `nprobe` centroids for each query embedding are found. This list is then flattened and the unique set of these centroids is built. Using the `ivf`, the list of all unique embedding IDs contained in these centroids is computed. 
Finally, these embedding IDs are converted to `pids` using `emb2pid`. This list of `pids` is the final candidate set. """ function retrieve(ranker::IndexScorer, config::ColBERTConfig, Q::AbstractArray{Float32}) - @assert isequal(size(Q)[2], config.query_settings.query_maxlen) "size(Q): $(size(Q)), query_maxlen: $(config.query_settings.query_maxlen)" # Q: (128, 32, 1) + @assert isequal(size(Q)[2], config.query_maxlen) "size(Q): $(size(Q)), query_maxlen: $(config.query_maxlen)" # Q: (128, 32, 1) Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension @assert isequal(length(size(Q)), 2) "size(Q): $(size(Q))" @@ -98,7 +98,7 @@ function retrieve(ranker::IndexScorer, config::ColBERTConfig, Q::AbstractArray{F cells = Flux.gpu(transpose(Q)) * Flux.gpu(ranker.codec.centroids) |> Flux.cpu # TODO: how to take topk entries using GPU code? cells = mapslices( - row -> partialsortperm(row, 1:(config.search_settings.nprobe), rev = true), + row -> partialsortperm(row, 1:(config.nprobe), rev = true), cells, dims = 2) # take top nprobe centroids for each query centroid_ids = sort(unique(vec(cells))) @@ -144,7 +144,7 @@ function score_pids(ranker::IndexScorer, config::ColBERTConfig, # decompress these codes and residuals to get the original embeddings D_packed = decompress(ranker.codec, codes_packed, residuals_packed) @assert ndims(D_packed)==2 "ndims(D_packed): $(ndims(D_packed))" - @assert size(D_packed)[1]==config.doc_settings.dim "size(D_packed): $(size(D_packed)), config.doc_settings.dim: $(config.doc_settings.dim)" + @assert size(D_packed)[1]==config.dim "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" @assert size(D_packed)[2]==num_embs "size(D_packed): $(size(D_packed)), num_embs: $(num_embs)" @assert D_packed isa AbstractMatrix{Float32} "$(typeof(D_packed))" diff --git a/src/searching.jl b/src/searching.jl index 08329db..3c93990 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -10,11 +10,11 @@ function Searcher(index_path::String) end # loading the config from the path - config = JLD2.load(joinpath(index_path, "config.jld2"))["config"] + config = load_config(index_path) # loading the model and saving it to prevent multiple loads @info "Loading ColBERT layers from HuggingFace." - base_colbert = BaseColBERT(config.resource_settings.checkpoint, config) + base_colbert = BaseColBERT(config.checkpoint, config) checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), QueryTokenizer(base_colbert.tokenizer, config), config) From 5fb34622185e4ab000ad72971ab31a1ea6da1760 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 13:37:20 +0530 Subject: [PATCH 05/59] Updating the `BaseColBERT` constructor to only take the config. Also updating the examples. --- src/modelling/checkpoint.jl | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 6496a38..455e729 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -17,10 +17,12 @@ A [`BaseColBERT`](@ref) object. # Examples -The `config` in the below example is taken from the example in [`ColBERTConfig`](@ref). 
- ```julia-repl -julia> base_colbert = BaseColBERT(checkpoint, config); +julia> using ColBERT, CUDA; + +julia> config = ColBERTConfig(use_gpu = true); + +julia> base_colbert = BaseColBERT(config); julia> base_colbert.bert HGFBertModel( @@ -52,23 +54,24 @@ HGFBertModel( ), LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters ), - ), # Total: 192 arrays, 85_054_464 parameters, 324.477 MiB. + ), # Total: 192 arrays, 85_054_464 parameters, 40.422 KiB. Branch{(:pooled,) = (:hidden_state,)}( BertPooler(Dense(σ = NNlib.tanh_fast, W = (768, 768), b = true)), # 590_592 parameters ), -) # Total: 199 arrays, 109_482_240 parameters, 417.664 MiB. +) # Total: 199 arrays, 109_482_240 parameters, 43.578 KiB. julia> base_colbert.linear Dense(W = (768, 128), b = true) # 98_432 parameters julia> base_colbert.tokenizer -BertTextEncoder( +TrfTextEncoder( ├─ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_uncased_tokenizer, WordPiece(vocab_size = 30522, unk = [UNK], max_char = 100)), 5 patterns)), ├─ vocab = Vocab{String, SizedArray}(size = 30522, unk = [UNK], unki = 101), -├─ startsym = [CLS], -├─ endsym = [SEP], -├─ padsym = [PAD], -├─ trunc = 512, +├─ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)), +├─ annotate = annotate_strings, +├─ onehot = lookup_first, +├─ decode = nestedcall(remove_conti_prefix), +├─ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)), └─ process = Pipelines: ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source) ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token) @@ -78,7 +81,9 @@ BertTextEncoder( ╰─ target[token] := TextEncodeBase.nested2batch(target.token) ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment) ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment) - ╰─ target := (target.token, target.segment, target.attention_mask) + ╰─ target[sequence_mask] := identity(target.attention_mask) + ╰─ target := (target.token, target.segment, target.attention_mask, target.sequence_mask) + ``` """ struct BaseColBERT @@ -87,9 +92,10 @@ struct BaseColBERT tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder end -function BaseColBERT(checkpoint::String, config::ColBERTConfig) +function BaseColBERT(config::ColBERTConfig) # since Transformers.jl doesn't support local loading - # we manually load the linear layer + # we manually load the linear layers + checkpoint = config.checkpoint bert_config = Transformers.load_config(checkpoint) bert_state_dict = HuggingFace.load_state_dict(checkpoint) bert_model = HuggingFace.load_model( From a4d3ea04fe3107176eac675091bbc950cb654772 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 14:03:34 +0530 Subject: [PATCH 06/59] Removing the `config` and `doc`/`query_tokenizers` from `Checkpoint`, as they aren't really needed. 
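A `Checkpoint` now only carries the `BaseColBERT` model and the punctuation `skiplist`; the config and the tokenizer wrappers are passed in separately wherever they are needed. A rough sketch of the new construction path, assuming the default `colbert-ir/colbertv2.0` weights and tokenizer can be fetched from HuggingFace:

```julia
using ColBERT

config = ColBERTConfig()                       # defaults; checkpoint = "colbert-ir/colbertv2.0"
base_colbert = BaseColBERT(config)             # single-argument constructor from the previous patch
checkpoint = Checkpoint(base_colbert, config)  # no DocTokenizer/QueryTokenizer arguments anymore

# with mask_punctuation = true (the default), the skiplist holds the token IDs
# of the ASCII punctuation characters; otherwise it is empty
length(checkpoint.skiplist)
```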
--- src/modelling/checkpoint.jl | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 455e729..6104202 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -3,7 +3,8 @@ bert::Transformers.HuggingFace.HGFBertModel, linear::Transformers.Layers.Dense, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder) -A struct representing the BERT model, linear layer, and the tokenizer used to compute embeddings for documents and queries. +A struct representing the BERT model, linear layer, and the tokenizer used to compute +embeddings for documents and queries. # Arguments @@ -111,17 +112,17 @@ function BaseColBERT(config::ColBERTConfig) end """ - Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, config::ColBERTConfig) + Checkpoint(model::BaseColBERT, config::ColBERTConfig) -A wrapper for [`BaseColBERT`](@ref), which includes a [`ColBERTConfig`](@ref) and tokenization-specific functions via the [`DocTokenizer`](@ref) and [`QueryTokenizer`] types. +A wrapper for [`BaseColBERT`](@ref), containing information for generating embeddings +for docs and queries. -If the config's [`DocSettings`](@ref) are configured to mask punctuations, then the `skiplist` property of the created [`Checkpoint`](@ref) will be set to a list of token IDs of punctuations. +If the `config` is set to mask punctuations, then the `skiplist` property of the created +[`Checkpoint`](@ref) will be set to a list of token IDs of punctuations. Otherwise, it will be empty. # Arguments - `model`: The [`BaseColBERT`](@ref) to be wrapped. - - `doc_tokenizer`: A [`DocTokenizer`](@ref) used for functions related to document tokenization. - - `query_tokenizer`: A [`QueryTokenizer`](@ref) used for functions related to query tokenization. - `config`: The underlying [`ColBERTConfig`](@ref). # Returns @@ -133,10 +134,9 @@ The created [`Checkpoint`](@ref). Continuing from the example for [`BaseColBERT`](@ref): ```julia-repl -julia> checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), - QueryTokenizer(base_colbert.tokenizer, config), config) +julia> checkpoint = Checkpoint(base_colbert, config) -julia> checkPoint.skiplist # by default, all punctuations +julia> checkpoint.skiplist # by default, all punctuations 32-element Vector{Int64}: 1000 1001 @@ -171,22 +171,18 @@ julia> checkPoint.skiplist # by default, all punctuations """ struct Checkpoint model::BaseColBERT - doc_tokenizer::DocTokenizer - query_tokenizer::QueryTokenizer - config::ColBERTConfig - skiplist::Union{Missing, Vector{Int64}} + skiplist::Vector{Int64} end -function Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, - query_tokenizer::QueryTokenizer, config::ColBERTConfig) +function Checkpoint(model::BaseColBERT, config::ColBERTConfig) if config.mask_punctuation punctuation_list = string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~")) skiplist = [TextEncodeBase.lookup(model.tokenizer.vocab, punct) for punct in punctuation_list] else - skiplist = missing + skiplist = Vector{Int64}() end - Checkpoint(model, doc_tokenizer, query_tokenizer, config, skiplist) + Checkpoint(model, skiplist) end """ From 408a1cdf53325c0a356e1bcb4c3535c9dd5bfaa3 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 15:32:11 +0530 Subject: [PATCH 07/59] Removing the `DocTokenizer` type (it's unnecessary), and changing `tensorize` to `tensorize_docs`. Also moving util functions to `utils.jl`. 
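Callers now pass the config and the raw tokenizer straight to `tensorize_docs`, and the batching helpers (`_sort_by_length`, `_split_into_batches`) move to `src/utils/utils.jl`. A sketch of the updated call, adapted from the docstring example in this patch (assumes the tokenizer for the default checkpoint can be downloaded):

```julia
using ColBERT, Transformers

config = ColBERTConfig()
tokenizer = Transformers.load_tokenizer(config.checkpoint)

batch_text = [
    "hello world",
    "thank you!",
    "a",
    "this is some longer text, so length should be longer",
]

# old call: tensorize(DocTokenizer(tokenizer, config), tokenizer, batch_text, 3)
# new call: the config replaces the DocTokenizer argument
batches, reverse_indices = ColBERT.tensorize_docs(config, tokenizer, batch_text, 3)

integer_ids, integer_mask = batches[1]  # token IDs and attention masks, sorted by document length
reverse_indices                         # maps the sorted documents back to their original order
```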
--- .../tokenization/doc_tokenization.jl | 177 +++--------------- src/utils/utils.jl | 62 ++++++ 2 files changed, 85 insertions(+), 154 deletions(-) diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index 8b052a0..998c05f 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -1,57 +1,30 @@ """ - DocTokenizer(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - config::ColBERTConfig) - -Construct a `DocTokenizer` from a given tokenizer and configuration. The resulting structure supports functions to perform CoLBERT-style document operations on document texts. - -# Arguments - - - `tokenizer`: A tokenizer that has been trained on the BERT vocabulary. Fetched from HuggingFace. - - `config`: The underlying [`ColBERTConfig`](@ref). - -# Returns - -A `DocTokenizer` object. -""" -struct DocTokenizer - D_marker_token_id::Int - config::ColBERTConfig -end - -function DocTokenizer(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - config::ColBERTConfig) - D_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.doc_token_id) - DocTokenizer(D_marker_token_id, config) -end - -""" - tensorize(doc_tokenizer::DocTokenizer, + tensorize_docs(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}, bsize::Union{Missing, Int}) Convert a collection of documents to tensors in the ColBERT format. -This function adds the document marker token at the beginning of each document and then converts the text data into integer IDs and masks using the `tokenizer`. The returned objects are determined by the `bsize` argument. More specifically: - -- If `bsize` is missing, then a tuple `integer_ids, integer_mask` is returned, where `integer_ids` is an `Array` of token IDs for the modified documents, and `integer_mask` is an `Array` of attention masks for each document. -- If `bsize` is not missing, then more optimizing operations are performed on the documents. First, the arrays of token IDs and attention masks are sorted by document lengths (this is for more efficient use of GPUs on the batches; see [`_sort_by_length`](@ref)), and a list `reverse_indices` is computed, which remembers the original order of the documents (to reorder them later). The arrays of token IDs and attention masks are then batched into batches of size `bsize` (see [`_split_into_batches`](@ref)). Finally, the batches along with the list of `reverse_indices` are returned. +This function adds the document marker token at the beginning of each document +and then converts the text data into integer IDs and masks using the `tokenizer`. +Some optimizing operations are performed on the documents. First, the arrays of +token IDs and attention masks are sorted by document lengths (this is for more +efficient use of GPUs on the batches; see [`_sort_by_length`](@ref)), and a list +`reverse_indices` is computed, which remembers the original order of the documents +(to reorder them later). The arrays of token IDs and attention masks are then +batched into batches of size `bsize` (see [`_split_into_batches`](@ref)). +Finally, the batches along with the list of `reverse_indices` are returned. # Arguments -- `doc_tokenizer`: An instance of the `DocTokenizer` type. This object contains information about the document marker token ID. +- `config`: The `ColBERTConfig` to be used to fetch the document marker token ID. 
- `tokenizer`: The tokenizer which is used to convert text data into integer IDs. - `batch_text`: A document texts that will be converted into tensors of token IDs. -- `bsize`: The size of the batches to split the `batch_text` into. Can also be `missing`. +- `bsize`: The size of the batches to split the `batch_text` into. # Returns -If `bsize` is `missing`, then a tuple is returned, which contains: - -- `integer_ids`: An `Array` of integer IDs representing the token IDs of the documents in the input collection. It has shape `(L, N)`, where `L` is the length of the largest document in `batch_text` (i.e the document with the largest number of tokens), and `N` is the number of documents in the batch. -- `integer_mask`: An `Array` of bits representing the attention mask for each document. It has shape `(L, N)`, the same as `integer_ids`. - -If `bsize` is not `missing`, then a tuple containing the following is returned: +A tuple containing the following is returned: - `batches`: A `Vector` of tuples of arrays of token IDs and masks, sorted in the order of document lengths. Each array in each tuple has shape `(L, N)`, where `L` is the length of the largest document in `batch_text`, and `N` is the number of documents in the batch being considered. - `reverse_indices`: A `Vector` containing the indices of the documents in their original order. @@ -59,54 +32,11 @@ If `bsize` is not `missing`, then a tuple containing the following is returned: # Examples ```julia-repl -julia> base_colbert = BaseColBERT("colbert-ir/colbertv2.0", config); +julia> using ColBERT, Transformers; -julia> tokenizer = base_colbert.tokenizer; +julia> config = ColBERTConfig(); -julia> doc_tokenizer = DocTokenizer(tokenizer, config); - -julia> batch_text = [ - "hello world", - "thank you!", - "a", - "this is some longer text, so length should be longer", -]; - -julia> integer_ids, integer_mask = tensorize(doc_tokenizer, tokenizer, batch_text, missing); # no batching - -julia> integer_ids -14×4 reinterpret(Int32, ::Matrix{PrimitiveOneHot.OneHot{0x0000773a}}): - 102 102 102 102 - 3 3 3 3 - 7593 4068 1038 2024 - 2089 2018 103 2004 - 103 1000 1 2071 - 1 103 1 2937 - 1 1 1 3794 - 1 1 1 1011 - 1 1 1 2062 - 1 1 1 3092 - 1 1 1 2324 - 1 1 1 2023 - 1 1 1 2937 - 1 1 1 103 - -julia> integer_mask -14×4 Matrix{Bool}: - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 0 1 - 0 1 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 +julia> tokenizer = Transformers.load_tokenizer(config.checkpoint); julia> batch_text = [ "hello world", @@ -116,12 +46,10 @@ julia> batch_text = [ "this is an even longer document. 
this is some longer text, so length should be longer", ]; -julia> batches, reverse_indices = tensorize(doc_tokenizer, tokenizer, batch_text, 3) -2-element Vector{Tuple{AbstractArray, AbstractMatrix}}: - (Int32[102 102 102; 3 3 3; … ; 1 1 1; 1 1 1], Bool[1 1 1; 1 1 1; … ; 0 0 0; 0 0 0]) - (Int32[102 102; 3 3; … ; 1 2937; 1 103], Bool[1 1; 1 1; … ; 0 1; 0 1]) +julia> batches, reverse_indices = ColBERT.tensorize_docs(config, tokenizer, batch_text, 3) +(Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}[([102 102 102; 3 3 3; … ; 1 1 1; 1 1 1], [1 1 1; 1 1 1; … ; 0 0 0; 0 0 0]), ([102 102; 3 3; … ; 1 2937; 1 103], [1 1; 1 1; … ; 0 1; 0 1])], [2, 3, 1, 4, 5]) -julia> batches[1][1] # this time they are sorted by length +julia> batches[1][1] # sorted by length 21×3 Matrix{Int32}: 102 102 102 3 3 3 @@ -155,7 +83,7 @@ julia> reverse_indices # the original order ``` """ -function tensorize(doc_tokenizer::DocTokenizer, +function tensorize_docs(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}, bsize::Union{Missing, Int}) # placeholder for [D] marker token @@ -171,7 +99,9 @@ function tensorize(doc_tokenizer::DocTokenizer, @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" # adding the [D] marker token ID - integer_ids[2, :] .= doc_tokenizer.D_marker_token_id + D_marker_token_id = TextEncodeBase.lookup( + tokenizer.vocab, config.doc_token_id) + integer_ids[2, :] .= D_marker_token_id if ismissing(bsize) error("Currently bsize can't be missing!") @@ -190,64 +120,3 @@ function tensorize(doc_tokenizer::DocTokenizer, batches, reverse_indices end end - -""" - _sort_by_length( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - -Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. - -# Arguments - - - `integer_ids`: The token IDs of documents to be sorted. - - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). - - `bsize`: The size of batches to be considered. - -# Returns - -Depending upon `bsize`, the following are returned: - - - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the `integer_ids` and `integer_mask` are returned unchanged. - - If the number of documents is larger than `bsize`, then the passages are first sorted by the number of attended tokens (figured out from the `integer_mask`), and then the sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of `reverse_indices`, i.e a mapping from the documents to their indices in the original order. -""" -function _sort_by_length( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - batch_size = size(integer_ids)[2] - if batch_size <= bsize - # if the number of passages fits the batch size, do nothing - integer_ids, integer_mask, Vector(1:batch_size) - end - - lengths = vec(sum(integer_mask; dims = 1)) # number of attended tokens in each passage - indices = sortperm(lengths) # get the indices which will sort lengths - reverse_indices = sortperm(indices) # invert the indices list - - integer_ids[:, indices], integer_mask[:, indices], reverse_indices -end - -""" - _split_into_batches( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - -Split the given `integer_ids` and `integer_mask` into batches of size `bsize`. - -# Arguments - - - `integer_ids`: The array of token IDs to batch. 
- - `integer_mask`: The array of attention masks to batch. - -# Returns - -Batches of token IDs and attention masks, with each batch having size `bsize` (with the possibility of the last batch being smaller). -""" -function _split_into_batches( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - batch_size = size(integer_ids)[2] - batches = Vector{Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}}() - for offset in 1:bsize:batch_size - push!(batches, - (integer_ids[:, offset:min(batch_size, offset + bsize - 1)], - integer_mask[:, offset:min(batch_size, offset + bsize - 1)])) - end - batches -end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index ee1d100..190bee5 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -31,3 +31,65 @@ function batch(group::Vector, bsize::Int; provide_offset::Bool = false) end batches end + + +""" + _sort_by_length( + integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) + +Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. + +# Arguments + + - `integer_ids`: The token IDs of documents to be sorted. + - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). + - `bsize`: The size of batches to be considered. + +# Returns + +Depending upon `bsize`, the following are returned: + + - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the `integer_ids` and `integer_mask` are returned unchanged. + - If the number of documents is larger than `bsize`, then the passages are first sorted by the number of attended tokens (figured out from the `integer_mask`), and then the sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of `reverse_indices`, i.e a mapping from the documents to their indices in the original order. +""" +function _sort_by_length( + integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) + batch_size = size(integer_ids)[2] + if batch_size <= bsize + # if the number of passages fits the batch size, do nothing + integer_ids, integer_mask, Vector(1:batch_size) + end + + lengths = vec(sum(integer_mask; dims = 1)) # number of attended tokens in each passage + indices = sortperm(lengths) # get the indices which will sort lengths + reverse_indices = sortperm(indices) # invert the indices list + + integer_ids[:, indices], integer_mask[:, indices], reverse_indices +end + +""" + _split_into_batches( + integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) + +Split the given `integer_ids` and `integer_mask` into batches of size `bsize`. + +# Arguments + + - `integer_ids`: The array of token IDs to batch. + - `integer_mask`: The array of attention masks to batch. + +# Returns + +Batches of token IDs and attention masks, with each batch having size `bsize` (with the possibility of the last batch being smaller). 
+""" +function _split_into_batches( + integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) + batch_size = size(integer_ids)[2] + batches = Vector{Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}}() + for offset in 1:bsize:batch_size + push!(batches, + (integer_ids[:, offset:min(batch_size, offset + bsize - 1)], + integer_mask[:, offset:min(batch_size, offset + bsize - 1)])) + end + batches +end From 6f1cc7ee928599ce54f79aca3608598638d1d104 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 15:33:24 +0530 Subject: [PATCH 08/59] Adding `config` as an argument to `docs` and `docFromText`; this is a better interface. --- src/modelling/checkpoint.jl | 160 +++++++++++++++++++----------------- 1 file changed, 86 insertions(+), 74 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 6104202..42384b6 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -84,7 +84,6 @@ TrfTextEncoder( ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment) ╰─ target[sequence_mask] := identity(target.attention_mask) ╰─ target := (target.token, target.segment, target.attention_mask, target.sequence_mask) - ``` """ struct BaseColBERT @@ -205,25 +204,34 @@ An array of booleans indicating whether the corresponding token ID is included i # Examples -Continuing with the example for [`tensorize`](@ref) and the `skiplist` from the example in [`Checkpoint`](@ref). +Continuing with the example for [`tensorize_docs`](@ref) and the `skiplist` from the example in [`Checkpoint`](@ref). ```julia-repl -julia> mask_skiplist(checkPoint.model.tokenizer, integer_ids, checkPoint.skiplist) -14×4 BitMatrix: - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 0 0 1 - 0 1 0 1 - 0 0 0 1 - 0 0 0 0 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 - 0 0 0 1 +julia> integer_ids = batches[1][1]; + +julia> ColBERT.mask_skiplist(checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist) +21×3 BitMatrix: + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 0 1 0 + 0 0 1 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 + 0 0 0 ``` """ function mask_skiplist(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, @@ -236,13 +244,15 @@ function mask_skiplist(tokenizer::Transformers.TextEncoders.AbstractTransformerT end """ - doc(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, + doc( + config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) Compute the hidden state of the BERT and linear layers of ColBERT for documents. # Arguments + - `config`: The [`ColBERTConfig`](@ref) being used. - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. - `integer_ids`: An array of token IDs to be fed into the BERT model. - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. 
@@ -256,30 +266,31 @@ A tuple `D, mask`, where: # Examples -Continuing from the example in [`tensorize`](@ref) and [`Checkpoint`](@ref): +Continuing from the example in [`tensorize_docs`](@ref) and [`Checkpoint`](@ref): ```julia-repl -julia> D, mask = doc(checkPoint, integer_ids, integer_mask); +julia> integer_ids, integer_mask = batches[1] + +julia> D, mask = ColBERT.doc(config, checkpoint, integer_ids, integer_mask); + +julia> typeof(D), size(D) +(CuArray{Float32, 3, CUDA.DeviceMemory}, (128, 21, 3)) julia> mask -1×14×4 BitArray{3}: +1×21×3 CuArray{Bool, 3, CUDA.DeviceMemory}: [:, :, 1] = - 1 1 1 1 1 0 0 0 0 0 0 0 0 0 + 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 [:, :, 2] = - 1 1 1 1 0 1 0 0 0 0 0 0 0 0 + 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 [:, :, 3] = - 1 1 1 1 0 0 0 0 0 0 0 0 0 0 - -[:, :, 4] = - 1 1 1 1 1 1 1 0 1 1 1 1 1 1 + 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ``` """ -function doc(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, +function doc( + config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) - use_gpu = checkpoint.config.use_gpu - integer_ids = integer_ids |> Flux.gpu integer_mask = integer_mask |> Flux.gpu @@ -294,7 +305,7 @@ function doc(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, D = D .* mask # clear out embeddings of masked tokens - if !use_gpu + if !config.use_gpu # doing this because normalize gives exact results D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding else @@ -312,8 +323,8 @@ function doc(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, end """ - docFromText( - checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int}) + docFromText(config::ColBERTConfig, checkpoint::Checkpoint, + docs::Vector{String}, bsize::Union{Missing, Int}) Get ColBERT embeddings for `docs` using `checkpoint`. @@ -321,6 +332,7 @@ This function also applies ColBERT-style document pre-processing for each docume # Arguments +- `config`: The [`ColBERTConfig`](@ref) being used. - `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings. - `docs`: A list of documents to get the embeddings for. - `bsize`: A batch size for processing documents in batches. 
@@ -341,44 +353,45 @@ julia> docs = [ "this is some longer text, so length should be longer", ]; -julia> embs, doclens = docFromText(checkPoint, docs, config.index_bsize) +julia> embs, doclens = ColBERT.docFromText(config, checkpoint, docs, config.index_bsize) (Float32[0.07590997 0.00056472444 … -0.09958261 -0.03259005; 0.08413661 -0.016337946 … -0.061889287 -0.017708546; … ; -0.11584533 0.016651645 … 0.0073241345 0.09233974; 0.043868616 0.084660925 … -0.0294838 -0.08536169], [5 5 4 13]) julia> embs 128×27 Matrix{Float32}: - 0.07591 0.000564724 … -0.0811892 -0.0995826 -0.0325901 - 0.0841366 -0.0163379 -0.0118506 -0.0618893 -0.0177085 - -0.0301104 -0.0128125 0.0138397 -0.0573847 0.177861 - 0.0375673 0.216562 -0.110819 0.00425483 -0.00131543 - 0.0252677 0.151702 -0.0272065 0.0350983 -0.0381015 - 0.00608629 -0.0415363 … 0.122848 0.0747104 0.0836627 - -0.185256 -0.106582 0.0352982 -0.0405874 -0.064156 - -0.0816655 -0.142809 0.0565001 -0.134649 0.00380807 - 0.00471224 0.00444499 0.0112827 0.0253297 0.0665076 - -0.121564 -0.189994 0.0151938 -0.119054 -0.0980481 - 0.157599 0.0919844 … 0.0330667 0.0205288 0.0184296 - 0.0132481 -0.0430333 0.0404867 0.0575921 0.101702 - 0.0695787 0.0281928 -0.0378472 -0.053183 -0.123457 - -0.0933986 -0.0390347 0.0279156 0.0309749 0.00298161 - 0.0458561 0.0729707 0.103661 0.00905471 0.127777 - 0.00452597 0.05959 … 0.148845 0.0569492 0.293592 - ⋮ ⋱ ⋮ - 0.0510929 -0.138272 -0.00646483 -0.0171806 -0.0618908 - 0.128495 0.181198 -0.00408871 0.0274591 0.0343185 - -0.0961544 -0.0223997 0.0117907 -0.0813832 0.038232 - 0.0285498 0.0556695 … -0.0139291 -0.14533 -0.0176019 - 0.011212 -0.164717 0.071643 -0.0662124 0.164667 - -0.00178153 0.0600864 0.120243 0.0490749 0.0562548 - -0.0261783 0.0343851 0.0469064 0.040038 -0.0536367 - -0.0696538 -0.020624 0.0441996 0.0842775 0.0567261 - -0.0940356 -0.106123 … 0.00334512 0.00795235 -0.0439883 - 0.0567849 -0.0312434 -0.113022 0.0616158 -0.0738149 - -0.0143086 0.105833 -0.142671 -0.0430241 -0.0831739 - 0.044704 0.0783603 -0.0413787 0.0315282 -0.171445 - 0.129225 0.112544 0.120684 0.107231 0.119762 - 0.000207455 -0.124472 … -0.0930788 -0.0519733 0.0837618 - -0.115845 0.0166516 0.0577464 0.00732413 0.0923397 - 0.0438686 0.0846609 -0.0967041 -0.0294838 -0.0853617 + 0.0759101 0.00056477 -0.0256841 0.0847256 … 0.0321216 -0.0811892 -0.0995827 -0.03259 + 0.0841366 -0.0163379 -0.0573766 0.0125381 0.0838632 -0.0118507 -0.0618893 -0.0177087 + -0.0301104 -0.0128124 0.0137095 0.00290062 0.0347227 0.0138398 -0.0573847 0.177861 + 0.0375674 0.216562 0.220287 -0.011 -0.0213431 -0.110819 0.00425487 -0.00131534 + 0.0252677 0.151702 0.189658 -0.104252 -0.0654913 -0.0272064 0.0350983 -0.0381015 + 0.00608619 -0.0415363 -0.0479571 0.00884466 … 0.00207629 0.122848 0.0747105 0.0836628 + -0.185256 -0.106582 -0.0394912 -0.119268 0.163837 0.0352982 -0.0405874 -0.064156 + -0.0816655 -0.142809 -0.15595 -0.109608 0.0882721 0.0565001 -0.134649 0.00380792 + 0.00471225 0.00444501 0.0144707 0.0682628 0.0386771 0.0112827 0.0253297 0.0665075 + -0.121564 -0.189994 -0.173724 -0.0678208 -0.0832335 0.0151939 -0.119054 -0.0980481 + 0.157599 0.0919844 0.0748075 -0.122389 … 0.0599421 0.0330669 0.0205288 0.0184296 + 0.0132481 -0.0430333 -0.0679477 0.0918445 0.14166 0.0404866 0.0575921 0.101701 + 0.0695786 0.0281928 0.000234582 0.0570102 -0.137199 -0.0378472 -0.0531831 -0.123457 + -0.0933987 -0.0390347 -0.0274184 -0.0452961 0.14876 0.0279156 0.0309748 0.00298152 + 0.0458562 0.0729707 0.0336343 0.189599 0.0570071 0.103661 0.00905471 0.127777 + 0.00452595 0.05959 0.0768679 -0.036913 
… 0.0768966 0.148845 0.0569493 0.293592 + -0.0385804 -0.00754613 0.0375564 0.00207589 -0.0161775 0.133667 0.266788 0.0394272 + ⋮ ⋱ ⋮ + 0.0510928 -0.138272 -0.111771 -0.192081 -0.0312752 -0.00646487 -0.0171807 -0.0618908 + 0.128495 0.181198 0.131882 -0.064132 -0.00662879 -0.00408871 0.027459 0.0343185 + -0.0961544 -0.0223997 0.025595 -0.12089 0.0042998 0.0117906 -0.0813832 0.0382321 + 0.0285496 0.0556695 0.0805605 -0.0728611 … 0.138845 -0.0139292 -0.14533 -0.017602 + 0.0112119 -0.164717 -0.188169 0.0315999 0.112653 0.071643 -0.0662124 0.164667 + -0.0017815 0.0600865 0.0858722 0.00955078 -0.0506793 0.120243 0.0490749 0.0562548 + -0.0261784 0.0343851 0.0447504 -0.105545 -0.0713677 0.0469064 0.040038 -0.0536368 + -0.0696538 -0.020624 -0.0465219 -0.121079 -0.0636235 0.0441996 0.0842775 0.0567261 + -0.0940355 -0.106123 -0.0424943 0.0650131 … 0.00190927 0.00334517 0.00795241 -0.0439884 + 0.0567849 -0.0312434 -0.0715616 0.136271 -0.0648593 -0.113022 0.0616157 -0.0738149 + -0.0143086 0.105833 0.0762297 0.0102708 -0.162572 -0.142671 -0.0430241 -0.0831737 + 0.0447039 0.0783602 0.0957613 0.0603179 0.0415507 -0.0413788 0.0315282 -0.171445 + 0.129225 0.112544 0.0815397 -0.00357054 0.097503 0.120684 0.107231 0.119762 + 0.00020747 -0.124472 -0.120445 -0.0102294 … -0.24173 -0.0930788 -0.0519734 0.0837617 + -0.115845 0.0166517 0.0199255 -0.044735 -0.0353863 0.0577463 0.00732411 0.0923398 + 0.0438687 0.0846609 0.0960215 0.112225 -0.178799 -0.096704 -0.0294837 -0.0853618 julia> doclens 4-element Vector{Int64}: @@ -386,19 +399,18 @@ julia> doclens 5 4 13 - ``` """ -function docFromText( - checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int}) +function docFromText(config::ColBERTConfig, checkpoint::Checkpoint, + docs::Vector{String}, bsize::Union{Missing, Int}) if ismissing(bsize) # integer_ids, integer_mask = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize) # doc(checkpoint, integer_ids, integer_mask) error("Currently bsize cannot be missing!") else - text_batches, reverse_indices = tensorize( - checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize) - batches = [doc(checkpoint, integer_ids, integer_mask) + text_batches, reverse_indices = tensorize_docs( + config, checkpoint.model.tokenizer, docs, bsize) + batches = [doc(config, checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in text_batches] # aggregate all embeddings From 91e251bf1882a72d1f010febd38e5a712be28a25 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 17:24:57 +0530 Subject: [PATCH 09/59] Minor edit in docs. --- src/modelling/tokenization/doc_tokenization.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index 998c05f..544c1a9 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -26,7 +26,10 @@ Finally, the batches along with the list of `reverse_indices` are returned. A tuple containing the following is returned: -- `batches`: A `Vector` of tuples of arrays of token IDs and masks, sorted in the order of document lengths. Each array in each tuple has shape `(L, N)`, where `L` is the length of the largest document in `batch_text`, and `N` is the number of documents in the batch being considered. +- `batches`: A `Vector` of tuples of arrays of token IDs and masks, sorted in the order + of document lengths. 
Each array in each tuple has shape `(L, N)`, where `L` is the length + of the largest document in `batch_text`, and `N` is the number of documents in the batch + being considered. - `reverse_indices`: A `Vector` containing the indices of the documents in their original order. # Examples From c1d7fbf83b1c6fcd4cd6363d2ddd51799d66e218 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 17:25:30 +0530 Subject: [PATCH 10/59] Removing the `QueryTokenizer` struct, and renaming `tensorize` to `tensorize_queries`. --- .../tokenization/query_tokenization.jl | 74 +++++++------------ 1 file changed, 28 insertions(+), 46 deletions(-) diff --git a/src/modelling/tokenization/query_tokenization.jl b/src/modelling/tokenization/query_tokenization.jl index 89ecd97..eee1e1d 100644 --- a/src/modelling/tokenization/query_tokenization.jl +++ b/src/modelling/tokenization/query_tokenization.jl @@ -1,61 +1,42 @@ """ - QueryTokenizer(tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - config::ColBERTConfig) - -Construct a `QueryTokenizer` from a given tokenizer and configuration. The resulting structure supports functions to perform CoLBERT-style query operations on query texts, including addition of the query marker token (`"[Q]"`) and the `"[MASK]"` token augmentation. - -# Arguments - - - `tokenizer`: A tokenizer that has been trained on the BERT vocabulary. Fetched from HuggingFace. This tokenizer should be configured to truncate or pad a sequence to the maximum allowed query length given by the config (see [`QuerySettings`](@ref)). - - `config`: The underlying [`ColBERTConfig`](@ref). - -# Returns - -A `QueryTokenizer` object. -""" -struct QueryTokenizer - Q_marker_token_id::Int - mask_token_id::Int - config::ColBERTConfig -end - -function QueryTokenizer( - tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - config::ColBERTConfig) - Q_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.query_token_id) - mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") - QueryTokenizer(Q_marker_token_id, mask_token_id, config) -end - -""" - tensorize(query_tokenizer::DocTokenizer, + tensorize_queries(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}, bsize::Union{Missing, Int}) Convert a collection of queries to tensors in the ColBERT format. -This function adds the query marker token at the beginning of each query text and then converts the text data into integer IDs and masks using the `tokenizer`. The returned tensors are batched into sizes given by the `bsize` argument. +This function adds the query marker token at the beginning of each query text +and then converts the text data into integer IDs and masks using the `tokenizer`. +The returned tensors are batched into sizes given by the `bsize` argument. # Arguments - - `query_tokenizer`: An instance of the [`QueryTokenizer`](@ref) type. This object contains information about the query marker token ID and the mask token ID. + - `config`: The [`ColBERTConfig`](@ref) to be used to figure out the query marker token ID. - `tokenizer`: The tokenizer which is used to convert text data into integer IDs. - `batch_text`: A document texts that will be converted into tensors of token IDs. - `bsize`: The size of the batches to split the `batch_text` into. # Returns -`batches`, A `Vector` of tuples of arrays of token IDs and masks corresponding to the query texts. 
Each array in each tuple has shape `(L, N)`, where `L` is the maximum query length specified by the config (see [`QuerySettings`](@ref)), and `N` is the number of queries in the batch being considered. +`batches`, A `Vector` of tuples of arrays of token IDs and masks corresponding to +the query texts. Each array in each tuple has shape `(L, N)`, where `L` is the +maximum query length specified by the config (see [`ColBERTConfig`](@ref)), and `N` +is the number of queries in the batch being considered. # Examples -In this example, we first fetch the tokenizer from HuggingFace, and then configure the tokenizer to truncate or pad each sequence to the maximum query length specified by the config. Note that, at the time of writing this package, configuring tokenizers in [`Transformers.jl`](https://github.com/chengchingwen/Transformers.jl) doesn't have a clean interface; so, we have to manually configure the tokenizer. The `config` used is the same as in the example for [`ColBERTConfig`](@ref). +In this example, we first fetch the tokenizer from HuggingFace, and then configure the +tokenizer to truncate or pad each sequence to the maximum query length specified by the +config. Note that, at the time of writing this package, configuring tokenizers in +[`Transformers.jl`](https://github.com/chengchingwen/Transformers.jl) doesn't have a +clean interface; so, we have to manually configure the tokenizer. ```julia-repl -julia> base_colbert = BaseColBERT("colbert-ir/colbertv2.0", config); +julia> using ColBERT, Transformers, TextEncodeBase; -julia> tokenizer = base_colbert.tokenizer; +julia> config = ColBERTConfig(); + +julia> tokenizer = Transformers.load_tokenizer(config.checkpoint); julia> process = tokenizer.process; @@ -69,11 +50,9 @@ julia> tokenizer = Transformers.TextEncoders.BertTextEncoder( tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); -julia> query_tokenizer = QueryTokenizer(tokenizer, config); - julia> queries = ["what are white spots on raspberries?"]; -julia> batches = tensorize(query_tokenizer, tokenizer, queries, 128); +julia> batches = ColBERT.tensorize_queries(config, tokenizer, queries, 128); julia> integer_ids, integer_mask = batches[1][1], batches[1][2]; @@ -132,7 +111,7 @@ julia> integer_mask 0 ``` """ -function tensorize(query_tokenizer::QueryTokenizer, +function tensorize_queries(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}, bsize::Union{Missing, Int}) if ismissing(bsize) @@ -149,16 +128,19 @@ function tensorize(query_tokenizer::QueryTokenizer, integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(size(integer_mask))" @assert isequal( - size(integer_ids)[1], query_tokenizer.config.query_maxlen) "size(integer_ids): $(size(integer_ids)), query_maxlen: $(query_tokenizer.config.query_maxlen)" + size(integer_ids)[1], config.query_maxlen) "size(integer_ids): $(size(integer_ids)), query_maxlen: $(query_tokenizer.config.query_maxlen)" @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" # adding the [Q] marker token ID and [MASK] augmentation - integer_ids[2, :] .= query_tokenizer.Q_marker_token_id - integer_ids[integer_ids .== 1] .= query_tokenizer.mask_token_id + Q_marker_token_id = 
TextEncodeBase.lookup( + tokenizer.vocab, config.query_token_id) + mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") + integer_ids[2, :] .= Q_marker_token_id + integer_ids[integer_ids .== 1] .= mask_token_id - if query_tokenizer.config.attend_to_mask_tokens - integer_mask[integer_ids .== query_tokenizer.mask_token_id] .= 1 + if config.attend_to_mask_tokens + integer_mask[integer_ids .== mask_token_id] .= 1 @assert isequal(sum(integer_mask), prod(size(integer_mask))) "sum(integer_mask): $(sum(integer_mask)), prod(size(integer_mask)): $(prod(size(integer_mask)))" end From 48720bd04f60ac3f6630ead2b043b755a15495a2 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 17:27:12 +0530 Subject: [PATCH 11/59] Adding the `config` as an argument to the `query` and `queryFromText` functions. --- src/modelling/checkpoint.jl | 204 ++++++++++++++++++++++++------------ 1 file changed, 139 insertions(+), 65 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 42384b6..928c8ca 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -200,11 +200,15 @@ Otherwise, all tokens are included in the mask. # Returns -An array of booleans indicating whether the corresponding token ID is included in the mask or not. The array has the same shape as `integer_ids`, i.e `(L, N)`, where `L` is the maximum length of any document in `integer_ids` and `N` is the number of documents. +An array of booleans indicating whether the corresponding token ID +is included in the mask or not. The array has the same shape as +`integer_ids`, i.e `(L, N)`, where `L` is the maximum length of +any document in `integer_ids` and `N` is the number of documents. # Examples -Continuing with the example for [`tensorize_docs`](@ref) and the `skiplist` from the example in [`Checkpoint`](@ref). +Continuing with the example for [`tensorize_docs`](@ref) and the +`skiplist` from the example in [`Checkpoint`](@ref). ```julia-repl julia> integer_ids = batches[1][1]; @@ -252,7 +256,7 @@ Compute the hidden state of the BERT and linear layers of ColBERT for documents. # Arguments - - `config`: The [`ColBERTConfig`](@ref) being used. + - `config`: The [`ColBERTConfig`](@ref) being used. - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. - `integer_ids`: An array of token IDs to be fed into the BERT model. - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. @@ -261,8 +265,13 @@ Compute the hidden state of the BERT and linear layers of ColBERT for documents. A tuple `D, mask`, where: - - `D` is an array containing the normalized embeddings for each token in each document. It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any document and `N` is the total number of documents. - - `mask` is an array containing attention masks for all documents, after masking out any tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` is the same as described above. + - `D` is an array containing the normalized embeddings for each token in each document. + It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer + of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of + any document and `N` is the total number of documents. 
+ - `mask` is an array containing attention masks for all documents, after masking out any + tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` + is the same as described above. # Examples @@ -339,7 +348,10 @@ This function also applies ColBERT-style document pre-processing for each docume # Returns -A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector` of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings across all documents in `docs`. +A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector` +of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding +dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings +across all documents in `docs`. # Examples @@ -444,56 +456,85 @@ function docFromText(config::ColBERTConfig, checkpoint::Checkpoint, end """ - query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, + query( + config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) Compute the hidden state of the BERT and linear layers of ColBERT for queries. # Arguments + - `config`: The [`ColBERTConfig`](@ref) to be used. - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. - `integer_ids`: An array of token IDs to be fed into the BERT model. - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. # Returns -`Q`, where `Q` is an array containing the normalized embeddings for each token in the query matrix. It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any query and `N` is the total number of queries. +`Q`, where `Q` is an array containing the normalized embeddings for each token in the query matrix. +It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), +and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any query and `N` is +the total number of queries. 
# Examples -Continuing from the queries example for [`tensorize`](@ref) and [`Checkpoint`](@ref): +Continuing from the queries example for [`tensorize_queries`](@ref) and [`Checkpoint`](@ref): ```julia-repl -julia> query(checkPoint, integer_ids, integer_mask) -128×32×1 Array{Float32, 3}: +julia> ColBERT.query(config, checkpoint, integer_ids, integer_mask) +128×32×1 CuArray{Float32, 3, CUDA.DeviceMemory}: [:, :, 1] = - 0.0158567 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200357 - -0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138702 -0.0409767 -0.126037 -0.126829 -0.13149 - -0.0231786 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583 - -0.0290509 0.143255 0.0306142 0.042658 -0.164401 -0.161857 -0.160327 - 0.0921477 0.0588331 0.250449 0.234636 0.0664076 0.0659837 0.0711357 - 0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634 - -0.0768143 -0.00587305 0.00543038 0.00443374 -0.11757 -0.112495 -0.11112 - ⋮ ⋱ ⋮ - -0.0859686 0.0623054 0.0974813 0.126841 0.0182795 0.0230549 0.031103 - 0.0392043 0.0162653 0.0926306 0.104053 0.0491495 0.0484318 0.0438132 - -0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0617945 -0.0631367 -0.0675882 - 0.013123 0.0565132 -0.0349061 -0.0464192 0.0724731 0.0780166 0.074623 - -0.117425 0.162483 0.11039 0.136364 -0.00538225 -0.00685449 -0.0019436 - -0.0401158 -0.0045094 0.0539569 0.0689953 -0.00518063 -0.00600252 -0.00771469 - 0.0893983 0.0695061 -0.0499409 -0.035411 0.0960932 0.0961893 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.022061 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699095 -0.0684749 -0.0662904 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0556761 -0.0556784 -0.059571 + 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117006 0.115806 0.115938 0.112977 0.107919 + 0.220185 0.0304873 0.165348 0.150315 -0.0116249 0.0173332 0.0165187 0.0168762 0.0178042 0.0200356 + -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0465292 -0.0693319 -0.0737462 -0.0777439 -0.0776733 -0.0830504 + -0.109909 -0.170906 -0.0138701 -0.0409766 -0.177391 -0.113141 -0.118738 -0.126037 -0.126829 -0.13149 + -0.0231787 0.0532214 0.0607473 0.0279048 0.0634681 0.112296 0.111831 0.117017 0.114073 0.108536 + 0.0620549 0.0465075 0.0821693 0.0606439 0.0592031 … 0.0167847 0.0148605 0.0150612 0.0133353 0.0126583 + -0.0290508 0.143255 0.0306142 0.0426579 0.129972 -0.17502 -0.169493 -0.164401 -0.161857 -0.160327 + 0.0921475 0.058833 0.250449 0.234636 0.0412965 0.0590262 0.0642577 0.0664076 0.0659837 0.0711358 + 0.0279402 -0.0278357 0.144855 0.147958 -0.0268559 0.161106 0.157629 0.154552 0.155525 0.163634 + -0.0768143 -0.00587302 0.00543038 0.00443376 -0.0134111 -0.126912 -0.123969 -0.11757 -0.112495 -0.11112 + -0.0184337 0.00668561 -0.191863 -0.161345 0.0222466 … -0.103246 -0.10374 -0.107664 -0.107267 -0.114564 + 0.0112104 0.0214651 -0.0923963 -0.0823052 0.0600248 0.103589 0.103387 0.106261 0.105065 0.10409 + 0.110971 0.272576 0.148319 0.143233 0.239578 0.11224 0.107913 0.109914 0.112652 0.108365 + -0.131066 0.0376254 -0.0164237 -0.000193318 0.00344707 -0.0893371 -0.0919217 -0.0969305 -0.0935498 -0.096145 + -0.0402605 0.0350559 0.0162864 0.0269105 0.00968855 -0.0623393 -0.0670097 -0.070679 -0.0655848 -0.0564059 + 0.0799973 0.0482302 0.0712078 0.0792903 0.0108783 … 0.00820444 0.00854873 0.00889943 0.00932721 0.00751066 + -0.137565 -0.0369116 
-0.065728 -0.0664102 -0.0238012 0.029041 0.0292468 0.0297059 0.0278639 0.0257616 + 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0979401 -0.057629 -0.053911 -0.0566325 -0.0568765 -0.0581378 + 0.0656851 0.0195639 0.0288789 0.0559219 0.0315515 0.0472323 0.054771 0.0596156 0.0541802 0.0525933 + 0.0668634 -0.00400549 0.0297102 0.0505045 -0.00082792 0.0414113 0.0400276 0.0361149 0.0325914 0.0260693 + -0.0691096 0.0348577 -0.000312685 0.0232462 -0.00250495 … -0.141874 -0.142026 -0.132163 -0.129679 -0.131122 + -0.0273036 0.0653352 0.0332689 0.017918 0.0875479 0.0500921 0.0471914 0.0469949 0.0434268 0.0442646 + -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0468719 -0.0772672 -0.0805913 -0.0809244 -0.0823798 -0.081472 + ⋮ ⋱ ⋮ + 0.0506199 0.00290888 0.047947 0.063503 -0.0072114 0.0360347 0.0326486 0.033966 0.0327732 0.0261081 + -0.0288586 -0.150171 -0.0699125 -0.108002 -0.142865 -0.0775934 -0.072192 -0.0697569 -0.0715358 -0.0683193 + -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0524162 0.0457267 0.0532778 0.0649795 0.0697126 0.0808413 + 0.0445508 0.0296366 0.0325647 0.0521935 0.0436496 0.129031 0.126605 0.12324 0.120497 0.117703 + -0.127301 -0.0224252 -0.00579415 -0.00877803 -0.0140665 … -0.080026 -0.080839 -0.0823464 -0.0803394 -0.0856279 + 0.0304881 0.0396951 0.0798097 0.0736797 0.0800866 0.0426674 0.0411406 0.0460205 0.0460111 0.0532082 + 0.0488798 0.252244 0.0866849 0.098552 0.251561 -0.0236942 -0.035116 -0.0395483 -0.0463498 -0.0494207 + -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0352487 -0.0476357 -0.0435388 -0.0404835 -0.0410673 -0.0367272 + 0.023548 -0.00147361 0.0629259 0.106951 0.0406627 0.00627022 0.00403014 -0.000107777 -0.000898423 0.00296315 + -0.0574151 -0.0875744 -0.103787 -0.114166 -0.103979 … -0.0708782 -0.0700138 -0.0687795 -0.070967 -0.0636385 + 0.0280373 0.149767 -0.0899733 -0.0732524 0.162316 0.022177 0.0183834 0.0201251 0.0197228 0.0219051 + -0.0617143 -0.0573989 -0.0973785 -0.0805046 -0.0525925 0.0997715 0.102691 0.107432 0.108591 0.109502 + -0.0859687 0.0623054 0.0974813 0.126841 0.0595557 0.0187937 0.0191363 0.0182794 0.0230548 0.031103 + 0.0392044 0.0162653 0.0926306 0.104054 0.0509464 0.0559883 0.0553617 0.0491496 0.0484319 0.0438133 + -0.0340362 -0.0278067 -0.0181035 -0.0282369 -0.0490531 … -0.0564175 -0.0562518 -0.0617946 -0.0631367 -0.0675882 + 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0456515 0.0676478 0.0698765 0.0724731 0.0780165 0.0746229 + -0.117425 0.162483 0.11039 0.136364 0.135339 -0.00432259 -0.00508357 -0.00538224 -0.00685447 -0.00194357 + -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00295334 -0.00671544 -0.00322498 -0.00518066 -0.00600254 -0.0077147 + 0.0893984 0.0695061 -0.049941 -0.035411 0.0767663 0.0913505 0.0964841 0.0960931 0.0961892 0.103431 + -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 + -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.180245 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 + 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 ``` """ -function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, +function query( + config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) - use_gpu = checkpoint.config.use_gpu - integer_ids = integer_ids |> Flux.gpu integer_mask = integer_mask |> Flux.gpu @@ -509,7 +550,7 @@ function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, Q = Q .* mask - if !use_gpu + if !config.use_gpu # doing 
this because normalize gives exact results Q = mapslices(v -> iszero(v) ? v : normalize(v), Q, dims = 1) # normalize each embedding else @@ -531,7 +572,7 @@ function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, end """ - queryFromText( + queryFromText(config::ColBERTConfig, checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int}) Get ColBERT embeddings for `queries` using `checkpoint`. @@ -540,13 +581,16 @@ This function also applies ColBERT-style query pre-processing for each query in # Arguments + - `config`: The [`ColBERTConfig`](@ref) to be used. - `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings. - `queries`: A list of queries to get the embeddings for. - `bsize`: A batch size for processing queries in batches. # Returns -`embs`, where `embs` is an array of embeddings. The array `embs` has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for ColBERT's linear layer), `L` is the maximum length of any query in the batch, and `N` is the total number of queries in `queries`. +`embs`, where `embs` is an array of embeddings. The array `embs` has shape `(D, L, N)`, +where `D` is the embedding dimension (`128` for ColBERT's linear layer), `L` is the +maximum length of any query in the batch, and `N` is the total number of queries in `queries`. # Examples @@ -555,34 +599,64 @@ Continuing from the example in [`Checkpoint`](@ref): ```julia-repl julia> queries = ["what are white spots on raspberries?"]; -julia> queryFromText(checkPoint, queries, 128) +julia> ColBERT.queryFromText(config, checkpoint, queries, 128) 128×32×1 Array{Float32, 3}: [:, :, 1] = - 0.0158567 0.169676 0.092745 0.0798617 … 0.115806 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 0.0165188 0.0168762 0.0178042 0.0200357 - -0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0737461 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138702 -0.0409767 -0.118738 -0.126037 -0.126829 -0.13149 - -0.0231786 0.0532214 0.0607473 0.0279048 0.111831 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0148605 0.0150612 0.0133353 0.0126583 - -0.0290509 0.143255 0.0306142 0.042658 -0.169493 -0.164401 -0.161857 -0.160327 - 0.0921477 0.0588331 0.250449 0.234636 0.0642578 0.0664076 0.0659837 0.0711357 - 0.0279402 -0.0278357 0.144855 0.147958 0.157629 0.154552 0.155525 0.163634 - -0.0768143 -0.00587305 0.00543038 0.00443374 -0.123969 -0.11757 -0.112495 -0.11112 - -0.0184338 0.00668557 -0.191863 -0.161345 … -0.10374 -0.107664 -0.107267 -0.114564 - ⋮ ⋱ ⋮ - -0.0859686 0.0623054 0.0974813 0.126841 0.0191363 0.0182795 0.0230549 0.031103 - 0.0392043 0.0162653 0.0926306 0.104053 0.0553615 0.0491495 0.0484318 0.0438132 - -0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0562518 -0.0617945 -0.0631367 -0.0675882 - 0.013123 0.0565132 -0.0349061 -0.0464192 0.0698766 0.0724731 0.0780166 0.074623 - -0.117425 0.162483 0.11039 0.136364 -0.0050836 -0.00538225 -0.00685449 -0.0019436 - -0.0401158 -0.0045094 0.0539569 0.0689953 -0.00322497 -0.00518063 -0.00600252 -0.00771469 - 0.0893983 0.0695061 -0.0499409 -0.035411 0.0964842 0.0960932 0.0961893 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 … -0.0275017 -0.0197172 -0.022061 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0735711 -0.0699095 -0.0684749 -0.0662904 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0553564 -0.0556761 -0.0556784 -0.059571 + 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117734 0.117006 0.115806 0.115938 0.112977 0.107919 + 0.220185 0.0304873 0.165348 0.150315 
-0.0116249 0.0181126 0.0173332 0.0165187 0.0168762 0.0178042 0.0200356 + -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0465292 -0.0672796 -0.0693319 -0.0737462 -0.0777439 -0.0776733 -0.0830504 + -0.109909 -0.170906 -0.0138701 -0.0409766 -0.177391 -0.10489 -0.113141 -0.118738 -0.126037 -0.126829 -0.13149 + -0.0231787 0.0532214 0.0607473 0.0279048 0.0634681 0.113961 0.112296 0.111831 0.117017 0.114073 0.108536 + 0.0620549 0.0465075 0.0821693 0.0606439 0.0592031 … 0.0174852 0.0167847 0.0148605 0.0150612 0.0133353 0.0126583 + -0.0290508 0.143255 0.0306142 0.0426579 0.129972 -0.175238 -0.17502 -0.169493 -0.164401 -0.161857 -0.160327 + 0.0921475 0.058833 0.250449 0.234636 0.0412965 0.0555153 0.0590262 0.0642577 0.0664076 0.0659837 0.0711358 + 0.0279402 -0.0278357 0.144855 0.147958 -0.0268559 0.162062 0.161106 0.157629 0.154552 0.155525 0.163634 + -0.0768143 -0.00587302 0.00543038 0.00443376 -0.0134111 -0.128129 -0.126912 -0.123969 -0.11757 -0.112495 -0.11112 + -0.0184337 0.00668561 -0.191863 -0.161345 0.0222466 … -0.102283 -0.103246 -0.10374 -0.107664 -0.107267 -0.114564 + 0.0112104 0.0214651 -0.0923963 -0.0823052 0.0600248 0.103233 0.103589 0.103387 0.106261 0.105065 0.10409 + 0.110971 0.272576 0.148319 0.143233 0.239578 0.109759 0.11224 0.107913 0.109914 0.112652 0.108365 + -0.131066 0.0376254 -0.0164237 -0.000193318 0.00344707 -0.0862689 -0.0893371 -0.0919217 -0.0969305 -0.0935498 -0.096145 + -0.0402605 0.0350559 0.0162864 0.0269105 0.00968855 -0.0587467 -0.0623393 -0.0670097 -0.070679 -0.0655848 -0.0564059 + 0.0799973 0.0482302 0.0712078 0.0792903 0.0108783 … 0.00501423 0.00820444 0.00854873 0.00889943 0.00932721 0.00751066 + -0.137565 -0.0369116 -0.065728 -0.0664102 -0.0238012 0.0250844 0.029041 0.0292468 0.0297059 0.0278639 0.0257616 + 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0979401 -0.0583169 -0.057629 -0.053911 -0.0566325 -0.0568765 -0.0581378 + 0.0656851 0.0195639 0.0288789 0.0559219 0.0315515 0.03907 0.0472323 0.054771 0.0596156 0.0541802 0.0525933 + 0.0668634 -0.00400549 0.0297102 0.0505045 -0.00082792 0.0399623 0.0414113 0.0400276 0.0361149 0.0325914 0.0260693 + -0.0691096 0.0348577 -0.000312685 0.0232462 -0.00250495 … -0.146082 -0.141874 -0.142026 -0.132163 -0.129679 -0.131122 + -0.0273036 0.0653352 0.0332689 0.017918 0.0875479 0.0535029 0.0500921 0.0471914 0.0469949 0.0434268 0.0442646 + -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0468719 -0.0741133 -0.0772672 -0.0805913 -0.0809244 -0.0823798 -0.081472 + -0.0262739 0.109895 0.0117273 0.0222689 0.100869 0.0119844 0.0132486 0.012956 0.0175875 0.013171 0.0195091 + 0.0861164 0.0799029 0.00381147 0.0170927 0.103322 0.0238912 0.0209658 0.0226638 0.0209905 0.0230679 0.0221191 + 0.125112 0.0880232 0.0351989 0.022897 0.0862715 … -0.0219898 -0.0238914 -0.0207844 -0.0229276 -0.0238033 -0.0236367 + ⋮ ⋱ ⋮ + -0.158838 0.0415251 -0.0584126 -0.0373528 0.0819274 -0.212757 -0.214835 -0.213414 -0.212899 -0.215478 -0.210674 + -0.039636 -0.0837763 -0.0837142 -0.0597521 -0.0868467 0.0309127 0.0339911 0.03399 0.0313526 0.0316408 0.0309661 + 0.0755214 0.0960326 0.0858578 0.0614626 0.111979 … 0.102411 0.101302 0.108277 0.109034 0.107593 0.111863 + 0.0506199 0.00290888 0.047947 0.063503 -0.0072114 0.0388324 0.0360347 0.0326486 0.033966 0.0327732 0.0261081 + -0.0288586 -0.150171 -0.0699125 -0.108002 -0.142865 -0.0811611 -0.0775934 -0.072192 -0.0697569 -0.0715358 -0.0683193 + -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0524162 0.046386 0.0457267 0.0532778 0.0649795 0.0697126 0.0808413 + 0.0445508 0.0296366 0.0325647 0.0521935 0.0436496 0.125633 
0.129031 0.126605 0.12324 0.120497 0.117703 + -0.127301 -0.0224252 -0.00579415 -0.00877803 -0.0140665 … -0.0826691 -0.080026 -0.080839 -0.0823464 -0.0803394 -0.0856279 + 0.0304881 0.0396951 0.0798097 0.0736797 0.0800866 0.0448139 0.0426674 0.0411406 0.0460205 0.0460111 0.0532082 + 0.0488798 0.252244 0.0866849 0.098552 0.251561 -0.0212669 -0.0236942 -0.035116 -0.0395483 -0.0463498 -0.0494207 + -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0352487 -0.0486577 -0.0476357 -0.0435388 -0.0404835 -0.0410673 -0.0367272 + 0.023548 -0.00147361 0.0629259 0.106951 0.0406627 0.00599323 0.00627022 0.00403014 -0.000107777 -0.000898423 0.00296315 + -0.0574151 -0.0875744 -0.103787 -0.114166 -0.103979 … -0.0697383 -0.0708782 -0.0700138 -0.0687795 -0.070967 -0.0636385 + 0.0280373 0.149767 -0.0899733 -0.0732524 0.162316 0.0233808 0.022177 0.0183834 0.0201251 0.0197228 0.0219051 + -0.0617143 -0.0573989 -0.0973785 -0.0805046 -0.0525925 0.0936075 0.0997715 0.102691 0.107432 0.108591 0.109502 + -0.0859687 0.0623054 0.0974813 0.126841 0.0595557 0.0244905 0.0187937 0.0191363 0.0182794 0.0230548 0.031103 + 0.0392044 0.0162653 0.0926306 0.104054 0.0509464 0.0516558 0.0559883 0.0553617 0.0491496 0.0484319 0.0438133 + -0.0340362 -0.0278067 -0.0181035 -0.0282369 -0.0490531 … -0.0528032 -0.0564175 -0.0562518 -0.0617946 -0.0631367 -0.0675882 + 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0456515 0.0670016 0.0676478 0.0698765 0.0724731 0.0780165 0.0746229 + -0.117425 0.162483 0.11039 0.136364 0.135339 -0.00589512 -0.00432259 -0.00508357 -0.00538224 -0.00685447 -0.00194357 + -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00295334 -0.0122461 -0.00671544 -0.00322498 -0.00518066 -0.00600254 -0.0077147 + 0.0893984 0.0695061 -0.049941 -0.035411 0.0767663 0.0880484 0.0913505 0.0964841 0.0960931 0.0961892 0.103431 + -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0318565 -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 + -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.180245 -0.0800835 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 + 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0577217 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 ``` """ -function queryFromText( +function queryFromText(config::ColBERTConfig, checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int}) if ismissing(bsize) error("Currently bsize cannot be missing!") @@ -593,7 +667,7 @@ function queryFromText( process = tokenizer.process truncpad_pipe = Pipeline{:token}( TextEncodeBase.trunc_or_pad( - checkpoint.config.query_maxlen, "[PAD]", :tail, :tail), + config.query_maxlen, "[PAD]", :tail, :tail), :token) process = process[1:4] |> truncpad_pipe |> process[6:end] tokenizer = Transformers.TextEncoders.BertTextEncoder( @@ -601,8 +675,8 @@ function queryFromText( endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc) # get ids and masks, embeddings and returning the concatenated tensors - batches = tensorize(checkpoint.query_tokenizer, tokenizer, queries, bsize) - batches = [query(checkpoint, integer_ids, integer_mask) + batches = tensorize_queries(config, tokenizer, queries, bsize) + batches = [query(config, checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in batches] Q = cat(batches..., dims = 3) From 20f0402f79795c1264407aefd77d259dca255e76 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 21:43:52 +0530 Subject: [PATCH 12/59] Simplyfying the signatures of the setup functions; only using primitive types and the most important types for ColBERT (i.e 
`Checkpoint` and `ColBERTConfig`). This makes testing easier. --- src/indexing/collection_indexer.jl | 179 ++++++++++++++++++----------- 1 file changed, 112 insertions(+), 67 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 0bd1d73..4f54f7a 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -50,126 +50,171 @@ function CollectionIndexer( end """ - _sample_pids(indexer::CollectionIndexer) + encode_passages( + config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String}) + +Encode a list of passages using `checkpoint`. + +The given `passages` are run through the underlying BERT model and the linear layer to +generate the embeddings, after doing relevant document-specific preprocessing. +See [`docFromText`](@ref) for more details. + +# Arguments + + - `config`: The [`ColBERTConfig`](@ref) to be used. + - `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages. + - `passages`: A list of strings representing the passages to be encoded. + +# Returns + +A tuple `embs, doclens` where: + + - `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`, + where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages. + - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage, + i.e the total number of attended tokens for each document passage. +""" +function encode_passages( + config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String}) + @info "Encoding $(length(passages)) passages." + + if length(passages) == 0 + error("The list of passages to encode is empty!") + end + + embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}() + # batching here to avoid storing intermediate embeddings on GPU + # batching also occurs inside docFromText to do batch packing optimizations + for passages_batch in batch(passages, config.index_bsize * 50) + embs_, doclens_ = docFromText(config, checkpoint, passages_batch, + config.index_bsize) + push!(embs, embs_) + append!(doclens, vec(doclens_)) + end + embs = cat(embs..., dims = 2) + embs, doclens +end + +""" + _sample_pids(num_documents::Int) Sample PIDs from the collection to be used to compute clusters using a ``k``-means clustering algorithm. # Arguments - - `indexer`: The collection indexer object containing the collection of passages to be indexed. + - `num_documents`: The total number of documents in the collection. It is assumed that each document has an ID + (aka PID) in the range of integers between `1` and `num_documents` (both inclusive). # Returns A `Set` of `Int`s containing the sampled PIDs. 
""" -function _sample_pids(indexer::CollectionIndexer) - num_passages = length(indexer.config.collection.data) +function _sample_pids(num_documents::Int) typical_doclen = 120 - num_sampled_pids = 16 * sqrt(typical_doclen * num_passages) - num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_passages)) - - sampled_pids = Set(sample(1:num_passages, num_sampled_pids)) + num_sampled_pids = 16 * sqrt(typical_doclen * num_documents) + num_sampled_pids = Int(min(1 + floor(num_sampled_pids), num_documents)) + sampled_pids = Set(sample(1:num_documents, num_sampled_pids)) @info "# of sampled PIDs = $(length(sampled_pids))" sampled_pids end """ - _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) - -Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref), compute the average document length using the embeddings, and save the sampled embeddings to disk. + _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, + collection::Vector{String}, sampled_pids::Set{Int}) -The embeddings for the sampled documents are saved in a file named `sample.jld2` with it's path specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the total number of embeddings over all documents. +Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref). -Sample the passages with `pid` in `sampled_pids` from the `collection` and compute the average passage length. The function returns a tuple containing the embedded passages and the average passage length. +The embeddings for the sampled documents are saved in a file named `sample.jld2` with it's path +specified by the indexing directory. This embedding array has shape `(D, N)`, where `D` is the +embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the +total number of embeddings over all documents. # Arguments - - `indexer`: An instance of `CollectionIndexer`. + - `config`: The [`ColBERTConfig`](@ref) to be used. + - `checkpoint`: The [`Checkpoint`] used to encode the passages. + - `collection`: The underlying collection of passages to get the samples from. - `sampled_pids`: Set of PIDs sampled by [`_sample_pids`](@ref). # Returns The average document length (i.e number of attended tokens) computed from the sampled documents. """ -function _sample_embeddings(indexer::CollectionIndexer, sampled_pids::Set{Int}) +function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, + collection::Vector{String}, sampled_pids::Set{Int}) # collect all passages with pids in sampled_pids - collection = indexer.config.collection sorted_sampled_pids = sort(collect(sampled_pids)) - local_sample = collection.data[sorted_sampled_pids] + local_sample = collection[sorted_sampled_pids] - local_sample_embs, local_sample_doclens = encode_passages(indexer.encoder, local_sample) + # get the local sample embeddings + local_sample_embs, local_sample_doclens = encode_passages( + config, checkpoint, local_sample) @debug "Local sample embeddings shape: $(size(local_sample_embs)), \t Local sample doclens: $(local_sample_doclens)" @assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))" + @assert length(local_sample) == length(local_sample_doclens) - indexer.num_sample_embs = size(local_sample_embs)[2] - indexer.avg_doclen_est = length(local_sample_doclens) > 0 ? 
- sum(local_sample_doclens) / length(local_sample_doclens) : 0 + num_sample_embs = size(local_sample_embs)[2] + avg_doclen_est = length(local_sample_doclens) > 0 ? + sum(local_sample_doclens) / length(local_sample_doclens) : 0 - sample_path = joinpath(indexer.config.index_path, "sample.jld2") - @info "avg_doclen_est = $(indexer.avg_doclen_est) \t length(local_sample) = $(length(local_sample))" + sample_path = joinpath(config.index_path, "sample.jld2") + @info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))" @info "Saving sampled embeddings to $(sample_path)." JLD2.save(sample_path, Dict("local_sample_embs" => local_sample_embs)) - indexer.avg_doclen_est + avg_doclen_est end """ - _save_plan(indexer::CollectionIndexer) + setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) -Save the indexing plan to a JSON file. +Initialize the index by computing some indexing-specific estimates and save the indexing plan to disk. -Information about the number of chunks, number of clusters, estimated number of embeddings over all documents and the estimated average document length is saved to a file named `plan.json`, with directory specified by the indexing directory. +The number of chunks into which the document embeddings will be stored is simply computed using the +number of documents and the size of a chunk. A bunch of pids used for initializing the centroids for +the embedding clusters are sampled using the [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref) +functions, and these samples are used to calculate the average document lengths and the estimated number +of embeddings which will be computed across all documents. Finally, the number of clusters to be used +for indexing is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``, +and the indexing plan is saved to `plan.json`, with the path being specified by the indexing directory. # Arguments - - `indexer`: The `CollectionIndexer` object that contains the index plan to be saved. + - `config`: The [`ColBERTConfig`](@ref) being used to set up the indexing. + - `checkpoint`: The [`Checkpoint`](@ref) used to compute embeddings. + - `collection`: The underlying collection of passages to initialize the index for. """ -function _save_plan(indexer::CollectionIndexer) - @info "Saving the index plan to $(indexer.plan_path)." - # TODO: export the config as json as well - open(indexer.plan_path, "w") do io +function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) + chunksize = min(25000, 1 + fld(length(collection), config.nranks)) + num_chunks = cld(length(collection), chunksize) + + # sample passages for training centroids later + sampled_pids = _sample_pids(length(collection)) + avg_doclen_est = _sample_embeddings(config, checkpoint, collection, sampled_pids) + + # computing the number of partitions, i.e clusters + num_passages = length(collection) + num_embeddings_est = num_passages * avg_doclen_est + num_partitions = Int(floor(2^(floor(log2(16 * sqrt(num_embeddings_est)))))) + + @info "Creating $(num_partitions) clusters." + @info "Estimated $(num_embeddings_est) embeddings." + + @info "Saving the index plan to $(joinpath(config.index_path, "plan.json"))." 
+ open(joinpath(config.index_path, "plan.json"), "w") do io JSON.print(io, Dict( - "num_chunks" => indexer.num_chunks, - "num_partitions" => indexer.num_partitions, - "num_embeddings_est" => indexer.num_embeddings_est, - "avg_doclen_est" => indexer.avg_doclen_est + "num_chunks" => num_chunks, + "num_partitions" => num_partitions, + "num_embeddings_est" => num_embeddings_est, + "avg_doclen_est" => avg_doclen_est ), 4 # indent ) end -end - -""" - setup(indexer::CollectionIndexer) - -Initialize `indexer` by computing some indexing-specific estimates and save the indexing plan to disk. - -The number of chunks into which the document embeddings will be stored (`indexer.num_chunks`) is simply computed using the number of documents and the size of a chunk obtained from [`get_chunksize`](@ref). A bunch of pids used for initializing the centroids for the embedding clusters are sampled using the [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref) functions, and these samples are used to calculate the average document lengths and the estimated number of embeddings which will be computed across all documents. Finally, the number of clusters (`indexer.num_partitions`) to be used for indexing is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``, and the indexing plan is saved to `plan.json` (see [`_save_plan`](@ref)) in the indexing directory. - -# Arguments - - - `indexer::CollectionIndexer`: The indexer to be initialized. -""" -function setup(indexer::CollectionIndexer) - collection = indexer.config.collection - indexer.num_chunks = Int(ceil(length(collection.data) / get_chunksize( - collection, indexer.config.nranks))) - - # sample passages for training centroids later - sampled_pids = _sample_pids(indexer) - avg_doclen_est = _sample_embeddings(indexer, sampled_pids) - - # computing the number of partitions, i.e clusters - num_passages = length(indexer.config.collection.data) - indexer.num_embeddings_est = num_passages * avg_doclen_est - indexer.num_partitions = Int(floor(2^(floor(log2(16 * - sqrt(indexer.num_embeddings_est)))))) - - @info "Creating $(indexer.num_partitions) clusters." - @info "Estimated $(indexer.num_embeddings_est) embeddings." - _save_plan(indexer) + @info "Saving the config to the indexing path." + ColBERT.save(config) end """ From bda3168240db6450797e8dc652bdabc36e863bc4 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 22:56:24 +0530 Subject: [PATCH 13/59] Using `JLD2.save_object` instead of `JLD2.save`. --- src/indexing/collection_indexer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 4f54f7a..2e44391 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -160,7 +160,7 @@ function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, sample_path = joinpath(config.index_path, "sample.jld2") @info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))" @info "Saving sampled embeddings to $(sample_path)." - JLD2.save(sample_path, Dict("local_sample_embs" => local_sample_embs)) + JLD2.save_object(sample_path, local_sample_embs) avg_doclen_est end From eca7babf35da93f12a6c2a8a3cfa4bb5e4755b40 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 23:26:18 +0530 Subject: [PATCH 14/59] Creating the index dir in the setup function if it doesn't exist. 
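
A minimal sketch of the added guard (note that `mkdir` only creates the final path
component, so this assumes the parent of `index_path` already exists; `mkpath` would
also create any missing intermediate directories):

```julia
# create the index directory on the first setup call if it is missing
isdir(config.index_path) || mkdir(config.index_path)

# alternative that also creates missing parent directories:
# mkpath(config.index_path)
```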
--- src/indexing/collection_indexer.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 2e44391..14ec2ec 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -185,6 +185,8 @@ and the indexing plan is saved to `plan.json`, with the path being specified by - `collection`: The underlying collection of passages to initialize the index for. """ function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) + isdir(config.index_path) || mkdir(config.index_path) + chunksize = min(25000, 1 + fld(length(collection), config.nranks)) num_chunks = cld(length(collection), chunksize) From b33b7bf27d4fbf9dc16c0621d6ac17d504afcbf2 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 23:27:01 +0530 Subject: [PATCH 15/59] Simplyfing all the clustering related code. --- src/indexing/collection_indexer.jl | 42 ++++++++++++++---------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 14ec2ec..34b5834 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -228,16 +228,16 @@ The sample embeddings saved by the [`setup`](@ref) function are loaded, shuffled # Arguments - - `indexer`: The [`CollectionIndexer`](@ref). + - `index_path`: The path of the index. # Returns The tuple `sample, sample_heldout`. """ -function _concatenate_and_split_sample(indexer::CollectionIndexer) +function _concatenate_and_split_sample(index_path::String) # load the sample embeddings - sample_path = joinpath(indexer.config.index_path, "sample.jld2") - sample = JLD2.load(sample_path, "local_sample_embs") + sample_path = joinpath(index_path, "sample.jld2") + sample = JLD2.load_object(sample_path) @debug "Original sample shape: $(size(sample))" # randomly shuffle embeddings @@ -271,19 +271,17 @@ Compute the average residuals and other statistics of the held-out sample embedd A tuple `bucket_cutoffs, bucket_weights, avg_residual`. 
""" function _compute_avg_residuals( - indexer::CollectionIndexer, centroids::AbstractMatrix{Float32}, + nbits::Int, centroids::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}) - compressor = ResidualCodec( - indexer.config, centroids, 0.0, Vector{Float32}(), Vector{Float32}()) - codes = compress_into_codes(compressor, heldout) # get centroid codes + codes = compress_into_codes(centroids, heldout) # get centroid codes @assert codes isa AbstractVector{UInt32} "$(typeof(codes))" - heldout_reconstruct = Flux.gpu(compressor.centroids[:, codes]) # get corresponding centroids - heldout_avg_residual = Flux.gpu(heldout) - heldout_reconstruct # compute the residual + heldout_reconstruct = Flux.gpu(centroids[:, codes]) # get corresponding centroids + heldout_avg_residual = Flux.gpu(heldout) - heldout_reconstruct # compute the residual - avg_residual = mean(abs.(heldout_avg_residual), dims = 2) # for each dimension, take mean of absolute values of residuals + avg_residual = mean(abs.(heldout_avg_residual), dims = 2) # for each dimension, take mean of absolute values of residuals # computing bucket weights and cutoffs - num_options = 2^indexer.config.nbits + num_options = 2^nbits quantiles = Vector(0:(num_options - 1)) / num_options bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end], quantiles .+ (0.5 / num_options) @@ -310,24 +308,24 @@ Average residuals and other compression data is computed via the [`_compute_avg_ - `indexer::CollectionIndexer`: The [`CollectionIndexer`](@ref) to be trained. """ -function train(indexer::CollectionIndexer) - sample, heldout = _concatenate_and_split_sample(indexer) +function train(config::ColBERTConfig) + sample, heldout = _concatenate_and_split_sample(config.index_path) @assert sample isa AbstractMatrix{Float32} "$(typeof(sample))" @assert heldout isa AbstractMatrix{Float32} "$(typeof(heldout))" - centroids = kmeans(sample, indexer.num_partitions, - maxiter = indexer.config.kmeans_niters, display = :iter).centers - @assert size(centroids)[2]==indexer.num_partitions "size(centroids): $(size(centroids)), indexer.num_partitions: $(indexer.num_partitions)" + # loading the indexing plan + plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json")) + + centroids = kmeans(sample, plan_metadata["num_partitions"], + maxiter = config.kmeans_niters, display = :iter).centers + @assert size(centroids)[2]==plan_metadata["num_partitions"] "size(centroids): $(size(centroids)), num_partitions: $(plan_metadata["num_partitions"])" @assert centroids isa AbstractMatrix{Float32} "$(typeof(centroids))" bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals( - indexer, centroids, heldout) + config.nbits, centroids, heldout) @info "avg_residual = $(avg_residual)" - codec = ResidualCodec( - indexer.config, centroids, avg_residual, bucket_cutoffs, bucket_weights) - indexer.saver.codec = codec - save_codec(indexer.saver) + save_codec(config.index_path, centroids, bucket_cutoffs, bucket_weights, avg_residual) end """ From 314219b8c772f6955da6416e6e1105186355bbba Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 23:28:40 +0530 Subject: [PATCH 16/59] Simplyfing saving and loading of the codec to/from the index path. 
--- src/indexing/codecs/residual.jl | 48 ++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 1bae993..93a0ab0 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -36,12 +36,48 @@ julia> codec = load_codec(index_path); ``` """ function load_codec(index_path::String) - config = load_config(index_path) - centroids = load(joinpath(index_path, "centroids.jld2"), "centroids") - avg_residual = load(joinpath(index_path, "avg_residual.jld2"), "avg_residual") - buckets = load(joinpath(index_path, "buckets.jld2")) - ResidualCodec(config, centroids, avg_residual, - buckets["bucket_cutoffs"], buckets["bucket_weights"]) + centroids_path = joinpath(index_path, "centroids.jld2") + avg_residual_path = joinpath(index_path, "avg_residual.jld2") + bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") + bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") + @info "Loading codec from $(centroids_path), $(avg_residual_path), $(bucket_cutoffs_path) and $(bucket_weights_path)." + + centroids = JLD2.load_object(centroids_path) + avg_residual = JLD2.load_object(avg_residual_path) + bucket_cutoffs = JLD2.load_object(bucket_cutoffs_path) + bucket_weights = JLD2.load_object(bucket_weights_path) + + centroids, avg_residual, bucket_cutoffs, bucket_weights +end + +""" + save_codec(saver::IndexSaver) + +Save the codec used by the `saver` to disk. + +This will create three files in the directory specified by the indexing path: + + - `centroids.jld2` containing the centroids. + - `avg_residual.jld2` containing the average residual. + - `buckets.jld2` containing the bucket cutoffs and weights. + +Also see [`train`](@ref). + +# Arguments + + - `saver::IndexSaver`: The index saver to use. +""" +function save_codec(index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, bucket_weights::Vector{Float32}, avg_residual::Float32) + centroids_path = joinpath(index_path, "centroids.jld2") + avg_residual_path = joinpath(index_path, "avg_residual.jld2") + bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") + bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") + @info "Saving codec to $(centroids_path), $(avg_residual_path), $(bucket_cutoffs_path) and $(bucket_weights_path)." + + JLD2.save_object(centroids_path, centroids) + JLD2.save_object(avg_residual_path, avg_residual) + JLD2.save_object(bucket_cutoffs_path, bucket_cutoffs) + JLD2.save_object(bucket_weights_path, bucket_weights) end """ From 7ad3d1c265a46cc9fc23fdb3a9d417482ccdb71c Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 23:29:12 +0530 Subject: [PATCH 17/59] Minor change in `compress_into_codes`; using `centroids` as an argument now. 
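A hypothetical round-trip mirroring the new `save_codec`/`load_codec` pair: each codec component lives in its own JLD2 file under the index path. All values below are invented for illustration.

```julia
using JLD2

index_path = mktempdir()
codec = Dict(
    "centroids" => rand(Float32, 128, 4),
    "avg_residual" => 0.18f0,
    "bucket_cutoffs" => Float32[-0.03, 0.0, 0.03],
    "bucket_weights" => Float32[-0.06, -0.01, 0.01, 0.06]
)

# one file per component, matching the file names used by save_codec
for (name, object) in codec
    JLD2.save_object(joinpath(index_path, "$(name).jld2"), object)
end

@assert JLD2.load_object(joinpath(index_path, "centroids.jld2")) == codec["centroids"]
```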
--- src/indexing/codecs/residual.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 93a0ab0..1926658 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -97,14 +97,14 @@ A `Vector{UInt32}` of codes, where each code corresponds to the nearest centroid ``` ``` """ -function compress_into_codes(codec::ResidualCodec, embs::AbstractMatrix{Float32}) +function compress_into_codes(centroids::AbstractMatrix{Float32}, embs::AbstractMatrix{Float32}) codes = Vector{UInt32}() - bsize = Int(floor((1 << 29) / size(codec.centroids)[2])) + bsize = Int(floor((1 << 29) / size(centroids)[2])) offset = 1 while (offset <= size(embs)[2]) # batch on the second dimension dot_products = transpose(Flux.gpu(embs[ - :, offset:min(size(embs)[2], offset + bsize - 1)])) * Flux.gpu(codec.centroids) + :, offset:min(size(embs)[2], offset + bsize - 1)])) * Flux.gpu(centroids) indices = (cartesian_index -> cartesian_index.I[2]).(argmax(dot_products, dims = 2)[ :, 1]) append!(codes, indices) From c5f39f341427054bc167e9726281392dcdc86d5f Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 10:30:24 +0530 Subject: [PATCH 18/59] Allowing a custom chunksize. --- src/indexing/collection_indexer.jl | 3 ++- src/infra/config.jl | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 34b5834..abfad58 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -187,7 +187,8 @@ and the indexing plan is saved to `plan.json`, with the path being specified by function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) isdir(config.index_path) || mkdir(config.index_path) - chunksize = min(25000, 1 + fld(length(collection), config.nranks)) + chunksize = 0 + chunksize = ismissing(config.chunksize) ? min(25000, 1 + fld(length(collection), config.nranks)) : config.chunksize num_chunks = cld(length(collection), chunksize) # sample passages for training centroids later diff --git a/src/infra/config.jl b/src/infra/config.jl index d243cba..0ec2365 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -25,6 +25,8 @@ Structure containing config for running and training various components. - `attend_to_mask_tokens`: Whether or not to attend to mask tokens in the query. Default value is false. - `index_path`: Path to save the index files. - `index_bsize`: Batch size used for some parts of indexing. + - `chunksize`: Custom size of a chunk, i.e the number of passages for which data is to be stored in one chunk. Default is `missing`, + in which case `chunksize` is determined from the size of the `collection` and `nranks`. - `nbits`: Number of bits used to compress residuals. - `kmeans_niters`: Number of iterations used for k-means clustering. - `nprobe`: The number of nearest centroids to fetch during a search. Default is `2`. Also see [`retrieve`](@ref). @@ -77,6 +79,7 @@ Base.@kwdef struct ColBERTConfig # indexing settings index_path::String = "" index_bsize::Int = 64 + chunksize::Union{Missing, Int} = missing nbits::Int = 2 kmeans_niters::Int = 20 From bebb4a5a734b92b93c3461f725c7db3b54a532fe Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 12:05:38 +0530 Subject: [PATCH 19/59] Returning a codec `Dict` from `load_codec`. 
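A CPU-only sketch of what `compress_into_codes` computes once the GPU transfer and batching are stripped away: each embedding is assigned the ID of the centroid with the largest dot product. Shapes below are arbitrary.

```julia
centroids = rand(Float32, 128, 8)   # (dim, num_partitions)
embs      = rand(Float32, 128, 5)   # (dim, num_embeddings)

dot_products = transpose(embs) * centroids                      # (5, 8) similarities
codes = UInt32[argmax(dot_products[i, :]) for i in 1:size(embs, 2)]

@assert length(codes) == size(embs, 2)
```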
--- src/indexing/codecs/residual.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 1926658..c41bfd7 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -47,7 +47,12 @@ function load_codec(index_path::String) bucket_cutoffs = JLD2.load_object(bucket_cutoffs_path) bucket_weights = JLD2.load_object(bucket_weights_path) - centroids, avg_residual, bucket_cutoffs, bucket_weights + Dict( + "centroids" => centroids, + "avg_residual" => avg_residual, + "bucket_cutoffs" => bucket_cutoffs, + "bucket_weights" => bucket_weights + ) end """ From 07c2b8afa86a437742dab4942b3a0bcfd1cdaf0b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 12:06:35 +0530 Subject: [PATCH 20/59] Simplyfying the arguments of the `binarize` and `compress` functions; will help in testing. --- src/indexing/codecs/residual.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index c41bfd7..f6022cf 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -135,9 +135,7 @@ Convert a matrix of residual vectors into a matrix of integer residual vector us A matrix of compressed integer residual vectors. """ -function binarize(codec::ResidualCodec, residuals::AbstractMatrix{Float32}) - dim = codec.config.dim - nbits = codec.config.nbits +function binarize(dim::Int, nbits::Int, bucket_cutoffs::Vector{Float32}, residuals::AbstractMatrix{Float32}) num_embeddings = size(residuals)[2] if dim % (nbits * 8) != 0 @@ -145,7 +143,7 @@ function binarize(codec::ResidualCodec, residuals::AbstractMatrix{Float32}) end # need to subtract one here, to preserve the number of options (2 ^ nbits) - bucket_indices = (x -> searchsortedfirst(codec.bucket_cutoffs, x)).(residuals) .- 1 # torch.bucketize + bucket_indices = (x -> searchsortedfirst(bucket_cutoffs, x)).(residuals) .- 1 # torch.bucketize bucket_indices = stack([bucket_indices for i in 1:nbits], dims = 1) # add an nbits-wide extra dimension positionbits = fill(1, (nbits, 1, 1)) for i in 1:nbits @@ -179,18 +177,18 @@ All embeddings are compressed to their nearest centroid IDs and their quantized A tuple containing a vector of codes and the compressed residuals matrix. """ -function compress(codec::ResidualCodec, embs::AbstractMatrix{Float32}) +function compress(centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, dim::Int, nbits::Int, embs::AbstractMatrix{Float32}) codes, residuals = Vector{UInt32}(), Vector{Matrix{UInt8}}() offset = 1 bsize = 1 << 18 while (offset <= size(embs)[2]) # batch on second dimension batch = embs[:, offset:min(size(embs)[2], offset + bsize - 1)] - codes_ = compress_into_codes(codec, batch) # get centroid codes - centroids_ = codec.centroids[:, codes_] # get corresponding centroids + codes_ = compress_into_codes(centroids, batch) # get centroid codes + centroids_ = centroids[:, codes_] # get corresponding centroids residuals_ = batch - centroids_ append!(codes, codes_) - push!(residuals, binarize(codec, residuals_)) + push!(residuals, binarize(dim, nbits, bucket_cutoffs, residuals_)) offset += bsize end residuals = cat(residuals..., dims = 2) From a2870a5d1514401b41019492d416226336b68011 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 12:07:19 +0530 Subject: [PATCH 21/59] Saving the chunksize in the indexing plan metadata. 
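The first step of `binarize` is easiest to see on a toy example: with `nbits = 2` there are `2^nbits - 1 = 3` cutoffs, and `searchsortedfirst` minus one maps each residual value into a bucket index in `0:3`. The cutoff and residual values below are invented.

```julia
bucket_cutoffs = Float32[-0.03, 0.0, 0.03]   # 2^nbits - 1 cutoffs for nbits = 2
residuals = Float32[-0.05 0.01; 0.0 0.1]     # toy residuals

bucket_indices = (x -> searchsortedfirst(bucket_cutoffs, x)).(residuals) .- 1
# -0.05 -> 0, 0.0 -> 1, 0.01 -> 2, 0.1 -> 3
```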
--- src/indexing/collection_indexer.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index abfad58..381e404 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -207,6 +207,7 @@ function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector open(joinpath(config.index_path, "plan.json"), "w") do io JSON.print(io, Dict( + "chunksize" => chunksize, "num_chunks" => num_chunks, "num_partitions" => num_partitions, "num_embeddings_est" => num_embeddings_est, From 4f9353d236affde53dca1c20b735dbd827c2a6d0 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 12:08:09 +0530 Subject: [PATCH 22/59] Simplyfing the `index` and `save_chunk` functions; using mostly primitive types. --- src/indexing/collection_indexer.jl | 73 +++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 11 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 381e404..254ea2e 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -330,6 +330,57 @@ function train(config::ColBERTConfig) save_codec(config.index_path, centroids, bucket_cutoffs, bucket_weights, avg_residual) end +""" + save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, + embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) + +Save a single chunk of compressed embeddings and their relevant metadata to disk. + +The codes and compressed residuals for the chunk are saved in files named `.codec.jld2`. The document lengths are saved in a file named `doclens..jld2`. Relevant metadata, including number of documents in the chunk, number of embeddings and the passage offsets are saved in a file named `.metadata.json`. + +# Arguments + + - `saver`: The [`IndexSaver`](@ref) containing relevant information to save the chunk. + - `chunk_idx`: The index of the current chunk being saved. + - `offset`: The offset in the original document collection where this chunk starts. + - `embs`: The embeddings matrix for the current chunk. + - `doclens`: The document lengths vector for the current chunk. 
+""" +function save_chunk(config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, + embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) + codes, residuals = compress(codec["centroids"], codec["bucket_cutoffs"], config.dim, config.nbits, embs) + path_prefix = joinpath(config.index_path, string(chunk_idx)) + @assert length(codes)==size(embs)[2] "length(codes): $(length(codes)), size(embs): $(size(embs))" + + # saving the compressed embeddings + codes_path = "$(path_prefix).codes.jld2" + residuals_path = "$(path_prefix).residuals.jld2" + @info "Saving compressed codes to $(codes_path) and residuals to $(residuals_path)" + JLD2.save_object(codes_path, codes) + JLD2.save_object(residuals_path, residuals) + + # saving doclens + doclens_path = joinpath( + config.index_path, "doclens.$(chunk_idx).jld2") + @info "Saving doclens to $(doclens_path)" + JLD2.save_object(doclens_path, doclens) + + # the metadata + metadata_path = joinpath( + config.index_path, "$(chunk_idx).metadata.json") + @info "Saving metadata to $(metadata_path)" + open(metadata_path, "w") do io + JSON.print(io, + Dict( + "passage_offset" => passage_offset, + "num_passages" => length(doclens), + "num_embeddings" => length(codes) + ), + 4 # indent + ) + end +end + """ index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) @@ -342,20 +393,20 @@ The documents are processed in batches of size `chunksize` (see [`enumerate_batc - `indexer`: The [`CollectionIndexer`](@ref) used to build the index. - `chunksize`: Size of a chunk into which the index is to be stored. """ -function index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) - load_codec!(indexer.saver) # load the codec objects - batches = enumerate_batches( - indexer.config.collection, chunksize = chunksize, - nranks = indexer.config.nranks) - for (chunk_idx, offset, passages) in batches - # TODO: add functionality to not re-write chunks if they already exist! - # TODO: add multiprocessing to this step! - embs, doclens = encode_passages(indexer.encoder, passages) +function index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) + codec = load_codec(config.index_path) + plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json")) + passage_offset = 1 + for chunk_idx in 1:plan_metadata["num_chunks"] + passage_end_offset = min(length(collection), passage_offset + plan_metadata["chunksize"] - 1) + embs, doclens = encode_passages(config, checkpoint, collection[passage_offset:passage_end_offset]) @assert embs isa AbstractMatrix{Float32} "$(typeof(embs))" @assert doclens isa AbstractVector{Int} "$(typeof(doclens))" - @info "Saving chunk $(chunk_idx): \t $(length(passages)) passages and $(size(embs)[2]) embeddings. From offset #$(offset) onward." - save_chunk(indexer.saver, chunk_idx, offset, embs, doclens) + @info "Saving chunk $(chunk_idx): \t $(passage_end_offset - passage_offset + 1) passages and $(size(embs)[2]) embeddings. From passage #$(passage_offset) onward." + save_chunk(config, codec, chunk_idx, passage_offset, embs, doclens) + + passage_offset += plan_metadata["chunksize"] end end From f6c7308ffe6b8d51705ed7ed875d3231448d4e62 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:25:50 +0530 Subject: [PATCH 23/59] Removing unused exports. 
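The offset arithmetic in the new `index` loop, shown stand-alone with made-up sizes (2,500 passages, chunks of 1,000):

```julia
collection_size, chunksize = 2_500, 1_000
num_chunks = cld(collection_size, chunksize)    # 3

let passage_offset = 1
    for chunk_idx in 1:num_chunks
        passage_end_offset = min(collection_size, passage_offset + chunksize - 1)
        println("chunk $(chunk_idx): passages $(passage_offset):$(passage_end_offset)")
        passage_offset += chunksize
    end
end
# chunk 1: passages 1:1000
# chunk 2: passages 1001:2000
# chunk 3: passages 2001:2500
```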
--- src/ColBERT.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/ColBERT.jl b/src/ColBERT.jl index f54a7af..fae58a6 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -23,23 +23,19 @@ export Collection, Queries # config and other infra include("infra/config.jl") -export RunSettings, TokenizerSettings, ResourceSettings, - DocSettings, QuerySettings, IndexingSettings, - SearchSettings, ColBERTConfig +export ColBERTConfig # models, document/query tokenizers include("modelling/tokenization/doc_tokenization.jl") include("modelling/tokenization/query_tokenization.jl") include("modelling/checkpoint.jl") -export BaseColBERT, Checkpoint, DocTokenizer, QueryTokenizer +export BaseColBERT, Checkpoint # indexer include("indexing/codecs/residual.jl") include("indexing.jl") -include("indexing/collection_encoder.jl") -include("indexing/index_saver.jl") include("indexing/collection_indexer.jl") -export Indexer, CollectionIndexer, index +export Indexer, index # searcher include("search/strided_tensor.jl") From fb5c0c30145a84c749299ac87f703e7d9b54ebb6 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:26:28 +0530 Subject: [PATCH 24/59] Simplyfying `load_codes.` --- src/indexing/codecs/residual.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index f6022cf..223dc6a 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -290,11 +290,10 @@ Load the codes from disk for a given chunk index. The codes are stored in the fi A vector of codes for the specified chunk. """ -function load_codes(codec::ResidualCodec, chunk_idx::Int) +function load_codes(index_path::String, chunk_idx::Int) codes_path = joinpath( - codec.config.index_path, "$(chunk_idx).codes.jld2") - codes = JLD2.load(codes_path, "codes") - codes + index_path, "$(chunk_idx).codes.jld2") + JLD2.load_object(codes_path) end function load_residuals(codec::ResidualCodec, chunk_idx::Int) From 0b93025d308bbb89651510cca6102bd4baab99cb Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:27:16 +0530 Subject: [PATCH 25/59] Simplyfying the `finalize` functions. --- src/indexing/collection_indexer.jl | 164 ++++++++++++----------------- 1 file changed, 70 insertions(+), 94 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 254ea2e..3d0c816 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -1,54 +1,3 @@ -""" - CollectionIndexer(config::ColBERTConfig, encoder::CollectionEncoder, saver::IndexSaver) - -Structure which performs all the index-building operations, including sampling initial centroids, clustering, computing document embeddings, compressing and building the `ivf`. - -# Arguments - - - `config`: The [`ColBERTConfig`](@ref) used to build the model. - - `encoder`: The [`CollectionEncoder`](@ref) to be used for encoding documents. - - `saver`: The [`IndexSaver`](@ref), responsible for saving the index to disk. - -# Returns - -A [`CollectionIndexer`](@ref) object, containing all indexing-related information. See the [`setup`](@ref), [`train`](@ref), [`index`](@ref) and [`finalize`](@ref) functions for building the index. 
-""" -mutable struct CollectionIndexer - config::ColBERTConfig - encoder::CollectionEncoder - saver::IndexSaver - plan_path::String - num_chunks::Int - num_embeddings_est::Float64 - num_partitions::Int - num_sample_embs::Int - avg_doclen_est::Float64 - embeddings_offsets::Vector{Int} - num_embeddings::Int - metadata_path::String -end - -function CollectionIndexer( - config::ColBERTConfig, encoder::CollectionEncoder, saver::IndexSaver) - plan_path = joinpath(config.index_path, "plan.json") - metadata_path = joinpath(config.index_path, "metadata.json") - - CollectionIndexer( - config, - encoder, - saver, - plan_path, - 0, # num_chunks - 0.0, # num_embeddings_est - 0, # num_partitions - 0, # num_sample_embs - 0.0, # avg_doclen_est - [], # embeddings_offsets - 0, # num_embeddings - metadata_path - ) -end - """ encode_passages( config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String}) @@ -411,7 +360,7 @@ function index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector end """ - finalize(indexer::CollectionIndexer) + finalize(indexer) Finalize the indexing process by saving all files, collecting embedding ID offsets, building IVF, and updating metadata. @@ -421,32 +370,65 @@ See [`_check_all_files_are_saved`](@ref), [`_collect_embedding_id_offset`](@ref) - `indexer::CollectionIndexer`: The [`CollectionIndexer`](@ref) used to finalize the indexing process. """ -function finalize(indexer::CollectionIndexer) - _check_all_files_are_saved(indexer) - _collect_embedding_id_offset(indexer) - _build_ivf(indexer) - _update_metadata(indexer) +function finalize(index_path::String) + _check_all_files_are_saved(index_path) + _collect_embedding_id_offset(index_path) + _build_ivf(index_path) +end + +""" + check_chunk_exists(saver::IndexSaver, chunk_idx::Int) + +Check if the index chunk exists for the given `chunk_idx`. + +# Arguments + + - `saver`: The `IndexSaver` object that contains the indexing settings. + - `chunk_idx`: The index of the chunk to check. + +# Returns + +A boolean indicating whether all relevant files for the chunk exist. +""" +function check_chunk_exists(index_path::String, chunk_idx::Int) + path_prefix = joinpath(index_path, string(chunk_idx)) + codes_path = "$(path_prefix).codes.jld2" + residuals_path = "$(path_prefix).residuals.jld2" + doclens_path = joinpath(index_path, "doclens.$(chunk_idx).jld2") + metadata_path = joinpath(index_path, "$(chunk_idx).metadata.json") + + for file in [codes_path, residuals_path, doclens_path, metadata_path] + if !isfile(file) + return false + end + end + + true end -function _check_all_files_are_saved(indexer::CollectionIndexer) +function _check_all_files_are_saved(index_path::String) + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + @info "Checking if all files are saved." - for chunk_idx in 1:(indexer.num_chunks) - if !(check_chunk_exists(indexer.saver, chunk_idx)) - @error "Could not find chunk $(chunk_idx)!" + for chunk_idx in 1:(plan_metadata["num_chunks"]) + if !(check_chunk_exists(index_path, chunk_idx)) + @error "Some files for chunk $(chunk_idx) are missing!" end end @info "Found all files!" end -function _collect_embedding_id_offset(indexer::CollectionIndexer) +function _collect_embedding_id_offset(index_path::String) + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + @info "Collecting embedding ID offsets." 
passage_offset = 1 embedding_offset = 1 embeddings_offsets = Vector{Int}() - for chunk_idx in 1:(indexer.num_chunks) + for chunk_idx in 1:(plan_metadata["num_chunks"]) metadata_path = joinpath( - indexer.config.index_path, "$(chunk_idx).metadata.json") + index_path, "$(chunk_idx).metadata.json") chunk_metadata = open(metadata_path, "r") do io chunk_metadata = JSON.parse(io) @@ -462,18 +444,32 @@ function _collect_embedding_id_offset(indexer::CollectionIndexer) JSON.print(io, chunk_metadata, 4) end end + num_embeddings = embedding_offset - 1 + @assert length(embeddings_offsets) == plan_metadata["num_chunks"] - indexer.num_embeddings = embedding_offset - 1 - indexer.embeddings_offsets = embeddings_offsets + @info "Saving the indexing metadata." + metadata_path = joinpath(index_path, "metadata.json") + open(metadata_path, "w") do io + JSON.print(io, + # TODO: export the config here as well! + Dict( + "num_embeddings" => num_embeddings, + "embeddings_offsets" => embeddings_offsets + ), + 4 + ) + end end -function _build_ivf(indexer::CollectionIndexer) +function _build_ivf(index_path::String) + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + @info "Building the centroid to embedding IVF." codes = Vector{UInt32}() @info "Loading codes for each embedding." - for chunk_idx in 1:(indexer.num_chunks) - chunk_codes = load_codes(indexer.saver.codec, chunk_idx) + for chunk_idx in 1:(plan_metadata["num_chunks"]) + chunk_codes = load_codes(index_path, chunk_idx) append!(codes, chunk_codes) end @assert codes isa AbstractVector{UInt32} "$(typeof(codes))" @@ -482,31 +478,11 @@ function _build_ivf(indexer::CollectionIndexer) ivf, values = sortperm(codes), sort(codes) @info "Getting unique codes and their counts." - ivf_lengths = counts(values, 1:(indexer.num_partitions)) + ivf_lengths = counts(values, 1:(plan_metadata["num_partitions"])) @info "Saving the IVF." - ivf_path = joinpath(indexer.config.index_path, "ivf.jld2") - JLD2.save(ivf_path, Dict( - "ivf" => ivf, - "ivf_lengths" => ivf_lengths - )) -end - -function _update_metadata(indexer::CollectionIndexer) - @info "Saving the indexing metadata." - metadata_path = joinpath(indexer.config.index_path, "metadata.json") - - open(metadata_path, "w") do io - JSON.print(io, - # TODO: export the config here as well! - Dict( - "num_chunks" => indexer.num_chunks, - "num_partitions" => indexer.num_partitions, - "num_embeddings" => indexer.num_embeddings, - "avg_doclen" => Int(floor(indexer.num_embeddings / - length(indexer.config.collection.data))) - ), - 4 - ) - end + ivf_path = joinpath(index_path, "ivf.jld2") + ivf_lengths_path = joinpath(index_path, "ivf_lengths.jld2") + JLD2.save_object(ivf_path, ivf) + JLD2.save_object(ivf_lengths_path, ivf_lengths) end From 01d6a959b9eeb877d804a49c6bb1a7b724371a0b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:40:27 +0530 Subject: [PATCH 26/59] Updating docstrings for indexing functions. 
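A stand-alone sketch of the inverted file construction in `_build_ivf`: sorting embedding IDs by their centroid code groups them per centroid, and `counts` records how many embeddings fall under each centroid. The codes below are toy values.

```julia
using StatsBase: counts

codes = UInt32[3, 1, 3, 2, 1, 3]    # centroid code of each embedding
num_partitions = 4

ivf = sortperm(codes)                                # [2, 5, 4, 1, 3, 6]
ivf_lengths = counts(sort(codes), 1:num_partitions)  # [2, 1, 3, 0]
```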
--- src/indexing/collection_indexer.jl | 86 +++++++++++++++++++----------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 3d0c816..49d039d 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -136,8 +136,9 @@ and the indexing plan is saved to `plan.json`, with the path being specified by function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) isdir(config.index_path) || mkdir(config.index_path) - chunksize = 0 - chunksize = ismissing(config.chunksize) ? min(25000, 1 + fld(length(collection), config.nranks)) : config.chunksize + chunksize = 0 + chunksize = ismissing(config.chunksize) ? + min(25000, 1 + fld(length(collection), config.nranks)) : config.chunksize num_chunks = cld(length(collection), chunksize) # sample passages for training centroids later @@ -171,15 +172,17 @@ function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector end """ - _concatenate_and_split_sample(indexer::CollectionIndexer) + _concatenate_and_split_sample(index_path::String) Randomly shuffle and split the sampled embeddings. -The sample embeddings saved by the [`setup`](@ref) function are loaded, shuffled randomly, and then split into a `sample` and a `sample_heldout` set, with `sample_heldout` containing a `0.05` fraction of the original sampled embeddings. +The sample embeddings saved by the [`setup`](@ref) function are loaded, shuffled randomly, +and then split into a `sample` and a `sample_heldout` set, with `sample_heldout` containing +a `0.05` fraction of the original sampled embeddings. # Arguments - - `index_path`: The path of the index. + - `index_path`: The path of the index. # Returns @@ -206,20 +209,25 @@ function _concatenate_and_split_sample(index_path::String) end """ - _compute_avg_residuals(indexer::CollectionIndexer, centroids::AbstractMatrix{Float32}, + _compute_avg_residuals( + nbits::Int, centroids::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}) Compute the average residuals and other statistics of the held-out sample embeddings. # Arguments - - `indexer`: The underlying [`CollectionIndexer`](@ref). - - `centroids`: A matrix containing the centroids of the computed using a ``k``-means clustering algorithm on the sampled embeddings. Has shape `(D, indexer.num_partitions)`, where `D` is the embedding dimension (`128`) and `indexer.num_partitions` is the number of clusters. - - `heldout`: A matrix containing the held-out embeddings, computed using [`_concatenate_and_split_sample`](@ref). + - `nbits`: The number of bits used to compress the residuals. + - `centroids`: A matrix containing the centroids of the computed using a ``k``-means + clustering algorithm on the sampled embeddings. Has shape `(D, indexer.num_partitions)`, + where `D` is the embedding dimension (`128`) and `indexer.num_partitions` is the number + of clusters. + - `heldout`: A matrix containing the held-out embeddings, computed using + [`_concatenate_and_split_sample`](@ref). # Returns -A tuple `bucket_cutoffs, bucket_weights, avg_residual`. +A tuple `bucket_cutoffs, bucket_weights, avg_residual`, which will be used in compression/decompression of residuals. 
""" function _compute_avg_residuals( nbits::Int, centroids::AbstractMatrix{Float32}, @@ -249,15 +257,17 @@ function _compute_avg_residuals( end """ - train(indexer::CollectionIndexer) + train(config::ColBERTConfig) -Train a [`CollectionIndexer`](@ref) by computing centroids using a ``k``-means clustering algorithn, and store the compression information on disk. +Compute centroids using a ``k``-means clustering algorithn, and store the compression information +on disk. -Average residuals and other compression data is computed via the [`_compute_avg_residuals`](@ref) function, and the codec is saved on disk using [`save_codec`](@ref). +Average residuals and other compression data is computed via the [`_compute_avg_residuals`](@ref) +function, and the codec is saved on disk using [`save_codec`](@ref). # Arguments - - `indexer::CollectionIndexer`: The [`CollectionIndexer`](@ref) to be trained. + - `config`: The [`ColBERTConfig`](@ref) used to train the indexer. """ function train(config::ColBERTConfig) sample, heldout = _concatenate_and_split_sample(config.index_path) @@ -280,24 +290,30 @@ function train(config::ColBERTConfig) end """ - save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, + save_chunk( + config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) Save a single chunk of compressed embeddings and their relevant metadata to disk. -The codes and compressed residuals for the chunk are saved in files named `.codec.jld2`. The document lengths are saved in a file named `doclens..jld2`. Relevant metadata, including number of documents in the chunk, number of embeddings and the passage offsets are saved in a file named `.metadata.json`. +The codes and compressed residuals for the chunk are saved in files named `.codes.jld2`. +and `.residuals.jld2` respectively. The document lengths are saved in a file named +`doclens..jld2`. Relevant metadata, including number of documents in the chunk, +number of embeddings and the passage offsets are saved in a file named `.metadata.json`. # Arguments - - `saver`: The [`IndexSaver`](@ref) containing relevant information to save the chunk. + - `config`: The [`ColBERTConfig`](@ref) being used. - `chunk_idx`: The index of the current chunk being saved. - - `offset`: The offset in the original document collection where this chunk starts. + - `passage_offset`: The index of the first passage in the chunk. - `embs`: The embeddings matrix for the current chunk. - `doclens`: The document lengths vector for the current chunk. """ -function save_chunk(config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, +function save_chunk( + config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) - codes, residuals = compress(codec["centroids"], codec["bucket_cutoffs"], config.dim, config.nbits, embs) + codes, residuals = compress( + codec["centroids"], codec["bucket_cutoffs"], config.dim, config.nbits, embs) path_prefix = joinpath(config.index_path, string(chunk_idx)) @assert length(codes)==size(embs)[2] "length(codes): $(length(codes)), size(embs): $(size(embs))" @@ -331,24 +347,30 @@ function save_chunk(config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_ end """ - index(indexer::CollectionIndexer; chunksize::Union{Int, Missing} = missing) + index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) Build the index using `indexer`. 
-The documents are processed in batches of size `chunksize` (see [`enumerate_batches`](@ref)). Embeddings and document lengths are computed for each batch (see [`encode_passages`](@ref)), and they are saved to disk along with relevant metadata (see [`save_chunk`](@ref)). +The documents are processed in batches of size `chunksize`, determined by the config +(see [`ColBERTConfig`](@ref) and [`setup`](@ref)). Embeddings and document lengths are +computed for each batch (see [`encode_passages`](@ref)), and they are saved to disk +along with relevant metadata (see [`save_chunk`](@ref)). # Arguments - - `indexer`: The [`CollectionIndexer`](@ref) used to build the index. - - `chunksize`: Size of a chunk into which the index is to be stored. + - `config`: The [`ColBERTConfig`](@ref) being used. + - `checkpoint`: The [`Checkpoint`](@ref) to compute embeddings. + - `collection`: The collection to index. """ function index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) codec = load_codec(config.index_path) plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json")) passage_offset = 1 for chunk_idx in 1:plan_metadata["num_chunks"] - passage_end_offset = min(length(collection), passage_offset + plan_metadata["chunksize"] - 1) - embs, doclens = encode_passages(config, checkpoint, collection[passage_offset:passage_end_offset]) + passage_end_offset = min( + length(collection), passage_offset + plan_metadata["chunksize"] - 1) + embs, doclens = encode_passages( + config, checkpoint, collection[passage_offset:passage_end_offset]) @assert embs isa AbstractMatrix{Float32} "$(typeof(embs))" @assert doclens isa AbstractVector{Int} "$(typeof(doclens))" @@ -360,15 +382,17 @@ function index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector end """ - finalize(indexer) + finalize(index_path::String) -Finalize the indexing process by saving all files, collecting embedding ID offsets, building IVF, and updating metadata. +Finalize the indexing process by saving all files, collecting embedding ID offsets, +and building the IVF. -See [`_check_all_files_are_saved`](@ref), [`_collect_embedding_id_offset`](@ref), [`_build_ivf`](@ref) and [`_update_metadata`](@ref) for more details. +See [`_check_all_files_are_saved`](@ref), [`_collect_embedding_id_offset`](@ref), +[`_build_ivf`](@ref) for more details. # Arguments - - `indexer::CollectionIndexer`: The [`CollectionIndexer`](@ref) used to finalize the indexing process. + - `index_path`: The path of the index. """ function finalize(index_path::String) _check_all_files_are_saved(index_path) @@ -445,7 +469,7 @@ function _collect_embedding_id_offset(index_path::String) end end num_embeddings = embedding_offset - 1 - @assert length(embeddings_offsets) == plan_metadata["num_chunks"] + @assert length(embeddings_offsets) == plan_metadata["num_chunks"] @info "Saving the indexing metadata." metadata_path = joinpath(index_path, "metadata.json") From b53801c6f2ccb92480011ec3f979d8e12a08a963 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:50:43 +0530 Subject: [PATCH 27/59] Updating docstrings of the functions in `residual.jl`. 
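As a quick illustration of what `_check_all_files_are_saved` looks for per chunk, a hypothetical check for chunk 1 under an assumed `./local_index` directory (the file names match those written by `save_chunk`):

```julia
index_path, chunk_idx = "./local_index", 1
chunk_files = ["$(chunk_idx).codes.jld2", "$(chunk_idx).residuals.jld2",
    "doclens.$(chunk_idx).jld2", "$(chunk_idx).metadata.json"]

all(isfile, joinpath.(index_path, chunk_files))   # true once index() has saved the chunk
```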
--- src/indexing/codecs/residual.jl | 83 +++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 223dc6a..654d3e1 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -28,12 +28,13 @@ mutable struct ResidualCodec end """ -# Examples + load_codec(index_path::String) -```julia-repl -julia> codec = load_codec(index_path); +Load compression/decompression information from the index path. -``` +# Arguments + + - `index_path`: The path of the index. """ function load_codec(index_path::String) centroids_path = joinpath(index_path, "centroids.jld2") @@ -56,23 +57,23 @@ function load_codec(index_path::String) end """ - save_codec(saver::IndexSaver) - -Save the codec used by the `saver` to disk. + save_codec( + index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, + bucket_weights::Vector{Float32}, avg_residual::Float32) -This will create three files in the directory specified by the indexing path: - - - `centroids.jld2` containing the centroids. - - `avg_residual.jld2` containing the average residual. - - `buckets.jld2` containing the bucket cutoffs and weights. - -Also see [`train`](@ref). +Save compression/decompression information from the index path. # Arguments - - `saver::IndexSaver`: The index saver to use. + - `index_path`: The path of the index. + - `centroids`: The matrix of centroids of the index. + - `bucket_cutoffs`: Cutoffs used to determine buckets during residual compression. + - `bucket_weights`: Weights used to determine the decompressed values during decompression. + - `avg_residual`: The average residual value, computed from the heldout set (see [`_compute_avg_residuals`](@ref)). """ -function save_codec(index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, bucket_weights::Vector{Float32}, avg_residual::Float32) +function save_codec( + index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, + bucket_weights::Vector{Float32}, avg_residual::Float32) centroids_path = joinpath(index_path, "centroids.jld2") avg_residual_path = joinpath(index_path, "avg_residual.jld2") bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") @@ -86,23 +87,23 @@ function save_codec(index_path::String, centroids::Matrix{Float32}, bucket_cutof end """ - compress_into_codes(codec::ResidualCodec, embs::AbstractMatrix{Float32}) + compress_into_codes( + centroids::AbstractMatrix{Float32}, embs::AbstractMatrix{Float32}) -Compresses a matrix of embeddings into a vector of codes using the given [`ResidualCodec`](@ref), where the code for each embedding is its nearest centroid ID. +Compresses a matrix of embeddings into a vector of codes using the given `centroids`, +where the code for each embedding is its nearest centroid ID. # Arguments - - `codec`: The [`ResidualCodec`](@ref) used to compress the embeddings. + - `centroids`: The matrix of centroids. - `embs`: The matrix of embeddings to be compressed. # Returns A `Vector{UInt32}` of codes, where each code corresponds to the nearest centroid ID for the embedding. 
-
-```
-```
 """
-function compress_into_codes(centroids::AbstractMatrix{Float32}, embs::AbstractMatrix{Float32})
+function compress_into_codes(
+        centroids::AbstractMatrix{Float32}, embs::AbstractMatrix{Float32})
     codes = Vector{UInt32}()
 
     bsize = Int(floor((1 << 29) / size(centroids)[2]))
@@ -122,20 +123,25 @@ function compress_into_codes(centroids::AbstractMatrix{Float32}, embs::AbstractM
 end
 
 """
-    binarize(codec::ResidualCodec, residuals::AbstractMatrix{Float32})
+    binarize(dim::Int, nbits::Int, bucket_cutoffs::Vector{Float32},
+        residuals::AbstractMatrix{Float32})
 
-Convert a matrix of residual vectors into a matrix of integer residual vector using `nbits` bits (specified by the underlying `config`).
+Convert a matrix of residual vectors into a matrix of integer residual vectors
+using `nbits` bits.
 
 # Arguments
 
-  - `codec`: A [`ResidualCodec`](@ref) object containing the compression information.
-  - `residuals`: The matrix of residuals to be converted.
+  - `dim`: The embedding dimension (see [`ColBERTConfig`](@ref)).
+  - `nbits`: Number of bits to compress the residuals into.
+  - `bucket_cutoffs`: Cutoffs used to determine residual buckets.
+  - `residuals`: The matrix of residuals to be compressed.
 
 # Returns
 
-A matrix of compressed integer residual vectors.
+An `AbstractMatrix{UInt8}` of compressed integer residual vectors.
 """
-function binarize(dim::Int, nbits::Int, bucket_cutoffs::Vector{Float32}, residuals::AbstractMatrix{Float32})
+function binarize(dim::Int, nbits::Int, bucket_cutoffs::Vector{Float32},
+        residuals::AbstractMatrix{Float32})
     num_embeddings = size(residuals)[2]
 
     if dim % (nbits * 8) != 0
@@ -162,22 +168,31 @@ function binarize(dim::Int, nbits::Int, bucket_cutoffs::Vector{Float32}, residua
 end
 
 """
-    compress(codec::ResidualCodec, embs::AbstractMatrix{Float32})
+    compress(centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32},
+        dim::Int, nbits::Int, embs::AbstractMatrix{Float32})
 
-Compress a matrix of embeddings into a compact representation using the specified [`ResidualCodec`](@ref).
+Compress a matrix of embeddings into a compact representation.
 
-All embeddings are compressed to their nearest centroid IDs and their quantized residual vectors (where the quantization is done in `nbits` bits, specified by the `config` of `codec`). If `emb` denotes an embedding and `centroid` is is nearest centroid, the residual vector is defined to be `emb - centroid`.
+All embeddings are compressed to their nearest centroid IDs and
+their quantized residual vectors (where the quantization is done
+in `nbits` bits). If `emb` denotes an embedding and `centroid`
+is its nearest centroid, the residual vector is defined to be
+`emb - centroid`.
 
 # Arguments
 
-  - `codec`: A [`ResidualCodec`](@ref) object containing the centroids and other parameters for the compression algorithm.
+  - `centroids`: The matrix of centroids.
+  - `bucket_cutoffs`: Cutoffs used to determine residual buckets.
+  - `dim`: The embedding dimension (see [`ColBERTConfig`](@ref)).
+  - `nbits`: Number of bits to compress the residuals into.
   - `embs`: The input embeddings to be compressed.
 
 # Returns
 
 A tuple containing a vector of codes and the compressed residuals matrix. 
""" -function compress(centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, dim::Int, nbits::Int, embs::AbstractMatrix{Float32}) +function compress(centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, + dim::Int, nbits::Int, embs::AbstractMatrix{Float32}) codes, residuals = Vector{UInt32}(), Vector{Matrix{UInt8}}() offset = 1 From 92a46aec10f556c359eefe7c5b47e564018d52c7 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:52:24 +0530 Subject: [PATCH 28/59] Removing unnecessary files. --- src/indexing/collection_encoder.jl | 57 ----------- src/indexing/index_saver.jl | 151 ----------------------------- 2 files changed, 208 deletions(-) delete mode 100644 src/indexing/collection_encoder.jl delete mode 100644 src/indexing/index_saver.jl diff --git a/src/indexing/collection_encoder.jl b/src/indexing/collection_encoder.jl deleted file mode 100644 index 4b09903..0000000 --- a/src/indexing/collection_encoder.jl +++ /dev/null @@ -1,57 +0,0 @@ -""" - CollectionEncoder(config::ColBERTConfig, checkpoint::Checkpoint) - -Structure to represent an encoder used to encode document passages to their corresponding embeddings. - -# Arguments - - - `config`: The underlying [`ColBERTConfig`](@ref). - - `checkpoint`: The [`Checkpoint`](@ref) used by the model. - -# Returns - -A [`CollectionEncoder`](@ref). -""" -struct CollectionEncoder - config::ColBERTConfig - checkpoint::Checkpoint -end - -""" - encode_passages(encoder::CollectionEncoder, passages::Vector{String}) - -Encode a list of passages using `encoder`. - -The given `passages` are run through the underlying BERT model and the linear layer to generate the embeddings, after doing relevant document-specific preprocessing. See [`docFromText`](@ref) for more details. - -# Arguments - - - `encoder`: The encoder used to encode the passages. - - `passages`: A list of strings representing the passages to be encoded. - -# Returns - -A tuple `embs, doclens` where: - - - `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`, where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages. - - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage, i.e the total number of attended tokens for each document passage. -""" -function encode_passages(encoder::CollectionEncoder, passages::Vector{String}) - @info "Encoding $(length(passages)) passages." - - if length(passages) == 0 - error("The list of passages to encode is empty!") - end - - embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}() - # batching here to avoid storing intermediate embeddings on GPU - # batching also occurs inside docFromText to do batch packing optimizations - for passages_batch in batch(passages, encoder.config.index_bsize * 50) - embs_, doclens_ = docFromText(encoder.checkpoint, passages_batch, - encoder.config.index_bsize) - push!(embs, embs_) - append!(doclens, vec(doclens_)) - end - embs = cat(embs..., dims = 2) - embs, doclens -end diff --git a/src/indexing/index_saver.jl b/src/indexing/index_saver.jl deleted file mode 100644 index fc343bd..0000000 --- a/src/indexing/index_saver.jl +++ /dev/null @@ -1,151 +0,0 @@ -""" - IndexSaver(config::ColBERTConfig, codec::Union{Missing, ResidualCodec} = missing) - -A structure to load/save various indexing components. - -# Arguments - - - `config`: A [`ColBERTConfig`](@ref). - - `codec`: A codec to encode and decode the embeddings. 
-""" -Base.@kwdef mutable struct IndexSaver - config::ColBERTConfig - codec::Union{Missing, ResidualCodec} = missing -end - -""" - load_codec!(saver::IndexSaver) - -Load a codec from disk into `saver`. - -The path of of the codec is inferred from the config stored in `saver`. - -# Arguments - - - `saver`: An [`IndexSaver`](@ref) into which the codec is to be loaded. -""" -function load_codec!(saver::IndexSaver) - index_path = saver.config.index_path - centroids = JLD2.load(joinpath(index_path, "centroids.jld2"), "centroids") - avg_residual = JLD2.load(joinpath(index_path, "avg_residual.jld2"), "avg_residual") - buckets = JLD2.load(joinpath(index_path, "buckets.jld2")) - saver.codec = ResidualCodec(saver.config, centroids, avg_residual, - buckets["bucket_cutoffs"], buckets["bucket_weights"]) -end - -""" - save_codec(saver::IndexSaver) - -Save the codec used by the `saver` to disk. - -This will create three files in the directory specified by the indexing path: - - - `centroids.jld2` containing the centroids. - - `avg_residual.jld2` containing the average residual. - - `buckets.jld2` containing the bucket cutoffs and weights. - -Also see [`train`](@ref). - -# Arguments - - - `saver::IndexSaver`: The index saver to use. -""" -function save_codec(saver::IndexSaver) - index_path = saver.config.index_path - centroids_path = joinpath(index_path, "centroids.jld2") - avg_residual_path = joinpath(index_path, "avg_residual.jld2") - buckets_path = joinpath(index_path, "buckets.jld2") - @info "Saving codec to $(centroids_path), $(avg_residual_path) and $(buckets_path)" - - JLD2.save(centroids_path, Dict("centroids" => saver.codec.centroids)) - JLD2.save(avg_residual_path, Dict("avg_residual" => saver.codec.avg_residual)) - JLD2.save( - buckets_path, - Dict( - "bucket_cutoffs" => saver.codec.bucket_cutoffs, - "bucket_weights" => saver.codec.bucket_weights - ) - ) -end - -""" - save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, - embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) - -Save a single chunk of compressed embeddings and their relevant metadata to disk. - -The codes and compressed residuals for the chunk are saved in files named `.codec.jld2`. The document lengths are saved in a file named `doclens..jld2`. Relevant metadata, including number of documents in the chunk, number of embeddings and the passage offsets are saved in a file named `.metadata.json`. - -# Arguments - - - `saver`: The [`IndexSaver`](@ref) containing relevant information to save the chunk. - - `chunk_idx`: The index of the current chunk being saved. - - `offset`: The offset in the original document collection where this chunk starts. - - `embs`: The embeddings matrix for the current chunk. - - `doclens`: The document lengths vector for the current chunk. 
-""" -function save_chunk(saver::IndexSaver, chunk_idx::Int, offset::Int, - embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) - codes, residuals = compress(saver.codec, embs) - path_prefix = joinpath(saver.config.index_path, string(chunk_idx)) - @assert length(codes)==size(embs)[2] "length(codes): $(length(codes)), size(embs): $(size(embs))" - - # saving the compressed embeddings - codes_path = "$(path_prefix).codes.jld2" - residuals_path = "$(path_prefix).residuals.jld2" - @info "Saving compressed codes to $(codes_path) and residuals to $(residuals_path)" - JLD2.save(codes_path, Dict("codes" => codes)) - JLD2.save(residuals_path, Dict("residuals" => residuals)) - - # saving doclens - doclens_path = joinpath( - saver.config.index_path, "doclens.$(chunk_idx).jld2") - @info "Saving doclens to $(doclens_path)" - JLD2.save(doclens_path, Dict("doclens" => doclens)) - - # the metadata - metadata_path = joinpath( - saver.config.index_path, "$(chunk_idx).metadata.json") - @info "Saving metadata to $(metadata_path)" - open(metadata_path, "w") do io - JSON.print(io, - Dict( - "passage_offset" => offset, - "num_passages" => length(doclens), - "num_embeddings" => length(codes) - ), - 4 # indent - ) - end -end - -""" - check_chunk_exists(saver::IndexSaver, chunk_idx::Int) - -Check if the index chunk exists for the given `chunk_idx`. - -# Arguments - - - `saver`: The `IndexSaver` object that contains the indexing settings. - - `chunk_idx`: The index of the chunk to check. - -# Returns - -A boolean indicating whether all relevant files for the chunk exist. -""" -function check_chunk_exists(saver::IndexSaver, chunk_idx::Int) - index_path = saver.config.index_path - path_prefix = joinpath(index_path, string(chunk_idx)) - codes_path = "$(path_prefix).codes.jld2" - residuals_path = "$(path_prefix).residuals.jld2" - doclens_path = joinpath(index_path, "doclens.$(chunk_idx).jld2") - metadata_path = joinpath(index_path, "$(chunk_idx).metadata.json") - - for file in [codes_path, residuals_path, doclens_path, metadata_path] - if !isfile(file) - return false - end - end - - true -end From 5a5be9c5b373d2bb4abc26de154814980d20896e Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 14:55:49 +0530 Subject: [PATCH 29/59] Updating the examples file. 
--- examples/indexing.jl | 52 ++++++++++---------------------------------- 1 file changed, 12 insertions(+), 40 deletions(-) diff --git a/examples/indexing.jl b/examples/indexing.jl index 1765d2d..ede2bb3 100644 --- a/examples/indexing.jl +++ b/examples/indexing.jl @@ -6,47 +6,19 @@ using Random # set the global seed Random.seed!(0) -# create the config -dataroot = "downloads/lotte" -dataset = "lifestyle" -datasplit = "dev" -path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") - -collection = Collection(path) -length(collection.data) - -nbits = 2 # encode each dimension with 2 bits -doc_maxlen = 300 # truncate passages at 300 tokens - -checkpoint = "colbert-ir/colbertv2.0" # the HF checkpoint -index_root = "experiments/notebook/indexes" -index_name = "short_$(dataset).$(datasplit).$(nbits)bits" -index_path = joinpath(index_root, index_name) - config = ColBERTConfig( - RunSettings( - experiment = "notebook", - use_gpu = true - ), - TokenizerSettings(), - ResourceSettings( - checkpoint = checkpoint, - collection = collection, - index_name = index_name - ), - DocSettings( - doc_maxlen = doc_maxlen, - ), - QuerySettings(), - IndexingSettings( - index_path = index_path, - index_bsize = 3, - nbits = nbits, - kmeans_niters = 20 - ), - SearchSettings() + use_gpu = true, + collection = "./cityofaustin", + doc_maxlen = 300, + index_path = "./cityofaustin_index/", + chunksize = 500, ) indexer = Indexer(config) -index(indexer) -ColBERT.save(config) +checkpoint = indexer.checkpoint +collection = indexer.collection + +@time ColBERT.setup(config, checkpoint, collection) +@time ColBERT.train(config) +@time ColBERT.index(config, checkpoint, collection) +@time ColBERT.finalize(config.index_path) From 716eb7672c5989485b9ec8810c51b0c3ac9f094f Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 17:13:36 +0530 Subject: [PATCH 30/59] Some minor optimizations in the code for `docFromText`. 
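Assuming the four steps above have completed, the artifacts they wrote can be inspected from the same session; a small follow-up sketch (not part of the example file, and it relies on the `config` defined above):

```julia
using JSON

plan = JSON.parsefile(joinpath(config.index_path, "plan.json"))
plan["num_chunks"], plan["num_partitions"]

codec = ColBERT.load_codec(config.index_path)
size(codec["centroids"])    # (dim, num_partitions)
codec["avg_residual"]
```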
--- src/modelling/checkpoint.jl | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 928c8ca..cd49659 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -420,16 +420,26 @@ function docFromText(config::ColBERTConfig, checkpoint::Checkpoint, # doc(checkpoint, integer_ids, integer_mask) error("Currently bsize cannot be missing!") else - text_batches, reverse_indices = tensorize_docs( - config, checkpoint.model.tokenizer, docs, bsize) - batches = [doc(config, checkpoint, integer_ids, integer_mask) - for (integer_ids, integer_mask) in text_batches] + integer_ids, integer_mask = tensorize_docs(config, checkpoint.model.tokenizer, docs) + + # we sort passages by length to do batch packing for more efficient use of the GPU + integer_ids, integer_mask, reverse_indices = _sort_by_length( + integer_ids, integer_mask, bsize) + + @assert length(reverse_indices)==length(docs) "length(reverse_indices): $(length(reverse_indices)), length(batch_text): $(length(docs))" + @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" + @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" + @assert reverse_indices isa Vector{Int64} "$(typeof(reverse_indices))" # aggregate all embeddings D, mask = Vector{AbstractArray{Float32}}(), Vector{AbstractArray{Bool}}() - for (_D, _mask) in batches - push!(D, _D) - push!(mask, _mask) + passage_offset = 1 + while(passage_offset <= length(docs)) + passage_end_offset = min(length(docs), passage_offset + bsize - 1) + D_, mask_ = doc(config, checkpoint, integer_ids[:, passage_offset:passage_end_offset], integer_mask[:, passage_offset:passage_end_offset]) + push!(D, D_) + push!(mask, mask_) + passage_offset += bsize end # concat embeddings and masks, and put them in the original order From ddf22cbd7d54aec3bc6ce0650001ae9d540a1c05 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 17:14:06 +0530 Subject: [PATCH 31/59] Making `tensorize_docs` return only the ids and mask. --- .../tokenization/doc_tokenization.jl | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index 544c1a9..18de140 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -88,7 +88,7 @@ julia> reverse_indices # the original order """ function tensorize_docs(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - batch_text::Vector{String}, bsize::Union{Missing, Int}) + batch_text::Vector{String}) # placeholder for [D] marker token batch_text = [". 
" * doc for doc in batch_text] @@ -97,29 +97,16 @@ function tensorize_docs(config::ColBERTConfig, ids, mask = encoded_text.token, encoded_text.attention_mask integer_ids = reinterpret(Int32, ids) integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] - @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(integer_mask)" - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" # adding the [D] marker token ID D_marker_token_id = TextEncodeBase.lookup( tokenizer.vocab, config.doc_token_id) integer_ids[2, :] .= D_marker_token_id + + @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(integer_mask)" + @assert isequal(size(integer_ids)[2], length(batch_text)) + @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" + @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" - if ismissing(bsize) - error("Currently bsize can't be missing!") - else - # we sort passages by length to do batch packing for more efficient use of the GPU - integer_ids, integer_mask, reverse_indices = _sort_by_length( - integer_ids, integer_mask, bsize) - @assert length(reverse_indices)==length(batch_text) "length(reverse_indices): $(length(reverse_indices)), length(batch_text): $(length(batch_text))" - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" - @assert reverse_indices isa Vector{Int64} "$(typeof(reverse_indices))" - - batches = _split_into_batches(integer_ids, integer_mask, bsize) - @assert batches isa Vector{Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}} "$(typeof(batches))" - - batches, reverse_indices - end + integer_ids, integer_mask end From 262b7c7bf224f01b4f2a009e4d49e51a69d2a123 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 17:14:44 +0530 Subject: [PATCH 32/59] Processing `config.passages_batch_size` passages in `encode_passages`, and adding this to config with 300 as default. 
--- src/indexing/collection_indexer.jl | 8 +++++--- src/infra/config.jl | 2 ++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 49d039d..fcec3ac 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -34,11 +34,13 @@ function encode_passages( embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}() # batching here to avoid storing intermediate embeddings on GPU # batching also occurs inside docFromText to do batch packing optimizations - for passages_batch in batch(passages, config.index_bsize * 50) - embs_, doclens_ = docFromText(config, checkpoint, passages_batch, - config.index_bsize) + passage_offset = 1 + while passage_offset <= length(passages) + passage_end_offset = min(length(passages), passage_offset + config.passages_batch_size - 1) + embs_, doclens_ = docFromText(config, checkpoint, passages[passage_offset:passage_end_offset], config.index_bsize) push!(embs, embs_) append!(doclens, vec(doclens_)) + passage_offset += config.passages_batch_size end embs = cat(embs..., dims = 2) embs, doclens diff --git a/src/infra/config.jl b/src/infra/config.jl index 0ec2365..034289f 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -27,6 +27,7 @@ Structure containing config for running and training various components. - `index_bsize`: Batch size used for some parts of indexing. - `chunksize`: Custom size of a chunk, i.e the number of passages for which data is to be stored in one chunk. Default is `missing`, in which case `chunksize` is determined from the size of the `collection` and `nranks`. + - `passages_batch_size`: The number of passages sent as a batch to encoding functions. Default is `300`. - `nbits`: Number of bits used to compress residuals. - `kmeans_niters`: Number of iterations used for k-means clustering. - `nprobe`: The number of nearest centroids to fetch during a search. Default is `2`. Also see [`retrieve`](@ref). @@ -80,6 +81,7 @@ Base.@kwdef struct ColBERTConfig index_path::String = "" index_bsize::Int = 64 chunksize::Union{Missing, Int} = missing + passages_batch_size::Int = 300 nbits::Int = 2 kmeans_niters::Int = 20 From f0b941af94888216860920e770a923c77f0ba329 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 17:15:27 +0530 Subject: [PATCH 33/59] Removing unnecessary util functions. --- src/utils/utils.jl | 62 ---------------------------------------------- 1 file changed, 62 deletions(-) diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 190bee5..b94779a 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,38 +1,3 @@ -""" - batch(group::Vector, bsize::Int; [provide_offset::Bool = false]) - -Create batches of data from `group`. - -Each batch is a subvector of `group` with length equal to `bsize`. If `provide_offset` is true, each batch will be a tuple containing both the offset and the subvector, otherwise only the subvector will be returned. - -# Arguments - - - `group::Vector`: The input vector from which to create batches. - - `bsize::Int`: The size of each batch. - - `provide_offset::Bool = false`: Whether to include the offset in the output batches. Defaults to `false`. - -# Returns - -A vector of tuples, where each tuple contains an offset and a subvector, or just a vector containing subvectors, depending on the value of `provide_offset`. -""" -function batch(group::Vector, bsize::Int; provide_offset::Bool = false) - vtype = provide_offset ? 
- Vector{Tuple{Int, typeof(group)}} : - Vector{typeof(group)} - batches = vtype() - offset = 1 - while offset <= length(group) - if provide_offset - push!(batches, (offset, group[offset:min(length(group), offset + bsize - 1)])) - else - push!(batches, group[offset:min(length(group), offset + bsize - 1)]) - end - offset += bsize - end - batches -end - - """ _sort_by_length( integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) @@ -66,30 +31,3 @@ function _sort_by_length( integer_ids[:, indices], integer_mask[:, indices], reverse_indices end - -""" - _split_into_batches( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - -Split the given `integer_ids` and `integer_mask` into batches of size `bsize`. - -# Arguments - - - `integer_ids`: The array of token IDs to batch. - - `integer_mask`: The array of attention masks to batch. - -# Returns - -Batches of token IDs and attention masks, with each batch having size `bsize` (with the possibility of the last batch being smaller). -""" -function _split_into_batches( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - batch_size = size(integer_ids)[2] - batches = Vector{Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}}() - for offset in 1:bsize:batch_size - push!(batches, - (integer_ids[:, offset:min(batch_size, offset + bsize - 1)], - integer_mask[:, offset:min(batch_size, offset + bsize - 1)])) - end - batches -end From 528c1ec79fe2d755320c78d0aa08356f571d1d62 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 17:23:58 +0530 Subject: [PATCH 34/59] Applying format; changing the default `index_bsize` to `32`, and updating docstring of `tensorize_docs`. --- src/indexing/collection_indexer.jl | 7 +- src/infra/config.jl | 2 +- src/modelling/checkpoint.jl | 16 +-- .../tokenization/doc_tokenization.jl | 101 ++++++++++-------- 4 files changed, 69 insertions(+), 57 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index fcec3ac..b4eaeb2 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -36,8 +36,11 @@ function encode_passages( # batching also occurs inside docFromText to do batch packing optimizations passage_offset = 1 while passage_offset <= length(passages) - passage_end_offset = min(length(passages), passage_offset + config.passages_batch_size - 1) - embs_, doclens_ = docFromText(config, checkpoint, passages[passage_offset:passage_end_offset], config.index_bsize) + passage_end_offset = min( + length(passages), passage_offset + config.passages_batch_size - 1) + embs_, doclens_ = docFromText( + config, checkpoint, passages[passage_offset:passage_end_offset], + config.index_bsize) push!(embs, embs_) append!(doclens, vec(doclens_)) passage_offset += config.passages_batch_size diff --git a/src/infra/config.jl b/src/infra/config.jl index 034289f..2d728c8 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -79,7 +79,7 @@ Base.@kwdef struct ColBERTConfig # indexing settings index_path::String = "" - index_bsize::Int = 64 + index_bsize::Int = 32 chunksize::Union{Missing, Int} = missing passages_batch_size::Int = 300 nbits::Int = 2 diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index cd49659..81fe6aa 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -266,12 +266,12 @@ Compute the hidden state of the BERT and linear layers of ColBERT for documents. 
A tuple `D, mask`, where: - `D` is an array containing the normalized embeddings for each token in each document. - It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer - of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of - any document and `N` is the total number of documents. + It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer + of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of + any document and `N` is the total number of documents. - `mask` is an array containing attention masks for all documents, after masking out any - tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` - is the same as described above. + tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` + is the same as described above. # Examples @@ -434,9 +434,11 @@ function docFromText(config::ColBERTConfig, checkpoint::Checkpoint, # aggregate all embeddings D, mask = Vector{AbstractArray{Float32}}(), Vector{AbstractArray{Bool}}() passage_offset = 1 - while(passage_offset <= length(docs)) + while (passage_offset <= length(docs)) passage_end_offset = min(length(docs), passage_offset + bsize - 1) - D_, mask_ = doc(config, checkpoint, integer_ids[:, passage_offset:passage_end_offset], integer_mask[:, passage_offset:passage_end_offset]) + D_, mask_ = doc( + config, checkpoint, integer_ids[:, passage_offset:passage_end_offset], + integer_mask[:, passage_offset:passage_end_offset]) push!(D, D_) push!(mask, mask_) passage_offset += bsize diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index 18de140..5edd4e1 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -1,36 +1,27 @@ """ tensorize_docs(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - batch_text::Vector{String}, bsize::Union{Missing, Int}) + batch_text::Vector{String}) Convert a collection of documents to tensors in the ColBERT format. This function adds the document marker token at the beginning of each document and then converts the text data into integer IDs and masks using the `tokenizer`. -Some optimizing operations are performed on the documents. First, the arrays of -token IDs and attention masks are sorted by document lengths (this is for more -efficient use of GPUs on the batches; see [`_sort_by_length`](@ref)), and a list -`reverse_indices` is computed, which remembers the original order of the documents -(to reorder them later). The arrays of token IDs and attention masks are then -batched into batches of size `bsize` (see [`_split_into_batches`](@ref)). -Finally, the batches along with the list of `reverse_indices` are returned. # Arguments - `config`: The `ColBERTConfig` to be used to fetch the document marker token ID. - `tokenizer`: The tokenizer which is used to convert text data into integer IDs. - `batch_text`: A document texts that will be converted into tensors of token IDs. -- `bsize`: The size of the batches to split the `batch_text` into. # Returns A tuple containing the following is returned: -- `batches`: A `Vector` of tuples of arrays of token IDs and masks, sorted in the order - of document lengths. 
Each array in each tuple has shape `(L, N)`, where `L` is the length +- `integer_ids`: A `Matrix` of token IDs of shape `(L, N)`, where `L` is the length of the largest document in `batch_text`, and `N` is the number of documents in the batch being considered. -- `reverse_indices`: A `Vector` containing the indices of the documents in their original order. +- `integer_mask`: A `Matrix` of attention masks, of the same shape as `integer_ids`. # Examples @@ -49,40 +40,56 @@ julia> batch_text = [ "this is an even longer document. this is some longer text, so length should be longer", ]; -julia> batches, reverse_indices = ColBERT.tensorize_docs(config, tokenizer, batch_text, 3) -(Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}[([102 102 102; 3 3 3; … ; 1 1 1; 1 1 1], [1 1 1; 1 1 1; … ; 0 0 0; 0 0 0]), ([102 102; 3 3; … ; 1 2937; 1 103], [1 1; 1 1; … ; 0 1; 0 1])], [2, 3, 1, 4, 5]) - -julia> batches[1][1] # sorted by length -21×3 Matrix{Int32}: - 102 102 102 - 3 3 3 - 1038 7593 4068 - 103 2089 2018 - 1 103 1000 - 1 1 103 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - 1 1 1 - -julia> reverse_indices # the original order -5-element Vector{Int64}: - 2 - 3 - 1 - 4 - 5 +julia> integer_ids, integer_mask = ColBERT.tensorize_docs(config, tokenizer, batch_text) +(Int32[102 102 … 102 102; 3 3 … 3 3; … ; 1 1 … 1 2937; 1 1 … 1 103], Bool[1 1 … 1 1; 1 1 … 1 1; … ; 0 0 … 0 1; 0 0 … 0 1]) + +julia> integer_ids +21×5 reinterpret(Int32, ::Matrix{PrimitiveOneHot.OneHot{0x0000773a}}): + 102 102 102 102 102 + 3 3 3 3 3 + 7593 4068 1038 2024 2024 + 2089 2018 103 2004 2004 + 103 1000 1 2071 2020 + 1 103 1 2937 2131 + 1 1 1 3794 2937 + 1 1 1 1011 6255 + 1 1 1 2062 1013 + 1 1 1 3092 2024 + 1 1 1 2324 2004 + 1 1 1 2023 2071 + 1 1 1 2937 2937 + 1 1 1 103 3794 + 1 1 1 1 1011 + 1 1 1 1 2062 + 1 1 1 1 3092 + 1 1 1 1 2324 + 1 1 1 1 2023 + 1 1 1 1 2937 + 1 1 1 1 103 + +julia> integer_mask +21×5 Matrix{Bool}: + 1 1 1 1 1 + 1 1 1 1 1 + 1 1 1 1 1 + 1 1 1 1 1 + 1 1 0 1 1 + 0 1 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 1 1 + 0 0 0 0 1 + 0 0 0 0 1 + 0 0 0 0 1 + 0 0 0 0 1 + 0 0 0 0 1 + 0 0 0 0 1 + 0 0 0 0 1 ``` """ @@ -102,7 +109,7 @@ function tensorize_docs(config::ColBERTConfig, D_marker_token_id = TextEncodeBase.lookup( tokenizer.vocab, config.doc_token_id) integer_ids[2, :] .= D_marker_token_id - + @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(integer_mask)" @assert isequal(size(integer_ids)[2], length(batch_text)) @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" From ea617757338a5ceb2d65f299fcc85368ee88a53b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Mon, 12 Aug 2024 17:30:12 +0530 Subject: [PATCH 35/59] Adding some asserts. 
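[Editorial note] The asserts added in the diff below guard this batched encoding path. For context, a rough sketch of how the refactored `tensorize_docs` output is consumed batch by batch, mirroring the loop in `docFromText`; `config`, `tokenizer` and `batch_text` are assumed to be set up as in the docstring example above:

```julia
# tensorize_docs now returns the full (L, N) token-ID and attention-mask matrices
integer_ids, integer_mask = ColBERT.tensorize_docs(config, tokenizer, batch_text)

bsize = config.index_bsize
for offset in 1:bsize:size(integer_ids, 2)
    stop = min(size(integer_ids, 2), offset + bsize - 1)
    ids_batch = integer_ids[:, offset:stop]     # Int32 token IDs for this batch
    mask_batch = integer_mask[:, offset:stop]   # Bool attention mask for this batch
    # embeddings for this batch would be computed here, e.g. via doc(config, checkpoint, ...)
end
```
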
--- src/indexing/collection_indexer.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index b4eaeb2..4cab824 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -41,6 +41,8 @@ function encode_passages( embs_, doclens_ = docFromText( config, checkpoint, passages[passage_offset:passage_end_offset], config.index_bsize) + @assert embs_ isa Matrix{Float32} + @assert doclens_ isa Vector{Int} push!(embs, embs_) append!(doclens, vec(doclens_)) passage_offset += config.passages_batch_size From 951522659ad49f0043933428f1a45272665f70f5 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Wed, 14 Aug 2024 14:31:00 +0530 Subject: [PATCH 36/59] Changing the internals of `Indexer`, and adding a new constructor. --- src/indexing.jl | 62 +++++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 29855c6..4e15434 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -1,32 +1,44 @@ struct Indexer config::ColBERTConfig + checkpoint::Checkpoint + collection::Vector{String} end -function index(indexer::Indexer) - index_path = indexer.config.index_path - if isdir(index_path) - @info "Index at $(index_path) already exists! Skipping indexing." - return - end +function Indexer(config::ColBERTConfig) + base_colbert = BaseColBERT(config) + checkpoint = Checkpoint(base_colbert, config) + collection = readlines(config.collection) - config = indexer.config - checkpoint = config.checkpoint + @info "Loaded ColBERT layers from the $(checkpoint) HuggingFace checkpoint." + @info "Loaded $(length(collection)) documents from $(config.collection)." - # loading the models - @info "Loading ColBERT layers from HuggingFace." - base_colbert = BaseColBERT(checkpoint, config) - checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), - QueryTokenizer(base_colbert.tokenizer, config), config) - - # creating the encoder, saver and indexer - encoder = CollectionEncoder(config, checkPoint) - saver = IndexSaver(config = config) - collection_indexer = CollectionIndexer(config, encoder, saver) - - # building the index - @info "Building the index." - setup(collection_indexer) - train(collection_indexer) - index(collection_indexer) - finalize(collection_indexer) + Indexer(config, checkpoint, collection) end + +# function index(indexer::Indexer) +# if isdir(indexer.config.index_path) +# @info "Index at $(indexer.config.index_path) already exists! Skipping indexing." +# return +# end +# +# config = indexer.config +# checkpoint = config.checkpoint +# +# # loading the models +# @info "Loading ColBERT layers from HuggingFace." +# base_colbert = BaseColBERT(checkpoint, config) +# checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), +# QueryTokenizer(base_colbert.tokenizer, config), config) +# +# # creating the encoder, saver and indexer +# encoder = CollectionEncoder(config, checkPoint) +# saver = IndexSaver(config = config) +# collection_indexer = CollectionIndexer(config, encoder, saver) +# +# # building the index +# @info "Building the index." +# setup(collection_indexer) +# train(collection_indexer) +# index(collection_indexer) +# finalize(collection_indexer) +# end From 78dc814de594e16c65d234e7e7fc05f6039f5ed1 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Wed, 14 Aug 2024 14:32:13 +0530 Subject: [PATCH 37/59] Keeping the index function; will change it later. 
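[Editorial note] The previous patch turns `Indexer` into a plain struct holding the config, the loaded checkpoint and the collection. A small usage sketch (the paths are placeholders):

```julia
using ColBERT

config = ColBERTConfig(
    collection = "./my_documents",   # placeholder: one document per line
    index_path = "./my_index/"
)

indexer = Indexer(config)        # loads BaseColBERT + Checkpoint, reads the collection

indexer.config                   # the ColBERTConfig
indexer.checkpoint               # the Checkpoint holding the loaded model
length(indexer.collection)       # number of documents read from config.collection
```
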
--- src/indexing.jl | 54 ++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index 4e15434..bf02160 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -15,30 +15,30 @@ function Indexer(config::ColBERTConfig) Indexer(config, checkpoint, collection) end -# function index(indexer::Indexer) -# if isdir(indexer.config.index_path) -# @info "Index at $(indexer.config.index_path) already exists! Skipping indexing." -# return -# end -# -# config = indexer.config -# checkpoint = config.checkpoint -# -# # loading the models -# @info "Loading ColBERT layers from HuggingFace." -# base_colbert = BaseColBERT(checkpoint, config) -# checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), -# QueryTokenizer(base_colbert.tokenizer, config), config) -# -# # creating the encoder, saver and indexer -# encoder = CollectionEncoder(config, checkPoint) -# saver = IndexSaver(config = config) -# collection_indexer = CollectionIndexer(config, encoder, saver) -# -# # building the index -# @info "Building the index." -# setup(collection_indexer) -# train(collection_indexer) -# index(collection_indexer) -# finalize(collection_indexer) -# end +function index(indexer::Indexer) + if isdir(indexer.config.index_path) + @info "Index at $(indexer.config.index_path) already exists! Skipping indexing." + return + end + + config = indexer.config + checkpoint = config.checkpoint + + # loading the models + @info "Loading ColBERT layers from HuggingFace." + base_colbert = BaseColBERT(checkpoint, config) + checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), + QueryTokenizer(base_colbert.tokenizer, config), config) + + # creating the encoder, saver and indexer + encoder = CollectionEncoder(config, checkPoint) + saver = IndexSaver(config = config) + collection_indexer = CollectionIndexer(config, encoder, saver) + + # building the index + @info "Building the index." + setup(collection_indexer) + train(collection_indexer) + index(collection_indexer) + finalize(collection_indexer) +end From 23cc67b5cac09c5c7ce06cef904d31daf41f6b56 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 15 Aug 2024 22:53:19 +0530 Subject: [PATCH 38/59] Adding some type checks. --- src/indexing/codecs/residual.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 654d3e1..1ead085 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -48,6 +48,11 @@ function load_codec(index_path::String) bucket_cutoffs = JLD2.load_object(bucket_cutoffs_path) bucket_weights = JLD2.load_object(bucket_weights_path) + @assert centroids isa Matrix{Float32} + @assert avg_residual isa Float32 + @assert bucket_cutoffs isa Vector{Float32} + @assert bucket_weights isa Vector{Float32} + Dict( "centroids" => centroids, "avg_residual" => avg_residual, From 6c6383468313b07ce39505bc6189cf5e66482b6d Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 15 Aug 2024 22:57:04 +0530 Subject: [PATCH 39/59] Making the `setup`, `train` and `_sample_embeddings` functions test friendly; moving file saving and loading to a higher level `index` function. 
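[Editorial note] The type checks added to `load_codec` in the previous patch pin down what it returns. Assuming an index has already been written to the index path, the codec loads as a `Dict` with these entries (a sketch based on those asserts; the path is a placeholder):

```julia
codec = ColBERT.load_codec("./my_index/")   # placeholder path to an existing index

codec["centroids"]        # Matrix{Float32}: one centroid per column
codec["avg_residual"]     # Float32
codec["bucket_cutoffs"]   # Vector{Float32}
codec["bucket_weights"]   # Vector{Float32}
```
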
--- src/indexing/collection_indexer.jl | 64 ++++++++++++------------------ 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 4cab824..18ff08e 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -109,16 +109,15 @@ function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, @assert size(local_sample_embs)[2]==sum(local_sample_doclens) "size(local_sample_embs): $(size(local_sample_embs)), sum(local_sample_doclens): $(sum(local_sample_doclens))" @assert length(local_sample) == length(local_sample_doclens) - num_sample_embs = size(local_sample_embs)[2] avg_doclen_est = length(local_sample_doclens) > 0 ? sum(local_sample_doclens) / length(local_sample_doclens) : 0 - sample_path = joinpath(config.index_path, "sample.jld2") @info "avg_doclen_est = $(avg_doclen_est) \t length(local_sample) = $(length(local_sample))" - @info "Saving sampled embeddings to $(sample_path)." - JLD2.save_object(sample_path, local_sample_embs) - avg_doclen_est + Dict( + "avg_doclen_est" => avg_doclen_est, + "local_sample_embs" => local_sample_embs + ) end """ @@ -141,8 +140,6 @@ and the indexing plan is saved to `plan.json`, with the path being specified by - `collection`: The underlying collection of passages to initialize the index for. """ function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) - isdir(config.index_path) || mkdir(config.index_path) - chunksize = 0 chunksize = ismissing(config.chunksize) ? min(25000, 1 + fld(length(collection), config.nranks)) : config.chunksize @@ -150,32 +147,24 @@ function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector # sample passages for training centroids later sampled_pids = _sample_pids(length(collection)) - avg_doclen_est = _sample_embeddings(config, checkpoint, collection, sampled_pids) + local_sample_dict = _sample_embeddings(config, checkpoint, collection, sampled_pids) # computing the number of partitions, i.e clusters num_passages = length(collection) - num_embeddings_est = num_passages * avg_doclen_est + num_embeddings_est = num_passages * local_sample_dict["avg_doclen_est"] num_partitions = Int(floor(2^(floor(log2(16 * sqrt(num_embeddings_est)))))) @info "Creating $(num_partitions) clusters." @info "Estimated $(num_embeddings_est) embeddings." - @info "Saving the index plan to $(joinpath(config.index_path, "plan.json"))." - open(joinpath(config.index_path, "plan.json"), "w") do io - JSON.print(io, - Dict( - "chunksize" => chunksize, - "num_chunks" => num_chunks, - "num_partitions" => num_partitions, - "num_embeddings_est" => num_embeddings_est, - "avg_doclen_est" => avg_doclen_est - ), - 4 # indent - ) - end - - @info "Saving the config to the indexing path." - ColBERT.save(config) + Dict( + "chunksize" => chunksize, + "num_chunks" => num_chunks, + "num_partitions" => num_partitions, + "num_embeddings_est" => num_embeddings_est, + "avg_doclen_est" => local_sample_dict["avg_doclen_est"], + "local_sample_embs" => local_sample_dict["local_sample_embs"] + ) end """ @@ -276,24 +265,23 @@ function, and the codec is saved on disk using [`save_codec`](@ref). - `config`: The [`ColBERTConfig`](@ref) used to train the indexer. 
""" -function train(config::ColBERTConfig) - sample, heldout = _concatenate_and_split_sample(config.index_path) - @assert sample isa AbstractMatrix{Float32} "$(typeof(sample))" - @assert heldout isa AbstractMatrix{Float32} "$(typeof(heldout))" - - # loading the indexing plan - plan_metadata = JSON.parsefile(joinpath(config.index_path, "plan.json")) - - centroids = kmeans(sample, plan_metadata["num_partitions"], - maxiter = config.kmeans_niters, display = :iter).centers - @assert size(centroids)[2]==plan_metadata["num_partitions"] "size(centroids): $(size(centroids)), num_partitions: $(plan_metadata["num_partitions"])" +function train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, num_partitions::Int, nbits::Int, kmeans_niters::Int) + centroids = kmeans(sample, num_partitions, + maxiter = kmeans_niters, display = :iter).centers + @assert size(centroids)[2]==num_partitions + "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)" @assert centroids isa AbstractMatrix{Float32} "$(typeof(centroids))" bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals( - config.nbits, centroids, heldout) + nbits, centroids, heldout) @info "avg_residual = $(avg_residual)" - save_codec(config.index_path, centroids, bucket_cutoffs, bucket_weights, avg_residual) + Dict( + "centroids" => centroids, + "bucket_cutoffs" => bucket_cutoffs, + "bucket_weights" => bucket_weights, + "avg_residual" => avg_residual, + ) end """ From 28244b724fd7b36146de9875b260a1d9b76a47ec Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 15 Aug 2024 23:06:46 +0530 Subject: [PATCH 40/59] Updating docstrings. --- src/indexing/collection_indexer.jl | 72 +++++++++++++++++++----------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 18ff08e..2f40d02 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -19,7 +19,8 @@ See [`docFromText`](@ref) for more details. A tuple `embs, doclens` where: - `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`, - where `D` is the embedding dimension and `N` is the total number of embeddings across all the passages. + where `D` is the embedding dimension and `N` is the total number of embeddings + across all the passages. - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage, i.e the total number of attended tokens for each document passage. """ @@ -54,12 +55,14 @@ end """ _sample_pids(num_documents::Int) -Sample PIDs from the collection to be used to compute clusters using a ``k``-means clustering algorithm. +Sample PIDs from the collection to be used to compute clusters using a ``k``-means clustering +algorithm. # Arguments - - `num_documents`: The total number of documents in the collection. It is assumed that each document has an ID - (aka PID) in the range of integers between `1` and `num_documents` (both inclusive). + - `num_documents`: The total number of documents in the collection. It is assumed that each + document has an ID (aka PID) in the range of integers between `1` and `num_documents` + (both inclusive). # Returns @@ -80,8 +83,7 @@ end Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref). -The embeddings for the sampled documents are saved in a file named `sample.jld2` with it's path -specified by the indexing directory. 
This embedding array has shape `(D, N)`, where `D` is the +The embedding array has shape `(D, N)`, where `D` is the embedding dimension (`128`, after applying the linear layer of the ColBERT model) and `N` is the total number of embeddings over all documents. @@ -94,7 +96,10 @@ total number of embeddings over all documents. # Returns -The average document length (i.e number of attended tokens) computed from the sampled documents. +A `Dict` containing the average document length (i.e number of attended tokens) computed +from the sampled documents, and the embedding matrix for the local samples. The matrix has +shape `(D, N)`, where `D` is the embedding dimension (`128`) and `N` is the total number +of embeddings over all the sampled passages. """ function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}, sampled_pids::Set{Int}) @@ -116,28 +121,33 @@ function _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, Dict( "avg_doclen_est" => avg_doclen_est, - "local_sample_embs" => local_sample_embs + "local_sample_embs" => local_sample_embs ) end """ setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) -Initialize the index by computing some indexing-specific estimates and save the indexing plan to disk. +Initialize the index by computing some indexing-specific estimates and save the indexing plan +to disk. -The number of chunks into which the document embeddings will be stored is simply computed using the -number of documents and the size of a chunk. A bunch of pids used for initializing the centroids for -the embedding clusters are sampled using the [`_sample_pids`](@ref) and [`_sample_embeddings`](@ref) -functions, and these samples are used to calculate the average document lengths and the estimated number -of embeddings which will be computed across all documents. Finally, the number of clusters to be used -for indexing is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``, -and the indexing plan is saved to `plan.json`, with the path being specified by the indexing directory. +The number of chunks into which the document embeddings will be stored is simply computed using +the number of documents and the size of a chunk. A bunch of pids used for initializing the +centroids for the embedding clusters are sampled using the [`_sample_pids`](@ref) +and [`_sample_embeddings`](@ref) functions, and these samples are used to calculate the +average document lengths and the estimated number of embeddings which will be computed across +all documents. Finally, the number of clusters to be used for indexing is computed, and is +proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``. # Arguments - `config`: The [`ColBERTConfig`](@ref) being used to set up the indexing. - `checkpoint`: The [`Checkpoint`](@ref) used to compute embeddings. - `collection`: The underlying collection of passages to initialize the index for. + +# Returns + +A `Dict` containing the indexing plan. 
""" function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) chunksize = 0 @@ -164,7 +174,7 @@ function setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector "num_embeddings_est" => num_embeddings_est, "avg_doclen_est" => local_sample_dict["avg_doclen_est"], "local_sample_embs" => local_sample_dict["local_sample_embs"] - ) + ) end """ @@ -223,7 +233,8 @@ Compute the average residuals and other statistics of the held-out sample embedd # Returns -A tuple `bucket_cutoffs, bucket_weights, avg_residual`, which will be used in compression/decompression of residuals. +A tuple `bucket_cutoffs, bucket_weights, avg_residual`, which will be used in +compression/decompression of residuals. """ function _compute_avg_residuals( nbits::Int, centroids::AbstractMatrix{Float32}, @@ -253,23 +264,32 @@ function _compute_avg_residuals( end """ - train(config::ColBERTConfig) + train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, + num_partitions::Int, nbits::Int, kmeans_niters::Int) Compute centroids using a ``k``-means clustering algorithn, and store the compression information on disk. Average residuals and other compression data is computed via the [`_compute_avg_residuals`](@ref) -function, and the codec is saved on disk using [`save_codec`](@ref). - +function. # Arguments - - `config`: The [`ColBERTConfig`](@ref) used to train the indexer. + - `sample`: The matrix of sampled embeddings used to compute clusters. + - `heldout`: The matrix of sample embeddings used to compute the residual information. + - `num_partitions`: The number of clusters to compute. + - `nbits`: The number of bits used to encode the residuals. + - `kmeans_niters`: The maximum number of iterations in the ``k``-means algorithm. + +# Returns + +A `Dict` containing the residual codec, i.e information used to compress/decompress residuals. """ -function train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, num_partitions::Int, nbits::Int, kmeans_niters::Int) +function train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, + num_partitions::Int, nbits::Int, kmeans_niters::Int) centroids = kmeans(sample, num_partitions, maxiter = kmeans_niters, display = :iter).centers - @assert size(centroids)[2]==num_partitions - "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)" + @assert size(centroids)[2] == num_partitions + "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)" @assert centroids isa AbstractMatrix{Float32} "$(typeof(centroids))" bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals( @@ -280,7 +300,7 @@ function train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32} "centroids" => centroids, "bucket_cutoffs" => bucket_cutoffs, "bucket_weights" => bucket_weights, - "avg_residual" => avg_residual, + "avg_residual" => avg_residual ) end From 15c525f67a1edf761c07a34a9cd1267717bc748b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 15 Aug 2024 23:21:02 +0530 Subject: [PATCH 41/59] Completing the `Indexer` functions. --- src/indexing.jl | 85 +++++++++++++++++++++++------- src/indexing/collection_indexer.jl | 19 ------- 2 files changed, 65 insertions(+), 39 deletions(-) diff --git a/src/indexing.jl b/src/indexing.jl index bf02160..823353d 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -4,6 +4,20 @@ struct Indexer collection::Vector{String} end +""" + Indexer(config::ColBERTConfig) + +Type representing an ColBERT indexer. 
+ +# Arguments + + - `config`: The [`ColBERTConfig`](@ref) used to build the index. + +# Returns + +An [`Indexer`] wrapping a [`ColBERTConfig`](@ref), a [`Checkpoint`](@ref) and +a collection of documents to index. +""" function Indexer(config::ColBERTConfig) base_colbert = BaseColBERT(config) checkpoint = Checkpoint(base_colbert, config) @@ -15,30 +29,61 @@ function Indexer(config::ColBERTConfig) Indexer(config, checkpoint, collection) end +""" + index(indexer::Indexer) + +Build an index given the configuration stored in `indexer`. + +# Arguments + + - `indexer`: An `Indexer` which is used to build the index on disk. +""" function index(indexer::Indexer) if isdir(indexer.config.index_path) @info "Index at $(indexer.config.index_path) already exists! Skipping indexing." return end - config = indexer.config - checkpoint = config.checkpoint - - # loading the models - @info "Loading ColBERT layers from HuggingFace." - base_colbert = BaseColBERT(checkpoint, config) - checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), - QueryTokenizer(base_colbert.tokenizer, config), config) - - # creating the encoder, saver and indexer - encoder = CollectionEncoder(config, checkPoint) - saver = IndexSaver(config = config) - collection_indexer = CollectionIndexer(config, encoder, saver) - - # building the index - @info "Building the index." - setup(collection_indexer) - train(collection_indexer) - index(collection_indexer) - finalize(collection_indexer) + # getting and saving the indexing plan + isdir(indexer.config.index_path) || mkdir(indexer.config.index_path) + plan_dict = setup(indexer.config, indexer.checkpoint, indexer.collection) + + sample_path = joinpath(indexer.config.index_path, "sample.jld2") + @info "Saving sampled embeddings to $(sample_path)." + JLD2.save_object(sample_path, plan_dict["local_sample_embs"]) + + @info "Saving the index plan to $(joinpath(indexer.config.index_path, "plan.json"))." + open(joinpath(indexer.config.index_path, "plan.json"), "w") do io + JSON.print(io, + Dict( + "chunksize" => plan_dict["chunksize"], + "num_chunks" => plan_dict["num_chunks"], + "num_partitions" => plan_dict["num_partitions"], + "num_embeddings_est" => plan_dict["num_embeddings_est"], + "avg_doclen_est" => plan_dict["avg_doclen_est"] + ), + 4 # indent + ) + end + + @info "Saving the config to the indexing path." 
+ ColBERT.save(indexer.config) + + # training/clustering + sample, heldout = _concatenate_and_split_sample(indexer.config.index_path) + @assert sample isa AbstractMatrix{Float32} "$(typeof(sample))" + @assert heldout isa AbstractMatrix{Float32} "$(typeof(heldout))" + + codec = train(sample, heldout, plan_dict["num_partitions"], + indexer.config.nbits, indexer.config.kmeans_niters) + save_codec(indexer.config.index_path, codec["centroids"], codec["bucket_cutoffs"], + codec["bucket_weights"], codec["avg_residual"]) + + # indexing + index(indexer.config, indexer.checkpoint, indexer.collection) + + # finalizing + _check_all_files_are_saved(indexer.config.index_path) + _collect_embedding_id_offset(indexer.config.index_path) + _build_ivf(indexer.config.index_path) end diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 2f40d02..f112b79 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -396,25 +396,6 @@ function index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector end end -""" - finalize(index_path::String) - -Finalize the indexing process by saving all files, collecting embedding ID offsets, -and building the IVF. - -See [`_check_all_files_are_saved`](@ref), [`_collect_embedding_id_offset`](@ref), -[`_build_ivf`](@ref) for more details. - -# Arguments - - - `index_path`: The path of the index. -""" -function finalize(index_path::String) - _check_all_files_are_saved(index_path) - _collect_embedding_id_offset(index_path) - _build_ivf(index_path) -end - """ check_chunk_exists(saver::IndexSaver, chunk_idx::Int) From 7c9894068a1a12c0514b96e0d7e2d7472f2dee5a Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Thu, 15 Aug 2024 23:23:19 +0530 Subject: [PATCH 42/59] Updating the example indexing script. --- examples/indexing.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/examples/indexing.jl b/examples/indexing.jl index ede2bb3..f1d098f 100644 --- a/examples/indexing.jl +++ b/examples/indexing.jl @@ -11,14 +11,8 @@ config = ColBERTConfig( collection = "./cityofaustin", doc_maxlen = 300, index_path = "./cityofaustin_index/", - chunksize = 500, + chunksize = 500 ) indexer = Indexer(config) -checkpoint = indexer.checkpoint -collection = indexer.collection - -@time ColBERT.setup(config, checkpoint, collection) -@time ColBERT.train(config) -@time ColBERT.index(config, checkpoint, collection) -@time ColBERT.finalize(config.index_path) +index(indexer) From de2db685c72bf2c06f9c15647f403d48ec929bc5 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 10:49:18 +0530 Subject: [PATCH 43/59] Simplifying modelling functions for queries; making them more test friendly. --- src/modelling/checkpoint.jl | 21 ++- .../tokenization/query_tokenization.jl | 156 +++++++++--------- 2 files changed, 98 insertions(+), 79 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 81fe6aa..fd9ffa5 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -493,7 +493,7 @@ the total number of queries. 
Continuing from the queries example for [`tensorize_queries`](@ref) and [`Checkpoint`](@ref): ```julia-repl -julia> ColBERT.query(config, checkpoint, integer_ids, integer_mask) +julia> ColBERT.query(config, checkpoint, integer_ids, integer_mask)[:, :, 1] 128×32×1 CuArray{Float32, 3, CUDA.DeviceMemory}: [:, :, 1] = 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117006 0.115806 0.115938 0.112977 0.107919 @@ -542,6 +542,7 @@ julia> ColBERT.query(config, checkpoint, integer_ids, integer_mask) -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.180245 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 + ``` """ function query( @@ -687,10 +688,20 @@ function queryFromText(config::ColBERTConfig, endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc) # get ids and masks, embeddings and returning the concatenated tensors - batches = tensorize_queries(config, tokenizer, queries, bsize) - batches = [query(config, checkpoint, integer_ids, integer_mask) - for (integer_ids, integer_mask) in batches] - Q = cat(batches..., dims = 3) + integer_ids, integer_mask = tensorize_queries(config, tokenizer, queries) + + # aggregate all embeddings + Q = Vector{AbstractArray{Float32}}() + query_offset = 1 + while (query_offset <= length(queries)) + query_end_offset = min(length(queries), query_offset + bsize - 1) + Q_ = query( + config, checkpoint, integer_ids[:, query_offset:query_end_offset], + integer_mask[:, query_offset:query_end_offset]) + push!(Q, Q_) + query_offset += bsize + end + Q = cat(Q..., dims = 3) @assert ndims(Q)==3 "ndims(Q): $(ndims(Q))" @assert Q isa AbstractArray{Float32} "$(typeof(Q))" diff --git a/src/modelling/tokenization/query_tokenization.jl b/src/modelling/tokenization/query_tokenization.jl index eee1e1d..c6b9217 100644 --- a/src/modelling/tokenization/query_tokenization.jl +++ b/src/modelling/tokenization/query_tokenization.jl @@ -1,27 +1,25 @@ """ tensorize_queries(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - batch_text::Vector{String}, bsize::Union{Missing, Int}) + batch_text::Vector{String}) -Convert a collection of queries to tensors in the ColBERT format. +Convert a collection of queries to tensors of token IDs and attention masks. This function adds the query marker token at the beginning of each query text and then converts the text data into integer IDs and masks using the `tokenizer`. -The returned tensors are batched into sizes given by the `bsize` argument. # Arguments - `config`: The [`ColBERTConfig`](@ref) to be used to figure out the query marker token ID. - `tokenizer`: The tokenizer which is used to convert text data into integer IDs. - `batch_text`: A document texts that will be converted into tensors of token IDs. - - `bsize`: The size of the batches to split the `batch_text` into. # Returns -`batches`, A `Vector` of tuples of arrays of token IDs and masks corresponding to -the query texts. Each array in each tuple has shape `(L, N)`, where `L` is the -maximum query length specified by the config (see [`ColBERTConfig`](@ref)), and `N` -is the number of queries in the batch being considered. +A tuple `integer_ids`, `integer_mask` containing the token IDs and the attention mask. 
Each +of these two matrices has shape `(L, N)`, where `L` is the maximum query length specified +by the `config` (see [`ColBERTConfig`](@ref)), and `N` is the number of queries in +`batch_text`. # Examples @@ -50,74 +48,87 @@ julia> tokenizer = Transformers.TextEncoders.BertTextEncoder( tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); -julia> queries = ["what are white spots on raspberries?"]; - -julia> batches = ColBERT.tensorize_queries(config, tokenizer, queries, 128); - -julia> integer_ids, integer_mask = batches[1][1], batches[1][2]; - -julia> integer_ids -32×1 Matrix{Int32}: - 102 - 2 - 2055 - 2025 - 2318 - 7517 - 2007 - 20711 - 2362 - 20969 - 1030 - 103 - ⋮ - 104 - 104 - 104 - 104 - 104 - 104 - 104 - 104 - 104 - 104 - 104 +julia> queries = [ + "what are white spots on raspberries?", + "what do rabbits eat?" +]; + +julia> integer_ids, integer_mask = ColBERT.tensorize_queries(config, tokenizer, queries); + +julia> 32×2 reinterpret(Int32, ::Matrix{OneHot{0x0000773a}}): + 102 102 + 2 2 + 2055 2055 + 2025 2080 + 2318 20404 + 7517 4522 + 2007 1030 + 20711 103 + 2362 104 + 20969 104 + 1030 104 + 103 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 + 104 104 julia> integer_mask -32×1 Matrix{Bool}: - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - ⋮ - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 +32×2 Matrix{Bool}: + 1 1 + 1 1 + 1 1 + 1 1 + 1 1 + 1 1 + 1 1 + 1 1 + 1 0 + 1 0 + 1 0 + 1 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + 0 0 + ``` """ function tensorize_queries(config::ColBERTConfig, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder, - batch_text::Vector{String}, bsize::Union{Missing, Int}) - if ismissing(bsize) - error("Currently bsize cannot be missing!") - end - + batch_text::Vector{String}) # placeholder for [Q] marker token batch_text = [". " * query for query in batch_text] @@ -144,8 +155,5 @@ function tensorize_queries(config::ColBERTConfig, @assert isequal(sum(integer_mask), prod(size(integer_mask))) "sum(integer_mask): $(sum(integer_mask)), prod(size(integer_mask)): $(prod(size(integer_mask)))" end - batches = _split_into_batches(integer_ids, integer_mask, bsize) - @assert batches isa Vector{Tuple{AbstractMatrix{Int32}, AbstractMatrix{Bool}}} "$(typeof(batches))" - - batches + integer_ids, integer_mask end From 88c1fc5d90cab5b15925efae21adab6626173f9b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 11:34:11 +0530 Subject: [PATCH 44/59] Saving all metadata in `plan.json`. --- src/indexing/collection_indexer.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index f112b79..3ed85aa 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -439,6 +439,7 @@ function _check_all_files_are_saved(index_path::String) end function _collect_embedding_id_offset(index_path::String) + @assert isfile(joinpath(index_path, "plan.json")) "Fatal: plan.json doesn't exist!" plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) @info "Collecting embedding ID offsets." 
@@ -468,14 +469,11 @@ function _collect_embedding_id_offset(index_path::String) @assert length(embeddings_offsets) == plan_metadata["num_chunks"] @info "Saving the indexing metadata." - metadata_path = joinpath(index_path, "metadata.json") - open(metadata_path, "w") do io + plan_metadata["num_embeddings"] = num_embeddings + plan_metadata["embeddings_offsets"] = embeddings_offsets + open(joinpath(index_path, "plan.json"), "w") do io JSON.print(io, - # TODO: export the config here as well! - Dict( - "num_embeddings" => num_embeddings, - "embeddings_offsets" => embeddings_offsets - ), + plan_metadata, 4 ) end From 9b6a01877e9d814f79471cd33d249a9cdf0e2059 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 11:43:40 +0530 Subject: [PATCH 45/59] Minor fix in loading codes. --- src/indexing/collection_indexer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index 3ed85aa..e27c2c9 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -487,7 +487,7 @@ function _build_ivf(index_path::String) @info "Loading codes for each embedding." for chunk_idx in 1:(plan_metadata["num_chunks"]) - chunk_codes = load_codes(index_path, chunk_idx) + chunk_codes = JLD2.load_object(joinpath(index_path, "$(chunk_idx).codes.jld2")) append!(codes, chunk_codes) end @assert codes isa AbstractVector{UInt32} "$(typeof(codes))" From ac90533ac981ac7f9a1c5bcc9ee77639f1631fe6 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 11:48:04 +0530 Subject: [PATCH 46/59] Adding more fields to the `Searcher`, and updating it's constructor. --- src/searching.jl | 75 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/src/searching.jl b/src/searching.jl index 3c93990..5a06210 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -1,7 +1,15 @@ struct Searcher config::ColBERTConfig checkpoint::Checkpoint - ranker::IndexScorer + centroids::Matrix{Float32} + bucket_cutoffs::Vector{Float32} + bucket_weights::Vector{Float32} + ivf::Vector{Int} + ivf_lengths::Vector{Int} + doclens::Vector{Int} + codes::Vector{UInt32} + residuals::Matrix{UInt8} + emb2pid::Vector{Int} end function Searcher(index_path::String) @@ -10,15 +18,68 @@ function Searcher(index_path::String) end # loading the config from the path - config = load_config(index_path) + config = load_config(index_path) # loading the model and saving it to prevent multiple loads - @info "Loading ColBERT layers from HuggingFace." - base_colbert = BaseColBERT(config.checkpoint, config) - checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), - QueryTokenizer(base_colbert.tokenizer, config), config) + base_colbert = BaseColBERT(config) + checkpoint = Checkpoint(base_colbert, config) + @info "Loaded ColBERT layers from the $(config.checkpoint) HuggingFace checkpoint." 
+ + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + codec = load_codec(index_path) + ivf = JLD2.load_object(joinpath(index_path, "ivf.jld2")) + ivf_lengths = JLD2.load_object(joinpath(index_path, "ivf_lengths.jld2")) + + # loading all doclens + doclens = Vector{Int}() + for chunk_idx in 1:plan_metadata["num_chunks"] + doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2") + chunk_doclens = JLD2.load_object(doclens_file) + append!(doclens, chunk_doclens) + end + + # loading all compressed embeddings + num_embeddings = plan_metadata["num_embeddings"] + dim, nbits = config.dim, config.nbits + @assert (dim * nbits) % 8==0 "(dim, nbits): $((dim, nbits))" + codes = zeros(UInt32, num_embeddings) + residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings) + codes_offset = 1 + for chunk_idx in 1:plan_metadata["num_chunks"] + chunk_codes = JLD2.load_object(joinpath(index_path, "$(chunk_idx).codes.jld2")) + chunk_residuals = JLD2.load_object(joinpath(index_path, "$(chunk_idx).residuals.jld2")) + + codes_endpos = codes_offset + length(chunk_codes) - 1 + codes[codes_offset:codes_endpos] = chunk_codes + residuals[:, codes_offset:codes_endpos] = chunk_residuals + + codes_offset = codes_offset + length(chunk_codes) + end - Searcher(config, checkPoint, IndexScorer(index_path)) + # the emb2pid mapping + @info "Building the emb2pid mapping." + @assert isequal(sum(doclens), plan_metadata["num_embeddings"]) "sum(doclens): $(sum(doclens)), num_embeddings: $(plan_metadata["num_embeddings"])" + emb2pid = zeros(Int, plan_metadata["num_embeddings"]) + + offset_doclens = 1 + for (pid, dlength) in enumerate(doclens) + emb2pid[offset_doclens:(offset_doclens + dlength - 1)] .= pid + offset_doclens += dlength + end + + Searcher( + config, + checkpoint, + codec["centroids"], + codec["bucket_cutoffs"], + codec["bucket_weights"], + ivf, + ivf_lengths, + doclens, + codes, + residuals, + emb2pid + ) end """ From cf69b8d5e03cee273d84d925306fb6f3c01f1872 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 11:48:34 +0530 Subject: [PATCH 47/59] Minor change to the `encode_query` function. --- src/searching.jl | 86 ++++++++++++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 32 deletions(-) diff --git a/src/searching.jl b/src/searching.jl index 5a06210..92c8974 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -98,45 +98,67 @@ An array containing the embeddings for each token in the query. Also see [queryF # Examples -Here's an example using the config given in docs for [`ColBERTConfig`](@ref). +Here's an example using the `config` and `checkpoint` from the example for [`Checkpoint`](@ref). 
```julia-repl -julia> searcher = Searcher(config); - -julia> encode_query(searcher, "what are white spots on raspberries?") +julia> encode_query(config, checkpoint, "what are white spots on raspberries?") 128×32×1 Array{Float32, 3}: [:, :, 1] = - 0.0158567 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200357 - -0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138702 -0.0409767 -0.126037 -0.126829 -0.13149 - -0.0231786 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583 - -0.0290509 0.143255 0.0306142 0.042658 -0.164401 -0.161857 -0.160327 - 0.0921477 0.0588331 0.250449 0.234636 0.0664076 0.0659837 0.0711357 - 0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634 - -0.0768143 -0.00587305 0.00543038 0.00443374 -0.11757 -0.112495 -0.11112 - -0.0184338 0.00668557 -0.191863 -0.161345 … -0.107664 -0.107267 -0.114564 - 0.0112104 0.0214651 -0.0923963 -0.0823051 0.106261 0.105065 0.10409 - ⋮ ⋱ ⋮ - -0.0617142 -0.0573989 -0.0973785 -0.0805046 0.107432 0.108591 0.109501 - -0.0859686 0.0623054 0.0974813 0.126841 0.0182795 0.0230549 0.031103 - 0.0392043 0.0162653 0.0926306 0.104053 0.0491495 0.0484318 0.0438132 - -0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0617945 -0.0631367 -0.0675882 - 0.013123 0.0565132 -0.0349061 -0.0464192 0.0724731 0.0780166 0.074623 - -0.117425 0.162483 0.11039 0.136364 -0.00538225 -0.00685449 -0.0019436 - -0.0401158 -0.0045094 0.0539569 0.0689953 -0.00518063 -0.00600252 -0.00771469 - 0.0893983 0.0695061 -0.0499409 -0.035411 0.0960932 0.0961893 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.022061 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699095 -0.0684749 -0.0662904 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0556761 -0.0556784 -0.059571 + 0.0158568 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919 + 0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200356 + -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0777439 -0.0776733 -0.0830504 + -0.109909 -0.170906 -0.0138701 -0.0409766 -0.126037 -0.126829 -0.13149 + -0.0231787 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536 + 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583 + -0.0290508 0.143255 0.0306142 0.0426579 -0.164401 -0.161857 -0.160327 + 0.0921475 0.058833 0.250449 0.234636 0.0664076 0.0659837 0.0711358 + 0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634 + -0.0768143 -0.00587302 0.00543038 0.00443376 -0.11757 -0.112495 -0.11112 + -0.0184337 0.00668561 -0.191863 -0.161345 … -0.107664 -0.107267 -0.114564 + 0.0112104 0.0214651 -0.0923963 -0.0823052 0.106261 0.105065 0.10409 + 0.110971 0.272576 0.148319 0.143233 0.109914 0.112652 0.108365 + -0.131066 0.0376254 -0.0164237 -0.000193318 -0.0969305 -0.0935498 -0.096145 + -0.0402605 0.0350559 0.0162864 0.0269105 -0.070679 -0.0655848 -0.0564059 + 0.0799973 0.0482302 0.0712078 0.0792903 … 0.00889943 0.00932721 0.00751066 + -0.137565 -0.0369116 -0.065728 -0.0664102 0.0297059 0.0278639 0.0257616 + 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0566325 -0.0568765 -0.0581378 + 0.0656851 0.0195639 0.0288789 0.0559219 0.0596156 0.0541802 0.0525933 + 0.0668634 -0.00400549 0.0297102 0.0505045 0.0361149 0.0325914 0.0260693 + -0.0691096 0.0348577 -0.000312685 0.0232462 … -0.132163 -0.129679 -0.131122 + -0.0273036 0.0653352 0.0332689 0.017918 0.0469949 0.0434268 0.0442646 + 
-0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0809244 -0.0823798 -0.081472 + -0.0262739 0.109895 0.0117273 0.0222689 0.0175875 0.013171 0.0195091 + 0.0861164 0.0799029 0.00381147 0.0170927 0.0209905 0.0230679 0.0221191 + ⋮ ⋱ ⋮ + -0.039636 -0.0837763 -0.0837142 -0.0597521 0.0313526 0.0316408 0.0309661 + 0.0755214 0.0960326 0.0858578 0.0614626 … 0.109034 0.107593 0.111863 + 0.0506199 0.00290888 0.047947 0.063503 0.033966 0.0327732 0.0261081 + -0.0288586 -0.150171 -0.0699125 -0.108002 -0.0697569 -0.0715358 -0.0683193 + -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0649795 0.0697126 0.0808413 + 0.0445508 0.0296366 0.0325647 0.0521935 0.12324 0.120497 0.117703 + -0.127301 -0.0224252 -0.00579415 -0.00877803 … -0.0823464 -0.0803394 -0.0856279 + 0.0304881 0.0396951 0.0798097 0.0736797 0.0460205 0.0460111 0.0532082 + 0.0488798 0.252244 0.0866849 0.098552 -0.0395483 -0.0463498 -0.0494207 + -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0404835 -0.0410673 -0.0367272 + 0.023548 -0.00147361 0.0629259 0.106951 -0.000107777 -0.000898423 0.00296315 + -0.0574151 -0.0875744 -0.103787 -0.114166 … -0.0687795 -0.070967 -0.0636385 + 0.0280373 0.149767 -0.0899733 -0.0732524 0.0201251 0.0197228 0.0219051 + -0.0617143 -0.0573989 -0.0973785 -0.0805046 0.107432 0.108591 0.109502 + -0.0859687 0.0623054 0.0974813 0.126841 0.0182794 0.0230548 0.031103 + 0.0392044 0.0162653 0.0926306 0.104054 0.0491496 0.0484319 0.0438133 + -0.0340362 -0.0278067 -0.0181035 -0.0282369 … -0.0617946 -0.0631367 -0.0675882 + 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0724731 0.0780165 0.0746229 + -0.117425 0.162483 0.11039 0.136364 -0.00538224 -0.00685447 -0.00194357 + -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00518066 -0.00600254 -0.0077147 + 0.0893984 0.0695061 -0.049941 -0.035411 0.0960931 0.0961892 0.103431 + -0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.0220611 -0.018135 + -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699094 -0.0684748 -0.0662903 + 0.100019 -0.0618588 0.106134 0.0989047 -0.055676 -0.0556784 -0.0595709 ``` """ -function encode_query(searcher::Searcher, query::String) +function encode_query(config::ColBERTConfig, checkpoint::Checkpoint, query::String) queries = [query] - bsize = 128 - Q = queryFromText(searcher.checkpoint, queries, bsize) - Q + queryFromText(config, checkpoint, queries, config.index_bsize) end function search(searcher::Searcher, query::String, k::Int) From c661775c98463cdd6a69a209dda41bcb3570368c Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 11:49:14 +0530 Subject: [PATCH 48/59] Removing unnecessary functions. --- src/indexing/codecs/residual.jl | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 1ead085..9ad780a 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -294,31 +294,3 @@ function decompress( embeddings end - -""" - load_codes(codec::ResidualCodec, chunk_idx::Int) - -Load the codes from disk for a given chunk index. The codes are stored in the file `.codes.jld2` located inside the -`index_path` provided by the configuration. - -# Arguments - - - `codec`: The [`ResidualCodec`](@ref) object containing the compression information. - - `chunk_idx`: The chunk index for which the codes should be loaded. - -# Returns - -A vector of codes for the specified chunk. 
-""" -function load_codes(index_path::String, chunk_idx::Int) - codes_path = joinpath( - index_path, "$(chunk_idx).codes.jld2") - JLD2.load_object(codes_path) -end - -function load_residuals(codec::ResidualCodec, chunk_idx::Int) - residual_path = joinpath( - codec.config.index_path, "$(chunk_idx).residuals.jld2") - residuals = JLD2.load(residual_path, "residuals") - residuals -end From 1702ae3813bdac0407154738328aaa89509ab552 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 12:21:29 +0530 Subject: [PATCH 49/59] Simplyfying the decompression functions; making them test friendly. --- src/indexing/codecs/residual.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 9ad780a..14a1327 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -223,10 +223,8 @@ function compress(centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, codes, residuals end -function decompress_residuals(codec::ResidualCodec, binary_residuals::AbstractMatrix{UInt8}) - dim = codec.config.dim - nbits = codec.config.nbits - +function decompress_residuals(dim::Int, nbits::Int, bucket_weights::Vector{Float32}, + binary_residuals::AbstractMatrix{UInt8}) @assert ndims(binary_residuals)==2 "ndims(binary_residuals): $(ndims(binary_residuals))" @assert size(binary_residuals)[1]==(dim / 8) * nbits "size(binary_residuals): $(size(binary_residuals)), (dim / 8) * nbits: $((dim / 8) * nbits)" @@ -252,7 +250,7 @@ function decompress_residuals(codec::ResidualCodec, binary_residuals::AbstractMa # reshaping to get rid of the nbits wide dimension unpacked_bits = reshape(unpacked_bits, size(unpacked_bits)[2:end]...) - embeddings = codec.bucket_weights[unpacked_bits] + embeddings = bucket_weights[unpacked_bits] @assert ndims(embeddings)==2 "ndims(embeddings): $(ndims(embeddings))" @assert size(embeddings)[2]==size(binary_residuals)[2] "size(embeddings): $(size(embeddings)), size(binary_residuals): $(size(binary_residuals)) " @@ -262,7 +260,8 @@ function decompress_residuals(codec::ResidualCodec, binary_residuals::AbstractMa end function decompress( - codec::ResidualCodec, codes::Vector{UInt32}, residuals::AbstractMatrix{UInt8}) + dim::Int, nbits::Int, centroids::Matrix{Float32}, bucket_weights::Vector{Float32}, + codes::Vector{UInt32}, residuals::AbstractMatrix{UInt8}) @assert ndims(codes)==1 "ndims(codes): $(ndims(codes))" @assert ndims(residuals)==2 "ndims(residuals): $(ndims(residuals))" @assert length(codes)==size(residuals)[2] "length(codes): $(length(codes)), size(residuals): $(size(residuals))" @@ -276,8 +275,8 @@ function decompress( batch_residuals = residuals[ :, batch_offset:min(batch_offset + bsize - 1, length(codes))] - centroids_ = codec.centroids[:, batch_codes] - residuals_ = decompress_residuals(codec, batch_residuals) + centroids_ = centroids[:, batch_codes] + residuals_ = decompress_residuals(dim, nbits, bucket_weights, batch_residuals) batch_embeddings = centroids_ + residuals_ batch_embeddings = mapslices( From 65d8d52637954f2f8c5eec3b6fdf58a7a5ccc6e6 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 12:26:29 +0530 Subject: [PATCH 50/59] Simplifying the search code, and removing the `IndexScorer`. 
--- src/search/index_storage.jl | 143 ++++++------------------------------ src/searching.jl | 18 ++++- 2 files changed, 35 insertions(+), 126 deletions(-) diff --git a/src/search/index_storage.jl b/src/search/index_storage.jl index fb207ef..d5df58c 100644 --- a/src/search/index_storage.jl +++ b/src/search/index_storage.jl @@ -1,165 +1,72 @@ -struct IndexScorer - metadata::Dict - codec::ResidualCodec - ivf::Vector{Int} - ivf_lengths::Vector{Int} - doclens::Vector{Int} - codes::Vector{UInt32} - residuals::Matrix{UInt8} - emb2pid::Vector{Int} -end - -""" -# Examples - -```julia-repl -julia> IndexScorer(index_path) - -``` -""" -function IndexScorer(index_path::String) - @info "Loading the index from {index_path}." - - # loading the config from the index path - config = load_config(index_path) - - # the metadata - metadata_path = joinpath(index_path, "metadata.json") - metadata = JSON.parsefile(metadata_path) - - # loading the codec - codec = load_codec(index_path) - - # loading ivf into a StridedTensor - ivf_path = joinpath(index_path, "ivf.jld2") - ivf_dict = JLD2.load(ivf_path) - ivf, ivf_lengths = ivf_dict["ivf"], ivf_dict["ivf_lengths"] - # ivf = StridedTensor(ivf, ivf_lengths) - - # loading all doclens - doclens = Vector{Int}() - for chunk_idx in 1:metadata["num_chunks"] - doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2") - chunk_doclens = JLD2.load(doclens_file, "doclens") - append!(doclens, chunk_doclens) - end - - # loading all compressed embeddings - num_embeddings = metadata["num_embeddings"] - dim, nbits = config.dim, config.nbits - @assert (dim * nbits) % 8==0 "(dim, nbits): $((dim, nbits))" - codes = zeros(UInt32, num_embeddings) - residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings) - codes_offset = 1 - for chunk_idx in 1:metadata["num_chunks"] - chunk_codes = load_codes(codec, chunk_idx) - chunk_residuals = load_residuals(codec, chunk_idx) - - codes_endpos = codes_offset + length(chunk_codes) - 1 - codes[codes_offset:codes_endpos] = chunk_codes - residuals[:, codes_offset:codes_endpos] = chunk_residuals - - codes_offset = codes_offset + length(chunk_codes) - end - - # the emb2pid mapping - @info "Building the emb2pid mapping." - @assert isequal(sum(doclens), metadata["num_embeddings"]) "sum(doclens): $(sum(doclens)), num_embeddings: $(metadata["num_embeddings"])" - emb2pid = zeros(Int, metadata["num_embeddings"]) - - offset_doclens = 1 - for (pid, dlength) in enumerate(doclens) - emb2pid[offset_doclens:(offset_doclens + dlength - 1)] .= pid - offset_doclens += dlength - end - - IndexScorer( - metadata, - codec, - ivf, - ivf_lengths, - doclens, - codes, - residuals, - emb2pid - ) -end - """ Return a candidate set of `pids` for the query matrix `Q`. This is done as follows: the nearest `nprobe` centroids for each query embedding are found. This list is then flattened and the unique set of these centroids is built. Using the `ivf`, the list of all unique embedding IDs contained in these centroids is computed. Finally, these embedding IDs are converted to `pids` using `emb2pid`. This list of `pids` is the final candidate set. """ -function retrieve(ranker::IndexScorer, config::ColBERTConfig, Q::AbstractArray{Float32}) - @assert isequal(size(Q)[2], config.query_maxlen) "size(Q): $(size(Q)), query_maxlen: $(config.query_maxlen)" # Q: (128, 32, 1) - - Q = reshape(Q, size(Q)[1:end .!= end]...) 
# squeeze out the last dimension - @assert isequal(length(size(Q)), 2) "size(Q): $(size(Q))" - +function retrieve(ivf::Vector{Int}, ivf_lengths::Vector{Int}, centroids::Matrix{Float32}, + emb2pid::Vector{Int}, nprobe::Int, Q::AbstractMatrix{Float32}) # score of each query embedding with each centroid and take top nprobe centroids - cells = Flux.gpu(transpose(Q)) * Flux.gpu(ranker.codec.centroids) |> Flux.cpu + cells = Flux.gpu(transpose(Q)) * Flux.gpu(centroids) |> Flux.cpu # TODO: how to take topk entries using GPU code? cells = mapslices( - row -> partialsortperm(row, 1:(config.nprobe), rev = true), + row -> partialsortperm(row, 1:(nprobe), rev = true), cells, dims = 2) # take top nprobe centroids for each query centroid_ids = sort(unique(vec(cells))) # get all embedding IDs contained in centroid_ids using ivf centroid_ivf_offsets = cat( - [1], 1 .+ cumsum(ranker.ivf_lengths)[1:end .!= end], dims = 1) + [1], 1 .+ cumsum(ivf_lengths)[1:end .!= end], dims = 1) eids = Vector{Int}() for centroid_id in centroid_ids offset = centroid_ivf_offsets[centroid_id] - length = ranker.ivf_lengths[centroid_id] - append!(eids, ranker.ivf[offset:(offset + length - 1)]) + length = ivf_lengths[centroid_id] + append!(eids, ivf[offset:(offset + length - 1)]) end - @assert isequal(length(eids), sum(ranker.ivf_lengths[centroid_ids])) "length(eids): $(length(eids)), sum(ranker.ivf_lengths[centroid_ids]): $(sum(ranker.ivf_lengths[centroid_ids]))" + @assert isequal(length(eids), sum(ivf_lengths[centroid_ids])) "length(eids): $(length(eids)), sum(ranker.ivf_lengths[centroid_ids]): $(sum(ivf_lengths[centroid_ids]))" eids = sort(unique(eids)) # get pids from the emb2pid mapping - pids = sort(unique(ranker.emb2pid[eids])) + pids = sort(unique(emb2pid[eids])) pids end """ - Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this. 
""" -function score_pids(ranker::IndexScorer, config::ColBERTConfig, - Q::AbstractArray{Float32}, pids::Vector{Int}) +function score_pids(config::ColBERTConfig, centroids::Matrix{Float32}, + bucket_weights::Vector{Float32}, doclens::Vector{Int}, codes::Vector{UInt32}, + residuals::Matrix{UInt8}, Q::AbstractArray{Float32}, pids::Vector{Int}) # get codes and residuals for all embeddings across all pids - num_embs = sum(ranker.doclens[pids]) + num_embs = sum(doclens[pids]) codes_packed = zeros(UInt32, num_embs) - residuals_packed = zeros(UInt8, size(ranker.residuals)[1], num_embs) - pid_offsets = cat([1], 1 .+ cumsum(ranker.doclens)[1:end .!= end], dims = 1) + residuals_packed = zeros(UInt8, size(residuals)[1], num_embs) + pid_offsets = cat([1], 1 .+ cumsum(doclens)[1:end .!= end], dims = 1) offset = 1 for pid in pids pid_offset = pid_offsets[pid] - num_embs_pid = ranker.doclens[pid] - codes_packed[offset:(offset + num_embs_pid - 1)] = ranker.codes[pid_offset:(pid_offset + num_embs_pid - 1)] - residuals_packed[:, offset:(offset + num_embs_pid - 1)] = ranker.residuals[ + num_embs_pid = doclens[pid] + codes_packed[offset:(offset + num_embs_pid - 1)] = codes[pid_offset:(pid_offset + num_embs_pid - 1)] + residuals_packed[:, offset:(offset + num_embs_pid - 1)] = residuals[ :, pid_offset:(pid_offset + num_embs_pid - 1)] offset += num_embs_pid end @assert offset==num_embs + 1 "offset: $(offset), num_embs + 1: $(num_embs + 1)" # decompress these codes and residuals to get the original embeddings - D_packed = decompress(ranker.codec, codes_packed, residuals_packed) + D_packed = decompress( + config.dim, config.nbits, centroids, bucket_weights, codes_packed, residuals_packed) @assert ndims(D_packed)==2 "ndims(D_packed): $(ndims(D_packed))" @assert size(D_packed)[1]==config.dim "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" @assert size(D_packed)[2]==num_embs "size(D_packed): $(size(D_packed)), num_embs: $(num_embs)" @assert D_packed isa AbstractMatrix{Float32} "$(typeof(D_packed))" # get the max-sim scores - if size(Q)[3] > 1 - error("Only one query is supported at the moment!") - end - @assert size(Q)[3]==1 "size(Q): $(size(Q))" Q = reshape(Q, size(Q)[1:2]...) 
scores = Vector{Float32}() query_doc_scores = Flux.gpu(transpose(Q)) * Flux.gpu(D_packed) # (num_query_tokens, num_embeddings) offset = 1 for pid in pids - num_embs_pid = ranker.doclens[pid] + num_embs_pid = doclens[pid] pid_scores = query_doc_scores[:, offset:min(num_embs, offset + num_embs_pid - 1)] push!(scores, sum(maximum(pid_scores, dims = 2))) @@ -169,11 +76,3 @@ function score_pids(ranker::IndexScorer, config::ColBERTConfig, scores end - -function rank(ranker::IndexScorer, config::ColBERTConfig, Q::AbstractArray{Float32}) - pids = retrieve(ranker, config, Q) - scores = score_pids(ranker, config, Q, pids) - indices = sortperm(scores, rev = true) - - pids[indices], scores[indices] -end diff --git a/src/searching.jl b/src/searching.jl index 92c8974..04edd85 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -162,11 +162,21 @@ function encode_query(config::ColBERTConfig, checkpoint::Checkpoint, query::Stri end function search(searcher::Searcher, query::String, k::Int) - dense_search(searcher, encode_query(searcher, query), k) -end + Q = encode_query(searcher.config, searcher.checkpoint, query) + + if size(Q)[3] > 1 + error("Only one query is supported at the moment!") + end + @assert size(Q)[3]==1 "size(Q): $(size(Q))" + @assert isequal(size(Q)[2], searcher.config.query_maxlen) "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)" # Q: (128, 32, 1) + + Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension + @assert isequal(length(size(Q)), 2) "size(Q): $(size(Q))" -function dense_search(searcher::Searcher, Q::AbstractArray{Float32}, k::Int) - pids, scores = rank(searcher.ranker, searcher.config, Q) + pids = retrieve(searcher.ivf, searcher.ivf_lengths, searcher.centroids, searcher.emb2pid, searcher.config.nprobe, Q) + scores = score_pids(searcher.config, searcher.centroids, searcher.bucket_weights, searcher.doclens, searcher.codes, searcher.residuals, Q, pids) + indices = sortperm(scores, rev = true) + pids, scores = pids[indices], scores[indices] pids[1:k], scores[1:k] end From d8078ebd56c21f312d729b7ad649013a88d4f573 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 12:33:49 +0530 Subject: [PATCH 51/59] File rename. --- src/ColBERT.jl | 3 +-- src/search/{index_storage.jl => ranking.jl} | 0 2 files changed, 1 insertion(+), 2 deletions(-) rename src/search/{index_storage.jl => ranking.jl} (100%) diff --git a/src/ColBERT.jl b/src/ColBERT.jl index fae58a6..4212fe4 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -38,8 +38,7 @@ include("indexing/collection_indexer.jl") export Indexer, index # searcher -include("search/strided_tensor.jl") -include("search/index_storage.jl") +include("search/ranking.jl") include("searching.jl") export Searcher, search diff --git a/src/search/index_storage.jl b/src/search/ranking.jl similarity index 100% rename from src/search/index_storage.jl rename to src/search/ranking.jl From affdb5eb40c4f3187803ec5a42e7c0b4b330762b Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 12:34:50 +0530 Subject: [PATCH 52/59] Removing `ResidalCodec`. 
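With the struct removed, the codec components are passed around as plain arrays. A hedged sketch of how decompression is now invoked (`index_path`, `config`, `codes` and `residuals` are placeholders for an existing index directory, its `ColBERTConfig`, and the compressed embeddings loaded from one of its chunks):

```julia
# load_codec returns a Dict with "centroids", "bucket_cutoffs", "bucket_weights", "avg_residual"
codec = load_codec(index_path)
embs = decompress(config.dim, config.nbits, codec["centroids"],
    codec["bucket_weights"], codes, residuals)
```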
--- src/indexing/codecs/residual.jl | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index 14a1327..c2305cb 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -1,32 +1,3 @@ -""" - ResidualCodec( - config::ColBERTConfig, centroids::AbstractMatrix{Float32}, avg_residual::Float32, - bucket_cutoffs::AbstractVector{Float32}, bucket_weights::AbstractVector{Float32}) - -A struct that represents a compressor for ColBERT embeddings. - -It stores information about the configuration of the model, the centroids used to quantize the residuals, the average residual value, and the cutoffs and weights used to determine which buckets each residual belongs to. - -# Arguments - - - `config`: A [`ColBERTConfig`](@ref), representing all configuration parameters related to various ColBERT components. - - `centroids`: A matrix of centroids used to quantize the residuals. Has shape `(D, N)`, where `D` is the embedding dimension and `N` is the number of clusters. - - `avg_residual`: The average residual value. - - `bucket_cutoffs`: A vector of cutoff values used to determine which buckets each residual belongs to. - - `bucket_weights`: A vector of weights used to determine the importance of each bucket. - -# Returns - -A `ResidualCodec` object. -""" -mutable struct ResidualCodec - config::ColBERTConfig - centroids::AbstractMatrix{Float32} - avg_residual::Float32 - bucket_cutoffs::AbstractVector{Float32} - bucket_weights::AbstractVector{Float32} -end - """ load_codec(index_path::String) From 5a49d9269178ad8bd5f0eac2cc202a8e038ba6a2 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 12:37:24 +0530 Subject: [PATCH 53/59] Removing data structs; they aren't really needed. --- src/ColBERT.jl | 5 -- src/data/collection.jl | 120 ----------------------------------------- src/data/queries.jl | 4 -- 3 files changed, 129 deletions(-) delete mode 100644 src/data/collection.jl delete mode 100644 src/data/queries.jl diff --git a/src/ColBERT.jl b/src/ColBERT.jl index 4212fe4..3e5e295 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -16,11 +16,6 @@ using Transformers # utils include("utils/utils.jl") -# datasets -include("data/collection.jl") -include("data/queries.jl") -export Collection, Queries - # config and other infra include("infra/config.jl") export ColBERTConfig diff --git a/src/data/collection.jl b/src/data/collection.jl deleted file mode 100644 index 04e7860..0000000 --- a/src/data/collection.jl +++ /dev/null @@ -1,120 +0,0 @@ -# TODO: implement on-disk collections, and the case where pids are not necessarily sorted and can be arbitrary -""" - Collection(path::String) - -A wrapper around a collection of documents, which stores the underlying collection as a `Vector{String}`. - -# Arguments - - - `path::String`: A path to the document dataset. It is assumed that `path` refers to a CSV file. Each line of the - the CSV file should be of the form `pid \\t document`, where `pid` is the integer index of the document. `pid`s should be in the range ``[1, N]``, where ``N`` is the number of documents, and should be sorted. - -# Examples - -Here's an example which loads a small subset of the LoTTe dataset defined in `short_collections.tsv` (see the `examples` folder in the package). 
- -```julia-repl -julia> using ColBERT; - -julia> dataroot = "downloads/lotte"; - -julia> dataset = "lifestyle"; - -julia> datasplit = "dev"; - -julia> path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") -"downloads/lotte/lifestyle/dev/short_collection.tsv" - -julia> collection = Collection(path) -Collection at downloads/lotte/lifestyle/dev/short_collection.tsv with 10 passages. -``` -""" -struct Collection - path::String - data::Vector{String} -end - -function Collection(path::String) - file = CSV.File(path; delim = '\t', header = [:pid, :text], - types = Dict(:pid => Int, :text => String), debug = true, quoted = false) - @info "Loaded $(length(file.text)[1]) passages." - Collection(path, file.text) -end - -""" - get_chunksize(collection::Collection, nranks::Int) - -Determine the size of chunks used to store the index, based on the size of the `collection` and the number of available GPUs. - -# Arguments - - - `collection::Collection`: The underlying collection of documents. - - `nranks::Int`: Number of available GPUs to compute the index. At this point, the package only supports `nranks = 1`. - -# Examples - -Continuing from the example from the [`Collection`](@ref) constructor: - -```julia-repl -julia> get_chunksize(collection, 1) -11 -``` -""" -function get_chunksize(collection::Collection, nranks::Int) - Int(min(25000, 1 + floor(length(collection.data) / nranks))) -end - -""" - enumerate_batches(collection::Collection; [chunksize, nranks]) - -Batch the `collection` into chunks containing tuples of the form `(chunk_idx, offset, passages)`, where `chunk_idx` is the index of the chunk, `offset` is the index of the first passsage in the chunk, and `passages` is a `Vector{String}` containing the passages in the chunk. - -# Arguments - - - `collection::Collection`: The collection to batch. - - `chunksize::Union{Int, Missing}`: The chunksize to use to batch the collection. Default `missing`. If this is `missing`, then `chunksize` is determined using [`get_chunksize`](@ref) based on the `collection` and `nranks`. - - `nranks::Union{Int, Missing}`: The number of available GPUs. Default `missing`. Currently the package only supports `nranks = 1`. - -The `collection` is batched into chunks of uniform size (with the last chunk potentially having a smaller size). - -# Examples - -Continuing from the example in the [`Collection`](@ref) constructor. 
- -```julia-repl -julia> enumerate_batches(collection; nranks = 1); - -julia> enumerate_batches(collection; chunksize = 3); - -``` -""" -function enumerate_batches( - collection::Collection; chunksize::Union{Int, Missing} = missing, - nranks::Union{Int, Missing} = missing) - if ismissing(chunksize) - if ismissing(nranks) - error("Atleast one of the arguments chunksize or nranks must be specified!") - end - chunksize = get_chunksize(collection, nranks) - end - - num_passages = length(collection.data) - batches = Vector{Tuple{Int, Int, Vector{String}}}() - chunk_idx, offset = 1, 1 - while true - push!(batches, - (chunk_idx, offset, - collection.data[offset:min(offset + chunksize - 1, num_passages)])) - chunk_idx += 1 - offset += chunksize - - if offset > num_passages - break - end - end - batches -end - -function Base.show(io::IO, collection::Collection) - print(io, "Collection at $(collection.path) with $(length(collection.data)) passages.") -end diff --git a/src/data/queries.jl b/src/data/queries.jl deleted file mode 100644 index 36cb803..0000000 --- a/src/data/queries.jl +++ /dev/null @@ -1,4 +0,0 @@ -Base.@kwdef struct Queries - path::String - data::Vector{String} -end From 6fda1e38a8d20c5330d658c53dbca363d0d7f07f Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 12:58:29 +0530 Subject: [PATCH 54/59] Moving loading/saving functions to their own files. --- src/ColBERT.jl | 4 + src/indexing/codecs/residual.jl | 64 --------------- src/indexing/collection_indexer.jl | 57 ------------- src/infra/config.jl | 69 ---------------- src/loaders.jl | 108 +++++++++++++++++++++++++ src/savers.jl | 123 +++++++++++++++++++++++++++++ 6 files changed, 235 insertions(+), 190 deletions(-) create mode 100644 src/loaders.jl create mode 100644 src/savers.jl diff --git a/src/ColBERT.jl b/src/ColBERT.jl index 3e5e295..f3c1e22 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -37,4 +37,8 @@ include("search/ranking.jl") include("searching.jl") export Searcher, search +# loaders and savers +include("loaders.jl") +include("savers.jl") + end diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index c2305cb..799c68e 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -1,67 +1,3 @@ -""" - load_codec(index_path::String) - -Load compression/decompression information from the index path. - -# Arguments - - - `index_path`: The path of the index. -""" -function load_codec(index_path::String) - centroids_path = joinpath(index_path, "centroids.jld2") - avg_residual_path = joinpath(index_path, "avg_residual.jld2") - bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") - bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") - @info "Loading codec from $(centroids_path), $(avg_residual_path), $(bucket_cutoffs_path) and $(bucket_weights_path)." 
- - centroids = JLD2.load_object(centroids_path) - avg_residual = JLD2.load_object(avg_residual_path) - bucket_cutoffs = JLD2.load_object(bucket_cutoffs_path) - bucket_weights = JLD2.load_object(bucket_weights_path) - - @assert centroids isa Matrix{Float32} - @assert avg_residual isa Float32 - @assert bucket_cutoffs isa Vector{Float32} - @assert bucket_weights isa Vector{Float32} - - Dict( - "centroids" => centroids, - "avg_residual" => avg_residual, - "bucket_cutoffs" => bucket_cutoffs, - "bucket_weights" => bucket_weights - ) -end - -""" - save_codec( - index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, - bucket_weights::Vector{Float32}, avg_residual::Float32) - -Save compression/decompression information from the index path. - -# Arguments - - - `index_path`: The path of the index. - - `centroids`: The matrix of centroids of the index. - - `bucket_cutoffs`: Cutoffs used to determine buckets during residual compression. - - `bucket_weights`: Weights used to determine the decompressed values during decompression. - - `avg_residual`: The average residual value, computed from the heldout set (see [`_compute_avg_residuals`](@ref)). -""" -function save_codec( - index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, - bucket_weights::Vector{Float32}, avg_residual::Float32) - centroids_path = joinpath(index_path, "centroids.jld2") - avg_residual_path = joinpath(index_path, "avg_residual.jld2") - bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") - bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") - @info "Saving codec to $(centroids_path), $(avg_residual_path), $(bucket_cutoffs_path) and $(bucket_weights_path)." - - JLD2.save_object(centroids_path, centroids) - JLD2.save_object(avg_residual_path, avg_residual) - JLD2.save_object(bucket_cutoffs_path, bucket_cutoffs) - JLD2.save_object(bucket_weights_path, bucket_weights) -end - """ compress_into_codes( centroids::AbstractMatrix{Float32}, embs::AbstractMatrix{Float32}) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index e27c2c9..4204483 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -304,63 +304,6 @@ function train(sample::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32} ) end -""" - save_chunk( - config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, - embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) - -Save a single chunk of compressed embeddings and their relevant metadata to disk. - -The codes and compressed residuals for the chunk are saved in files named `.codes.jld2`. -and `.residuals.jld2` respectively. The document lengths are saved in a file named -`doclens..jld2`. Relevant metadata, including number of documents in the chunk, -number of embeddings and the passage offsets are saved in a file named `.metadata.json`. - -# Arguments - - - `config`: The [`ColBERTConfig`](@ref) being used. - - `chunk_idx`: The index of the current chunk being saved. - - `passage_offset`: The index of the first passage in the chunk. - - `embs`: The embeddings matrix for the current chunk. - - `doclens`: The document lengths vector for the current chunk. 
-""" -function save_chunk( - config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, - embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) - codes, residuals = compress( - codec["centroids"], codec["bucket_cutoffs"], config.dim, config.nbits, embs) - path_prefix = joinpath(config.index_path, string(chunk_idx)) - @assert length(codes)==size(embs)[2] "length(codes): $(length(codes)), size(embs): $(size(embs))" - - # saving the compressed embeddings - codes_path = "$(path_prefix).codes.jld2" - residuals_path = "$(path_prefix).residuals.jld2" - @info "Saving compressed codes to $(codes_path) and residuals to $(residuals_path)" - JLD2.save_object(codes_path, codes) - JLD2.save_object(residuals_path, residuals) - - # saving doclens - doclens_path = joinpath( - config.index_path, "doclens.$(chunk_idx).jld2") - @info "Saving doclens to $(doclens_path)" - JLD2.save_object(doclens_path, doclens) - - # the metadata - metadata_path = joinpath( - config.index_path, "$(chunk_idx).metadata.json") - @info "Saving metadata to $(metadata_path)" - open(metadata_path, "w") do io - JSON.print(io, - Dict( - "passage_offset" => passage_offset, - "num_passages" => length(doclens), - "num_embeddings" => length(codes) - ), - 4 # indent - ) - end -end - """ index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) diff --git a/src/infra/config.jl b/src/infra/config.jl index 2d728c8..91b7140 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -89,72 +89,3 @@ Base.@kwdef struct ColBERTConfig nprobe::Int = 2 ncandidates::Int = 8192 end - -""" - save(config::ColBERTConfig) - -Save a [`ColBERTConfig`](@ref) to disk in JSON. - -# Arguments - - - `config`: The [`ColBERTConfig`](@ref) to save. - -# Examples - -```jldoctest -julia> using ColBERT; - -julia> config = ColBERTConfig( - use_gpu = true, - collection = "/home/codetalker7/documents", - index_path = "./local_index" - ); - -julia> ColBERT.save(config); - -``` -""" -function save(config::ColBERTConfig) - properties = [Pair{String, Any}(string(field), getproperty(config, field)) - for field in fieldnames(ColBERTConfig)] - isdir(config.index_path) || mkdir(config.index_path) - open(joinpath(config.index_path, "config.json"), "w+") do io - JSON.print( - io, - Dict(properties), - 4 - ) - end -end - -""" - load_config(index_path::String) - -Load a [`ColBERTConfig`](@ref) from disk. - -# Arguments - - - `index_path`: The path of the directory where the config resides. - -# Examples - -```jldoctest -julia> using ColBERT; - -julia> config = ColBERTConfig( - use_gpu = true, - collection = "/home/codetalker7/documents", - index_path = "./local_index" - ); - -julia> ColBERT.save(config); - -julia> ColBERT.load_config("./local_index") -ColBERTConfig(true, 0, 1, "[unused0]", "[unused1]", "[Q]", "[D]", "colbert-ir/colbertv2.0", "/home/codetalker7/documents", 128, 220, true, 32, false, "./local_index", 64, 2, 20, 2, 8192) -``` -""" -function load_config(index_path::String) - config_dict = JSON.parsefile(joinpath(index_path, "config.json")) - key_vals = collect(zip(Symbol.(keys(config_dict)), values(config_dict))) - eval(:(ColBERTConfig($([Expr(:kw, :($key), :($val)) for (key, val) in key_vals]...)))) -end diff --git a/src/loaders.jl b/src/loaders.jl new file mode 100644 index 0000000..d6777ed --- /dev/null +++ b/src/loaders.jl @@ -0,0 +1,108 @@ +""" + load_codec(index_path::String) + +Load compression/decompression information from the index path. + +# Arguments + + - `index_path`: The path of the index. 
+""" +function load_codec(index_path::String) + centroids_path = joinpath(index_path, "centroids.jld2") + avg_residual_path = joinpath(index_path, "avg_residual.jld2") + bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") + bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") + @info "Loading codec from $(centroids_path), $(avg_residual_path), "* + "$(bucket_cutoffs_path) and $(bucket_weights_path)." + + centroids = JLD2.load_object(centroids_path) + avg_residual = JLD2.load_object(avg_residual_path) + bucket_cutoffs = JLD2.load_object(bucket_cutoffs_path) + bucket_weights = JLD2.load_object(bucket_weights_path) + + @assert centroids isa Matrix{Float32} + @assert avg_residual isa Float32 + @assert bucket_cutoffs isa Vector{Float32} + @assert bucket_weights isa Vector{Float32} + + Dict( + "centroids" => centroids, + "avg_residual" => avg_residual, + "bucket_cutoffs" => bucket_cutoffs, + "bucket_weights" => bucket_weights + ) +end + +""" + load_config(index_path::String) + +Load a [`ColBERTConfig`](@ref) from disk. + +# Arguments + + - `index_path`: The path of the directory where the config resides. + +# Examples + +```jldoctest +julia> using ColBERT; + +julia> config = ColBERTConfig( + use_gpu = true, + collection = "/home/codetalker7/documents", + index_path = "./local_index" + ); + +julia> ColBERT.save(config); + +julia> ColBERT.load_config("./local_index") +ColBERTConfig(true, 0, 1, "[unused0]", "[unused1]", "[Q]", "[D]", "colbert-ir/colbertv2.0", "/home/codetalker7/documents", 128, 220, true, 32, false, "./local_index", 64, 2, 20, 2, 8192) +``` +""" +function load_config(index_path::String) + config_dict = JSON.parsefile(joinpath(index_path, "config.json")) + key_vals = collect(zip(Symbol.(keys(config_dict)), values(config_dict))) + eval(:(ColBERTConfig($([Expr(:kw, :($key), :($val)) for (key, val) in key_vals]...)))) +end + +function load_doclens(index_path::String) + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + doclens = Vector{Int}() + for chunk_idx in 1:plan_metadata["num_chunks"] + doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2") + chunk_doclens = JLD2.load_object(doclens_file) + append!(doclens, chunk_doclens) + end + @assert isequal(sum(doclens), plan_metadata["num_embeddings"]) + "sum(doclens): $(sum(doclens)), num_embeddings: $(plan_metadata["num_embeddings"])" + doclens +end + +function load_compressed_embs(index_path::String) + config = load_config(index_path) + plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) + @assert (config.dim * config.nbits) % 8==0 "(dim, nbits): $((config.dim, config.nbits))" + + codes = zeros(UInt32, plan_metadata["num_embeddings"]) + residuals = zeros(UInt8, Int((config.dim / 8) * config.nbits), plan_metadata["num_embeddings"]) + codes_offset = 1 + for chunk_idx in 1:plan_metadata["num_chunks"] + chunk_codes = JLD2.load_object(joinpath(index_path, "$(chunk_idx).codes.jld2")) + chunk_residuals = JLD2.load_object(joinpath(index_path, "$(chunk_idx).residuals.jld2")) + + codes_endpos = codes_offset + length(chunk_codes) - 1 + codes[codes_offset:codes_endpos] = chunk_codes + residuals[:, codes_offset:codes_endpos] = chunk_residuals + + codes_offset = codes_offset + length(chunk_codes) + end + @assert length(codes) == plan_metadata["num_embeddings"] + "length(codes): $(length(codes)), num_embeddings: $(plan_metadata["num_embeddings"])" + @assert ndims(residuals) == 2 + @assert size(residuals)[2] == plan_metadata["num_embeddings"] + "size(residuals): $(size(residuals)), 
num_embeddings: $(plan_metadata["num_embeddings"])" + @assert codes isa Vector{UInt32} + @assert residuals isa Matrix{UInt8} + + codes, residuals +end diff --git a/src/savers.jl b/src/savers.jl new file mode 100644 index 0000000..7f1a3d5 --- /dev/null +++ b/src/savers.jl @@ -0,0 +1,123 @@ +""" + save_codec( + index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, + bucket_weights::Vector{Float32}, avg_residual::Float32) + +Save compression/decompression information from the index path. + +# Arguments + + - `index_path`: The path of the index. + - `centroids`: The matrix of centroids of the index. + - `bucket_cutoffs`: Cutoffs used to determine buckets during residual compression. + - `bucket_weights`: Weights used to determine the decompressed values during decompression. + - `avg_residual`: The average residual value, computed from the heldout set (see [`_compute_avg_residuals`](@ref)). +""" +function save_codec( + index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, + bucket_weights::Vector{Float32}, avg_residual::Float32) + centroids_path = joinpath(index_path, "centroids.jld2") + avg_residual_path = joinpath(index_path, "avg_residual.jld2") + bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") + bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") + @info "Saving codec to $(centroids_path), $(avg_residual_path), $(bucket_cutoffs_path) and $(bucket_weights_path)." + + JLD2.save_object(centroids_path, centroids) + JLD2.save_object(avg_residual_path, avg_residual) + JLD2.save_object(bucket_cutoffs_path, bucket_cutoffs) + JLD2.save_object(bucket_weights_path, bucket_weights) +end + +""" + save_chunk( + config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, + embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) + +Save a single chunk of compressed embeddings and their relevant metadata to disk. + +The codes and compressed residuals for the chunk are saved in files named `.codes.jld2`. +and `.residuals.jld2` respectively. The document lengths are saved in a file named +`doclens..jld2`. Relevant metadata, including number of documents in the chunk, +number of embeddings and the passage offsets are saved in a file named `.metadata.json`. + +# Arguments + + - `config`: The [`ColBERTConfig`](@ref) being used. + - `chunk_idx`: The index of the current chunk being saved. + - `passage_offset`: The index of the first passage in the chunk. + - `embs`: The embeddings matrix for the current chunk. + - `doclens`: The document lengths vector for the current chunk. 
+""" +function save_chunk( + config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, + embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) + codes, residuals = compress( + codec["centroids"], codec["bucket_cutoffs"], config.dim, config.nbits, embs) + path_prefix = joinpath(config.index_path, string(chunk_idx)) + @assert length(codes)==size(embs)[2] "length(codes): $(length(codes)), size(embs): $(size(embs))" + + # saving the compressed embeddings + codes_path = "$(path_prefix).codes.jld2" + residuals_path = "$(path_prefix).residuals.jld2" + @info "Saving compressed codes to $(codes_path) and residuals to $(residuals_path)" + JLD2.save_object(codes_path, codes) + JLD2.save_object(residuals_path, residuals) + + # saving doclens + doclens_path = joinpath( + config.index_path, "doclens.$(chunk_idx).jld2") + @info "Saving doclens to $(doclens_path)" + JLD2.save_object(doclens_path, doclens) + + # the metadata + metadata_path = joinpath( + config.index_path, "$(chunk_idx).metadata.json") + @info "Saving metadata to $(metadata_path)" + open(metadata_path, "w") do io + JSON.print(io, + Dict( + "passage_offset" => passage_offset, + "num_passages" => length(doclens), + "num_embeddings" => length(codes) + ), + 4 # indent + ) + end +end + +""" + save(config::ColBERTConfig) + +Save a [`ColBERTConfig`](@ref) to disk in JSON. + +# Arguments + + - `config`: The [`ColBERTConfig`](@ref) to save. + +# Examples + +```jldoctest +julia> using ColBERT; + +julia> config = ColBERTConfig( + use_gpu = true, + collection = "/home/codetalker7/documents", + index_path = "./local_index" + ); + +julia> ColBERT.save(config); + +``` +""" +function save(config::ColBERTConfig) + properties = [Pair{String, Any}(string(field), getproperty(config, field)) + for field in fieldnames(ColBERTConfig)] + isdir(config.index_path) || mkdir(config.index_path) + open(joinpath(config.index_path, "config.json"), "w+") do io + JSON.print( + io, + Dict(properties), + 4 + ) + end +end From 87d8534c8a61635670c64deb6a04e31bb35f9635 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 13:36:35 +0530 Subject: [PATCH 55/59] Simplifying `Searcher` constructor. --- src/searching.jl | 55 ++++++++++++------------------------------------ 1 file changed, 14 insertions(+), 41 deletions(-) diff --git a/src/searching.jl b/src/searching.jl index 04edd85..3f1c69e 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -12,61 +12,34 @@ struct Searcher emb2pid::Vector{Int} end +function _build_emb2pid(doclens::Vector{Int}) + num_embeddings = sum(doclens) + emb2pid = zeros(Int, num_embeddings) + offset_doclens = 1 + for (pid, dlength) in enumerate(doclens) + emb2pid[offset_doclens:(offset_doclens + dlength - 1)] .= pid + offset_doclens += dlength + end + emb2pid +end + function Searcher(index_path::String) if !isdir(index_path) error("Index at $(index_path) does not exist! Please build the index first and try again.") end - # loading the config from the path config = load_config(index_path) - - # loading the model and saving it to prevent multiple loads base_colbert = BaseColBERT(config) checkpoint = Checkpoint(base_colbert, config) @info "Loaded ColBERT layers from the $(config.checkpoint) HuggingFace checkpoint." 
- - plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json")) codec = load_codec(index_path) ivf = JLD2.load_object(joinpath(index_path, "ivf.jld2")) ivf_lengths = JLD2.load_object(joinpath(index_path, "ivf_lengths.jld2")) - - # loading all doclens - doclens = Vector{Int}() - for chunk_idx in 1:plan_metadata["num_chunks"] - doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2") - chunk_doclens = JLD2.load_object(doclens_file) - append!(doclens, chunk_doclens) - end - - # loading all compressed embeddings - num_embeddings = plan_metadata["num_embeddings"] - dim, nbits = config.dim, config.nbits - @assert (dim * nbits) % 8==0 "(dim, nbits): $((dim, nbits))" - codes = zeros(UInt32, num_embeddings) - residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings) - codes_offset = 1 - for chunk_idx in 1:plan_metadata["num_chunks"] - chunk_codes = JLD2.load_object(joinpath(index_path, "$(chunk_idx).codes.jld2")) - chunk_residuals = JLD2.load_object(joinpath(index_path, "$(chunk_idx).residuals.jld2")) - - codes_endpos = codes_offset + length(chunk_codes) - 1 - codes[codes_offset:codes_endpos] = chunk_codes - residuals[:, codes_offset:codes_endpos] = chunk_residuals - - codes_offset = codes_offset + length(chunk_codes) - end - - # the emb2pid mapping + doclens = load_doclens(index_path) + codes, residuals = load_compressed_embs(index_path) @info "Building the emb2pid mapping." - @assert isequal(sum(doclens), plan_metadata["num_embeddings"]) "sum(doclens): $(sum(doclens)), num_embeddings: $(plan_metadata["num_embeddings"])" - emb2pid = zeros(Int, plan_metadata["num_embeddings"]) + emb2pid = _build_emb2pid(doclens) - offset_doclens = 1 - for (pid, dlength) in enumerate(doclens) - emb2pid[offset_doclens:(offset_doclens + dlength - 1)] .= pid - offset_doclens += dlength - end - Searcher( config, checkpoint, From 2af3336cc8f847218543ff4b73e41a5daf632fc3 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 14:22:24 +0530 Subject: [PATCH 56/59] Refactoring the ranking code; making it more test friendly. --- src/search/ranking.jl | 74 ++++++++++++++++++++++++------------------- src/searching.jl | 9 ++++-- 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/src/search/ranking.jl b/src/search/ranking.jl index d5df58c..bb4ff0c 100644 --- a/src/search/ranking.jl +++ b/src/search/ranking.jl @@ -20,7 +20,9 @@ function retrieve(ivf::Vector{Int}, ivf_lengths::Vector{Int}, centroids::Matrix{ length = ivf_lengths[centroid_id] append!(eids, ivf[offset:(offset + length - 1)]) end - @assert isequal(length(eids), sum(ivf_lengths[centroid_ids])) "length(eids): $(length(eids)), sum(ranker.ivf_lengths[centroid_ids]): $(sum(ivf_lengths[centroid_ids]))" + @assert isequal(length(eids), sum(ivf_lengths[centroid_ids])) + "length(eids): $(length(eids)), sum(ranker.ivf_lengths[centroid_ids]):" * + "$(sum(ivf_lengths[centroid_ids]))" eids = sort(unique(eids)) # get pids from the emb2pid mapping @@ -28,18 +30,12 @@ function retrieve(ivf::Vector{Int}, ivf_lengths::Vector{Int}, centroids::Matrix{ pids end -""" - - Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this. 
-""" -function score_pids(config::ColBERTConfig, centroids::Matrix{Float32}, - bucket_weights::Vector{Float32}, doclens::Vector{Int}, codes::Vector{UInt32}, - residuals::Matrix{UInt8}, Q::AbstractArray{Float32}, pids::Vector{Int}) - # get codes and residuals for all embeddings across all pids - num_embs = sum(doclens[pids]) - codes_packed = zeros(UInt32, num_embs) - residuals_packed = zeros(UInt8, size(residuals)[1], num_embs) +function _collect_compressed_embs_for_pids(doclens::Vector{Int}, codes::Vector{UInt32}, + residuals::Matrix{UInt8}, pids::Vector{Int}) + num_embeddings = sum(doclens[pids]) + codes_packed = zeros(UInt32, num_embeddings) + residuals_packed = zeros(UInt8, size(residuals)[1], num_embeddings) pid_offsets = cat([1], 1 .+ cumsum(doclens)[1:end .!= end], dims = 1) - offset = 1 for pid in pids pid_offset = pid_offsets[pid] @@ -49,30 +45,42 @@ function score_pids(config::ColBERTConfig, centroids::Matrix{Float32}, :, pid_offset:(pid_offset + num_embs_pid - 1)] offset += num_embs_pid end - @assert offset==num_embs + 1 "offset: $(offset), num_embs + 1: $(num_embs + 1)" + @assert offset==num_embeddings + 1 "offset: $(offset), num_embs + 1: $(num_embeddings + 1)" + codes_packed, residuals_packed +end + +function maxsim( + Q::Matrix{Float32}, D::Matrix{Float32}, pids::Vector{Int}, doclens::Vector{Int}) + scores = zeros(Float32, length(pids)) + num_embeddings = sum(doclens[pids]) + query_doc_scores = Flux.gpu(transpose(Q)) * Flux.gpu(D) # (num_query_tokens, num_embeddings) + offset = 1 + for (idx, pid) in enumerate(pids) + num_embs_pids = doclens[pid] + offset_end = min(num_embeddings, offset + num_embs_pids - 1) + pid_scores = query_doc_scores[:, offset:offset_end] + scores[idx] = sum(maximum(pid_scores, dims = 2)) + offset += num_embs_pids + end + @assert offset == num_embeddings + 1 "offset: $(offset), num_embs + 1: $(num_embeddings + 1)" + scores +end - # decompress these codes and residuals to get the original embeddings +""" + - Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this. +""" +function score_pids(config::ColBERTConfig, centroids::Matrix{Float32}, + bucket_weights::Vector{Float32}, doclens::Vector{Int}, codes::Vector{UInt32}, + residuals::Matrix{UInt8}, Q::Matrix{Float32}, pids::Vector{Int}) + codes_packed, residuals_packed = _collect_compressed_embs_for_pids( + doclens, codes, residuals, pids) D_packed = decompress( config.dim, config.nbits, centroids, bucket_weights, codes_packed, residuals_packed) @assert ndims(D_packed)==2 "ndims(D_packed): $(ndims(D_packed))" - @assert size(D_packed)[1]==config.dim "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" - @assert size(D_packed)[2]==num_embs "size(D_packed): $(size(D_packed)), num_embs: $(num_embs)" + @assert size(D_packed)[1] == config.dim + "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" + @assert size(D_packed)[2] == sum(doclens[pids]) + "size(D_packed): $(size(D_packed)), num_embs: $(sum(doclens[pids]))" @assert D_packed isa AbstractMatrix{Float32} "$(typeof(D_packed))" - - # get the max-sim scores - Q = reshape(Q, size(Q)[1:2]...) 
- - scores = Vector{Float32}() - query_doc_scores = Flux.gpu(transpose(Q)) * Flux.gpu(D_packed) # (num_query_tokens, num_embeddings) - offset = 1 - for pid in pids - num_embs_pid = doclens[pid] - pid_scores = query_doc_scores[:, offset:min(num_embs, offset + num_embs_pid - 1)] - push!(scores, sum(maximum(pid_scores, dims = 2))) - - offset += num_embs_pid - end - @assert offset==num_embs + 1 "offset: $(offset), num_embs + 1: $(num_embs + 1)" - - scores + maxsim(Q, D_packed, pids, doclens) end diff --git a/src/searching.jl b/src/searching.jl index 3f1c69e..5b22edb 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -141,13 +141,16 @@ function search(searcher::Searcher, query::String, k::Int) error("Only one query is supported at the moment!") end @assert size(Q)[3]==1 "size(Q): $(size(Q))" - @assert isequal(size(Q)[2], searcher.config.query_maxlen) "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)" # Q: (128, 32, 1) + @assert isequal(size(Q)[2], searcher.config.query_maxlen) + "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)" # Q: (128, 32, 1) Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension @assert isequal(length(size(Q)), 2) "size(Q): $(size(Q))" - pids = retrieve(searcher.ivf, searcher.ivf_lengths, searcher.centroids, searcher.emb2pid, searcher.config.nprobe, Q) - scores = score_pids(searcher.config, searcher.centroids, searcher.bucket_weights, searcher.doclens, searcher.codes, searcher.residuals, Q, pids) + pids = retrieve(searcher.ivf, searcher.ivf_lengths, searcher.centroids, + searcher.emb2pid, searcher.config.nprobe, Q) + scores = score_pids(searcher.config, searcher.centroids, searcher.bucket_weights, + searcher.doclens, searcher.codes, searcher.residuals, Q, pids) indices = sortperm(scores, rev = true) pids, scores = pids[indices], scores[indices] From 60a4e3025272d773b0e93944a76be18d0bdc5d82 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 14:27:42 +0530 Subject: [PATCH 57/59] Updating the examples. --- examples/indexing.jl | 6 +++--- examples/searching.jl | 22 +++++++--------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/examples/indexing.jl b/examples/indexing.jl index f1d098f..6c609e7 100644 --- a/examples/indexing.jl +++ b/examples/indexing.jl @@ -8,10 +8,10 @@ Random.seed!(0) config = ColBERTConfig( use_gpu = true, - collection = "./cityofaustin", + collection = "./short_collection", doc_maxlen = 300, - index_path = "./cityofaustin_index/", - chunksize = 500 + index_path = "./short_collection_index/", + chunksize = 3 ) indexer = Indexer(config) diff --git a/examples/searching.jl b/examples/searching.jl index bce7612..1bfecad 100644 --- a/examples/searching.jl +++ b/examples/searching.jl @@ -1,26 +1,18 @@ using ColBERT using CUDA -# create the config -dataroot = "downloads/lotte" -dataset = "lifestyle" -datasplit = "dev" -path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") - -nbits = 2 # encode each dimension with 2 bits - -index_root = "experiments/notebook/indexes" -index_name = "short_$(dataset).$(datasplit).$(nbits)bits" -index_path = joinpath(index_root, index_name) - # build the searcher +index_path = "short_collection_index" searcher = Searcher(index_path) +# load the collection +collection = readlines(searcher.config.collection) + # search for a query query = "what are white spots on raspberries?" 
pids, scores = search(searcher, query, 2) -print(searcher.config.resource_settings.collection.data[pids]) +print(collection[pids]) query = "are rabbits easy to housebreak?" -pids, scores = search(searcher, query, 9) -print(searcher.config.resource_settings.collection.data[pids]) +pids, scores = search(searcher, query, 1) +print(collection[pids]) From 60e0d81c08c493ba517918ea22bdef9545dc1dd2 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 14:28:54 +0530 Subject: [PATCH 58/59] Removing strided tensors for now. --- src/search/strided_tensor.jl | 108 ----------------------------------- 1 file changed, 108 deletions(-) delete mode 100644 src/search/strided_tensor.jl diff --git a/src/search/strided_tensor.jl b/src/search/strided_tensor.jl deleted file mode 100644 index d8555fc..0000000 --- a/src/search/strided_tensor.jl +++ /dev/null @@ -1,108 +0,0 @@ -""" - StridedTensor(packed_tensor::Vector{Int}, lengths::Vector{Int}) - -Type to perform `ivf` operations efficiently. - -# Arguments - - - `packed_tensor`: The `ivf`, i.e the centroid to embedding map build during indexing. It is assumed that this map is stored as a `Vector`, wherein the embedding IDs are stored consecutively for each centroid ID. - - `lengths`: The total number of embeddings for a centroid ID, for each centroid. - -# Returns - -A [`StridedTensor`](@ref), which computes and stores all relevant data to lookup the `ivf` efficiently. - -# Examples - -```julia-repl -julia> using JLD2; - -julia> ivf_path = joinpath(index_path, "ivf.jld2"); - -julia> ivf_dict = load(ivf_path); - -julia> ivf, ivf_lengths = ivf_dict["ivf"], ivf_dict["ivf_lengths"]; - -julia> ivf = StridedTensor(ivf, ivf_lengths) - -``` -""" -struct StridedTensor - tensor::Vector{Int} - lengths::Vector{Int} - strides::Vector{Int} - offsets::Vector{Int} - views::Dict{Int, Array{Int}} -end - -function StridedTensor(packed_tensor::Vector{Int}, lengths::Vector{Int}) - tensor = packed_tensor - strides = cat( - _select_strides(lengths, [0.5, 0.75, 0.9, 0.95]), [max(lengths...)], dims = 1) - strides = Int.(trunc.(strides)) - offsets = cat([0], cumsum(lengths), dims = 1) - - if offsets[length(offsets) - 1] + max(lengths...) > length(tensor) - padding = zeros(Int, max(lengths...)) - tensor = cat(tensor, padding, dims = 1) - end - - views = Dict(stride => _create_view(tensor, stride) for stride in strides) - - StridedTensor( - tensor, - lengths, - strides, - offsets, - views - ) -end - -""" - _select_strides(lengths::Vector{Int}, quantiles::Vector{Float64}) - -Get candidate strides computed using `quantiles` from `lengths`. - -# Arguments - - - `lengths`: A vector of `ivf` lengths to select candidate stride lengths from. - - `quantiles`: The quantiles to be computed. - -# Returns - -A `Vector` containing the corresponding quantiles. -""" -function _select_strides(lengths::Vector{Int}, quantiles::Vector{Float64}) - if length(lengths) < 5000 - quantile(lengths, quantiles) - else - sample = rand(1:length(lengths), 2000) - quantile(lengths[sample], quantiles) - end -end - -""" - _create_view(tensor::Vector{Int}, stride::Int) - -Create a view into `tensor`, where each column of the view corresponds to a slice of size `stride` in the original tensor. - -# Arguments - - - `tensor`: The input `Vector` to create views of. - - `stride`: The number of elements to include in each slice of the output tensor. 
- -# Returns - -An array of shape `(stride, outdim)`, where each column is a slice of size `stride` from the original tensor, and `outdim = length(tensor) - stride + 1`. -""" -function _create_view(tensor::Vector{Int}, stride::Int) - outdim = length(tensor) - stride + 1 - size = (stride, outdim) - tensor_view = zeros(Int, size) - - for column in 1:outdim - tensor_view[:, column] = copy(tensor[column:(column + stride - 1)]) - end - - tensor_view -end From cb99f55ffb226ff9182d5e7ea503fbf1732a2b94 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Fri, 16 Aug 2024 14:35:16 +0530 Subject: [PATCH 59/59] Running the Julia formatter. --- examples/indexing.jl | 2 +- examples/load_colbert.jl | 216 +++++++++++++++++++++++++++++++++++++++ src/infra/config.jl | 2 +- src/loaders.jl | 18 ++-- src/search/ranking.jl | 8 +- src/searching.jl | 2 +- 6 files changed, 233 insertions(+), 15 deletions(-) create mode 100644 examples/load_colbert.jl diff --git a/examples/indexing.jl b/examples/indexing.jl index 6c609e7..53c6e1b 100644 --- a/examples/indexing.jl +++ b/examples/indexing.jl @@ -11,7 +11,7 @@ config = ColBERTConfig( collection = "./short_collection", doc_maxlen = 300, index_path = "./short_collection_index/", - chunksize = 3 + chunksize = 3 ) indexer = Indexer(config) diff --git a/examples/load_colbert.jl b/examples/load_colbert.jl new file mode 100644 index 0000000..c9cffc8 --- /dev/null +++ b/examples/load_colbert.jl @@ -0,0 +1,216 @@ +using Transformers +using JSON3 +using Transformers.HuggingFace +const HF = Transformers.HuggingFace + +""" + _load_tokenizer_config(path_config) + +Load tokenizer config locally. +""" +function _load_tokenizer_config(path_config::AbstractString) + @assert isfile(path_config) "Tokenizer config file not found: $path_config" + return JSON3.read(read(path_config)) +end + +""" + extract_tokenizer_type(tkr_type::AbstractString) + +Extract tokenizer type from config. +""" +function extract_tokenizer_type(tkr_type::AbstractString) + m = match(r"(\S+)Tokenizer(Fast)?", tkr_type) + isnothing(m) && + error("Unknown tokenizer: $tkr_type") + tkr_type = Symbol(lowercase(m.captures[1])) +end + +""" + _load_tokenizer(cfg::HF.HGFConfig; path_tokenizer_config::AbstractString, + path_special_tokens_map::AbstractString, path_tokenizer::AbstractString) + +Local tokenizer loader. +""" +function _load_tokenizer(cfg::HF.HGFConfig; + path_tokenizer_config::AbstractString, + path_special_tokens_map::AbstractString, path_tokenizer::AbstractString) + @assert isfile(path_tokenizer_config) "Tokenizer config file not found: $path_tokenizer_config" + @assert isfile(path_special_tokens_map) "Special tokens map file not found: $path_special_tokens_map" + @assert isfile(path_tokenizer) "Tokenizer file not found: $path_tokenizer" + ## load tokenizer config + tkr_cfg = _load_tokenizer_config(path_tokenizer_config) + tkr_type_sym = extract_tokenizer_type(tkr_cfg.tokenizer_class) + tkr_type = HF.tokenizer_type(tkr_type_sym) # eg, Val(:bert)() + ## load special tokens + special_tokens = HF.load_special_tokens_map(path_special_tokens_map) + ## load tokenizer + kwargs = HF.extract_fast_tkr_kwargs( + tkr_type, tkr_cfg, cfg, special_tokens) + tokenizer, vocab, process_config, decode, textprocess = HF.load_fast_tokenizer( + tkr_type, path_tokenizer, cfg) + for (k, v) in process_config + kwargs[k] = v + end + ## construct tokenizer and mutate the decode +textprocess pipelines + tkr = HF.encoder_construct( + tkr_type, tokenizer, vocab; kwargs...) 
+ tkr = HF.setproperties!!( + tkr, (; decode, textprocess)) + return tkr +end + +""" + _load_model(cfg::HF.HGFConfig; path_model::AbstractString, + trainmode::Bool = false, lazy::Bool = false, mmap::Bool = true) + +Local model loader. +""" +function _load_model(cfg::HF.HGFConfig; + path_model::AbstractString, + trainmode::Bool = false, lazy::Bool = false, mmap::Bool = true) + @assert isfile(path_model) "Model file not found: $path_model" + @assert endswith(path_model, ".bin") "Model file must end with .bin (type torch `pickle`): $path_model" + ## Assume fixed + task = :model + + ## Load state dict + # We know we have pytorch_model.bin -> so format is :pickle and it's a single file + # status = HF.singlefilename(HF.WeightStatus{:pickle}) + status = HF.HasSingleFile{:pickle}(path_model) + state_dict = HF.load_state_dict_from( + status; lazy, mmap) + + ## + model_type = HF.get_model_type( + HF.getconfigname(cfg), task) + basekey = String(HF.basemodelkey(model_type)) + if HF.isbasemodel(model_type) + prefix = HF.haskeystartswith( + state_dict, basekey) ? basekey : "" + else + prefix = "" + if !HF.haskeystartswith( + state_dict, basekey) + new_state_dict = OrderedDict{ + Any, Any}() + for (key, val) in state_dict + new_state_dict[joinname(basekey, key)] = val + end + state_dict = new_state_dict + end + end + model = load_model( + model_type, cfg, state_dict, prefix) + trainmode || (model = Layers.testmode(model)) + return model +end + +""" + load_hgf_pretrained_local(dir_spec::AbstractString; + path_config::Union{Nothing, AbstractString} = nothing, + path_tokenizer_config::Union{Nothing, AbstractString} = nothing, + path_special_tokens_map::Union{Nothing, AbstractString} = nothing, + path_tokenizer::Union{Nothing, AbstractString} = nothing, + path_model::Union{Nothing, AbstractString} = nothing, + kwargs... + +) + +Local model loader. Honors the `load_hgf_pretrained` interface, where you can request +specific files to be loaded, eg, `my/dir/to/model:tokenizer` or `my/dir/to/model:config`. + +# Arguments + + - `dir_spec::AbstractString`: Directory specification (item specific after the colon is optional), eg, `my/dir/to/model` or `my/dir/to/model:tokenizer`. + - `path_config::Union{Nothing, AbstractString}`: Path to config file. + - `path_tokenizer_config::Union{Nothing, AbstractString}`: Path to tokenizer config file. + - `path_special_tokens_map::Union{Nothing, AbstractString}`: Path to special tokens map file. + - `path_tokenizer::Union{Nothing, AbstractString}`: Path to tokenizer file. + - `path_model::Union{Nothing, AbstractString}`: Path to model file. + - `kwargs...`: Additional keyword arguments for `_load_model` function like `mmap`, `lazy`, `trainmode`. +""" +function load_hgf_pretrained_local( + dir_spec::AbstractString; + path_config::Union{ + Nothing, AbstractString} = nothing, + path_tokenizer_config::Union{ + Nothing, AbstractString} = nothing, + path_special_tokens_map::Union{ + Nothing, AbstractString} = nothing, + path_tokenizer::Union{ + Nothing, AbstractString} = nothing, + path_model::Union{ + Nothing, AbstractString} = nothing, + kwargs... 
+) + + ## Extract if item was provided + name_item = rsplit(dir_spec, ':'; limit = 2) + all = length(name_item) == 1 + dir, item = if all + dir_spec, "model" + else + Iterators.map(String, name_item) + end + item = lowercase(item) + ## Set paths + @assert isdir(dir) "Local directory not found: $dir" + if isnothing(path_config) + path_config = joinpath(dir, "config.json") + end + if isnothing(path_tokenizer_config) + path_tokenizer_config = joinpath( + dir, "tokenizer_config.json") + end + if isnothing(path_special_tokens_map) + path_special_tokens_map = joinpath( + dir, "special_tokens_map.json") + end + if isnothing(path_tokenizer) + path_tokenizer = joinpath( + dir, "tokenizer.json") + end + if isnothing(path_model) + path_model = joinpath( + dir, "pytorch_model.bin") + end + ## Check if they exist + @assert isfile(path_config) "Config file not found: $path_config" + @assert isfile(path_tokenizer_config) "Tokenizer config file not found: $path_tokenizer_config" + @assert isfile(path_special_tokens_map) "Special tokens map file not found: $path_special_tokens_map" + @assert isfile(path_tokenizer) "Tokenizer file not found: $path_tokenizer" + @assert isfile(path_model) "Model file not found: $path_model" + + ## load config + cfg = HF._load_config(path_config) + item == "config" && return cfg + + ## load tokenizer + if item == "tokenizer" || all + tkr = _load_tokenizer( + cfg; path_tokenizer_config, + path_special_tokens_map, + path_tokenizer) + end + item == "tokenizer" && return tkr + + ## load model + model = _load_model( + cfg; path_model, kwargs...) + + if all + return tkr, model + else + return model + end +end + +## Example +using Transformers.TextEncoders + +# My files can be found in this directory +dir = "colbert-ir" +textenc, model = load_hgf_pretrained_local(dir) + +encoded = encode(textenc, "test it") +output = model(encoded) diff --git a/src/infra/config.jl b/src/infra/config.jl index 91b7140..332c810 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -79,7 +79,7 @@ Base.@kwdef struct ColBERTConfig # indexing settings index_path::String = "" - index_bsize::Int = 32 + index_bsize::Int = 32 chunksize::Union{Missing, Int} = missing passages_batch_size::Int = 300 nbits::Int = 2 diff --git a/src/loaders.jl b/src/loaders.jl index d6777ed..8a2aaed 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -12,8 +12,8 @@ function load_codec(index_path::String) avg_residual_path = joinpath(index_path, "avg_residual.jld2") bucket_cutoffs_path = joinpath(index_path, "bucket_cutoffs.jld2") bucket_weights_path = joinpath(index_path, "bucket_weights.jld2") - @info "Loading codec from $(centroids_path), $(avg_residual_path), "* - "$(bucket_cutoffs_path) and $(bucket_weights_path)." + @info "Loading codec from $(centroids_path), $(avg_residual_path), " * + "$(bucket_cutoffs_path) and $(bucket_weights_path)." 
centroids = JLD2.load_object(centroids_path) avg_residual = JLD2.load_object(avg_residual_path) @@ -74,7 +74,7 @@ function load_doclens(index_path::String) append!(doclens, chunk_doclens) end @assert isequal(sum(doclens), plan_metadata["num_embeddings"]) - "sum(doclens): $(sum(doclens)), num_embeddings: $(plan_metadata["num_embeddings"])" + "sum(doclens): $(sum(doclens)), num_embeddings: $(plan_metadata["num_embeddings"])" doclens end @@ -84,11 +84,13 @@ function load_compressed_embs(index_path::String) @assert (config.dim * config.nbits) % 8==0 "(dim, nbits): $((config.dim, config.nbits))" codes = zeros(UInt32, plan_metadata["num_embeddings"]) - residuals = zeros(UInt8, Int((config.dim / 8) * config.nbits), plan_metadata["num_embeddings"]) + residuals = zeros( + UInt8, Int((config.dim / 8) * config.nbits), plan_metadata["num_embeddings"]) codes_offset = 1 for chunk_idx in 1:plan_metadata["num_chunks"] chunk_codes = JLD2.load_object(joinpath(index_path, "$(chunk_idx).codes.jld2")) - chunk_residuals = JLD2.load_object(joinpath(index_path, "$(chunk_idx).residuals.jld2")) + chunk_residuals = JLD2.load_object(joinpath( + index_path, "$(chunk_idx).residuals.jld2")) codes_endpos = codes_offset + length(chunk_codes) - 1 codes[codes_offset:codes_endpos] = chunk_codes @@ -96,11 +98,11 @@ function load_compressed_embs(index_path::String) codes_offset = codes_offset + length(chunk_codes) end - @assert length(codes) == plan_metadata["num_embeddings"] - "length(codes): $(length(codes)), num_embeddings: $(plan_metadata["num_embeddings"])" + @assert length(codes) == plan_metadata["num_embeddings"] + "length(codes): $(length(codes)), num_embeddings: $(plan_metadata["num_embeddings"])" @assert ndims(residuals) == 2 @assert size(residuals)[2] == plan_metadata["num_embeddings"] - "size(residuals): $(size(residuals)), num_embeddings: $(plan_metadata["num_embeddings"])" + "size(residuals): $(size(residuals)), num_embeddings: $(plan_metadata["num_embeddings"])" @assert codes isa Vector{UInt32} @assert residuals isa Matrix{UInt8} diff --git a/src/search/ranking.jl b/src/search/ranking.jl index bb4ff0c..1ec1bf4 100644 --- a/src/search/ranking.jl +++ b/src/search/ranking.jl @@ -60,9 +60,9 @@ function maxsim( offset_end = min(num_embeddings, offset + num_embs_pids - 1) pid_scores = query_doc_scores[:, offset:offset_end] scores[idx] = sum(maximum(pid_scores, dims = 2)) - offset += num_embs_pids + offset += num_embs_pids end - @assert offset == num_embeddings + 1 "offset: $(offset), num_embs + 1: $(num_embeddings + 1)" + @assert offset==num_embeddings + 1 "offset: $(offset), num_embs + 1: $(num_embeddings + 1)" scores end @@ -78,9 +78,9 @@ function score_pids(config::ColBERTConfig, centroids::Matrix{Float32}, config.dim, config.nbits, centroids, bucket_weights, codes_packed, residuals_packed) @assert ndims(D_packed)==2 "ndims(D_packed): $(ndims(D_packed))" @assert size(D_packed)[1] == config.dim - "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" + "size(D_packed): $(size(D_packed)), config.dim: $(config.dim)" @assert size(D_packed)[2] == sum(doclens[pids]) - "size(D_packed): $(size(D_packed)), num_embs: $(sum(doclens[pids]))" + "size(D_packed): $(size(D_packed)), num_embs: $(sum(doclens[pids]))" @assert D_packed isa AbstractMatrix{Float32} "$(typeof(D_packed))" maxsim(Q, D_packed, pids, doclens) end diff --git a/src/searching.jl b/src/searching.jl index 5b22edb..a0f7e57 100644 --- a/src/searching.jl +++ b/src/searching.jl @@ -142,7 +142,7 @@ function search(searcher::Searcher, query::String, k::Int) end 
@assert size(Q)[3]==1 "size(Q): $(size(Q))" @assert isequal(size(Q)[2], searcher.config.query_maxlen) - "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)" # Q: (128, 32, 1) + "size(Q): $(size(Q)), query_maxlen: $(searcher.config.query_maxlen)" # Q: (128, 32, 1) Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension @assert isequal(length(size(Q)), 2) "size(Q): $(size(Q))"
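
---

A note for reviewers on the scoring touched above: the `sum(maximum(pid_scores, dims = 2))` line in the `src/search/ranking.jl` hunk is the MaxSim aggregation. Below is a minimal plain-Julia sketch of that reduction on toy data; the shapes and random inputs are made up for illustration and this is not the package's `maxsim` implementation.

```julia
# Toy sketch of MaxSim scoring (illustration only; shapes and data are assumptions).
using Random

Random.seed!(0)
dim, query_maxlen, doc_len = 128, 32, 100
Q = rand(Float32, dim, query_maxlen)   # query embeddings, one column per query token
D = rand(Float32, dim, doc_len)        # one document's embeddings, one column per token

sim = transpose(Q) * D                 # (query_maxlen, doc_len) token-token similarities
score = sum(maximum(sim, dims = 2))    # best document token per query token, then summed
```

In the `maxsim` function shown in the hunk, the same reduction is applied per candidate `pid`, over consecutive slices of `query_doc_scores` whose widths are given by `doclens`.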