Writing separate function to get doc embeddings and doclens, and updating the example.
codetalker7 committed Aug 26, 2024
1 parent b8fc0ec commit b20857d
Showing 1 changed file with 35 additions and 31 deletions.
66 changes: 35 additions & 31 deletions src/modelling/checkpoint.jl
@@ -227,6 +227,33 @@ function doc(bert::HF.HGFBertModel, linear::Layers.Dense,
attention_mask = NeuralAttentionlib.GenericSequenceMask(bitmask))).hidden_state)
end

function _doc_embeddings_and_doclens(
bert::HF.HGFBertModel, linear::Layers.Dense, skiplist::Vector{Int},
integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool})
D = doc(bert, linear, integer_ids, bitmask) # (dim, doc_maxlen, current_batch_size)
mask = _clear_masked_embeddings!(D, integer_ids, skiplist) # (1, doc_maxlen, current_batch_size)

# normalize each embedding in D; along dims = 1
_normalize_array!(D, dims = 1)

    # get the doclens by squeezing out the mask's leading singleton dimension
mask = reshape(mask, size(mask)[2:end]) # (doc_maxlen, current_batch_size)
doclens = vec(sum(mask, dims = 1))

    # flatten out embeddings, i.e., get embeddings for each token in each passage
D = _flatten_embeddings(D) # (dim, total_num_embeddings)

# remove embeddings for masked tokens
D = _remove_masked_tokens(D, mask) # (dim, total_num_masked_embeddings)

@assert ndims(D)==2 "ndims(D): $(ndims(D))"
@assert size(D, 2)==sum(doclens) "size(D): $(size(D)), sum(doclens): $(sum(doclens))"
@assert D isa AbstractMatrix{Float32} "$(typeof(D))"
@assert doclens isa AbstractVector{Int64} "$(typeof(doclens))"

D, doclens
end
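
For intuition, here is a minimal sketch of the mask-to-doclens bookkeeping performed above, written in plain Julia with made-up toy shapes; the real function delegates the last two steps to the package's `_flatten_embeddings` and `_remove_masked_tokens` helpers.

```julia
# Toy shapes (assumed for illustration): 4-dim embeddings, up to 5 tokens per passage, batch of 2.
dim, doc_maxlen, batch_size = 4, 5, 2
D = rand(Float32, dim, doc_maxlen, batch_size)      # one embedding per token per passage
mask = trues(doc_maxlen, batch_size)
mask[4:5, 2] .= false                               # pretend tokens 4–5 of passage 2 are pad/skiplist

doclens = vec(sum(mask, dims = 1))                  # kept tokens per passage, here [5, 3]
Dflat = reshape(D, dim, doc_maxlen * batch_size)    # flatten: one column per token
Dkept = Dflat[:, vec(mask)]                         # drop the masked columns

@assert size(Dkept, 2) == sum(doclens)              # the same invariant the function asserts
```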

"""
query(
config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32},
@@ -535,24 +562,20 @@ julia> bert = bert |> Flux.gpu;
julia> linear = linear |> Flux.gpu;
julia> passages = readlines("./downloads/lotte/lifestyle/dev/collection.tsv")[1:50000];
julia> passages = readlines("./downloads/lotte/lifestyle/dev/collection.tsv")[1:1000];
julia> punctuations_and_padsym = [string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~"));
tokenizer.padsym];
julia> skiplist = [lookup(tokenizer.vocab, sym)
for sym in punctuations_and_padsym];
julia> @time encode_passages(bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist)
julia> passages = [
"hello world",
"thank you!",
"a",
"this is some longer text, so length should be longer",
];
julia> @time embs, doclens = encode_passages(
bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist) # second run stats
[ Info: Encoding 1000 passages.
25.247094 seconds (29.65 M allocations: 1.189 GiB, 37.26% gc time, 0.00% compilation time)
(Float32[-0.08001435 -0.10785186 … -0.08651956 -0.12118215; 0.07319974 0.06629379 … 0.0929825 0.13665271; … ; -0.037957724 -0.039623592 … 0.031274226 0.063107446; 0.15484622 0.16779025 … 0.11533891 0.11508792], [279, 117, 251, 105, 133, 170, 181, 115, 190, 132 … 76, 204, 199, 244, 256, 125, 251, 261, 262, 263])
julia> @time embs, doclen = encode_passages(bert, linear, tokenizer, passages, dim, index_bsize, doc_token, skiplist)
```
"""
function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense,
@@ -577,27 +600,8 @@ function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense,

# run the tokens and attention mask through the transformer
# and mask the skiplist tokens
D = doc(bert, linear, integer_ids, bitmask) # (dim, doc_maxlen, current_batch_size)
mask = _clear_masked_embeddings!(D, integer_ids, skiplist) # (1, doc_maxlen, current_batch_size)

# normalize each embedding in D; along dims = 1
_normalize_array!(D, dims = 1)

# get the doclens by unsqueezing the mask
mask = reshape(mask, size(mask)[2:end]) # (doc_maxlen, current_batch_size)
doclens_ = vec(sum(mask, dims = 1))

# flatten out embeddings, i.e get embeddings for each token in each passage
D = _flatten_embeddings(D) # (dim, total_num_embeddings)

# remove embeddings for masked tokens
D = _remove_masked_tokens(D, mask) # (dim, total_num_masked_embeddings)

@assert ndims(D)==2 "ndims(D): $(ndims(D))"
@assert size(D, 1) == dim "size(D): $(size(D)), dim: $(dim)"
@assert size(D, 2)==sum(doclens_) "size(D): $(size(D)), sum(doclens): $(sum(doclens_))"
@assert D isa AbstractMatrix{Float32} "$(typeof(D))"
@assert doclens_ isa AbstractVector{Int64} "$(typeof(doclens_))"
D, doclens_ = _doc_embeddings_and_doclens(
bert, linear, skiplist, integer_ids, bitmask)

push!(embs, Flux.cpu(D))
append!(doclens, Flux.cpu(doclens_))
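
Since `encode_passages` returns a single flattened embedding matrix together with per-passage lengths (as in the docstring example above), a caller can recover each passage's block of columns from `doclens`. A minimal sketch, assuming `embs` and `doclens` are the two values returned by `encode_passages`:

```julia
# embs is (dim, sum(doclens)); passage i owns doclens[i] consecutive columns.
offsets = cumsum(doclens)
starts = [1; offsets[1:(end - 1)] .+ 1]
passage_embeddings = [embs[:, starts[i]:offsets[i]] for i in eachindex(doclens)]

@assert length(passage_embeddings) == length(doclens)
@assert all(size(passage_embeddings[i], 2) == doclens[i] for i in eachindex(doclens))
```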
