Skip to content

Commit

Permalink
Adding the queryFromText function to get query embeddings from query
Browse files Browse the repository at this point in the history
texts.
  • Loading branch information
codetalker7 committed Jul 24, 2024
1 parent f666be7 commit 4e241d7
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/modelling/checkpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,20 @@ function query(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask:
Q
end

function queryFromText(checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int})
if ismissing(bsize)
error("Currently bsize cannot be missing!")
end

# configure the tokenizer to truncate or pad to query_maxlen
tokenizer = checkpoint.model.tokenizer
process = tokenizer.process
truncpad_pipe = Pipeline{:token}(TextEncodeBase.trunc_or_pad(checkpoint.colbert_config.query_settings.query_maxlen, "[PAD]", :tail, :tail), :token)
process = process[1:4] |> truncpad_pipe |> process[6:end]
tokenizer = Transformers.TextEncoders.BertTextEncoder(tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc)

# get ids and masks, embeddings and returning the concatenated tensors
batches = tensorize(checkpoint.query_tokenizer, tokenizer, queries, bsize)
batches = [query(checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in batches]
cat(batches..., dims=3)
end

0 comments on commit 4e241d7

Please sign in to comment.