Skip to content

Commit

Permalink
Adding implementation of docFromText, to convert passages to
Browse files Browse the repository at this point in the history
embeddings for each token across all passages.
  • Loading branch information
codetalker7 committed Jun 2, 2024
1 parent 6cc241b commit 9f5f587
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/modelling/checkpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,35 @@ function doc(checkpoint::Checkpoint, integer_ids::AbstractArray, integer_mask::A
D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding
return D, mask
end

"""
    docFromText(checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int})

Tokenize the passages in `docs` with `tensorize` and embed them with `doc`.

# Arguments
- `checkpoint`: provides `doc_tokenizer` and `model.tokenizer` (both forwarded to
  `tensorize`) and is itself forwarded to `doc`.
- `docs`: the passages to embed.
- `bsize`: batch size forwarded to `tensorize`. When `missing`, all passages are
  tensorized and embedded in a single call.

# Returns
- If `bsize` is `missing`: the result of `doc` on the whole collection, returned as-is
  (no flattening and no removal of masked tokens).
- Otherwise: a tuple `(D, doclens)`. `D` is a matrix whose columns are the embeddings of
  the unmasked tokens of all passages, with passages restored to their original order.
  `doclens` is `sum(mask, dims = 1)`, i.e. the number of attended tokens per passage; it
  keeps the singleton first dimension produced by that reduction.

NOTE(review): the two branches return differently shaped values; callers must know which
branch they hit. Left as-is to stay backward-compatible.
"""
function docFromText(
        checkpoint::Checkpoint, docs::Vector{String}, bsize::Union{Missing, Int})
    # unbatched path: a single tensorize/doc round trip
    if ismissing(bsize)
        integer_ids, integer_mask = tensorize(
            checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize)
        return doc(checkpoint, integer_ids, integer_mask)
    end

    text_batches, reverse_indices = tensorize(
        checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize)
    batches = [doc(checkpoint, integer_ids, integer_mask)
               for (integer_ids, integer_mask) in text_batches]

    # aggregate all embeddings and masks; comprehensions keep the element type concrete
    # (an untyped `[]` accumulator would be a `Vector{Any}`)
    embeddings = [batch[1] for batch in batches]
    masks = [batch[2] for batch in batches]

    # concat along the passage dimension, and put passages back in their original order
    D = cat(embeddings..., dims = 3)[:, :, reverse_indices]
    mask = cat(masks..., dims = 3)[:, :, reverse_indices]

    # drop the leading dimension of the mask so it is (tokens, passages)
    mask = reshape(mask, size(mask)[2:end])

    # get doclens, i.e. number of attended tokens for each passage
    doclens = sum(mask, dims = 1)

    # flatten out embeddings, i.e. get one column per token position across all passages
    D = reshape(D, size(D, 1), prod(size(D)[2:end]))

    # remove embeddings for masked tokens; `vec(mask)` lines up with the columns of `D`
    # because both were flattened in the same column-major order
    D = D[:, vec(mask)]

    return D, doclens
end

0 comments on commit 9f5f587

Please sign in to comment.