Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The Searcher component. #14

Merged
merged 29 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a5c32c2
Setting up initial files and types for the searcher.
codetalker7 Jul 14, 2024
8319e15
Creating an example file for the searching component.
codetalker7 Jul 14, 2024
cd7b98d
Changing the default search settings; not doing PLAID optimizations
codetalker7 Jul 14, 2024
e2e3556
Adding a constructor for the searcher; we load and save the HF model to
codetalker7 Jul 14, 2024
55ac294
Adding a skeleton for the `search` method for a `Searcher`; will
codetalker7 Jul 14, 2024
cdc9e10
Fixing the constructor of `Searcher` to include the `QueryTokenizer`.
codetalker7 Jul 24, 2024
27ac472
Implementing the `encode_query` function to encode a query.
codetalker7 Jul 24, 2024
1c84c5a
Adding docstring for `encode_query`.
codetalker7 Jul 24, 2024
fc40fb9
Adding the `StridedTensor` type and its constructor for efficient ivf
codetalker7 Jul 25, 2024
12f25ef
Adding docstring for `_select_strides`.
codetalker7 Jul 25, 2024
818b3dc
Adding a docstring for `_create_view`.
codetalker7 Jul 25, 2024
8445e30
Adding a method of `load_codec` to load a codec from an indexing path.
codetalker7 Jul 25, 2024
d8c03e6
Adding function to load residuals from a particular chunk idx.
codetalker7 Jul 26, 2024
5d2cac8
Adding more fields to the `IndexScorer`, and also adding a new
codetalker7 Jul 26, 2024
e2ad470
Loading the config inside the constructor of `Searcher` instead of
codetalker7 Jul 26, 2024
dac07c9
Not using strided tensors for now, and building an `emb2pid` mapping.
codetalker7 Jul 26, 2024
dbe5fa8
Some informative messages.
codetalker7 Jul 26, 2024
a8be3a9
Not striding the embeddings in the constructor of `IndexScorer`.
codetalker7 Jul 27, 2024
1abf201
Adding a `retrieve` function to get all the candidate pids for a given
codetalker7 Jul 27, 2024
70b6f12
For now changing the type to `AbstractFloat`; will focus on precision
codetalker7 Jul 27, 2024
7d6e1a4
Adding function to decompress residual embeddings.
codetalker7 Jul 28, 2024
5808d5c
Adding a function to decompress codes and residuals.
codetalker7 Jul 28, 2024
d96623a
Some helper comments.
codetalker7 Jul 28, 2024
b01c73f
Adding function `score_pids` to score the pids.
codetalker7 Jul 28, 2024
8281704
Completing the implementation of `rank`.
codetalker7 Jul 28, 2024
d770438
Exporting the `search` function.
codetalker7 Jul 28, 2024
aeb4003
Completing the implementation of `dense_search`.
codetalker7 Jul 28, 2024
9a73a82
Updating the search example.
codetalker7 Jul 28, 2024
7b5bd41
Saving the config in the indexing example.
codetalker7 Jul 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ config = ColBERTConfig(
# create and run the indexer
indexer = Indexer(config)
index(indexer)
ColBERT.save(config)
25 changes: 25 additions & 0 deletions examples/searching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using ColBERT

# create the config
dataroot = "downloads/lotte"
dataset = "lifestyle"
datasplit = "dev"
path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")

nbits = 2 # encode each dimension with 2 bits

index_root = "experiments/notebook/indexes"
index_name = "short_$(dataset).$(datasplit).$(nbits)bits"
index_path = joinpath(index_root, index_name)

# build the searcher
searcher = Searcher(index_path)

# run a few searches, printing the top-k passages for each query
for (query, k) in [
    ("what are white spots on raspberries?", 2),
    ("are rabbits easy to housebreak?", 9),
]
    pids, scores = search(searcher, query, k)
    print(searcher.config.resource_settings.collection.data[pids])
end
6 changes: 6 additions & 0 deletions src/ColBERT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,10 @@ include("indexing/index_saver.jl")
include("indexing/collection_indexer.jl")
export Indexer, CollectionIndexer, index

# searcher
include("search/strided_tensor.jl")
include("search/index_storage.jl")
include("searching.jl")
export Searcher, search

end
80 changes: 80 additions & 0 deletions src/indexing/codecs/residual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ mutable struct ResidualCodec
bucket_weights::Vector{Float64}
end

"""

# Examples

```julia-repl
julia> codec = load_codec(index_path);
```
"""
function load_codec(index_path::String)
config = load(joinpath(index_path, "config.jld2"), "config")
centroids = load(joinpath(index_path, "centroids.jld2"), "centroids")
avg_residual = load(joinpath(index_path, "avg_residual.jld2"), "avg_residual")
buckets = load(joinpath(index_path, "buckets.jld2"))
ResidualCodec(config, centroids, avg_residual, buckets["bucket_cutoffs"], buckets["bucket_weights"])
end

"""
compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64})

Expand Down Expand Up @@ -130,6 +146,64 @@ function compress(codec::ResidualCodec, embs::Matrix{Float64})
codes, residuals
end

"""
    decompress_residuals(codec::ResidualCodec, binary_residuals::Array{UInt8})

Unpack the bit-packed residual matrix `binary_residuals` (shape
`((dim * nbits) ÷ 8, num_embeddings)`) into a `(dim, num_embeddings)` matrix
of residual values taken from `codec.bucket_weights`. This is the inverse of
the packing performed by `binarize`.
"""
function decompress_residuals(codec::ResidualCodec, binary_residuals::Array{UInt8})
    dim = codec.config.doc_settings.dim
    nbits = codec.config.indexing_settings.nbits

    @assert ndims(binary_residuals) == 2
    # integer arithmetic: dim * nbits is a multiple of 8 (checked when the
    # index is built), so no float division is needed here
    @assert size(binary_residuals)[1] == (dim * nbits) ÷ 8

    # unpack each UInt8 into its 8 bits, least-significant bit first
    unpacked_bits = BitVector()
    for byte in vec(binary_residuals)
        append!(unpacked_bits, [byte & (0x1 << n) != 0 for n in 0:7])
    end

    # reshape into (nbits, dim, num_embeddings); inverse of what binarize does
    unpacked_bits = reshape(unpacked_bits, nbits, dim, size(binary_residuals)[2])

    # weight the ith bit by 2^(i - 1) and sum over the nbits dimension to
    # recover the bucket index of each coordinate
    positionbits = reshape([1 << (i - 1) for i in 1:nbits], nbits, 1, 1)
    unpacked_bits = sum(unpacked_bits .* positionbits, dims = 1)
    unpacked_bits = unpacked_bits .+ 1    # bucket indices are 1-based

    # drop the singleton nbits dimension and map bucket indices to weights
    unpacked_bits = reshape(unpacked_bits, size(unpacked_bits)[2:end]...)
    codec.bucket_weights[unpacked_bits]
end

"""
    decompress(codec::ResidualCodec, codes::Vector{Int}, residuals::Array{UInt8})

Reconstruct approximate embeddings from their compressed form: for each
embedding, add its decompressed residual to the centroid given by its code,
then re-normalize. Returns a matrix with one column per entry of `codes`.
"""
function decompress(codec::ResidualCodec, codes::Vector{Int}, residuals::Array{UInt8})
    @assert ndims(codes) == 1
    @assert ndims(residuals) == 2
    @assert length(codes) == size(residuals)[2]

    # decompress in batches to bound peak memory usage
    D = Matrix{Float64}[]
    bsize = 1 << 15
    for batch_offset in 1:bsize:length(codes)
        batch_end = min(batch_offset + bsize - 1, length(codes))
        batch_codes = codes[batch_offset:batch_end]
        batch_residuals = residuals[:, batch_offset:batch_end]

        centroids_ = codec.centroids[:, batch_codes]
        residuals_ = decompress_residuals(codec, batch_residuals)

        batch_embeddings = centroids_ + residuals_
        # re-normalize each embedding column; leave exact-zero columns as-is
        batch_embeddings = mapslices(
            v -> iszero(v) ? v : normalize(v), batch_embeddings, dims = 1)
        push!(D, batch_embeddings)
    end

    # concatenate columns without splatting a potentially long batch list
    reduce(hcat, D)
end

"""
load_codes(codec::ResidualCodec, chunk_idx::Int)

Expand All @@ -150,3 +224,9 @@ function load_codes(codec::ResidualCodec, chunk_idx::Int)
codes = JLD2.load(codes_path, "codes")
codes
end

"""
    load_residuals(codec::ResidualCodec, chunk_idx::Int)

Load the compressed residuals for the chunk with index `chunk_idx` from the
index directory configured in `codec`.
"""
function load_residuals(codec::ResidualCodec, chunk_idx::Int)
    residuals_file = joinpath(
        codec.config.indexing_settings.index_path, "$(chunk_idx).residuals.jld2")
    return JLD2.load(residuals_file, "residuals")
end
5 changes: 2 additions & 3 deletions src/infra/settings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ Base.@kwdef struct IndexingSettings
end

Base.@kwdef struct SearchSettings
ncells::Union{Nothing, Int} = nothing
centroid_score_threshold::Union{Nothing, Float64} = nothing
ndocs::Union{Nothing, Int} = nothing
nprobe::Int = 2
ncandidates::Int = 8192
end
174 changes: 174 additions & 0 deletions src/search/index_storage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Holds everything needed to score queries against a saved index: the codec,
# the inverted file, and the packed codes/residuals of all embeddings.
struct IndexScorer
    metadata::Dict                  # index metadata parsed from metadata.json
    codec::ResidualCodec            # codec used to (de)compress embeddings
    ivf::Vector{Int}                # inverted file: embedding IDs grouped by centroid
    ivf_lengths::Vector{Int}        # number of embedding IDs per centroid in `ivf`
    doclens::Vector{Int}            # number of embeddings for each document
    codes::Vector{Int}              # centroid code of each embedding
    residuals::Matrix{UInt8}        # bit-packed residuals, one column per embedding
    emb2pid::Vector{Int}            # maps each embedding ID to its document's pid
end

"""

# Examples

```julia-repl
julia> IndexScorer(index_path)

```
"""
function IndexScorer(index_path::String)
@info "Loading the index from {index_path}."

# loading the config from the index path
config = JLD2.load(joinpath(index_path, "config.jld2"))["config"]

# the metadata
metadata_path = joinpath(index_path, "metadata.json")
metadata = JSON.parsefile(metadata_path)

# loading the codec
codec = load_codec(index_path)

# loading ivf into a StridedTensor
ivf_path = joinpath(index_path, "ivf.jld2")
ivf_dict = JLD2.load(ivf_path)
ivf, ivf_lengths = ivf_dict["ivf"], ivf_dict["ivf_lengths"]
# ivf = StridedTensor(ivf, ivf_lengths)

# loading all doclens
doclens = Vector{Int}()
for chunk_idx in 1:metadata["num_chunks"]
doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2")
chunk_doclens = JLD2.load(doclens_file, "doclens")
append!(doclens, chunk_doclens)
end

# loading all embeddings
num_embeddings = metadata["num_embeddings"]
dim, nbits = config.doc_settings.dim, config.indexing_settings.nbits
@assert (dim * nbits) % 8 == 0
codes = zeros(Int, num_embeddings)
residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings)
codes_offset = 1
for chunk_idx in 1:metadata["num_chunks"]
chunk_codes = load_codes(codec, chunk_idx)
chunk_residuals = load_residuals(codec, chunk_idx)

