Generating query embeddings. #17

Merged · 7 commits · Jul 24, 2024
146 changes: 140 additions & 6 deletions src/modelling/checkpoint.jl
@@ -1,4 +1,4 @@
using ..ColBERT: DocTokenizer, ColBERTConfig
using ..ColBERT: DocTokenizer, QueryTokenizer, ColBERTConfig

"""
BaseColBERT(; bert::Transformers.HuggingFace.HGFBertModel, linear::Transformers.Layers.Dense, tokenizer::Transformers.TextEncoders.AbstractTransformerTextEncoder)
@@ -101,13 +101,14 @@ end
"""
Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, query_tokenizer::QueryTokenizer, colbert_config::ColBERTConfig)

A wrapper for [`BaseColBERT`](@ref), which includes a [`ColBERTConfig`](@ref) and tokenization-specific functions via the [`DocTokenizer`](@ref) type.
A wrapper for [`BaseColBERT`](@ref), which includes a [`ColBERTConfig`](@ref) and tokenization-specific functions via the [`DocTokenizer`](@ref) and [`QueryTokenizer`](@ref) types.

If the config's [`DocSettings`](@ref) are configured to mask punctuation, then the `skiplist` property of the created [`Checkpoint`](@ref) will be set to the list of token IDs of punctuation symbols.

# Arguments
- `model`: The [`BaseColBERT`](@ref) to be wrapped.
- `doc_tokenizer`: A [`DocTokenizer`](@ref) used for functions related to document tokenization.
- `query_tokenizer`: A [`QueryTokenizer`](@ref) used for functions related to query tokenization.
- `colbert_config`: The underlying [`ColBERTConfig`](@ref).

# Returns
@@ -118,7 +119,7 @@ The created [`Checkpoint`](@ref).
Continuing from the example for [`BaseColBERT`](@ref):

```julia-repl
julia> checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), config);
julia> checkPoint = Checkpoint(base_colbert, DocTokenizer(base_colbert.tokenizer, config), QueryTokenizer(base_colbert.tokenizer, config), config);

julia> checkPoint.skiplist # by default, all punctuations
32-element Vector{Int64}:
@@ -157,18 +158,19 @@ julia> checkPoint.skiplist # by default, all punctuations
struct Checkpoint
model::BaseColBERT
doc_tokenizer::DocTokenizer
query_tokenizer::QueryTokenizer
colbert_config::ColBERTConfig
skiplist::Union{Missing, Vector{Int}}
end

function Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, colbert_config::ColBERTConfig)
function Checkpoint(model::BaseColBERT, doc_tokenizer::DocTokenizer, query_tokenizer::QueryTokenizer, colbert_config::ColBERTConfig)
if colbert_config.doc_settings.mask_punctuation
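# token IDs of ASCII punctuation symbols; doc() uses this skiplist to zero out their embeddings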
punctuation_list = string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"))
skiplist = [TextEncodeBase.lookup(model.tokenizer.vocab, punct) for punct in punctuation_list]
else
skiplist = missing
end
Checkpoint(model, doc_tokenizer, colbert_config, skiplist)
Checkpoint(model, doc_tokenizer, query_tokenizer, colbert_config, skiplist)
end

"""
@@ -223,7 +225,7 @@ end
"""
doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::AbstractArray)

Compute the hidden state of the BERT and linear layers of ColBERT.
Compute the hidden state of the BERT and linear layers of ColBERT for documents.

# Arguments

@@ -267,6 +269,7 @@ function doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::A

mask = mask_skiplist(checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist)
mask = reshape(mask, (1, size(mask)...)) # equivalent of unsqueeze
@assert isequal(size(mask)[2:end], size(D)[2:end])

D = D .* mask # clear out embeddings of masked tokens
D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding
@@ -379,3 +382,134 @@ function docFromText(checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{
D, doclens
end
end

"""
query(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::AbstractArray)

Compute the hidden state of the BERT and linear layers of ColBERT for queries.

# Arguments

- `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings.
- `integer_ids`: An array of token IDs to be fed into the BERT model.
- `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`.

# Returns

`Q`, where `Q` is an array containing the normalized embeddings for each token in the queries. It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e. `L` is the maximum length of any query and `N` is the total number of queries.

# Examples

Continuing from the queries example for [`tensorize`](@ref) and [`Checkpoint`](@ref):

```julia-repl
julia> query(checkPoint, integer_ids, integer_mask)
128×32×1 Array{Float32, 3}:
[:, :, 1] =
0.0158567 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919
0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200357
-0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0777439 -0.0776733 -0.0830504
-0.109909 -0.170906 -0.0138702 -0.0409767 -0.126037 -0.126829 -0.13149
-0.0231786 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536
0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583
-0.0290509 0.143255 0.0306142 0.042658 -0.164401 -0.161857 -0.160327
0.0921477 0.0588331 0.250449 0.234636 0.0664076 0.0659837 0.0711357
0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634
-0.0768143 -0.00587305 0.00543038 0.00443374 -0.11757 -0.112495 -0.11112
⋮ ⋱ ⋮
-0.0859686 0.0623054 0.0974813 0.126841 0.0182795 0.0230549 0.031103
0.0392043 0.0162653 0.0926306 0.104053 0.0491495 0.0484318 0.0438132
-0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0617945 -0.0631367 -0.0675882
0.013123 0.0565132 -0.0349061 -0.0464192 0.0724731 0.0780166 0.074623
-0.117425 0.162483 0.11039 0.136364 -0.00538225 -0.00685449 -0.0019436
-0.0401158 -0.0045094 0.0539569 0.0689953 -0.00518063 -0.00600252 -0.00771469
0.0893983 0.0695061 -0.0499409 -0.035411 0.0960932 0.0961893 0.103431
-0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.022061 -0.018135
-0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699095 -0.0684749 -0.0662904
0.100019 -0.0618588 0.106134 0.0989047 -0.0556761 -0.0556784 -0.059571

```
"""
function query(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::AbstractArray)
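# BERT forward pass over the query tokens, followed by the linear layer projecting the hidden states down to the embedding dimension (128 by default)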
Q = checkpoint.model.bert((token=integer_ids, attention_mask=NeuralAttentionlib.GenericSequenceMask(integer_mask))).hidden_state
Q = checkpoint.model.linear(Q)

# only the pad symbol is masked out, i.e. an empty skiplist is passed
mask = mask_skiplist(checkpoint.model.tokenizer, integer_ids, Vector{Int}())
mask = reshape(mask, (1, size(mask)...)) # equivalent of unsqueeze
@assert isequal(size(mask)[2:end], size(Q)[2:end])

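# zero out the embeddings of pad tokens before normalizing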
Q = Q .* mask
Q = mapslices(v -> iszero(v) ? v : normalize(v), Q, dims = 1) # normalize each embedding
Q
end

"""
queryFromText(checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int})

Get ColBERT embeddings for `queries` using `checkpoint`.

This function also applies ColBERT-style query pre-processing to each query in `queries` (for example, truncating or padding each query to the configured `query_maxlen`).

# Arguments

- `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings.
- `queries`: A list of queries to get the embeddings for.
- `bsize`: A batch size for processing queries in batches. Currently, `bsize` cannot be `missing`.

# Returns

`embs`, where `embs` is an array of embeddings. The array `embs` has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for ColBERT's linear layer), `L` is the maximum length of any query in the batch, and `N` is the total number of queries in `queries`.

# Examples

Continuing from the example in [`Checkpoint`](@ref):

```julia-repl
julia> queries = ["what are white spots on raspberries?"];

julia> queryFromText(checkPoint, queries, 128)
128×32×1 Array{Float32, 3}:
[:, :, 1] =
0.0158567 0.169676 0.092745 0.0798617 … 0.115806 0.115938 0.112977 0.107919
0.220185 0.0304873 0.165348 0.150315 0.0165188 0.0168762 0.0178042 0.0200357
-0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0737461 -0.0777439 -0.0776733 -0.0830504
-0.109909 -0.170906 -0.0138702 -0.0409767 -0.118738 -0.126037 -0.126829 -0.13149
-0.0231786 0.0532214 0.0607473 0.0279048 0.111831 0.117017 0.114073 0.108536
0.0620549 0.0465075 0.0821693 0.0606439 … 0.0148605 0.0150612 0.0133353 0.0126583
-0.0290509 0.143255 0.0306142 0.042658 -0.169493 -0.164401 -0.161857 -0.160327
0.0921477 0.0588331 0.250449 0.234636 0.0642578 0.0664076 0.0659837 0.0711357
0.0279402 -0.0278357 0.144855 0.147958 0.157629 0.154552 0.155525 0.163634
-0.0768143 -0.00587305 0.00543038 0.00443374 -0.123969 -0.11757 -0.112495 -0.11112
-0.0184338 0.00668557 -0.191863 -0.161345 … -0.10374 -0.107664 -0.107267 -0.114564
⋮ ⋱ ⋮
-0.0859686 0.0623054 0.0974813 0.126841 0.0191363 0.0182795 0.0230549 0.031103
0.0392043 0.0162653 0.0926306 0.104053 0.0553615 0.0491495 0.0484318 0.0438132
-0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0562518 -0.0617945 -0.0631367 -0.0675882
0.013123 0.0565132 -0.0349061 -0.0464192 0.0698766 0.0724731 0.0780166 0.074623
-0.117425 0.162483 0.11039 0.136364 -0.0050836 -0.00538225 -0.00685449 -0.0019436
-0.0401158 -0.0045094 0.0539569 0.0689953 -0.00322497 -0.00518063 -0.00600252 -0.00771469
0.0893983 0.0695061 -0.0499409 -0.035411 0.0964842 0.0960932 0.0961893 0.103431
-0.116265 -0.106331 -0.179832 -0.149728 … -0.0275017 -0.0197172 -0.022061 -0.018135
-0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0735711 -0.0699095 -0.0684749 -0.0662904
0.100019 -0.0618588 0.106134 0.0989047 -0.0553564 -0.0556761 -0.0556784 -0.059571

```
"""
function queryFromText(checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int})
if ismissing(bsize)
error("Currently bsize cannot be missing!")
end

# configure the tokenizer to truncate or pad to query_maxlen
tokenizer = checkpoint.model.tokenizer
process = tokenizer.process
truncpad_pipe = Pipeline{:token}(TextEncodeBase.trunc_or_pad(checkpoint.colbert_config.query_settings.query_maxlen, "[PAD]", :tail, :tail), :token)
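# the fifth stage of the original pipeline is assumed to be the default truncation/padding step; swap it for the query_maxlen version built above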
process = process[1:4] |> truncpad_pipe |> process[6:end]
tokenizer = Transformers.TextEncoders.BertTextEncoder(tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc)

# get the token IDs and attention masks, compute the embeddings, and return the concatenated tensor
batches = tensorize(checkpoint.query_tokenizer, tokenizer, queries, bsize)
batches = [query(checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in batches]
cat(batches..., dims=3)
end
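# A minimal usage sketch (hypothetical query strings; assumes the same config as in the
# docstring example above, where query_maxlen is 32, so every query is padded or truncated
# to length 32):
#
#   queries = ["what are white spots on raspberries?", "are rhubarb leaves edible?"]
#   Q = queryFromText(checkPoint, queries, 128)
#   size(Q)    # (128, 32, 2)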