"""
    queryFromText(checkpoint::Checkpoint, queries::Vector{String},
                  bsize::Union{Missing, Int})

Encode the text `queries` with `checkpoint`, one batch at a time, and return the
per-batch results concatenated along the third dimension.

A copy of the checkpoint's text encoder is built whose processing pipeline truncates or
pads every token sequence to `checkpoint.colbert_config.query_settings.query_maxlen`.
`tensorize` then yields `(integer_ids, integer_mask)` pairs, each of which is passed to
`query`.

# Arguments
- `checkpoint`: supplies the model tokenizer, the query tokenizer and the query settings.
- `queries`: the query strings to encode.
- `bsize`: batch size forwarded to `tensorize`; must not be `missing`.

# Returns
The value of `cat(...; dims = 3)` applied to the outputs of `query` for every batch.

# Throws
- `ErrorException` if `bsize` is `missing`.
"""
function queryFromText(
        checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int})
    ismissing(bsize) && error("Currently bsize cannot be missing!")

    # Build a truncate-or-pad stage targeting the configured query length.
    base_encoder = checkpoint.model.tokenizer
    maxlen = checkpoint.colbert_config.query_settings.query_maxlen
    truncpad_pipe = Pipeline{:token}(
        TextEncodeBase.trunc_or_pad(maxlen, "[PAD]", :tail, :tail), :token)

    # Keep stages 1-4 and 6-end of the existing pipeline and put the new stage in slot 5.
    # NOTE(review): presumably slot 5 is the encoder's own trunc/pad stage -- confirm.
    stages = base_encoder.process
    leading_stages = stages[1:4]
    trailing_stages = stages[6:end]
    query_process = leading_stages |> truncpad_pipe |> trailing_stages

    # Same underlying tokenizer, vocabulary and symbols; only the pipeline differs.
    query_encoder = Transformers.TextEncoders.BertTextEncoder(
        base_encoder.tokenizer,
        base_encoder.vocab,
        query_process;
        startsym = base_encoder.startsym,
        endsym = base_encoder.endsym,
        padsym = base_encoder.padsym,
        trunc = base_encoder.trunc,
    )

    # Tensorize into (ids, mask) batches, embed each batch, then join the embeddings.
    id_mask_batches = tensorize(checkpoint.query_tokenizer, query_encoder, queries, bsize)
    embedded = [query(checkpoint, ids, mask) for (ids, mask) in id_mask_batches]
    return cat(embedded...; dims = 3)
end