codes_endpos = codes_offset + length(chunk_codes) - 1
codes[codes_offset:codes_endpos] = chunk_codes
residuals[:, codes_offset:codes_endpos] = chunk_residuals

codes_offset = codes_offset + length(chunk_codes)
end

# the emb2pid mapping
@info "Building the emb2pid mapping."
@assert isequal(sum(doclens), metadata["num_embeddings"])
emb2pid = zeros(Int, metadata["num_embeddings"])

offset_doclens = 1
for (pid, dlength) in enumerate(doclens)
emb2pid[offset_doclens:offset_doclens + dlength - 1] .= pid
offset_doclens += dlength
end

IndexScorer(
metadata,
codec,
ivf,
ivf_lengths,
doclens,
codes,
residuals,
emb2pid,
)
end

"""

Return a candidate set of `pids` for the query matrix `Q`. This is done as follows: the nearest `nprobe` centroids for each query embedding are found. This list is then flattened and the unique set of these centroids is built. Using the `ivf`, the list of all unique embedding IDs contained in these centroids is computed. Finally, these embedding IDs are converted to `pids` using `emb2pid`. This list of `pids` is the final candidate set.
"""
function retrieve(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat})
@assert isequal(size(Q)[2], config.query_settings.query_maxlen) # Q: (128, 32, 1)

Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension
@assert isequal(length(size(Q)), 2)

