From f8288d9fe17403b7d994421faa211d1fa12229ae Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sat, 24 Aug 2024 23:31:05 +0530 Subject: [PATCH] Structural changes on the `tensorize` functions; making them independent of the config, and changing the way doc/query markers are added. Also updating examples. --- .../tokenization/doc_tokenization.jl | 99 ++++++-- .../tokenization/query_tokenization.jl | 238 +++++++++++------- 2 files changed, 218 insertions(+), 119 deletions(-) diff --git a/src/modelling/tokenization/doc_tokenization.jl b/src/modelling/tokenization/doc_tokenization.jl index 736be6c..583254d 100644 --- a/src/modelling/tokenization/doc_tokenization.jl +++ b/src/modelling/tokenization/doc_tokenization.jl @@ -1,5 +1,5 @@ """ - tensorize_docs(config::ColBERTConfig, + tensorize_docs(doc_token_id::String, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) @@ -26,11 +26,37 @@ A tuple containing the following is returned: # Examples ```julia-repl -julia> using ColBERT, Transformers; +julia> using ColBERT: tensorize_docs, load_hgf_pretrained_local; -julia> config = ColBERTConfig(); +julia> using Transformers, Transformers.TextEncoders, TextEncodeBase; -julia> tokenizer = Transformers.load_tokenizer(config.checkpoint); +julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer") + +# configure the tokenizers maxlen and padding/truncation +julia> doc_maxlen = 20; + +julia> process = tokenizer.process +Pipelines: + target[token] := TextEncodeBase.nestedcall(string_getvalue, source) + target[token] := Transformers.TextEncoders.grouping_sentence(target.token) + target[(token, segment)] := SequenceTemplate{String}([CLS]: Input[1]: [SEP]: (Input[2]: [SEP]:)...)(target.token) + target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token) + target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token) + target[token] := TextEncodeBase.nested2batch(target.token) + target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment) + target[segment] := TextEncodeBase.nested2batch(target.segment) + target[sequence_mask] := identity(target.attention_mask) + target := (target.token, target.segment, target.attention_mask, target.sequence_mask) + +julia> truncpad_pipe = Pipeline{:token}( + TextEncodeBase.trunc_or_pad(doc_maxlen - 1, "[PAD]", :tail, :tail), + :token); + +julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; + +julia> tokenizer = TextEncoders.BertTextEncoder( + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); julia> batch_text = [ "hello world", @@ -40,11 +66,12 @@ julia> batch_text = [ "this is an even longer document. this is some longer text, so length should be longer", ]; -julia> integer_ids, integer_mask = ColBERT.tensorize_docs(config, tokenizer, batch_text) -(Int32[102 102 … 102 102; 3 3 … 3 3; … ; 1 1 … 1 2937; 1 1 … 1 103], Bool[1 1 … 1 1; 1 1 … 1 1; … ; 0 0 … 0 1; 0 0 … 0 1]) +julia> integer_ids, integer_mask = tensorize_docs( + "[unused1]", tokenizer, batch_text) +(Int32[102 102 … 102 102; 3 3 … 3 3; … ; 1 1 … 1 2023; 1 1 … 1 2937], Bool[1 1 … 1 1; 1 1 … 1 1; … ; 0 0 … 0 1; 0 0 … 0 1]) julia> integer_ids -21×5 reinterpret(Int32, ::Matrix{PrimitiveOneHot.OneHot{0x0000773a}}): +20×5 Matrix{Int32}: 102 102 102 102 102 3 3 3 3 3 7593 4068 1038 2024 2024 @@ -65,10 +92,9 @@ julia> integer_ids 1 1 1 1 2324 1 1 1 1 2023 1 1 1 1 2937 - 1 1 1 1 103 julia> integer_mask -21×5 Matrix{Bool}: +20×5 Matrix{Bool}: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 @@ -89,31 +115,56 @@ julia> integer_mask 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1 - 0 0 0 0 1 +julia> TextEncoders.decode(tokenizer, integer_ids) +20×5 Matrix{String}: + "[CLS]" "[CLS]" "[CLS]" "[CLS]" "[CLS]" + "[unused1]" "[unused1]" "[unused1]" "[unused1]" "[unused1]" + "hello" "thank" "a" "this" "this" + "world" "you" "[SEP]" "is" "is" + "[SEP]" "!" "[PAD]" "some" "an" + "[PAD]" "[SEP]" "[PAD]" "longer" "even" + "[PAD]" "[PAD]" "[PAD]" "text" "longer" + "[PAD]" "[PAD]" "[PAD]" "," "document" + "[PAD]" "[PAD]" "[PAD]" "so" "." + "[PAD]" "[PAD]" "[PAD]" "length" "this" + "[PAD]" "[PAD]" "[PAD]" "should" "is" + "[PAD]" "[PAD]" "[PAD]" "be" "some" + "[PAD]" "[PAD]" "[PAD]" "longer" "longer" + "[PAD]" "[PAD]" "[PAD]" "[SEP]" "text" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "," + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "so" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "length" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "should" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "be" + "[PAD]" "[PAD]" "[PAD]" "[PAD]" "longer" ``` """ -function tensorize_docs(config::ColBERTConfig, +function tensorize_docs(doc_token_id::String, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) - # placeholder for [D] marker token - batch_text = [". " * doc for doc in batch_text] - + # we assume that tokenizer is configured to have maxlen: doc_maxlen - 1 # getting the integer ids and masks - encoded_text = Transformers.TextEncoders.encode(tokenizer, batch_text) + encoded_text = TextEncoders.encode(tokenizer, batch_text) ids, mask = encoded_text.token, encoded_text.attention_mask integer_ids = reinterpret(Int32, ids) integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] - # adding the [D] marker token ID - D_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.doc_token_id) - integer_ids[2, :] .= D_marker_token_id - - @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(integer_mask)" - @assert isequal(size(integer_ids)[2], length(batch_text)) - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" + # adding the [D] marker token ID as the second token + # first one is always the "[CLS]" token + D_marker_token_id = lookup(tokenizer.vocab, doc_token_id) |> Int32 + integer_ids = [integer_ids[begin:1, :]; + fill(D_marker_token_id, (1, length(batch_text))); + integer_ids[2:end, :]] + integer_mask = [integer_mask[begin:1, :]; + fill(true, (1, length(batch_text))); integer_mask[2:end, :]] + + @assert isequal(size(integer_ids), size(integer_mask)) + "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(integer_mask)" + @assert isequal(size(integer_ids, 2), length(batch_text)) + "size(integer_ids): $(size(integer_ids)), length(batch_text): $(length(batch_text))" + @assert integer_ids isa Matrix{Int32} "$(typeof(integer_ids))" + @assert integer_mask isa Matrix{Bool} "$(typeof(integer_mask))" integer_ids, integer_mask end diff --git a/src/modelling/tokenization/query_tokenization.jl b/src/modelling/tokenization/query_tokenization.jl index efbfe56..83b7ab5 100644 --- a/src/modelling/tokenization/query_tokenization.jl +++ b/src/modelling/tokenization/query_tokenization.jl @@ -1,5 +1,5 @@ """ - tensorize_queries(config::ColBERTConfig, + tensorize_queries(query_token_id::String, attend_to_mask_tokens::Bool, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) @@ -30,130 +30,178 @@ config. Note that, at the time of writing this package, configuring tokenizers i clean interface; so, we have to manually configure the tokenizer. ```julia-repl -julia> using ColBERT, Transformers, TextEncodeBase; +julia> using ColBERT: tensorize_queries, load_hgf_pretrained_local; -julia> config = ColBERTConfig(); +julia> using Transformers, Transformers.TextEncoders, TextEncodeBase; -julia> tokenizer = Transformers.load_tokenizer(config.checkpoint); +julia> tokenizer = load_hgf_pretrained_local("/home/codetalker7/models/colbertv2.0/:tokenizer") + +# configure the tokenizers maxlen and padding/truncation +julia> query_maxlen = 32; julia> process = tokenizer.process; julia> truncpad_pipe = Pipeline{:token}( - TextEncodeBase.trunc_or_pad(config.query_maxlen, "[PAD]", :tail, :tail), - :token); + TextEncodeBase.trunc_or_pad(query_maxlen - 1, "[PAD]", :tail, :tail), + :token); julia> process = process[1:4] |> truncpad_pipe |> process[6:end]; julia> tokenizer = TextEncoders.BertTextEncoder( - tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, - endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); + tokenizer.tokenizer, tokenizer.vocab, process; startsym = tokenizer.startsym, + endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc); -julia> queries = [ +julia> batch_text = [ "what are white spots on raspberries?", - "what do rabbits eat?" + "what do rabbits eat?", + "this is a really long query. I'm deliberately making this long"* + "so that you can actually see that this is really truncated at 32 tokens"* + "and that the other two queries are padded to get 32 tokens."* + "this makes this a nice query as an example." ]; -julia> integer_ids, integer_mask = ColBERT.tensorize_queries(config, tokenizer, queries); - -julia> 32×2 reinterpret(Int32, ::Matrix{OneHot{0x0000773a}}): - 102 102 - 2 2 - 2055 2055 - 2025 2080 - 2318 20404 - 7517 4522 - 2007 1030 - 20711 103 - 2362 104 - 20969 104 - 1030 104 - 103 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 - 104 104 +julia> integer_ids, integer_mask = tensorize_queries( + "[unused0]", false, tokenizer, batch_text); +(Int32[102 102 102; 2 2 2; … ; 104 104 8792; 104 104 2095], Bool[1 1 1; 1 1 1; … ; 0 0 1; 0 0 1]) + +julia> integer_ids +32×3 Matrix{Int32}: + 102 102 102 + 2 2 2 + 2055 2055 2024 + 2025 2080 2004 + 2318 20404 1038 + 7517 4522 2429 + 2007 1030 2147 + 20711 103 23033 + 2362 104 1013 + 20969 104 1046 + 1030 104 1006 + 103 104 1050 + 104 104 9970 + 104 104 2438 + 104 104 2024 + 104 104 2147 + 104 104 6500 + 104 104 2009 + 104 104 2018 + 104 104 2065 + 104 104 2942 + 104 104 2157 + 104 104 2009 + 104 104 2024 + 104 104 2004 + 104 104 2429 + 104 104 25450 + 104 104 2013 + 104 104 3591 + 104 104 19205 + 104 104 8792 + 104 104 2095 julia> integer_mask -32×2 Matrix{Bool}: - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 1 - 1 0 - 1 0 - 1 0 - 1 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - 0 0 - +32×3 Matrix{Bool}: + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 1 1 + 1 0 1 + 1 0 1 + 1 0 1 + 1 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + 0 0 1 + +julia> TextEncoders.decode(tokenizer, integer_ids) +32×3 Matrix{String}: + "[CLS]" "[CLS]" "[CLS]" + "[unused0]" "[unused0]" "[unused0]" + "what" "what" "this" + "are" "do" "is" + "white" "rabbits" "a" + "spots" "eat" "really" + "on" "?" "long" + "ras" "[SEP]" "query" + "##p" "[MASK]" "." + "##berries" "[MASK]" "i" + "?" "[MASK]" "'" + "[SEP]" "[MASK]" "m" + "[MASK]" "[MASK]" "deliberately" + "[MASK]" "[MASK]" "making" + "[MASK]" "[MASK]" "this" + "[MASK]" "[MASK]" "long" + "[MASK]" "[MASK]" "##so" + "[MASK]" "[MASK]" "that" + "[MASK]" "[MASK]" "you" + "[MASK]" "[MASK]" "can" + "[MASK]" "[MASK]" "actually" + "[MASK]" "[MASK]" "see" + "[MASK]" "[MASK]" "that" + "[MASK]" "[MASK]" "this" + "[MASK]" "[MASK]" "is" + "[MASK]" "[MASK]" "really" + "[MASK]" "[MASK]" "truncated" + "[MASK]" "[MASK]" "at" + "[MASK]" "[MASK]" "32" + "[MASK]" "[MASK]" "token" + "[MASK]" "[MASK]" "##san" + "[MASK]" "[MASK]" "##d" ``` """ -function tensorize_queries(config::ColBERTConfig, +function tensorize_queries(query_token_id::String, attend_to_mask_tokens::Bool, tokenizer::TextEncoders.AbstractTransformerTextEncoder, batch_text::Vector{String}) - # placeholder for [Q] marker token - batch_text = [". " * query for query in batch_text] - + # we assume that tokenizer is configured to have maxlen: query_maxlen - 1 # getting the integer ids and masks encoded_text = Transformers.TextEncoders.encode(tokenizer, batch_text) ids, mask = encoded_text.token, encoded_text.attention_mask integer_ids = reinterpret(Int32, ids) integer_mask = NeuralAttentionlib.getmask(mask, ids)[1, :, :] - @assert isequal(size(integer_ids), size(integer_mask)) "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(size(integer_mask))" - @assert isequal( - size(integer_ids)[1], config.query_maxlen) "size(integer_ids): $(size(integer_ids)), query_maxlen: $(query_tokenizer.config.query_maxlen)" - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert integer_mask isa AbstractMatrix{Bool} "$(typeof(integer_mask))" # adding the [Q] marker token ID and [MASK] augmentation Q_marker_token_id = TextEncodeBase.lookup( - tokenizer.vocab, config.query_token_id) - mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") - integer_ids[2, :] .= Q_marker_token_id - integer_ids[integer_ids .== 1] .= mask_token_id - - if config.attend_to_mask_tokens + tokenizer.vocab, query_token_id) |> Int32 + mask_token_id = TextEncodeBase.lookup(tokenizer.vocab, "[MASK]") |> Int32 + pad_token_id = TextEncodeBase.lookup( + tokenizer.vocab, tokenizer.config.padsym) |> Int32 + integer_ids = [integer_ids[begin:1, :]; + fill(Q_marker_token_id, (1, length(batch_text))); + integer_ids[2:end, :]] + integer_ids[integer_ids .== pad_token_id] .= mask_token_id + integer_mask = [integer_mask[begin:1, :]; + fill(true, (1, length(batch_text))); integer_mask[2:end, :]] + + if attend_to_mask_tokens integer_mask[integer_ids .== mask_token_id] .= 1 - @assert isequal(sum(integer_mask), prod(size(integer_mask))) "sum(integer_mask): $(sum(integer_mask)), prod(size(integer_mask)): $(prod(size(integer_mask)))" + @assert isequal(sum(integer_mask), prod(size(integer_mask))) + "sum(integer_mask): $(sum(integer_mask)), prod(size(integer_mask)): $(prod(size(integer_mask)))" end + @assert isequal(size(integer_ids), size(integer_mask)) + "size(integer_ids): $(size(integer_ids)), size(integer_mask): $(size(integer_mask))" + @assert integer_ids isa Matrix{Int32} "$(typeof(integer_ids))" + @assert integer_mask isa Matrix{Bool} "$(typeof(integer_mask))" integer_ids, integer_mask end