From bbcb151e5cd11f561db8b5c7dcfe988c25d5e474 Mon Sep 17 00:00:00 2001
From: Siddhant Chaudhary
Date: Mon, 26 Aug 2024 04:54:12 +0530
Subject: [PATCH] Structural changes to old functions, plus further
 refactoring; also removing unnecessary functions.

---
 src/indexing/collection_indexer.jl |  53 ---
 src/modelling/checkpoint.jl        | 513 +++++++++++++++++------------
 2 files changed, 304 insertions(+), 262 deletions(-)

diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl
index 752f12e..2dd3424 100644
--- a/src/indexing/collection_indexer.jl
+++ b/src/indexing/collection_indexer.jl
@@ -1,56 +1,3 @@
-"""
-    encode_passages(
-        config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
-
-Encode a list of passages using `checkpoint`.
-
-The given `passages` are run through the underlying BERT model and the linear layer to
-generate the embeddings, after doing relevant document-specific preprocessing.
-See [`docFromText`](@ref) for more details.
-
-# Arguments
-
-  - `config`: The [`ColBERTConfig`](@ref) to be used.
-  - `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages.
-  - `passages`: A list of strings representing the passages to be encoded.
-
-# Returns
-
-A tuple `embs, doclens` where:
-
-  - `embs::AbstractMatrix{Float32}`: The full embedding matrix. Of shape `(D, N)`,
-    where `D` is the embedding dimension and `N` is the total number of embeddings
-    across all the passages.
-  - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage,
-    i.e the total number of attended tokens for each document passage.
-"""
-function encode_passages(
-        config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String})
-    @info "Encoding $(length(passages)) passages."
-
-    if length(passages) == 0
-        error("The list of passages to encode is empty!")
-    end
-
-    embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}()
-    # batching here to avoid storing intermediate embeddings on GPU
-    # batching also occurs inside docFromText to do batch packing optimizations
-    for passage_offset in 1:(config.passages_batch_size):length(passages)
-        passage_end_offset = min(
-            length(passages), passage_offset + config.passages_batch_size - 1)
-        embs_, doclens_ = docFromText(
-            config, checkpoint, passages[passage_offset:passage_end_offset],
-            config.index_bsize)
-        @assert embs_ isa Matrix{Float32}
-        @assert doclens_ isa Vector{Int}
-        push!(embs, embs_)
-        append!(doclens, vec(doclens_))
-        embs_, doclens_ = nothing, nothing
-    end
-    embs = cat(embs..., dims = 2)
-    embs, doclens
-end
-
 """
     _sample_pids(num_documents::Int)
 
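The `encode_passages` wrapper deleted above has no replacement in this file; the patch re-introduces `encode_passages` in `src/modelling/checkpoint.jl` (second file below), with the model components passed explicitly instead of a `config`/`checkpoint` pair. A hypothetical caller-side sketch of that migration follows (illustration only, not part of this diff; the local model path and the skiplist construction mirror the new docstring example further below, and the `doc_maxlen` truncation tweak to the tokenizer is omitted):

```julia
# Hypothetical call-site migration (not part of this patch).
# Old wrapper (deleted above): embs, doclens = encode_passages(config, checkpoint, passages)
# New method (added below): pass the BERT trunk, linear head, tokenizer and settings directly.
using ColBERT: ColBERTConfig, load_hgf_pretrained_local, encode_passages
using Flux, TextEncodeBase

config = ColBERTConfig()
tokenizer, bert, linear = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/")
bert, linear = Flux.gpu(bert), Flux.gpu(linear)

punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"));
    tokenizer.padsym]
skiplist = [lookup(tokenizer.vocab, sym) for sym in punctuations_and_padsym]

passages = ["a toy passage", "another toy passage"]
embs, doclens = encode_passages(bert, linear, tokenizer, passages,
    config.dim, config.index_bsize, config.doc_token_id, skiplist)
```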
diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl
index ff9c57d..79cbda2 100644
--- a/src/modelling/checkpoint.jl
+++ b/src/modelling/checkpoint.jl
@@ -194,48 +194,159 @@ any document in `integer_ids` and `N` is the number of documents.
 
 # Examples
 
-Continuing with the example for [`tensorize_docs`](@ref) and the
-`skiplist` from the example in [`Checkpoint`](@ref).
+In this example, we'll mask out all punctuations as well as the pad symbol
+of a tokenizer.
 
 ```julia-repl
-julia> integer_ids = batches[1][1];
-
-julia> ColBERT.mask_skiplist(
-           checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist)
-21×3 BitMatrix:
- 1  1  1
- 1  1  1
- 1  1  1
- 1  1  1
- 0  1  0
- 0  0  1
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
- 0  0  0
+julia> using ColBERT: mask_skiplist;
+
+julia> using TextEncodeBase
+
+julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer");
+
+julia> punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"));
+           tokenizer.padsym];
+
+julia> skiplist = [lookup(tokenizer.vocab, sym)
+                   for sym in punctuations_and_padsym]
+33-element Vector{Int64}:
+ 1000
+ 1001
+ 1002
+ 1003
+ 1004
+ 1005
+ 1006
+ 1007
+ 1008
+ 1009
+ 1010
+ 1011
+ 1012
+ 1013
+ 1014
+ 1025
+ 1026
+ 1027
+ 1028
+ 1029
+ 1030
+ 1031
+ 1032
+ 1033
+ 1034
+ 1035
+ 1036
+ 1037
+ 1064
+ 1065
+ 1066
+ 1067
+    1
+
+julia> batch_text = [
+    "no punctuation text",
+    "this, batch,! of text contains puncts! but is larger so that? the other text contains pad symbol;"
+];
+
+julia> integer_ids, _ = tensorize_docs("[unused1]", tokenizer, batch_text)
+
+julia> integer_ids
+27×2 Matrix{Int32}:
+   102    102
+     3      3
+  2054   2024
+ 26137   1011
+  6594  14109
+ 14506   1011
+  3794   1000
+   103   1998
+     1   3794
+     1   3398
+     1  26137
+     1  16650
+     1   1000
+     1   2022
+     1   2004
+     1   3470
+     1   2062
+     1   2009
+     1   1030
+     1   1997
+     1   2061
+     1   3794
+     1   3398
+     1  11688
+     1   6455
+     1   1026
+     1    103
+
+julia> decode(tokenizer, integer_ids)
+27×2 Matrix{String}:
+ " [CLS]"      " [CLS]"
+ " [unused1]"  " [unused1]"
+ " no"         " this"
+ " pun"        " ,"
+ "ct"          " batch"
+ "uation"      " ,"
+ " text"       " !"
+ " [SEP]"      " of"
+ " [PAD]"      " text"
+ " [PAD]"      " contains"
+ " [PAD]"      " pun"
+ " [PAD]"      "cts"
+ " [PAD]"      " !"
+ " [PAD]"      " but"
+ " [PAD]"      " is"
+ " [PAD]"      " larger"
+ " [PAD]"      " so"
+ " [PAD]"      " that"
+ " [PAD]"      " ?"
+ " [PAD]" " the" + " [PAD]" " other" + " [PAD]" " text" + " [PAD]" " contains" + " [PAD]" " pad" + " [PAD]" " symbol" + " [PAD]" " ;" + " [PAD]" " [SEP]" + +julia> mask_skiplist(integer_ids, skiplist) +27×2 BitMatrix: + 1 1 + 1 1 + 1 1 + 1 0 + 1 1 + 1 0 + 1 0 + 1 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 0 + 0 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 0 + 0 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 1 + 0 0 + 0 1 ``` """ -function mask_skiplist( - tokenizer::TextEncoders.AbstractTransformerTextEncoder, - integer_ids::AbstractMatrix{Int32}, skiplist::Union{ - Missing, Vector{Int64}}) - filter = integer_ids .!= - TextEncodeBase.lookup(tokenizer.vocab, tokenizer.padsym) +function mask_skiplist!(mask::AbstractMatrix{Bool}, + integer_ids::AbstractMatrix{Int32}, skiplist::Vector{Int64}) for token_id in skiplist - filter = filter .& (integer_ids .!= token_id) + mask .= mask .& (integer_ids .!= token_id) end - filter end """ @@ -288,179 +399,10 @@ julia> mask 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ``` """ -function doc( - config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, - integer_mask::AbstractMatrix{Bool}) - integer_ids = integer_ids |> Flux.gpu - integer_mask = integer_mask |> Flux.gpu - - D = checkpoint.model.bert((token = integer_ids, - attention_mask = NeuralAttentionlib.GenericSequenceMask(integer_mask))).hidden_state - D = checkpoint.model.linear(D) - - mask = mask_skiplist( - checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist) - mask = reshape(mask, (1, size(mask)...)) # equivalent of unsqueeze - @assert isequal(size(mask)[2:end], size(D)[2:end]) - "size(mask): $(size(mask)), size(D): $(size(D))" - @assert mask isa AbstractArray{Bool} "$(typeof(mask))" - - D = D .* mask # clear out embeddings of masked tokens - - if !config.use_gpu - # doing this because normalize gives exact results - D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding - else - # TODO: try to do some tests to see the gap between this and LinearAlgebra.normalize - # mapreduce doesn't give exact normalization - norms = map(sqrt, mapreduce(abs2, +, D, dims = 1)) - norms[norms .== 0] .= 1 # avoid division by 0 - @assert isequal(size(norms)[2:end], size(D)[2:end]) - @assert size(norms)[1] == 1 - - D = D ./ norms - end - - D, mask -end - -""" - docFromText(config::ColBERTConfig, checkpoint::Checkpoint, - docs::Vector{String}, bsize::Union{Missing, Int}) - -Get ColBERT embeddings for `docs` using `checkpoint`. - -This function also applies ColBERT-style document pre-processing for each document in `docs`. - -# Arguments - -- `config`: The [`ColBERTConfig`](@ref) being used. -- `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings. -- `docs`: A list of documents to get the embeddings for. -- `bsize`: A batch size for processing documents in batches. - -# Returns - -A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector` -of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding -dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings -across all documents in `docs`. 
@@ -288,179 +399,10 @@ julia> mask
 1  1  1  1  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
 ```
 """
-function doc(
-        config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32},
-        integer_mask::AbstractMatrix{Bool})
-    integer_ids = integer_ids |> Flux.gpu
-    integer_mask = integer_mask |> Flux.gpu
-
-    D = checkpoint.model.bert((token = integer_ids,
-        attention_mask = NeuralAttentionlib.GenericSequenceMask(integer_mask))).hidden_state
-    D = checkpoint.model.linear(D)
-
-    mask = mask_skiplist(
-        checkpoint.model.tokenizer, integer_ids, checkpoint.skiplist)
-    mask = reshape(mask, (1, size(mask)...)) # equivalent of unsqueeze
-    @assert isequal(size(mask)[2:end], size(D)[2:end])
-    "size(mask): $(size(mask)), size(D): $(size(D))"
-    @assert mask isa AbstractArray{Bool} "$(typeof(mask))"
-
-    D = D .* mask # clear out embeddings of masked tokens
-
-    if !config.use_gpu
-        # doing this because normalize gives exact results
-        D = mapslices(v -> iszero(v) ? v : normalize(v), D, dims = 1) # normalize each embedding
-    else
-        # TODO: try to do some tests to see the gap between this and LinearAlgebra.normalize
-        # mapreduce doesn't give exact normalization
-        norms = map(sqrt, mapreduce(abs2, +, D, dims = 1))
-        norms[norms .== 0] .= 1 # avoid division by 0
-        @assert isequal(size(norms)[2:end], size(D)[2:end])
-        @assert size(norms)[1] == 1
-
-        D = D ./ norms
-    end
-
-    D, mask
-end
-
-"""
-    docFromText(config::ColBERTConfig, checkpoint::Checkpoint,
-        docs::Vector{String}, bsize::Union{Missing, Int})
-
-Get ColBERT embeddings for `docs` using `checkpoint`.
-
-This function also applies ColBERT-style document pre-processing for each document in `docs`.
-
-# Arguments
-
-- `config`: The [`ColBERTConfig`](@ref) being used.
-- `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings.
-- `docs`: A list of documents to get the embeddings for.
-- `bsize`: A batch size for processing documents in batches.
-
-# Returns
-
-A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector`
-of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding
-dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings
-across all documents in `docs`.
-
-# Examples
-
-Continuing from the example in [`Checkpoint`](@ref):
-
-```julia-repl
-julia> docs = [
-    "hello world",
-    "thank you!",
-    "a",
-    "this is some longer text, so length should be longer",
-];
-
-julia> embs, doclens = ColBERT.docFromText(config, checkpoint, docs, config.index_bsize)
-(Float32[0.07590997 0.00056472444 … -0.09958261 -0.03259005; 0.08413661 -0.016337946 … -0.061889287 -0.017708546; … ; -0.11584533 0.016651645 … 0.0073241345 0.09233974; 0.043868616 0.084660925 … -0.0294838 -0.08536169], [5 5 4 13])
-
-julia> embs
-128×27 Matrix{Float32}:
-  0.0759101   0.00056477   -0.0256841   0.0847256   …   0.0321216   -0.0811892   -0.0995827   -0.03259
-  0.0841366  -0.0163379    -0.0573766   0.0125381       0.0838632   -0.0118507   -0.0618893   -0.0177087
- -0.0301104  -0.0128124     0.0137095   0.00290062      0.0347227    0.0138398   -0.0573847    0.177861
-  0.0375674   0.216562      0.220287   -0.011          -0.0213431   -0.110819     0.00425487  -0.00131534
-  0.0252677   0.151702      0.189658   -0.104252       -0.0654913   -0.0272064    0.0350983   -0.0381015
-  0.00608619 -0.0415363    -0.0479571   0.00884466  …   0.00207629   0.122848     0.0747105    0.0836628
- -0.185256   -0.106582     -0.0394912  -0.119268        0.163837     0.0352982   -0.0405874   -0.064156
- -0.0816655  -0.142809     -0.15595    -0.109608        0.0882721    0.0565001   -0.134649     0.00380792
-  0.00471225  0.00444501    0.0144707   0.0682628       0.0386771    0.0112827    0.0253297    0.0665075
- -0.121564   -0.189994     -0.173724   -0.0678208      -0.0832335    0.0151939   -0.119054    -0.0980481
-  0.157599    0.0919844     0.0748075  -0.122389    …   0.0599421    0.0330669    0.0205288    0.0184296
-  0.0132481  -0.0430333    -0.0679477   0.0918445       0.14166      0.0404866    0.0575921    0.101701
-  0.0695786   0.0281928     0.000234582 0.0570102      -0.137199    -0.0378472   -0.0531831   -0.123457
- -0.0933987  -0.0390347    -0.0274184  -0.0452961       0.14876      0.0279156    0.0309748    0.00298152
-  0.0458562   0.0729707     0.0336343   0.189599        0.0570071    0.103661     0.00905471   0.127777
-  0.00452595  0.05959       0.0768679  -0.036913    …   0.0768966    0.148845     0.0569493    0.293592
- -0.0385804  -0.00754613    0.0375564   0.00207589     -0.0161775    0.133667     0.266788     0.0394272
-  ⋮                                                 ⋱                             ⋮
-  0.0510928  -0.138272     -0.111771   -0.192081       -0.0312752   -0.00646487  -0.0171807   -0.0618908
-  0.128495    0.181198      0.131882   -0.064132       -0.00662879  -0.00408871   0.027459     0.0343185
- -0.0961544  -0.0223997     0.025595   -0.12089         0.0042998    0.0117906   -0.0813832    0.0382321
-  0.0285496   0.0556695     0.0805605  -0.0728611   …   0.138845    -0.0139292   -0.14533     -0.017602
-  0.0112119  -0.164717     -0.188169    0.0315999       0.112653     0.071643    -0.0662124    0.164667
- -0.0017815   0.0600865     0.0858722   0.00955078     -0.0506793    0.120243     0.0490749    0.0562548
- -0.0261784   0.0343851     0.0447504  -0.105545       -0.0713677    0.0469064    0.040038    -0.0536368
- -0.0696538  -0.020624     -0.0465219  -0.121079       -0.0636235    0.0441996    0.0842775    0.0567261
- -0.0940355  -0.106123     -0.0424943   0.0650131   …   0.00190927   0.00334517   0.00795241  -0.0439884
-  0.0567849  -0.0312434    -0.0715616   0.136271       -0.0648593   -0.113022     0.0616157   -0.0738149
- -0.0143086   0.105833      0.0762297   0.0102708      -0.162572    -0.142671    -0.0430241   -0.0831737
-  0.0447039   0.0783602     0.0957613   0.0603179       0.0415507   -0.0413788    0.0315282   -0.171445
-  0.129225    0.112544      0.0815397  -0.00357054      0.097503     0.120684     0.107231     0.119762
-  0.00020747 -0.124472     -0.120445   -0.0102294   …  -0.24173     -0.0930788   -0.0519734    0.0837617
- -0.115845    0.0166517     0.0199255  -0.044735       -0.0353863    0.0577463    0.00732411   0.0923398
-  0.0438687   0.0846609     0.0960215   0.112225       -0.178799    -0.096704    -0.0294837   -0.0853618
-
-julia> doclens
-4-element Vector{Int64}:
-  5
-  5
-  4
- 13
-```
-"""
-function docFromText(config::ColBERTConfig, checkpoint::Checkpoint,
-        docs::Vector{String}, bsize::Union{Missing, Int})
-    if ismissing(bsize)
-        # integer_ids, integer_mask = tensorize(checkpoint.doc_tokenizer, checkpoint.model.tokenizer, docs, bsize)
-        # doc(checkpoint, integer_ids, integer_mask)
-        error("Currently bsize cannot be missing!")
-    else
-        integer_ids, integer_mask = tensorize_docs(
-            config, checkpoint.model.tokenizer, docs)
-
-        # we sort passages by length to do batch packing for more efficient use of the GPU
-        integer_ids, integer_mask, reverse_indices = _sort_by_length(
-            integer_ids, integer_mask, bsize)
-
-        @assert length(reverse_indices) == length(docs)
-        "length(reverse_indices): $(length(reverse_indices)), length(batch_text): $(length(docs))"
-        @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))"
-        @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))"
-        @assert reverse_indices isa Vector{Int64} "$(typeof(reverse_indices))"
-
-        # aggregate all embeddings
-        D, mask = Vector{AbstractArray{Float32}}(),
-        Vector{AbstractArray{Bool}}()
-        for passage_offset in 1:bsize:length(docs)
-            passage_end_offset = min(length(docs), passage_offset + bsize - 1)
-            D_, mask_ = doc(
-                config, checkpoint, integer_ids[
-                    :, passage_offset:passage_end_offset],
-                integer_mask[:, passage_offset:passage_end_offset])
-            push!(D, D_)
-            push!(mask, mask_)
-            D_, mask_ = nothing, nothing
-        end
-
-        # concat embeddings and masks, and put them in the original order
-        D, mask = cat(D..., dims = 3)[:, :, reverse_indices],
-        cat(mask..., dims = 3)[:, :, reverse_indices]
-        mask = reshape(mask, size(mask)[2:end])
-
-        # get doclens, i.e number of attended tokens for each passage
-        doclens = vec(sum(mask, dims = 1))
-
-        # flatten out embeddings, i.e get embeddings for each token in each passage
-        D = reshape(D, size(D)[1], prod(size(D)[2:end]))
-
-        # remove embeddings for masked tokens
-        D = D[:, reshape(mask, prod(size(mask)))]
-
-        @assert ndims(D)==2 "ndims(D): $(ndims(D))"
-        @assert size(D)[2]==sum(doclens) "size(D): $(size(D)), sum(doclens): $(sum(doclens))"
-        @assert D isa AbstractMatrix{Float32} "$(typeof(D))"
-        @assert doclens isa AbstractVector{Int64} "$(typeof(doclens))"
-
-        Flux.cpu(D), Flux.cpu(doclens)
-    end
+function doc(bert::HF.HGFBertModel, linear::Layers.Dense,
+        integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool})
+    linear(bert((token = integer_ids,
+        attention_mask = NeuralAttentionlib.GenericSequenceMask(bitmask))).hidden_state)
 end
 
 """
@@ -706,3 +648,156 @@ function queryFromText(config::ColBERTConfig,
 
     Flux.cpu(Q)
 end
+
+function _clear_masked_embeddings!(D::AbstractArray{Float32, 3},
+        integer_ids::AbstractMatrix{Int32}, skiplist::Vector{Int})
+    @assert isequal(size(D)[2:end], size(integer_ids))
+    "size(D): $(size(D)), size(integer_ids): $(size(integer_ids))"
+
+    # set everything to true
+    mask = similar(integer_ids, Bool)    # respects the device as well
+    mask .= true
+    mask_skiplist!(mask, integer_ids, skiplist)    # (doc_maxlen, current_batch_size)
+    mask = reshape(mask, (1, size(mask)...))       # (1, doc_maxlen, current_batch_size)
+
+    @assert isequal(size(mask)[2:end], size(D)[2:end])
+    "size(mask): $(size(mask)), size(D): $(size(D))"
+    @assert mask isa AbstractArray{Bool} "$(typeof(mask))"
+
+    D .= D .* mask    # clear embeddings of masked tokens
+    mask
+end
+
+function _flatten_embeddings(D::AbstractArray{Float32, 3})
+    reshape(D, size(D, 1), prod(size(D)[2:end]))
+end
+
+function _remove_masked_tokens(
+        D::AbstractMatrix{Float32}, mask::AbstractMatrix{Bool})
+    D[:, reshape(mask, prod(size(mask)))]
+end
+
+"""
+    encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense,
+        tokenizer::TextEncoders.AbstractTransformerTextEncoder,
+        passages::Vector{String}, dim::Int, index_bsize::Int,
+        doc_token::String, skiplist::Vector{Int})
+
+Encode a list of passages.
+
+The given `passages` are run through the underlying BERT model and the linear layer to
+generate the embeddings, after doing relevant document-specific preprocessing.
+
+# Arguments
+
+  - `bert`: The pre-trained BERT component of the ColBERT model.
+  - `linear`: The pre-trained linear component of the ColBERT model.
+  - `tokenizer`: The tokenizer used to build token IDs and attention masks for the passages.
+  - `passages`: A list of strings representing the passages to be encoded.
+  - `dim`: The embedding dimension.
+  - `index_bsize`: The batch size used to feed passages to the transformer.
+  - `doc_token`: The document marker token, e.g. `"[unused1]"`.
+  - `skiplist`: A list of token IDs whose embeddings are masked out (punctuation and the pad token).
+
+# Returns
+
+A tuple `embs, doclens` where:
+
+  - `embs::AbstractMatrix{Float32}`: The full embedding matrix, of shape `(D, N)`,
+    where `D` is the embedding dimension and `N` is the total number of embeddings
+    across all the passages.
+  - `doclens::AbstractVector{Int}`: A vector of document lengths for each passage,
+    i.e. the total number of attended tokens for each document passage.
+
+# Examples
+
+```julia-repl
+julia> using ColBERT: load_hgf_pretrained_local, ColBERTConfig, encode_passages;
+
+julia> using CUDA, Flux, Transformers, TextEncodeBase;
+
+julia> config = ColBERTConfig();
+
+julia> dim = config.dim
+128
+
+julia> index_bsize = 128;    # this is the batch size to be fed to the transformer
+
+julia> doc_maxlen = config.doc_maxlen
+300
+
+julia> doc_token = config.doc_token_id
+"[unused1]"
+
+julia> tokenizer, bert, linear = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/");
+
+julia> process = tokenizer.process;
+
+julia> truncpad_pipe = Pipeline{:token}(
+           TextEncodeBase.trunc_or_pad(doc_maxlen - 1, "[PAD]", :tail, :tail),
+           :token);
+
+julia> process = process[1:4] |> truncpad_pipe |> process[6:end];
+
+julia> tokenizer = TextEncoders.BertTextEncoder(
+           tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym,
+           endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc);
+
+julia> bert = bert |> Flux.gpu;
+
+julia> linear = linear |> Flux.gpu;
+
+julia> passages = readlines("./downloads/lotte/lifestyle/dev/collection.tsv")[1:50000];
+
+julia> punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"));
+           tokenizer.padsym];
+
+julia> skiplist = [lookup(tokenizer.vocab, sym)
+                   for sym in punctuations_and_padsym];
+
+julia> @time encode_passages(bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist)
+```
+"""
+function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense,
+        tokenizer::TextEncoders.AbstractTransformerTextEncoder,
+        passages::Vector{String}, dim::Int, index_bsize::Int,
+        doc_token::String, skiplist::Vector{Int})
+    @info "Encoding $(length(passages)) passages."
+    length(passages) == 0 && return rand(Float32, dim, 0), rand(Int, 0)
+
+    # batching here to avoid storing intermediate embeddings on GPU
+    embs, doclens = Vector{AbstractMatrix{Float32}}(), Vector{Int}()
+    for passage_offset in 1:index_bsize:length(passages)
+        passage_end_offset = min(
+            length(passages), passage_offset + index_bsize - 1)
+
+        # get the token IDs and attention mask
+        integer_ids, bitmask = tensorize_docs(
+            doc_token, tokenizer, passages[passage_offset:passage_end_offset])
+
+        integer_ids = integer_ids |> Flux.gpu
+        bitmask = bitmask |> Flux.gpu
+
+        # run the tokens and attention mask through the transformer
+        # and mask the skiplist tokens
+        D = doc(bert, linear, integer_ids, bitmask)                    # (dim, doc_maxlen, current_batch_size)
+        mask = _clear_masked_embeddings!(D, integer_ids, skiplist)     # (1, doc_maxlen, current_batch_size)
+
+        # normalize each embedding in D along dims = 1
+        _normalize_array!(D, dims = 1)
+
+        # get the doclens by squeezing out the singleton dimension of the mask
+        mask = reshape(mask, size(mask)[2:end])                        # (doc_maxlen, current_batch_size)
+        doclens_ = vec(sum(mask, dims = 1))
+
+        # flatten out embeddings, i.e. get embeddings for each token in each passage
+        D = _flatten_embeddings(D)
+
+        # remove embeddings for masked tokens
+        D = _remove_masked_tokens(D, mask)
+
+        @assert ndims(D)==2 "ndims(D): $(ndims(D))"
+        @assert size(D, 2)==sum(doclens_) "size(D): $(size(D)), sum(doclens): $(sum(doclens_))"
+        @assert D isa AbstractMatrix{Float32} "$(typeof(D))"
+        @assert doclens_ isa AbstractVector{Int64} "$(typeof(doclens_))"
+
+        push!(embs, Flux.cpu(D))
+        append!(doclens, Flux.cpu(doclens_))
+    end
+    embs = cat(embs..., dims = 2)
+    embs, doclens
+end
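`_normalize_array!` is called by the new `encode_passages` but is not defined anywhere in this diff. A minimal sketch of what such a helper could look like, assuming it performs in-place L2 normalization along `dims` and leaves all-zero columns (the cleared embeddings of masked tokens) untouched, mirroring the normalization branch deleted from the old `doc`; a GPU-friendly version would avoid the logical indexing:

```julia
# Hypothetical helper (not part of this diff); the real definition may differ.
function _normalize_array!(X::AbstractArray{T}; dims::Int = 1) where {T <: AbstractFloat}
    norms = sqrt.(sum(abs2, X; dims = dims))   # L2 norm along `dims`
    norms[iszero.(norms)] .= one(T)            # avoid dividing the all-zero (masked) columns by 0
    X ./= norms                                # normalize in place
end
```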