# score of each query embedding with each centroid and take top nprobe centroids
cells = transpose(Q) * ranker.codec.centroids
cells = mapslices(row -> partialsortperm(row, 1:config.search_settings.nprobe, rev=true), cells, dims = 2) # take top nprobe centroids for each query
centroid_ids = sort(unique(vec(cells)))

# get all embedding IDs contained in centroid_ids using ivf
centroid_ivf_offsets = cat([1], 1 .+ cumsum(ranker.ivf_lengths)[1:end .!= end], dims = 1)
eids = Vector{Int}()
for centroid_id in centroid_ids
offset = centroid_ivf_offsets[centroid_id]
length = ranker.ivf_lengths[centroid_id]
append!(eids, ranker.ivf[offset:offset + length - 1])
end
@assert isequal(length(eids), sum(ranker.ivf_lengths[centroid_ids]))
eids = sort(unique(eids))

# get pids from the emb2pid mapping
pids = sort(unique(ranker.emb2pid[eids]))
pids
end

"""
- Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this.
"""
function score_pids(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat}, pids::Vector{Int})
# get codes and residuals for all embeddings across all pids
num_embs = sum(ranker.doclens[pids])
codes_packed = zeros(Int, num_embs)
residuals_packed = zeros(UInt8, size(ranker.residuals)[1], num_embs)
pid_offsets = cat([1], 1 .+ cumsum(ranker.doclens)[1:end .!= end], dims=1)

