Adding tests for indexing functions + related changes to src.
codetalker7 committed Sep 7, 2024
1 parent f0ba596 commit f48ccf2
Showing 10 changed files with 254 additions and 230 deletions.
33 changes: 18 additions & 15 deletions src/indexing/collection_indexer.jl
@@ -134,6 +134,19 @@ function setup(collection::Vector{String}, avg_doclen_est::Float32,
)
end

function _bucket_cutoffs_and_weights(
nbits::Int, heldout_avg_residual::AbstractMatrix{Float32})
num_options = 1 << nbits
quantiles = collect(0:(num_options - 1)) / num_options
bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end],
quantiles .+ (0.5 / num_options)
bucket_cutoffs = Float32.(quantile(
heldout_avg_residual, bucket_cutoffs_quantiles))
bucket_weights = Float32.(quantile(
heldout_avg_residual, bucket_weights_quantiles))
bucket_cutoffs, bucket_weights
end
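
For intuition, here is a small standalone sketch of the quantile layout this helper produces, run on a random toy residual matrix rather than real ColBERT residuals:

```julia
using Statistics

nbits = 2
num_options = 1 << nbits                                 # 4 buckets
quantiles = collect(0:(num_options - 1)) / num_options   # [0.0, 0.25, 0.5, 0.75]
cutoff_qs = quantiles[2:end]                             # interior boundaries: [0.25, 0.5, 0.75]
weight_qs = quantiles .+ (0.5 / num_options)             # bucket midpoints: [0.125, 0.375, 0.625, 0.875]

residuals = randn(Float32, 128, 1000)                    # toy residual matrix
cutoffs = Float32.(quantile(residuals, cutoff_qs))       # 2^nbits - 1 bucket boundaries
weights = Float32.(quantile(residuals, weight_qs))       # 2^nbits reconstruction values
```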

"""
_compute_avg_residuals!(
nbits::Int, centroids::AbstractMatrix{Float32},
@@ -159,30 +172,20 @@ compression/decompression of residuals.
function _compute_avg_residuals!(
nbits::Int, centroids::AbstractMatrix{Float32},
heldout::AbstractMatrix{Float32}, codes::AbstractVector{UInt32})
-    @assert length(codes) == size(heldout, 2)
+    length(codes) == size(heldout, 2) ||
+        throw(DimensionMismatch("length(codes) must be equal to the number " *
+                                "of embeddings in heldout!"))

compress_into_codes!(codes, centroids, heldout) # get centroid codes
heldout_reconstruct = centroids[:, codes] # get corresponding centroids
heldout_avg_residual = heldout - heldout_reconstruct # compute the residual

avg_residual = mean(abs.(heldout_avg_residual), dims = 2) # for each dimension, take mean of absolute values of residuals

-    # computing bucket weights and cutoffs
-    num_options = 2^nbits
-    quantiles = Vector(0:(num_options - 1)) / num_options
-    bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end],
-    quantiles .+ (0.5 / num_options)
-
-    bucket_cutoffs = Float32.(quantile(
-        heldout_avg_residual, bucket_cutoffs_quantiles))
-    bucket_weights = Float32.(quantile(
-        heldout_avg_residual, bucket_weights_quantiles))
-    @assert bucket_cutoffs isa AbstractVector{Float32} "$(typeof(bucket_cutoffs))"
-    @assert bucket_weights isa AbstractVector{Float32} "$(typeof(bucket_weights))"
+    bucket_cutoffs, bucket_weights = _bucket_cutoffs_and_weights(
+        nbits, heldout_avg_residual)

-    @info "Got bucket_cutoffs_quantiles = $(bucket_cutoffs_quantiles) and bucket_weights_quantiles = $(bucket_weights_quantiles)"
@info "Got bucket_cutoffs = $(bucket_cutoffs) and bucket_weights = $(bucket_weights)"

bucket_cutoffs, bucket_weights, mean(avg_residual)
end
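
A toy end-to-end run of this function's logic (a sketch, not the library's API: `compress_into_codes!` is replaced by a plain nearest-centroid assignment via maximum inner product, which is assumed here to match its semantics):

```julia
using Statistics

centroids = rand(Float32, 8, 4)   # 4 centroids of dimension 8
heldout = rand(Float32, 8, 10)    # 10 held-out embeddings

# stand-in for compress_into_codes!: nearest centroid by inner product
codes = UInt32.(vec(getindex.(argmax(centroids' * heldout, dims = 1), 1)))

heldout_avg_residual = heldout - centroids[:, codes]       # 8×10 residuals
avg_residual = mean(abs.(heldout_avg_residual), dims = 2)  # per-dimension mean
cutoffs, weights = _bucket_cutoffs_and_weights(2, heldout_avg_residual)
```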

173 changes: 0 additions & 173 deletions src/modelling/checkpoint.jl
@@ -1,176 +1,3 @@
"""
BaseColBERT(;
bert::HuggingFace.HGFBertModel, linear::Layers.Dense,
tokenizer::TextEncoders.AbstractTransformerTextEncoder)
A struct representing the BERT model, linear layer, and the tokenizer used to compute
embeddings for documents and queries.
# Arguments
- `bert`: The pre-trained BERT model used to generate the embeddings.
- `linear`: The linear layer used to project the embeddings to a specific dimension.
- `tokenizer`: The tokenizer used by the BERT model.
# Returns
A [`BaseColBERT`](@ref) object.
# Examples
```julia-repl
julia> using ColBERT, CUDA;
julia> base_colbert = BaseColBERT("/home/codetalker7/models/colbertv2.0/");
julia> base_colbert.bert
HGFBertModel(
Chain(
CompositeEmbedding(
token = Embed(768, 30522), # 23_440_896 parameters
position = ApplyEmbed(.+, FixedLenPositionEmbed(768, 512)), # 393_216 parameters
segment = ApplyEmbed(.+, Embed(768, 2), Transformers.HuggingFace.bert_ones_like), # 1_536 parameters
),
DropoutLayer<nothing>(
LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters
),
),
Transformer<12>(
PostNormTransformerBlock(
DropoutLayer<nothing>(
SelfAttention(
MultiheadQKVAttenOp(head = 12, p = nothing),
Fork<3>(Dense(W = (768, 768), b = true)), # 1_771_776 parameters
Dense(W = (768, 768), b = true), # 590_592 parameters
),
),
LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters
DropoutLayer<nothing>(
Chain(
Dense(σ = NNlib.gelu, W = (768, 3072), b = true), # 2_362_368 parameters
Dense(W = (3072, 768), b = true), # 2_360_064 parameters
),
),
LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters
),
), # Total: 192 arrays, 85_054_464 parameters, 40.422 KiB.
Branch{(:pooled,) = (:hidden_state,)}(
BertPooler(Dense(σ = NNlib.tanh_fast, W = (768, 768), b = true)), # 590_592 parameters
),
) # Total: 199 arrays, 109_482_240 parameters, 43.578 KiB.
julia> base_colbert.linear
Dense(W = (768, 128), b = true) # 98_432 parameters
julia> base_colbert.tokenizer
TrfTextEncoder(
├─ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_uncased_tokenizer, WordPiece(vocab_size = 30522, unk = [UNK], max_char = 100)), 5 patterns)),
├─ vocab = Vocab{String, SizedArray}(size = 30522, unk = [UNK], unki = 101),
├─ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)),
├─ annotate = annotate_strings,
├─ onehot = lookup_first,
├─ decode = nestedcall(remove_conti_prefix),
├─ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)),
└─ process = Pipelines:
╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source)
╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token)
╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]:<type=1> Input[1]:<type=1> [SEP]:<type=1> (Input[2]:<type=2> [SEP]:<type=2>)...)(target.token)
╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token)
╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token)
╰─ target[token] := TextEncodeBase.nested2batch(target.token)
╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment)
╰─ target[segment] := TextEncodeBase.nested2batch(target.segment)
╰─ target[sequence_mask] := identity(target.attention_mask)
╰─ target := (target.token, target.segment, target.attention_mask, target.sequence_mask)
```
"""
struct BaseColBERT
bert::HF.HGFBertModel
linear::Layers.Dense
tokenizer::TextEncoders.AbstractTransformerTextEncoder
end

function BaseColBERT(modelpath::AbstractString)
tokenizer, bert_model, linear = load_hgf_pretrained_local(modelpath)
bert_model = bert_model |> Flux.gpu
linear = linear |> Flux.gpu
BaseColBERT(bert_model, linear, tokenizer)
end
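
One note on the constructor above: `Flux.gpu` is a no-op when no functional GPU is available, so the same loading path should also work on CPU-only machines. A minimal sketch of the device-movement pattern, on a toy array rather than the actual model:

```julia
using Flux

W = rand(Float32, 4, 4)
W_dev = W |> Flux.gpu      # moves to GPU if CUDA is functional, otherwise returns W unchanged
W_back = W_dev |> Flux.cpu # always safe to bring back to the CPU
```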

"""
Checkpoint(model::BaseColBERT, config::ColBERTConfig)
A wrapper for [`BaseColBERT`](@ref), containing information for generating embeddings
for docs and queries.
If the `config` is set to mask punctuation, then the `skiplist` property of the created
[`Checkpoint`](@ref) will be set to the list of punctuation token IDs. Otherwise, it will be empty.
# Arguments
- `model`: The [`BaseColBERT`](@ref) to be wrapped.
- `config`: The underlying [`ColBERTConfig`](@ref).
# Returns
The created [`Checkpoint`](@ref).
# Examples
Continuing from the example for [`BaseColBERT`](@ref):
```julia-repl
julia> checkpoint = Checkpoint(base_colbert, config)
julia> checkpoint.skiplist # by default, all punctuations
32-element Vector{Int64}:
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1064
1065
1066
1067
```
"""
struct Checkpoint
model::BaseColBERT
skiplist::Vector{Int64}
end

function Checkpoint(model::BaseColBERT, config::ColBERTConfig)
if config.mask_punctuation
punctuation_list = string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"))
skiplist = [TextEncodeBase.lookup(model.tokenizer.vocab, punct)
for punct in punctuation_list]
else
skiplist = Vector{Int64}()
end
Checkpoint(model, skiplist)
end
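
The skiplist construction above is a plain vocabulary lookup per punctuation symbol. A self-contained sketch with a toy vocabulary standing in for the real WordPiece vocab (the token IDs below are made up, not BERT's):

```julia
# hypothetical stand-in for TextEncodeBase.lookup on the BERT vocabulary
toy_vocab = Dict("!" => 1000, "\"" => 1001, "#" => 1002)

punctuation_list = string.(collect("!\"#"))
skiplist = [toy_vocab[punct] for punct in punctuation_list]  # [1000, 1001, 1002]
```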

"""
doc(
config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32},
7 changes: 4 additions & 3 deletions src/modelling/tokenization/tokenizer_utils.jl
@@ -116,9 +116,9 @@ A matrix equal to `data`, with the second row being filled with `marker`.
# Examples
```julia-repl
julia> using ColBERT: _add_marker_row;

julia> x = ones(Float32, 5, 5)
5×5 Matrix{Float32}:
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
@@ -138,5 +138,6 @@ julia> _add_marker_row(x, zero(Float32))
"""
function _add_marker_row(data::AbstractMatrix{T}, marker::T) where {T}
-    [data[begin:1, :]; fill(marker, (1, size(data, 2))); data[2:end, :]]
+    [data[begin:min(1, size(data, 1)), :]; fill(marker, (1, size(data, 2)));
+     data[2:end, :]]
end
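
The switch from `begin:1` to `begin:min(1, size(data, 1))` presumably guards the degenerate case of a matrix with zero rows, where `data[begin:1, :]` would throw a `BoundsError`. A quick check of that edge case (expected output written out by hand):

```julia-repl
julia> using ColBERT: _add_marker_row;

julia> _add_marker_row(ones(Float32, 0, 3), zero(Float32))
1×3 Matrix{Float32}:
 0.0  0.0  0.0
```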
76 changes: 38 additions & 38 deletions src/utils.jl
@@ -1,49 +1,49 @@
"""
_sort_by_length(
integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int)
Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`.
# Arguments
- `integer_ids`: The token IDs of documents to be sorted.
- `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits).
- `bsize`: The size of batches to be considered.
# Returns
Depending upon `bsize`, the following are returned:
- If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the
`integer_ids` and `integer_mask` are returned unchanged.
- If the number of documents is larger than `bsize`, then the passages are first sorted
by the number of attended tokens (figured out from the `integer_mask`), and then the
sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of
`reverse_indices`, i.e a mapping from the documents to their indices in the original
order.
"""
function _sort_by_length(
integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}, batch_size::Int)
size(integer_ids, 2) <= batch_size &&
return integer_ids, bitmask, Vector(1:size(integer_ids, 2))
lengths = vec(sum(bitmask; dims = 1)) # number of attended tokens in each passage
indices = sortperm(lengths) # get the indices which will sort lengths
reverse_indices = sortperm(indices) # invert the indices list
@assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))"
@assert bitmask isa BitMatrix "$(typeof(bitmask))"
@assert reverse_indices isa Vector{Int} "$(typeof(reverse_indices))"
integer_ids[:, indices], bitmask[:, indices], reverse_indices
end
# """
# _sort_by_length(
# integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int)
#
# Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`.
#
# # Arguments
#
# - `integer_ids`: The token IDs of documents to be sorted.
# - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits).
# - `bsize`: The size of batches to be considered.
#
# # Returns
#
# Depending upon `bsize`, the following are returned:
#
# - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the
# `integer_ids` and `integer_mask` are returned unchanged.
# - If the number of documents is larger than `bsize`, then the passages are first sorted
# by the number of attended tokens (figured out from the `integer_mask`), and then the
# sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of
# `reverse_indices`, i.e a mapping from the documents to their indices in the original
# order.
# """
# function _sort_by_length(
# integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}, batch_size::Int)
# size(integer_ids, 2) <= batch_size &&
# return integer_ids, bitmask, Vector(1:size(integer_ids, 2))
# lengths = vec(sum(bitmask; dims = 1)) # number of attended tokens in each passage
# indices = sortperm(lengths) # get the indices which will sort lengths
# reverse_indices = sortperm(indices) # invert the indices list
# @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))"
# @assert bitmask isa BitMatrix "$(typeof(bitmask))"
# @assert reverse_indices isa Vector{Int} "$(typeof(reverse_indices))"
# integer_ids[:, indices], bitmask[:, indices], reverse_indices
# end
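
The `sortperm`-of-`sortperm` trick in the (now commented-out) helper computes the inverse permutation, so indexing the sorted data with `reverse_indices` recovers the original order. A minimal sketch:

```julia
lengths = [5, 2, 9, 4]               # attended-token counts per passage
indices = sortperm(lengths)          # [2, 4, 1, 3]: sorts lengths ascending
reverse_indices = sortperm(indices)  # [3, 1, 4, 2]: the inverse permutation
sorted = lengths[indices]            # [2, 4, 5, 9]
@assert sorted[reverse_indices] == lengths   # original order recovered
```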

function compute_distances_kernel!(batch_distances::AbstractMatrix{Float32},
batch_data::AbstractMatrix{Float32},
centroids::AbstractMatrix{Float32})
batch_distances .= 0.0f0
# Compute squared distances: (a-b)^2 = a^2 + b^2 - 2ab
# a^2 term
sum_sq_data = sum(batch_data .^ 2, dims = 1) # (1, point_bsize)
# b^2 term
sum_sq_centroids = sum(centroids .^ 2, dims = 1)' # (num_centroids, 1)
# -2ab term
mul!(batch_distances, centroids', batch_data, -2.0f0, 1.0f0) # (num_centroids, point_bsize)
# Compute (a-b)^2 = a^2 + b^2 - 2ab
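
As a sanity check of the `(a - b)^2 = a^2 + b^2 - 2ab` expansion the kernel relies on, here is a standalone sketch on toy shapes; the real kernel writes into a preallocated buffer in the same `(num_centroids, point_bsize)` layout:

```julia
using LinearAlgebra

data = rand(Float32, 4, 6)        # (dim, point_bsize)
centroids = rand(Float32, 4, 3)   # (dim, num_centroids)

# a^2 + b^2 terms, broadcast into (num_centroids, point_bsize)
distances = sum(centroids .^ 2, dims = 1)' .+ sum(data .^ 2, dims = 1)
# add the -2ab term in place: distances = -2 * centroids' * data + distances
mul!(distances, centroids', data, -2.0f0, 1.0f0)

# brute-force squared distances for comparison
brute = [sum((data[:, j] - centroids[:, i]) .^ 2) for i in 1:3, j in 1:6]
@assert distances ≈ brute
```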
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
