From 48720bd04f60ac3f6630ead2b043b755a15495a2 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Sun, 11 Aug 2024 17:27:12 +0530 Subject: [PATCH] Adding the `config` as an argument to the `query` and `queryFromText` functions. --- src/modelling/checkpoint.jl | 204 ++++++++++++++++++++++++------------ 1 file changed, 139 insertions(+), 65 deletions(-) diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index 42384b6..928c8ca 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -200,11 +200,15 @@ Otherwise, all tokens are included in the mask. # Returns -An array of booleans indicating whether the corresponding token ID is included in the mask or not. The array has the same shape as `integer_ids`, i.e `(L, N)`, where `L` is the maximum length of any document in `integer_ids` and `N` is the number of documents. +An array of booleans indicating whether the corresponding token ID +is included in the mask or not. The array has the same shape as +`integer_ids`, i.e `(L, N)`, where `L` is the maximum length of +any document in `integer_ids` and `N` is the number of documents. # Examples -Continuing with the example for [`tensorize_docs`](@ref) and the `skiplist` from the example in [`Checkpoint`](@ref). +Continuing with the example for [`tensorize_docs`](@ref) and the +`skiplist` from the example in [`Checkpoint`](@ref). ```julia-repl julia> integer_ids = batches[1][1]; @@ -252,7 +256,7 @@ Compute the hidden state of the BERT and linear layers of ColBERT for documents. # Arguments - - `config`: The [`ColBERTConfig`](@ref) being used. + - `config`: The [`ColBERTConfig`](@ref) being used. - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. - `integer_ids`: An array of token IDs to be fed into the BERT model. - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. @@ -261,8 +265,13 @@ Compute the hidden state of the BERT and linear layers of ColBERT for documents. A tuple `D, mask`, where: - - `D` is an array containing the normalized embeddings for each token in each document. It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any document and `N` is the total number of documents. - - `mask` is an array containing attention masks for all documents, after masking out any tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` is the same as described above. + - `D` is an array containing the normalized embeddings for each token in each document. + It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer + of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of + any document and `N` is the total number of documents. + - `mask` is an array containing attention masks for all documents, after masking out any + tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` + is the same as described above. # Examples @@ -339,7 +348,10 @@ This function also applies ColBERT-style document pre-processing for each docume # Returns -A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector` of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings across all documents in `docs`. +A tuple `embs, doclens`, where `embs` is an array of embeddings and `doclens` is a `Vector` +of document lengths. The array `embs` has shape `(D, N)`, where `D` is the embedding +dimension (`128` for ColBERT's linear layer) and `N` is the total number of embeddings +across all documents in `docs`. # Examples @@ -444,56 +456,85 @@ function docFromText(config::ColBERTConfig, checkpoint::Checkpoint, end """ - query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, + query( + config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) Compute the hidden state of the BERT and linear layers of ColBERT for queries. # Arguments + - `config`: The [`ColBERTConfig`](@ref) to be used. - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. - `integer_ids`: An array of token IDs to be fed into the BERT model. - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. # Returns -`Q`, where `Q` is an array containing the normalized embeddings for each token in the query matrix. It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any query and `N` is the total number of queries. +`Q`, where `Q` is an array containing the normalized embeddings for each token in the query matrix. +It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer of ColBERT), +and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of any query and `N` is +the total number of queries. # Examples -Continuing from the queries example for [`tensorize`](@ref) and [`Checkpoint`](@ref): +Continuing from the queries example for [`tensorize_queries`](@ref) and [`Checkpoint`](@ref): ```julia-repl -julia> query(checkPoint, integer_ids, integer_mask) -128×32×1 Array{Float32, 3}: +julia> ColBERT.query(config, checkpoint, integer_ids, integer_mask) +128×32×1 CuArray{Float32, 3, CUDA.DeviceMemory}: [:, :, 1] = - 0.0158567 0.169676 0.092745 0.0798617 … 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 0.0168762 0.0178042 0.0200357 - -0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138702 -0.0409767 -0.126037 -0.126829 -0.13149 - -0.0231786 0.0532214 0.0607473 0.0279048 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0150612 0.0133353 0.0126583 - -0.0290509 0.143255 0.0306142 0.042658 -0.164401 -0.161857 -0.160327 - 0.0921477 0.0588331 0.250449 0.234636 0.0664076 0.0659837 0.0711357 - 0.0279402 -0.0278357 0.144855 0.147958 0.154552 0.155525 0.163634 - -0.0768143 -0.00587305 0.00543038 0.00443374 -0.11757 -0.112495 -0.11112 - ⋮ ⋱ ⋮ - -0.0859686 0.0623054 0.0974813 0.126841 0.0182795 0.0230549 0.031103 - 0.0392043 0.0162653 0.0926306 0.104053 0.0491495 0.0484318 0.0438132 - -0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0617945 -0.0631367 -0.0675882 - 0.013123 0.0565132 -0.0349061 -0.0464192 0.0724731 0.0780166 0.074623 - -0.117425 0.162483 0.11039 0.136364 -0.00538225 -0.00685449 -0.0019436 - -0.0401158 -0.0045094 0.0539569 0.0689953 -0.00518063 -0.00600252 -0.00771469 - 0.0893983 0.0695061 -0.0499409 -0.035411 0.0960932 0.0961893 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 … -0.0197172 -0.022061 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0699095 -0.0684749 -0.0662904 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0556761 -0.0556784 -0.059571 + 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117006 0.115806 0.115938 0.112977 0.107919 + 0.220185 0.0304873 0.165348 0.150315 -0.0116249 0.0173332 0.0165187 0.0168762 0.0178042 0.0200356 + -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0465292 -0.0693319 -0.0737462 -0.0777439 -0.0776733 -0.0830504 + -0.109909 -0.170906 -0.0138701 -0.0409766 -0.177391 -0.113141 -0.118738 -0.126037 -0.126829 -0.13149 + -0.0231787 0.0532214 0.0607473 0.0279048 0.0634681 0.112296 0.111831 0.117017 0.114073 0.108536 + 0.0620549 0.0465075 0.0821693 0.0606439 0.0592031 … 0.0167847 0.0148605 0.0150612 0.0133353 0.0126583 + -0.0290508 0.143255 0.0306142 0.0426579 0.129972 -0.17502 -0.169493 -0.164401 -0.161857 -0.160327 + 0.0921475 0.058833 0.250449 0.234636 0.0412965 0.0590262 0.0642577 0.0664076 0.0659837 0.0711358 + 0.0279402 -0.0278357 0.144855 0.147958 -0.0268559 0.161106 0.157629 0.154552 0.155525 0.163634 + -0.0768143 -0.00587302 0.00543038 0.00443376 -0.0134111 -0.126912 -0.123969 -0.11757 -0.112495 -0.11112 + -0.0184337 0.00668561 -0.191863 -0.161345 0.0222466 … -0.103246 -0.10374 -0.107664 -0.107267 -0.114564 + 0.0112104 0.0214651 -0.0923963 -0.0823052 0.0600248 0.103589 0.103387 0.106261 0.105065 0.10409 + 0.110971 0.272576 0.148319 0.143233 0.239578 0.11224 0.107913 0.109914 0.112652 0.108365 + -0.131066 0.0376254 -0.0164237 -0.000193318 0.00344707 -0.0893371 -0.0919217 -0.0969305 -0.0935498 -0.096145 + -0.0402605 0.0350559 0.0162864 0.0269105 0.00968855 -0.0623393 -0.0670097 -0.070679 -0.0655848 -0.0564059 + 0.0799973 0.0482302 0.0712078 0.0792903 0.0108783 … 0.00820444 0.00854873 0.00889943 0.00932721 0.00751066 + -0.137565 -0.0369116 -0.065728 -0.0664102 -0.0238012 0.029041 0.0292468 0.0297059 0.0278639 0.0257616 + 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0979401 -0.057629 -0.053911 -0.0566325 -0.0568765 -0.0581378 + 0.0656851 0.0195639 0.0288789 0.0559219 0.0315515 0.0472323 0.054771 0.0596156 0.0541802 0.0525933 + 0.0668634 -0.00400549 0.0297102 0.0505045 -0.00082792 0.0414113 0.0400276 0.0361149 0.0325914 0.0260693 + -0.0691096 0.0348577 -0.000312685 0.0232462 -0.00250495 … -0.141874 -0.142026 -0.132163 -0.129679 -0.131122 + -0.0273036 0.0653352 0.0332689 0.017918 0.0875479 0.0500921 0.0471914 0.0469949 0.0434268 0.0442646 + -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0468719 -0.0772672 -0.0805913 -0.0809244 -0.0823798 -0.081472 + ⋮ ⋱ ⋮ + 0.0506199 0.00290888 0.047947 0.063503 -0.0072114 0.0360347 0.0326486 0.033966 0.0327732 0.0261081 + -0.0288586 -0.150171 -0.0699125 -0.108002 -0.142865 -0.0775934 -0.072192 -0.0697569 -0.0715358 -0.0683193 + -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0524162 0.0457267 0.0532778 0.0649795 0.0697126 0.0808413 + 0.0445508 0.0296366 0.0325647 0.0521935 0.0436496 0.129031 0.126605 0.12324 0.120497 0.117703 + -0.127301 -0.0224252 -0.00579415 -0.00877803 -0.0140665 … -0.080026 -0.080839 -0.0823464 -0.0803394 -0.0856279 + 0.0304881 0.0396951 0.0798097 0.0736797 0.0800866 0.0426674 0.0411406 0.0460205 0.0460111 0.0532082 + 0.0488798 0.252244 0.0866849 0.098552 0.251561 -0.0236942 -0.035116 -0.0395483 -0.0463498 -0.0494207 + -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0352487 -0.0476357 -0.0435388 -0.0404835 -0.0410673 -0.0367272 + 0.023548 -0.00147361 0.0629259 0.106951 0.0406627 0.00627022 0.00403014 -0.000107777 -0.000898423 0.00296315 + -0.0574151 -0.0875744 -0.103787 -0.114166 -0.103979 … -0.0708782 -0.0700138 -0.0687795 -0.070967 -0.0636385 + 0.0280373 0.149767 -0.0899733 -0.0732524 0.162316 0.022177 0.0183834 0.0201251 0.0197228 0.0219051 + -0.0617143 -0.0573989 -0.0973785 -0.0805046 -0.0525925 0.0997715 0.102691 0.107432 0.108591 0.109502 + -0.0859687 0.0623054 0.0974813 0.126841 0.0595557 0.0187937 0.0191363 0.0182794 0.0230548 0.031103 + 0.0392044 0.0162653 0.0926306 0.104054 0.0509464 0.0559883 0.0553617 0.0491496 0.0484319 0.0438133 + -0.0340362 -0.0278067 -0.0181035 -0.0282369 -0.0490531 … -0.0564175 -0.0562518 -0.0617946 -0.0631367 -0.0675882 + 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0456515 0.0676478 0.0698765 0.0724731 0.0780165 0.0746229 + -0.117425 0.162483 0.11039 0.136364 0.135339 -0.00432259 -0.00508357 -0.00538224 -0.00685447 -0.00194357 + -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00295334 -0.00671544 -0.00322498 -0.00518066 -0.00600254 -0.0077147 + 0.0893984 0.0695061 -0.049941 -0.035411 0.0767663 0.0913505 0.0964841 0.0960931 0.0961892 0.103431 + -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 + -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.180245 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 + 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 ``` """ -function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, +function query( + config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}) - use_gpu = checkpoint.config.use_gpu - integer_ids = integer_ids |> Flux.gpu integer_mask = integer_mask |> Flux.gpu @@ -509,7 +550,7 @@ function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, Q = Q .* mask - if !use_gpu + if !config.use_gpu # doing this because normalize gives exact results Q = mapslices(v -> iszero(v) ? v : normalize(v), Q, dims = 1) # normalize each embedding else @@ -531,7 +572,7 @@ function query(checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, end """ - queryFromText( + queryFromText(config::ColBERTConfig, checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int}) Get ColBERT embeddings for `queries` using `checkpoint`. @@ -540,13 +581,16 @@ This function also applies ColBERT-style query pre-processing for each query in # Arguments + - `config`: The [`ColBERTConfig`](@ref) to be used. - `checkpoint`: A [`Checkpoint`](@ref) to be used to compute embeddings. - `queries`: A list of queries to get the embeddings for. - `bsize`: A batch size for processing queries in batches. # Returns -`embs`, where `embs` is an array of embeddings. The array `embs` has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for ColBERT's linear layer), `L` is the maximum length of any query in the batch, and `N` is the total number of queries in `queries`. +`embs`, where `embs` is an array of embeddings. The array `embs` has shape `(D, L, N)`, +where `D` is the embedding dimension (`128` for ColBERT's linear layer), `L` is the +maximum length of any query in the batch, and `N` is the total number of queries in `queries`. # Examples @@ -555,34 +599,64 @@ Continuing from the example in [`Checkpoint`](@ref): ```julia-repl julia> queries = ["what are white spots on raspberries?"]; -julia> queryFromText(checkPoint, queries, 128) +julia> ColBERT.queryFromText(config, checkpoint, queries, 128) 128×32×1 Array{Float32, 3}: [:, :, 1] = - 0.0158567 0.169676 0.092745 0.0798617 … 0.115806 0.115938 0.112977 0.107919 - 0.220185 0.0304873 0.165348 0.150315 0.0165188 0.0168762 0.0178042 0.0200357 - -0.00790007 -0.0192251 -0.0852364 -0.0799609 -0.0737461 -0.0777439 -0.0776733 -0.0830504 - -0.109909 -0.170906 -0.0138702 -0.0409767 -0.118738 -0.126037 -0.126829 -0.13149 - -0.0231786 0.0532214 0.0607473 0.0279048 0.111831 0.117017 0.114073 0.108536 - 0.0620549 0.0465075 0.0821693 0.0606439 … 0.0148605 0.0150612 0.0133353 0.0126583 - -0.0290509 0.143255 0.0306142 0.042658 -0.169493 -0.164401 -0.161857 -0.160327 - 0.0921477 0.0588331 0.250449 0.234636 0.0642578 0.0664076 0.0659837 0.0711357 - 0.0279402 -0.0278357 0.144855 0.147958 0.157629 0.154552 0.155525 0.163634 - -0.0768143 -0.00587305 0.00543038 0.00443374 -0.123969 -0.11757 -0.112495 -0.11112 - -0.0184338 0.00668557 -0.191863 -0.161345 … -0.10374 -0.107664 -0.107267 -0.114564 - ⋮ ⋱ ⋮ - -0.0859686 0.0623054 0.0974813 0.126841 0.0191363 0.0182795 0.0230549 0.031103 - 0.0392043 0.0162653 0.0926306 0.104053 0.0553615 0.0491495 0.0484318 0.0438132 - -0.0340363 -0.0278066 -0.0181035 -0.0282369 … -0.0562518 -0.0617945 -0.0631367 -0.0675882 - 0.013123 0.0565132 -0.0349061 -0.0464192 0.0698766 0.0724731 0.0780166 0.074623 - -0.117425 0.162483 0.11039 0.136364 -0.0050836 -0.00538225 -0.00685449 -0.0019436 - -0.0401158 -0.0045094 0.0539569 0.0689953 -0.00322497 -0.00518063 -0.00600252 -0.00771469 - 0.0893983 0.0695061 -0.0499409 -0.035411 0.0964842 0.0960932 0.0961893 0.103431 - -0.116265 -0.106331 -0.179832 -0.149728 … -0.0275017 -0.0197172 -0.022061 -0.018135 - -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.0735711 -0.0699095 -0.0684749 -0.0662904 - 0.100019 -0.0618588 0.106134 0.0989047 -0.0553564 -0.0556761 -0.0556784 -0.059571 + 0.0158568 0.169676 0.092745 0.0798617 0.153079 … 0.117734 0.117006 0.115806 0.115938 0.112977 0.107919 + 0.220185 0.0304873 0.165348 0.150315 -0.0116249 0.0181126 0.0173332 0.0165187 0.0168762 0.0178042 0.0200356 + -0.00790017 -0.0192251 -0.0852365 -0.0799609 -0.0465292 -0.0672796 -0.0693319 -0.0737462 -0.0777439 -0.0776733 -0.0830504 + -0.109909 -0.170906 -0.0138701 -0.0409766 -0.177391 -0.10489 -0.113141 -0.118738 -0.126037 -0.126829 -0.13149 + -0.0231787 0.0532214 0.0607473 0.0279048 0.0634681 0.113961 0.112296 0.111831 0.117017 0.114073 0.108536 + 0.0620549 0.0465075 0.0821693 0.0606439 0.0592031 … 0.0174852 0.0167847 0.0148605 0.0150612 0.0133353 0.0126583 + -0.0290508 0.143255 0.0306142 0.0426579 0.129972 -0.175238 -0.17502 -0.169493 -0.164401 -0.161857 -0.160327 + 0.0921475 0.058833 0.250449 0.234636 0.0412965 0.0555153 0.0590262 0.0642577 0.0664076 0.0659837 0.0711358 + 0.0279402 -0.0278357 0.144855 0.147958 -0.0268559 0.162062 0.161106 0.157629 0.154552 0.155525 0.163634 + -0.0768143 -0.00587302 0.00543038 0.00443376 -0.0134111 -0.128129 -0.126912 -0.123969 -0.11757 -0.112495 -0.11112 + -0.0184337 0.00668561 -0.191863 -0.161345 0.0222466 … -0.102283 -0.103246 -0.10374 -0.107664 -0.107267 -0.114564 + 0.0112104 0.0214651 -0.0923963 -0.0823052 0.0600248 0.103233 0.103589 0.103387 0.106261 0.105065 0.10409 + 0.110971 0.272576 0.148319 0.143233 0.239578 0.109759 0.11224 0.107913 0.109914 0.112652 0.108365 + -0.131066 0.0376254 -0.0164237 -0.000193318 0.00344707 -0.0862689 -0.0893371 -0.0919217 -0.0969305 -0.0935498 -0.096145 + -0.0402605 0.0350559 0.0162864 0.0269105 0.00968855 -0.0587467 -0.0623393 -0.0670097 -0.070679 -0.0655848 -0.0564059 + 0.0799973 0.0482302 0.0712078 0.0792903 0.0108783 … 0.00501423 0.00820444 0.00854873 0.00889943 0.00932721 0.00751066 + -0.137565 -0.0369116 -0.065728 -0.0664102 -0.0238012 0.0250844 0.029041 0.0292468 0.0297059 0.0278639 0.0257616 + 0.0479746 -0.102338 -0.0557072 -0.0833976 -0.0979401 -0.0583169 -0.057629 -0.053911 -0.0566325 -0.0568765 -0.0581378 + 0.0656851 0.0195639 0.0288789 0.0559219 0.0315515 0.03907 0.0472323 0.054771 0.0596156 0.0541802 0.0525933 + 0.0668634 -0.00400549 0.0297102 0.0505045 -0.00082792 0.0399623 0.0414113 0.0400276 0.0361149 0.0325914 0.0260693 + -0.0691096 0.0348577 -0.000312685 0.0232462 -0.00250495 … -0.146082 -0.141874 -0.142026 -0.132163 -0.129679 -0.131122 + -0.0273036 0.0653352 0.0332689 0.017918 0.0875479 0.0535029 0.0500921 0.0471914 0.0469949 0.0434268 0.0442646 + -0.0981665 -0.0296463 -0.0114686 -0.0348033 -0.0468719 -0.0741133 -0.0772672 -0.0805913 -0.0809244 -0.0823798 -0.081472 + -0.0262739 0.109895 0.0117273 0.0222689 0.100869 0.0119844 0.0132486 0.012956 0.0175875 0.013171 0.0195091 + 0.0861164 0.0799029 0.00381147 0.0170927 0.103322 0.0238912 0.0209658 0.0226638 0.0209905 0.0230679 0.0221191 + 0.125112 0.0880232 0.0351989 0.022897 0.0862715 … -0.0219898 -0.0238914 -0.0207844 -0.0229276 -0.0238033 -0.0236367 + ⋮ ⋱ ⋮ + -0.158838 0.0415251 -0.0584126 -0.0373528 0.0819274 -0.212757 -0.214835 -0.213414 -0.212899 -0.215478 -0.210674 + -0.039636 -0.0837763 -0.0837142 -0.0597521 -0.0868467 0.0309127 0.0339911 0.03399 0.0313526 0.0316408 0.0309661 + 0.0755214 0.0960326 0.0858578 0.0614626 0.111979 … 0.102411 0.101302 0.108277 0.109034 0.107593 0.111863 + 0.0506199 0.00290888 0.047947 0.063503 -0.0072114 0.0388324 0.0360347 0.0326486 0.033966 0.0327732 0.0261081 + -0.0288586 -0.150171 -0.0699125 -0.108002 -0.142865 -0.0811611 -0.0775934 -0.072192 -0.0697569 -0.0715358 -0.0683193 + -0.0646991 0.0724608 -0.00767811 -0.0184348 0.0524162 0.046386 0.0457267 0.0532778 0.0649795 0.0697126 0.0808413 + 0.0445508 0.0296366 0.0325647 0.0521935 0.0436496 0.125633 0.129031 0.126605 0.12324 0.120497 0.117703 + -0.127301 -0.0224252 -0.00579415 -0.00877803 -0.0140665 … -0.0826691 -0.080026 -0.080839 -0.0823464 -0.0803394 -0.0856279 + 0.0304881 0.0396951 0.0798097 0.0736797 0.0800866 0.0448139 0.0426674 0.0411406 0.0460205 0.0460111 0.0532082 + 0.0488798 0.252244 0.0866849 0.098552 0.251561 -0.0212669 -0.0236942 -0.035116 -0.0395483 -0.0463498 -0.0494207 + -0.0296798 -0.0494761 0.00688248 0.0264166 -0.0352487 -0.0486577 -0.0476357 -0.0435388 -0.0404835 -0.0410673 -0.0367272 + 0.023548 -0.00147361 0.0629259 0.106951 0.0406627 0.00599323 0.00627022 0.00403014 -0.000107777 -0.000898423 0.00296315 + -0.0574151 -0.0875744 -0.103787 -0.114166 -0.103979 … -0.0697383 -0.0708782 -0.0700138 -0.0687795 -0.070967 -0.0636385 + 0.0280373 0.149767 -0.0899733 -0.0732524 0.162316 0.0233808 0.022177 0.0183834 0.0201251 0.0197228 0.0219051 + -0.0617143 -0.0573989 -0.0973785 -0.0805046 -0.0525925 0.0936075 0.0997715 0.102691 0.107432 0.108591 0.109502 + -0.0859687 0.0623054 0.0974813 0.126841 0.0595557 0.0244905 0.0187937 0.0191363 0.0182794 0.0230548 0.031103 + 0.0392044 0.0162653 0.0926306 0.104054 0.0509464 0.0516558 0.0559883 0.0553617 0.0491496 0.0484319 0.0438133 + -0.0340362 -0.0278067 -0.0181035 -0.0282369 -0.0490531 … -0.0528032 -0.0564175 -0.0562518 -0.0617946 -0.0631367 -0.0675882 + 0.0131229 0.0565131 -0.0349061 -0.0464192 0.0456515 0.0670016 0.0676478 0.0698765 0.0724731 0.0780165 0.0746229 + -0.117425 0.162483 0.11039 0.136364 0.135339 -0.00589512 -0.00432259 -0.00508357 -0.00538224 -0.00685447 -0.00194357 + -0.0401157 -0.00450943 0.0539568 0.0689953 -0.00295334 -0.0122461 -0.00671544 -0.00322498 -0.00518066 -0.00600254 -0.0077147 + 0.0893984 0.0695061 -0.049941 -0.035411 0.0767663 0.0880484 0.0913505 0.0964841 0.0960931 0.0961892 0.103431 + -0.116265 -0.106331 -0.179832 -0.149728 -0.0913282 … -0.0318565 -0.0287848 -0.0275017 -0.0197172 -0.0220611 -0.018135 + -0.0443452 -0.192203 -0.0187912 -0.0247794 -0.180245 -0.0800835 -0.0780865 -0.073571 -0.0699094 -0.0684748 -0.0662903 + 0.100019 -0.0618588 0.106134 0.0989047 -0.0885639 -0.0577217 -0.0547317 -0.0553563 -0.055676 -0.0556784 -0.0595709 ``` """ -function queryFromText( +function queryFromText(config::ColBERTConfig, checkpoint::Checkpoint, queries::Vector{String}, bsize::Union{Missing, Int}) if ismissing(bsize) error("Currently bsize cannot be missing!") @@ -593,7 +667,7 @@ function queryFromText( process = tokenizer.process truncpad_pipe = Pipeline{:token}( TextEncodeBase.trunc_or_pad( - checkpoint.config.query_maxlen, "[PAD]", :tail, :tail), + config.query_maxlen, "[PAD]", :tail, :tail), :token) process = process[1:4] |> truncpad_pipe |> process[6:end] tokenizer = Transformers.TextEncoders.BertTextEncoder( @@ -601,8 +675,8 @@ function queryFromText( endsym = tokenizer.endsym, padsym = tokenizer.padsym, trunc = tokenizer.trunc) # get ids and masks, embeddings and returning the concatenated tensors - batches = tensorize(checkpoint.query_tokenizer, tokenizer, queries, bsize) - batches = [query(checkpoint, integer_ids, integer_mask) + batches = tensorize_queries(config, tokenizer, queries, bsize) + batches = [query(config, checkpoint, integer_ids, integer_mask) for (integer_ids, integer_mask) in batches] Q = cat(batches..., dims = 3)