offset = 1
for pid in pids
pid_offset = pid_offsets[pid]
num_embs_pid = ranker.doclens[pid]
codes_packed[offset: offset + num_embs_pid - 1] = ranker.codes[pid_offset: pid_offset + num_embs_pid - 1]
residuals_packed[:, offset: offset + num_embs_pid - 1] = ranker.residuals[:, pid_offset: pid_offset + num_embs_pid - 1]
offset += num_embs_pid
end
@assert offset == num_embs + 1

# decompress these codes and residuals to get the original embeddings
D_packed = decompress(ranker.codec, codes_packed, residuals_packed)
@assert ndims(D_packed) == 2
@assert size(D_packed)[1] == config.doc_settings.dim
@assert size(D_packed)[2] == num_embs

# get the max-sim scores
if size(Q)[3] > 1
error("Only one query is supported at the moment!")
end
@assert size(Q)[3] == 1
Q = reshape(Q, size(Q)[1:2]...)

scores = Vector{Float64}()
query_doc_scores = transpose(Q) * D_packed # (num_query_tokens, num_embeddings)
offset = 1
for pid in pids
num_embs_pid = ranker.doclens[pid]
pid_scores = query_doc_scores[:, offset:min(num_embs, offset + num_embs_pid - 1)]
push!(scores, sum(maximum(pid_scores, dims = 2)))

offset += num_embs_pid
end
@assert offset == num_embs + 1

scores
end

"""
    rank(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat})

Retrieve the candidate pids for the query matrix `Q`, score them, and return
the pids together with their scores, both sorted by score in decreasing order.
"""
function rank(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat})
    candidate_pids = retrieve(ranker, config, Q)
    candidate_scores = score_pids(ranker, config, Q, candidate_pids)
    perm = sortperm(candidate_scores, rev = true)

    return candidate_pids[perm], candidate_scores[perm]
end
Loading
Loading