Skip to content

Commit

Permalink
Structural changes to old functions, plus further refactoring; also
Browse files Browse the repository at this point in the history
removing unnecessary functions.
  • Loading branch information
codetalker7 committed Aug 25, 2024
1 parent 7147c73 commit bbcb151
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 262 deletions.
53 changes: 0 additions & 53 deletions src/indexing/collection_indexer.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,3 @@
"""
encode_passages(
config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
Encode a list of passages using `checkpoint`.
The given `passages` are run through the underlying BERT model and the linear layer to
generate the embeddings, after doing relevant document-specific preprocessing.
See [`docFromText`](@ref) for more details.
# Arguments
- `config`: The [`ColBERTConfig`](@ref) to be used.
- `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages.
- `passages`: A list of strings representing the passages to be encoded.
# Returns
A tuple `embs, doclens` where:
- `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`,
where `D` is the embedding dimension and `N` is the total number of embeddings
across all the passages.
- `doclens::AbstractVector{Int}`: A vector of document lengths for each passage,
i.e the total number of attended tokens for each document passage.
"""
function encode_passages(
config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
@info "Encoding $(length(passages)) passages."

if length(passages) == 0
error("The list of passages to encode is empty!")
end

embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}()
# batching here to avoid storing intermediate embeddings on GPU
# batching also occurs inside docFromText to do batch packing optimizations
for passage_offset in 1:(config.passages_batch_size):length(passages)
passage_end_offset = min(
length(passages), passage_offset + config.passages_batch_size - 1)
embs_, doclens_ = docFromText(
config, checkpoint, passages[passage_offset:passage_end_offset],
config.index_bsize)
@assert embs_ isa Matrix{Float32}
@assert doclens_ isa Vector{Int}
push!(embs, embs_)
append!(doclens, vec(doclens_))
embs_, doclens_ = nothing, nothing
end
embs = cat(embs..., dims = 2)
embs, doclens
end

"""
_sample_pids(num_documents::Int)
Expand Down
Loading

0 comments on commit bbcb151

Please sign in to comment.