diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl
index df1f339..a537058 100644
--- a/src/modelling/checkpoint.jl
+++ b/src/modelling/checkpoint.jl
@@ -54,3 +54,35 @@ function doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::A
     D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding
     return D, mask
 end
+
+function docFromText(checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int})
+    if ismissing(bsize)
+        integer_ids, integer_mask = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize)
+        doc(checkpoint, integer_ids, integer_mask)
+    else
+        text_batches, reverse_indices = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize)
+        batches = [doc(checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in text_batches]
+
+        # aggregate all embeddings
+        D, mask = [], []
+        for (_D, _mask) in batches
+            push!(D, _D)
+            push!(mask, _mask)
+        end
+
+        # concatenate embeddings and masks, and restore the original document order
+        D, mask = cat(D..., dims = 3)[:, :, reverse_indices], cat(mask..., dims = 3)[:, :, reverse_indices]
+        mask = reshape(mask, size(mask)[2:end])
+
+        # get doclens, i.e. the number of attended tokens for each passage
+        doclens = sum(mask, dims = 1)
+
+        # flatten out embeddings, i.e. get embeddings for each token in each passage
+        D = reshape(D, size(D)[1], prod(size(D)[2:end]))
+
+        # remove embeddings for masked tokens
+        D = D[:, reshape(mask, prod(size(mask)))]
+
+        D, doclens
+    end
+end
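
For reference, a minimal usage sketch of the new function (assuming a `Checkpoint` has already been constructed and that `tensorize` and `doc` from this file are in scope; the documents and batch size below are illustrative):

```julia
docs = ["hello world", "thank you BERT!"]

# no batching: tensorize the whole corpus at once and return whatever
# `doc` returns, i.e. the embedding tensor plus its attention mask
D, mask = docFromText(checkpoint, docs, missing)

# batching: embed in batches of 2, then return a (dim, total_attended_tokens)
# matrix of token embeddings and the per-passage attended-token counts
D, doclens = docFromText(checkpoint, docs, 2)
```

Note that the two branches return differently shaped results: the unbatched path passes the output of `doc` through unchanged, while the batched path flattens the embeddings and drops masked tokens.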