"""
    _doc_embeddings_and_doclens(
        bert::HF.HGFBertModel, linear::Layers.Dense, skiplist::Vector{Int},
        integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool})

Internal helper: compute the token embeddings and per-passage lengths for one
batch of tokenized passages.

The batch is run through [`doc`](@ref), the resulting array is passed to
`_clear_masked_embeddings!` together with `integer_ids` and `skiplist`, and the
embeddings are normalized along the first dimension. The mask returned by
`_clear_masked_embeddings!` is then used both to count the tokens kept for each
passage and to drop the remaining columns from the flattened embedding matrix.

# Arguments

  - `bert`, `linear`: the model components forwarded to `doc`.
  - `skiplist`: token ids forwarded to `_clear_masked_embeddings!`.
  - `integer_ids`: token ids of the batch.
  - `bitmask`: attention mask of the batch, forwarded to `doc`.

# Returns

A tuple `(D, doclens)` where `D` is an `AbstractMatrix{Float32}` with one column
per kept token and `doclens` is an `AbstractVector{Int64}` whose sum equals
`size(D, 2)` (both properties are checked with `@assert`).
"""
function _doc_embeddings_and_doclens(
        bert::HF.HGFBertModel, linear::Layers.Dense, skiplist::Vector{Int},
        integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool})
    # Shape annotations below follow the ones carried by the original code.
    # batch_embeddings: (dim, doc_maxlen, current_batch_size)
    batch_embeddings = doc(bert, linear, integer_ids, bitmask)

    # raw_mask: (1, doc_maxlen, current_batch_size); the call also mutates
    # batch_embeddings in place.
    raw_mask = _clear_masked_embeddings!(batch_embeddings, integer_ids, skiplist)

    # Normalize every embedding vector, i.e. along the first dimension.
    _normalize_array!(batch_embeddings; dims = 1)

    # Drop the leading singleton dimension of the mask so that each column
    # describes one passage: (doc_maxlen, current_batch_size).
    token_mask = reshape(raw_mask, size(raw_mask)[2:end])
    doclens = vec(sum(token_mask; dims = 1))

    # Flatten to one column per token, (dim, total_num_embeddings), then keep
    # only the columns selected by the mask.
    embeddings = _remove_masked_tokens(
        _flatten_embeddings(batch_embeddings), token_mask)

    # Internal invariants on the outputs.
    @assert ndims(embeddings)==2 "ndims(D): $(ndims(embeddings))"
    @assert size(embeddings, 2)==sum(doclens) "size(D): $(size(embeddings)), sum(doclens): $(sum(doclens))"
    @assert embeddings isa AbstractMatrix{Float32} "$(typeof(embeddings))"
    @assert doclens isa AbstractVector{Int64} "$(typeof(doclens))"

    return embeddings, doclens
end
[string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\] julia> skiplist = [lookup(tokenizer.vocab, sym) for sym in punctuations_and_padsym]; -julia> @time encode_passages(bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist) - -julia> passages = [ - "hello world", - "thank you!", - "a", - "this is some longer text, so length should be longer", -]; +julia> @time embs, doclens = encode_passages( + bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist) # second run stats +[ Info: Encoding 1000 passages. + 25.247094 seconds (29.65 M allocations: 1.189 GiB, 37.26% gc time, 0.00% compilation time) +(Float32[-0.08001435 -0.10785186 … -0.08651956 -0.12118215; 0.07319974 0.06629379 … 0.0929825 0.13665271; … ; -0.037957724 -0.039623592 … 0.031274226 0.063107446; 0.15484622 0.16779025 … 0.11533891 0.11508792], [279, 117, 251, 105, 133, 170, 181, 115, 190, 132 … 76, 204, 199, 244, 256, 125, 251, 261, 262, 263]) -julia> @time embs, doclen = encode_passages(bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist) ``` """ function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense, @@ -577,27 +600,8 @@ function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense, # run the tokens and attention mask through the transformer # and mask the skiplist tokens - D = doc(bert, linear, integer_ids, bitmask) # (dim, doc_maxlen, current_batch_size) - mask = _clear_masked_embeddings!(D, integer_ids, skiplist) # (1, doc_maxlen, current_batch_size) - - # normalize each embedding in D; along dims = 1 - _normalize_array!(D, dims = 1) - - # get the doclens by unsqueezing the mask - mask = reshape(mask, size(mask)[2:end]) # (doc_maxlen, current_batch_size) - doclens_ = vec(sum(mask, dims = 1)) - - # flatten out embeddings, i.e get embeddings for each token in each passage - D = _flatten_embeddings(D) # (dim, total_num_embeddings) - - # remove embeddings for masked tokens - D = _remove_masked_tokens(D, mask) # (dim, 
total_num_masked_embeddings) - - @assert ndims(D)==2 "ndims(D): $(ndims(D))" - @assert size(D, 1) == dim "size(D): $(size(D)), dim: $(dim)" - @assert size(D, 2)==sum(doclens_) "size(D): $(size(D)), sum(doclens): $(sum(doclens_))" - @assert D isa AbstractMatrix{Float32} "$(typeof(D))" - @assert doclens_ isa AbstractVector{Int64} "$(typeof(doclens_))" + D, doclens_ = _doc_embeddings_and_doclens( + bert, linear, skiplist, integer_ids, bitmask) push!(embs, Flux.cpu(D)) append!(doclens, Flux.cpu(doclens_))