diff --git a/examples/indexing.jl b/examples/indexing.jl index 8266e21..515f021 100644 --- a/examples/indexing.jl +++ b/examples/indexing.jl @@ -48,3 +48,4 @@ config = ColBERTConfig( # create and run the indexer indexer = Indexer(config) index(indexer) +ColBERT.save(config) diff --git a/examples/searching.jl b/examples/searching.jl new file mode 100644 index 0000000..a299e3c --- /dev/null +++ b/examples/searching.jl @@ -0,0 +1,25 @@ +using ColBERT + +# create the config +dataroot = "downloads/lotte" +dataset = "lifestyle" +datasplit = "dev" +path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv") + +nbits = 2 # encode each dimension with 2 bits + +index_root = "experiments/notebook/indexes" +index_name = "short_$(dataset).$(datasplit).$(nbits)bits" +index_path = joinpath(index_root, index_name) + +# build the searcher +searcher = Searcher(index_path) + +# search for a query +query = "what are white spots on raspberries?" +pids, scores = search(searcher, query, 2) +print(searcher.config.resource_settings.collection.data[pids]) + +query = "are rabbits easy to housebreak?" +pids, scores = search(searcher, query, 9) +print(searcher.config.resource_settings.collection.data[pids]) diff --git a/src/ColBERT.jl b/src/ColBERT.jl index 0f027d9..a2181cc 100644 --- a/src/ColBERT.jl +++ b/src/ColBERT.jl @@ -41,4 +41,10 @@ include("indexing/index_saver.jl") include("indexing/collection_indexer.jl") export Indexer, CollectionIndexer, index +# searcher +include("search/strided_tensor.jl") +include("search/index_storage.jl") +include("searching.jl") +export Searcher, search + end diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl index aa73b5a..0a1b38f 100644 --- a/src/indexing/codecs/residual.jl +++ b/src/indexing/codecs/residual.jl @@ -27,6 +27,22 @@ mutable struct ResidualCodec bucket_weights::Vector{Float64} end +""" + +# Examples + +```julia-repl +julia> codec = load_codec(index_path); +``` +""" +function load_codec(index_path::String) + config = load(joinpath(index_path, "config.jld2"), "config") + centroids = load(joinpath(index_path, "centroids.jld2"), "centroids") + avg_residual = load(joinpath(index_path, "avg_residual.jld2"), "avg_residual") + buckets = load(joinpath(index_path, "buckets.jld2")) + ResidualCodec(config, centroids, avg_residual, buckets["bucket_cutoffs"], buckets["bucket_weights"]) +end + """ compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64}) @@ -130,6 +146,64 @@ function compress(codec::ResidualCodec, embs::Matrix{Float64}) codes, residuals end +function decompress_residuals(codec::ResidualCodec, binary_residuals::Array{UInt8}) + dim = codec.config.doc_settings.dim + nbits = codec.config.indexing_settings.nbits + + @assert ndims(binary_residuals) == 2 + @assert size(binary_residuals)[1] == (dim / 8) * nbits + + # unpacking UInt8 into bits + unpacked_bits = BitVector() + for byte in vec(binary_residuals) + append!(unpacked_bits, [byte & (0x1< iszero(v) ? v : normalize(v), batch_embeddings, dims = 1) + push!(D, batch_embeddings) + + batch_offset += bsize + end + + cat(D..., dims = 2) +end + """ load_codes(codec::ResidualCodec, chunk_idx::Int) @@ -150,3 +224,9 @@ function load_codes(codec::ResidualCodec, chunk_idx::Int) codes = JLD2.load(codes_path, "codes") codes end + +function load_residuals(codec::ResidualCodec, chunk_idx::Int) + residual_path = joinpath(codec.config.indexing_settings.index_path, "$(chunk_idx).residuals.jld2") + residuals = JLD2.load(residual_path, "residuals") + residuals +end diff --git a/src/infra/settings.jl b/src/infra/settings.jl index b2954b8..7c766f7 100644 --- a/src/infra/settings.jl +++ b/src/infra/settings.jl @@ -138,7 +138,6 @@ Base.@kwdef struct IndexingSettings end Base.@kwdef struct SearchSettings - ncells::Union{Nothing, Int} = nothing - centroid_score_threshold::Union{Nothing, Float64} = nothing - ndocs::Union{Nothing, Int} = nothing + nprobe::Int = 2 + ncandidates::Int = 8192 end diff --git a/src/search/index_storage.jl b/src/search/index_storage.jl new file mode 100644 index 0000000..a3ffce7 --- /dev/null +++ b/src/search/index_storage.jl @@ -0,0 +1,174 @@ +struct IndexScorer + metadata::Dict + codec::ResidualCodec + ivf::Vector{Int} + ivf_lengths::Vector{Int} + doclens::Vector{Int} + codes::Vector{Int} + residuals::Matrix{UInt8} + emb2pid::Vector{Int} +end + +""" + +# Examples + +```julia-repl +julia> IndexScorer(index_path) + +``` +""" +function IndexScorer(index_path::String) + @info "Loading the index from {index_path}." + + # loading the config from the index path + config = JLD2.load(joinpath(index_path, "config.jld2"))["config"] + + # the metadata + metadata_path = joinpath(index_path, "metadata.json") + metadata = JSON.parsefile(metadata_path) + + # loading the codec + codec = load_codec(index_path) + + # loading ivf into a StridedTensor + ivf_path = joinpath(index_path, "ivf.jld2") + ivf_dict = JLD2.load(ivf_path) + ivf, ivf_lengths = ivf_dict["ivf"], ivf_dict["ivf_lengths"] + # ivf = StridedTensor(ivf, ivf_lengths) + + # loading all doclens + doclens = Vector{Int}() + for chunk_idx in 1:metadata["num_chunks"] + doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2") + chunk_doclens = JLD2.load(doclens_file, "doclens") + append!(doclens, chunk_doclens) + end + + # loading all embeddings + num_embeddings = metadata["num_embeddings"] + dim, nbits = config.doc_settings.dim, config.indexing_settings.nbits + @assert (dim * nbits) % 8 == 0 + codes = zeros(Int, num_embeddings) + residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings) + codes_offset = 1 + for chunk_idx in 1:metadata["num_chunks"] + chunk_codes = load_codes(codec, chunk_idx) + chunk_residuals = load_residuals(codec, chunk_idx) + + codes_endpos = codes_offset + length(chunk_codes) - 1 + codes[codes_offset:codes_endpos] = chunk_codes + residuals[:, codes_offset:codes_endpos] = chunk_residuals + + codes_offset = codes_offset + length(chunk_codes) + end + + # the emb2pid mapping + @info "Building the emb2pid mapping." + @assert isequal(sum(doclens), metadata["num_embeddings"]) + emb2pid = zeros(Int, metadata["num_embeddings"]) + + offset_doclens = 1 + for (pid, dlength) in enumerate(doclens) + emb2pid[offset_doclens:offset_doclens + dlength - 1] .= pid + offset_doclens += dlength + end + + IndexScorer( + metadata, + codec, + ivf, + ivf_lengths, + doclens, + codes, + residuals, + emb2pid, + ) +end + +""" + +Return a candidate set of `pids` for the query matrix `Q`. This is done as follows: the nearest `nprobe` centroids for each query embedding are found. This list is then flattened and the unique set of these centroids is built. Using the `ivf`, the list of all unique embedding IDs contained in these centroids is computed. Finally, these embedding IDs are converted to `pids` using `emb2pid`. This list of `pids` is the final candidate set. +""" +function retrieve(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat}) + @assert isequal(size(Q)[2], config.query_settings.query_maxlen) # Q: (128, 32, 1) + + Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension + @assert isequal(length(size(Q)), 2) + + # score of each query embedding with each centroid and take top nprobe centroids + cells = transpose(Q) * ranker.codec.centroids + cells = mapslices(row -> partialsortperm(row, 1:config.search_settings.nprobe, rev=true), cells, dims = 2) # take top nprobe centroids for each query + centroid_ids = sort(unique(vec(cells))) + + # get all embedding IDs contained in centroid_ids using ivf + centroid_ivf_offsets = cat([1], 1 .+ cumsum(ranker.ivf_lengths)[1:end .!= end], dims = 1) + eids = Vector{Int}() + for centroid_id in centroid_ids + offset = centroid_ivf_offsets[centroid_id] + length = ranker.ivf_lengths[centroid_id] + append!(eids, ranker.ivf[offset:offset + length - 1]) + end + @assert isequal(length(eids), sum(ranker.ivf_lengths[centroid_ids])) + eids = sort(unique(eids)) + + # get pids from the emb2pid mapping + pids = sort(unique(ranker.emb2pid[eids])) + pids +end + +""" +- Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this. +""" +function score_pids(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat}, pids::Vector{Int}) + # get codes and residuals for all embeddings across all pids + num_embs = sum(ranker.doclens[pids]) + codes_packed = zeros(Int, num_embs) + residuals_packed = zeros(UInt8, size(ranker.residuals)[1], num_embs) + pid_offsets = cat([1], 1 .+ cumsum(ranker.doclens)[1:end .!= end], dims=1) + + offset = 1 + for pid in pids + pid_offset = pid_offsets[pid] + num_embs_pid = ranker.doclens[pid] + codes_packed[offset: offset + num_embs_pid - 1] = ranker.codes[pid_offset: pid_offset + num_embs_pid - 1] + residuals_packed[:, offset: offset + num_embs_pid - 1] = ranker.residuals[:, pid_offset: pid_offset + num_embs_pid - 1] + offset += num_embs_pid + end + @assert offset == num_embs + 1 + + # decompress these codes and residuals to get the original embeddings + D_packed = decompress(ranker.codec, codes_packed, residuals_packed) + @assert ndims(D_packed) == 2 + @assert size(D_packed)[1] == config.doc_settings.dim + @assert size(D_packed)[2] == num_embs + + # get the max-sim scores + if size(Q)[3] > 1 + error("Only one query is supported at the moment!") + end + @assert size(Q)[3] == 1 + Q = reshape(Q, size(Q)[1:2]...) + + scores = Vector{Float64}() + query_doc_scores = transpose(Q) * D_packed # (num_query_tokens, num_embeddings) + offset = 1 + for pid in pids + num_embs_pid = ranker.doclens[pid] + pid_scores = query_doc_scores[:, offset:min(num_embs, offset + num_embs_pid - 1)] + push!(scores, sum(maximum(pid_scores, dims = 2))) + + offset += num_embs_pid + end + @assert offset == num_embs + 1 + + scores +end + +function rank(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat}) + pids = retrieve(ranker, config, Q) + scores = score_pids(ranker, config, Q, pids) + indices = sortperm(scores, rev=true) + + pids[indices], scores[indices] +end diff --git a/src/search/strided_tensor.jl b/src/search/strided_tensor.jl new file mode 100644 index 0000000..faecde3 --- /dev/null +++ b/src/search/strided_tensor.jl @@ -0,0 +1,107 @@ +""" + StridedTensor(packed_tensor::Vector{Int}, lengths::Vector{Int}) + +Type to perform `ivf` operations efficiently. + +# Arguments + +- `packed_tensor`: The `ivf`, i.e the centroid to embedding map build during indexing. It is assumed that this map is stored as a `Vector`, wherein the embedding IDs are stored consecutively for each centroid ID. +- `lengths`: The total number of embeddings for a centroid ID, for each centroid. + +# Returns + +A [`StridedTensor`](@ref), which computes and stores all relevant data to lookup the `ivf` efficiently. + +# Examples + +```julia-repl + +julia> using JLD2; + +julia> ivf_path = joinpath(index_path, "ivf.jld2"); + +julia> ivf_dict = load(ivf_path); + +julia> ivf, ivf_lengths = ivf_dict["ivf"], ivf_dict["ivf_lengths"]; + +julia> ivf = StridedTensor(ivf, ivf_lengths) +``` +""" +struct StridedTensor + tensor::Vector{Int} + lengths::Vector{Int} + strides::Vector{Int} + offsets::Vector{Int} + views::Dict{Int, Array{Int}} +end + +function StridedTensor(packed_tensor::Vector{Int}, lengths::Vector{Int}) + tensor = packed_tensor + strides = cat(_select_strides(lengths, [.5, .75, .9, .95]), [max(lengths...)], dims = 1) + strides = Int.(trunc.(strides)) + offsets = cat([0], cumsum(lengths), dims = 1) + + if offsets[length(offsets) - 1] + max(lengths...) > length(tensor) + padding = zeros(Int, max(lengths...)) + tensor = cat(tensor, padding, dims = 1) + end + + views = Dict(stride => _create_view(tensor, stride) for stride in strides) + + StridedTensor( + tensor, + lengths, + strides, + offsets, + views + ) +end + +""" + _select_strides(lengths::Vector{Int}, quantiles::Vector{Float64}) + +Get candidate strides computed using `quantiles` from `lengths`. + +# Arguments + +- `lengths`: A vector of `ivf` lengths to select candidate stride lengths from. +- `quantiles`: The quantiles to be computed. + +# Returns + +A `Vector` containing the corresponding quantiles. +""" +function _select_strides(lengths::Vector{Int}, quantiles::Vector{Float64}) + if length(lengths) < 5000 + quantile(lengths, quantiles) + else + sample = rand(1:length(lengths), 2000) + quantile(lengths[sample], quantiles) + end +end + +""" + _create_view(tensor::Vector{Int}, stride::Int) + +Create a view into `tensor`, where each column of the view corresponds to a slice of size `stride` in the original tensor. + +# Arguments + +- `tensor`: The input `Vector` to create views of. +- `stride`: The number of elements to include in each slice of the output tensor. + +# Returns + +An array of shape `(stride, outdim)`, where each column is a slice of size `stride` from the original tensor, and `outdim = length(tensor) - stride + 1`. +""" +function _create_view(tensor::Vector{Int}, stride::Int) + outdim = length(tensor) - stride + 1 + size = (stride, outdim) + tensor_view = zeros(Int, size) + + for column in 1:outdim + tensor_view[:, column] = copy(tensor[column:column + stride - 1]) + end + + tensor_view +end diff --git a/src/searching.jl b/src/searching.jl new file mode 100644 index 0000000..40a3ef3 --- /dev/null +++ b/src/searching.jl @@ -0,0 +1,91 @@ +using .ColBERT: Checkpoint, ColBERTConfig, Collection, IndexScorer + +struct Searcher + config::ColBERTConfig + checkpoint::Checkpoint + ranker::IndexScorer +end + +function Searcher(index_path::String) + if !isdir(index_path) + error("Index at $(index_path) does not exist! Please build the index first and try again.") + end + + # loading the config from the path + config = JLD2.load(joinpath(index_path, "config.jld2"))["config"] + + # loading the model and saving it to prevent multiple loads + @info "Loading ColBERT layers from HuggingFace." + base_colbert = BaseColBERT(config.resource_settings.checkpoint, config) + checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), QueryTokenizer(base_colbert.tokenizer, config), config) + + Searcher(config, checkPoint, IndexScorer(index_path)) +end + +""" + encode_query(searcher::Searcher, query::String) + +Encode a search query to a matrix of embeddings using the provided `searcher`. The encoded query can then be used to search the collection. + +# Arguments + +- `searcher`: A Searcher object that contains information about the collection and the index. +- `query`: The search query to encode. + +# Returns + +An array containing the embeddings for each token in the query. Also see [queryFromText](@ref) to see the size of the array. + +# Examples + +Here's an example using the config given in docs for [`ColBERTConfig`](@ref). + +```julia-repl +julia> searcher = Searcher(config); + +julia> encode_query(searcher, "what are white spots on raspberries?") +128×32×1 Array{Float32, 3}: +[:, :, 1] = + 0.0158567 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919 + 0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200357 + -0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0777439 -0.0776733 -0.0830504 + -0.109909 -0.170906 -0.0138702 -0.0409767 -0.126037 -0.126829 -0.13149 + -0.0231786 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536 + 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583 + -0.0290509 0.143255 0.0306142 0.042658 -0.164401 -0.161857 -0.160327 + 0.0921477 0.0588331 0.250449 0.234636 0.0664076 0.0659837 0.0711357 + 0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634 + -0.0768143 -0.00587305 0.00543038 0.00443374 -0.11757 -0.112495 -0.11112 + -0.0184338 0.00668557 -0.191863 -0.161345 … -0.107664 -0.107267 -0.114564 + 0.0112104 0.0214651 -0.0923963 -0.0823051 0.106261 0.105065 0.10409 + ⋮ ⋱ ⋮ + -0.0617142 -0.0573989 -0.0973785 -0.0805046 0.107432 0.108591 0.109501 + -0.0859686 0.0623054 0.0974813 0.126841 0.0182795 0.0230549 0.031103 + 0.0392043 0.0162653 0.0926306 0.104053 0.0491495 0.0484318 0.0438132 + -0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0617945 -0.0631367 -0.0675882 + 0.013123 0.0565132 -0.0349061 -0.0464192 0.0724731 0.0780166 0.074623 + -0.117425 0.162483 0.11039 0.136364 -0.00538225 -0.00685449 -0.0019436 + -0.0401158 -0.0045094 0.0539569 0.0689953 -0.00518063 -0.00600252 -0.00771469 + 0.0893983 0.0695061 -0.0499409 -0.035411 0.0960932 0.0961893 0.103431 + -0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.022061 -0.018135 + -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699095 -0.0684749 -0.0662904 + 0.100019 -0.0618588 0.106134 0.0989047 -0.0556761 -0.0556784 -0.059571 + +``` +""" +function encode_query(searcher::Searcher, query::String) + queries = [query] + bsize = 128 + Q = queryFromText(searcher.checkpoint, queries, bsize) + Q +end + +function search(searcher::Searcher, query::String, k::Int) + dense_search(searcher, encode_query(searcher, query), k) +end + +function dense_search(searcher::Searcher, Q::Array{<:AbstractFloat}, k::Int) + pids, scores = rank(searcher.ranker, searcher.config, Q) + + pids[1:k], scores[1:k] +end