diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index e666aef..6a45ae7 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -134,6 +134,19 @@ function setup(collection::Vector{String}, avg_doclen_est::Float32, ) end +function _bucket_cutoffs_and_weights( + nbits::Int, heldout_avg_residual::AbstractMatrix{Float32}) + num_options = 1 << nbits + quantiles = collect(0:(num_options - 1)) / num_options + bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end], + quantiles .+ (0.5 / num_options) + bucket_cutoffs = Float32.(quantile( + heldout_avg_residual, bucket_cutoffs_quantiles)) + bucket_weights = Float32.(quantile( + heldout_avg_residual, bucket_weights_quantiles)) + bucket_cutoffs, bucket_weights +end + """ _compute_avg_residuals( nbits::Int, centroids::AbstractMatrix{Float32}, @@ -159,30 +172,20 @@ compression/decompression of residuals. function _compute_avg_residuals!( nbits::Int, centroids::AbstractMatrix{Float32}, heldout::AbstractMatrix{Float32}, codes::AbstractVector{UInt32}) - @assert length(codes) == size(heldout, 2) + length(codes) == size(heldout, 2) || + throw(DimensionMismatch("length(codes) must be equal to the number " * + "of embeddings in heldout!")) compress_into_codes!(codes, centroids, heldout) # get centroid codes heldout_reconstruct = centroids[:, codes] # get corresponding centroids heldout_avg_residual = heldout - heldout_reconstruct # compute the residual - avg_residual = mean(abs.(heldout_avg_residual), dims = 2) # for each dimension, take mean of absolute values of residuals # computing bucket weights and cutoffs - num_options = 2^nbits - quantiles = Vector(0:(num_options - 1)) / num_options - bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[2:end], - quantiles .+ (0.5 / num_options) - - bucket_cutoffs = Float32.(quantile( - heldout_avg_residual, bucket_cutoffs_quantiles)) - bucket_weights = Float32.(quantile( - heldout_avg_residual, bucket_weights_quantiles)) - @assert bucket_cutoffs isa AbstractVector{Float32} "$(typeof(bucket_cutoffs))" - @assert bucket_weights isa AbstractVector{Float32} "$(typeof(bucket_weights))" + bucket_cutoffs, bucket_weights = _bucket_cutoffs_and_weights( + nbits, heldout_avg_residual) - @info "Got bucket_cutoffs_quantiles = $(bucket_cutoffs_quantiles) and bucket_weights_quantiles = $(bucket_weights_quantiles)" @info "Got bucket_cutoffs = $(bucket_cutoffs) and bucket_weights = $(bucket_weights)" - bucket_cutoffs, bucket_weights, mean(avg_residual) end diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index d6df3da..da61a10 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -1,176 +1,3 @@ -""" - BaseColBERT(; - bert::HuggingFace.HGFBertModel, linear::Layers.Dense, - tokenizer::TextEncoders.AbstractTransformerTextEncoder) - -A struct representing the BERT model, linear layer, and the tokenizer used to compute -embeddings for documents and queries. - -# Arguments - - - `bert`: The pre-trained BERT model used to generate the embeddings. - - `linear`: The linear layer used to project the embeddings to a specific dimension. - - `tokenizer`: The tokenizer to used by the BERT model. - -# Returns - -A [`BaseColBERT`](@ref) object. 
- -# Examples - -```julia-repl -julia> using ColBERT, CUDA; - -julia> base_colbert = BaseColBERT("/home/codetalker7/models/colbertv2.0/"); - -julia> base_colbert.bert -HGFBertModel( - Chain( - CompositeEmbedding( - token = Embed(768, 30522), # 23_440_896 parameters - position = ApplyEmbed(.+, FixedLenPositionEmbed(768, 512)), # 393_216 parameters - segment = ApplyEmbed(.+, Embed(768, 2), Transformers.HuggingFace.bert_ones_like), # 1_536 parameters - ), - DropoutLayer( - LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters - ), - ), - Transformer<12>( - PostNormTransformerBlock( - DropoutLayer( - SelfAttention( - MultiheadQKVAttenOp(head = 12, p = nothing), - Fork<3>(Dense(W = (768, 768), b = true)), # 1_771_776 parameters - Dense(W = (768, 768), b = true), # 590_592 parameters - ), - ), - LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters - DropoutLayer( - Chain( - Dense(σ = NNlib.gelu, W = (768, 3072), b = true), # 2_362_368 parameters - Dense(W = (3072, 768), b = true), # 2_360_064 parameters - ), - ), - LayerNorm(768, ϵ = 1.0e-12), # 1_536 parameters - ), - ), # Total: 192 arrays, 85_054_464 parameters, 40.422 KiB. - Branch{(:pooled,) = (:hidden_state,)}( - BertPooler(Dense(σ = NNlib.tanh_fast, W = (768, 768), b = true)), # 590_592 parameters - ), -) # Total: 199 arrays, 109_482_240 parameters, 43.578 KiB. - -julia> base_colbert.linear -Dense(W = (768, 128), b = true) # 98_432 parameters - -julia> base_colbert.tokenizer -TrfTextEncoder( -├─ TextTokenizer(MatchTokenization(WordPieceTokenization(bert_uncased_tokenizer, WordPiece(vocab_size = 30522, unk = [UNK], max_char = 100)), 5 patterns)), -├─ vocab = Vocab{String, SizedArray}(size = 30522, unk = [UNK], unki = 101), -├─ config = @NamedTuple{startsym::String, endsym::String, padsym::String, trunc::Union{Nothing, Int64}}(("[CLS]", "[SEP]", "[PAD]", 512)), -├─ annotate = annotate_strings, -├─ onehot = lookup_first, -├─ decode = nestedcall(remove_conti_prefix), -├─ textprocess = Pipelines(target[token] := join_text(source); target[token] := nestedcall(cleanup ∘ remove_prefix_space, target.token); target := (target.token)), -└─ process = Pipelines: - ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source) - ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token) - ╰─ target[(token, segment)] := SequenceTemplate{String}([CLS]: Input[1]: [SEP]: (Input[2]: [SEP]:)...)(target.token) - ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(512))(target.token) - ╰─ target[token] := TextEncodeBase.trunc_and_pad(512, [PAD], tail, tail)(target.token) - ╰─ target[token] := TextEncodeBase.nested2batch(target.token) - ╰─ target[segment] := TextEncodeBase.trunc_and_pad(512, 1, tail, tail)(target.segment) - ╰─ target[segment] := TextEncodeBase.nested2batch(target.segment) - ╰─ target[sequence_mask] := identity(target.attention_mask) - ╰─ target := (target.token, target.segment, target.attention_mask, target.sequence_mask) -``` -""" -struct BaseColBERT - bert::HF.HGFBertModel - linear::Layers.Dense - tokenizer::TextEncoders.AbstractTransformerTextEncoder -end - -function BaseColBERT(modelpath::AbstractString) - tokenizer, bert_model, linear = load_hgf_pretrained_local(modelpath) - bert_model = bert_model |> Flux.gpu - linear = linear |> Flux.gpu - BaseColBERT(bert_model, linear, tokenizer) -end - -""" - Checkpoint(model::BaseColBERT, config::ColBERTConfig) - -A wrapper for [`BaseColBERT`](@ref), containing information for generating embeddings -for docs and queries. 
- -If the `config` is set to mask punctuations, then the `skiplist` property of the created -[`Checkpoint`](@ref) will be set to a list of token IDs of punctuations. Otherwise, it will be empty. - -# Arguments - - - `model`: The [`BaseColBERT`](@ref) to be wrapped. - - `config`: The underlying [`ColBERTConfig`](@ref). - -# Returns - -The created [`Checkpoint`](@ref). - -# Examples - -Continuing from the example for [`BaseColBERT`](@ref): - -```julia-repl -julia> checkpoint = Checkpoint(base_colbert, config) - -julia> checkpoint.skiplist # by default, all punctuations -32-element Vector{Int64}: - 1000 - 1001 - 1002 - 1003 - 1004 - 1005 - 1006 - 1007 - 1008 - 1009 - 1010 - 1011 - 1012 - 1013 - ⋮ - 1028 - 1029 - 1030 - 1031 - 1032 - 1033 - 1034 - 1035 - 1036 - 1037 - 1064 - 1065 - 1066 - 1067 -``` -""" -struct Checkpoint - model::BaseColBERT - skiplist::Vector{Int64} -end - -function Checkpoint(model::BaseColBERT, config::ColBERTConfig) - if config.mask_punctuation - punctuation_list = string.(collect("!\"#\$%&\'()*+,-./:;<=>?@[\\]^_`{|}~")) - skiplist = [TextEncodeBase.lookup(model.tokenizer.vocab, punct) - for punct in punctuation_list] - else - skiplist = Vector{Int64}() - end - Checkpoint(model, skiplist) -end - """ doc( config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, diff --git a/src/modelling/tokenization/tokenizer_utils.jl b/src/modelling/tokenization/tokenizer_utils.jl index 2aefebd..fe60feb 100644 --- a/src/modelling/tokenization/tokenizer_utils.jl +++ b/src/modelling/tokenization/tokenizer_utils.jl @@ -116,9 +116,9 @@ A matrix equal to `data`, with the second row being filled with `marker`. # Examples ```julia-repl -julia> using ColBERT: _add_marker_row; +julia> using ColBERT: _add_marker_row; -julia> x = ones(Float32, 5, 5); +julia> x = ones(Float32, 5, 5); 5×5 Matrix{Float32}: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 @@ -138,5 +138,6 @@ julia> _add_marker_row(x, zero(Float32)) """ function _add_marker_row(data::AbstractMatrix{T}, marker::T) where {T} - [data[begin:1, :]; fill(marker, (1, size(data, 2))); data[2:end, :]] + [data[begin:min(1, size(data, 1)), :]; fill(marker, (1, size(data, 2))); + data[2:end, :]] end diff --git a/src/utils.jl b/src/utils.jl index 3f3c010..c160e67 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,39 +1,39 @@ -""" - _sort_by_length( - integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) - -Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. - -# Arguments - - - `integer_ids`: The token IDs of documents to be sorted. - - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). - - `bsize`: The size of batches to be considered. - -# Returns - -Depending upon `bsize`, the following are returned: - - - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the - `integer_ids` and `integer_mask` are returned unchanged. - - If the number of documents is larger than `bsize`, then the passages are first sorted - by the number of attended tokens (figured out from the `integer_mask`), and then the - sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of - `reverse_indices`, i.e a mapping from the documents to their indices in the original - order. 
-""" -function _sort_by_length( - integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}, batch_size::Int) - size(integer_ids, 2) <= batch_size && - return integer_ids, bitmask, Vector(1:size(integer_ids, 2)) - lengths = vec(sum(bitmask; dims = 1)) # number of attended tokens in each passage - indices = sortperm(lengths) # get the indices which will sort lengths - reverse_indices = sortperm(indices) # invert the indices list - @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" - @assert bitmask isa BitMatrix "$(typeof(bitmask))" - @assert reverse_indices isa Vector{Int} "$(typeof(reverse_indices))" - integer_ids[:, indices], bitmask[:, indices], reverse_indices -end +# """ +# _sort_by_length( +# integer_ids::AbstractMatrix{Int32}, integer_mask::AbstractMatrix{Bool}, bsize::Int) +# +# Sort sentences by number of attended tokens, if the number of sentences is larger than `bsize`. +# +# # Arguments +# +# - `integer_ids`: The token IDs of documents to be sorted. +# - `integer_mask`: The attention masks of the documents to be sorted (attention masks are just bits). +# - `bsize`: The size of batches to be considered. +# +# # Returns +# +# Depending upon `bsize`, the following are returned: +# +# - If the number of documents (second dimension of `integer_ids`) is atmost `bsize`, then the +# `integer_ids` and `integer_mask` are returned unchanged. +# - If the number of documents is larger than `bsize`, then the passages are first sorted +# by the number of attended tokens (figured out from the `integer_mask`), and then the +# sorted arrays `integer_ids`, `integer_mask` are returned, along with a list of +# `reverse_indices`, i.e a mapping from the documents to their indices in the original +# order. +# """ +# function _sort_by_length( +# integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}, batch_size::Int) +# size(integer_ids, 2) <= batch_size && +# return integer_ids, bitmask, Vector(1:size(integer_ids, 2)) +# lengths = vec(sum(bitmask; dims = 1)) # number of attended tokens in each passage +# indices = sortperm(lengths) # get the indices which will sort lengths +# reverse_indices = sortperm(indices) # invert the indices list +# @assert integer_ids isa AbstractMatrix{Int32} "$(typeof(integer_ids))" +# @assert bitmask isa BitMatrix "$(typeof(bitmask))" +# @assert reverse_indices isa Vector{Int} "$(typeof(reverse_indices))" +# integer_ids[:, indices], bitmask[:, indices], reverse_indices +# end function compute_distances_kernel!(batch_distances::AbstractMatrix{Float32}, batch_data::AbstractMatrix{Float32}, @@ -41,9 +41,9 @@ function compute_distances_kernel!(batch_distances::AbstractMatrix{Float32}, batch_distances .= 0.0f0 # Compute squared distances: (a-b)^2 = a^2 + b^2 - 2ab # a^2 term - sum_sq_data = sum(batch_data .^ 2, dims = 1) # (1, point_bsize) + sum_sq_data = sum(batch_data .^ 2, dims = 1) # (1, point_bsize) # b^2 term - sum_sq_centroids = sum(centroids .^ 2, dims = 1)' # (num_centroids, 1) + sum_sq_centroids = sum(centroids .^ 2, dims = 1)' # (num_centroids, 1) # -2ab term mul!(batch_distances, centroids', batch_data, -2.0f0, 1.0f0) # (num_centroids, point_bsize) # Compute (a-b)^2 = a^2 + b^2 - 2ab diff --git a/test/Project.toml b/test/Project.toml index fbc2394..97bba88 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" 
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/indexing/collection_indexer.jl b/test/indexing/collection_indexer.jl new file mode 100644 index 0000000..8217dcf --- /dev/null +++ b/test/indexing/collection_indexer.jl @@ -0,0 +1,165 @@ +using LinearAlgebra: __normalize! +using ColBERT: _sample_pids, _heldout_split, setup, _bucket_cutoffs_and_weights, + _normalize_array!, _compute_avg_residuals! + +@testset "_sample_pids tests" begin + # Test 1: More pids than given can't be sampled + num_documents = rand(0:100000) + pids = _sample_pids(num_documents) + @test length(pids) <= num_documents + + # Test 2: Edge case, when + num_documents = rand(0:1) + pids = _sample_pids(num_documents) + @test length(pids) <= num_documents +end + +@testset "_heldout_split" begin + # Test 1: A basic test with a large size + sample = rand(Float32, rand(1:20), 100000) + for heldout_fraction in Float32.(collect(0.1:0.1:1.0)) + sample_train, sample_heldout = _heldout_split( + sample; heldout_fraction = heldout_fraction) + heldout_size = min(50000, Int(floor(100000 * heldout_fraction))) + @test size(sample_train, 2) == 100000 - heldout_size + @test size(sample_heldout, 2) == heldout_size + end + + # Test 2: Edge case with 1 column, should return empty train and full heldout + sample = rand(Float32, 3, 1) + heldout_fraction = 0.5f0 + sample_train, sample_heldout = _heldout_split( + sample; heldout_fraction = heldout_fraction) + @test size(sample_train, 2) == 0 # No columns should be left in the train set + @test size(sample_heldout, 2) == 1 # All columns in the heldout set +end + +@testset "setup" begin + # Test 1: Number of documents and chunksize should not be altered + collection = string.(rand('a':'z', rand(1:1000))) + avg_doclen_est = Float32(100 * rand()) + nranks = rand(1:10) + num_clustering_embs = rand(1:1000) + chunksize = rand(1:20) + plan_dict = setup( + collection, avg_doclen_est, num_clustering_embs, chunksize, nranks) + @test plan_dict["avg_doclen_est"] == avg_doclen_est + @test plan_dict["chunksize"] == chunksize + @test plan_dict["num_documents"] == length(collection) + @test plan_dict["num_embeddings_est"] == avg_doclen_est * length(collection) + + # Test 2: Tests for number of chunks + avg_doclen_est = 1.0f0 + nranks = rand(1:10) + num_clustering_embs = rand(1:1000) + + ## without remainders + chunksize = rand(1:20) + collection = string.(rand('a':'z', chunksize * rand(1:100))) + plan_dict = setup( + collection, avg_doclen_est, num_clustering_embs, chunksize, nranks) + @test plan_dict["num_chunks"] == div(length(collection), chunksize) + + ## with remainders + chunksize = rand(1:20) + collection = string.(rand( + 'a':'z', chunksize * rand(1:100) + rand(1:(chunksize - 1)))) + plan_dict = setup( + collection, avg_doclen_est, num_clustering_embs, chunksize, nranks) + @test plan_dict["num_chunks"] == div(length(collection), chunksize) + 1 + + # Test 3: Tests for number of clusters + collection = string.(rand('a':'z', rand(1:1000))) + avg_doclen_est = Float32(100 * rand()) + nranks = rand(1:10) + num_clustering_embs = rand(1:10000) + chunksize = rand(1:20) + plan_dict = setup( + collection, avg_doclen_est, num_clustering_embs, chunksize, nranks) + @test plan_dict["num_partitions"] <= num_clustering_embs + @test plan_dict["num_partitions"] <= + 16 * sqrt(avg_doclen_est * length(collection)) +end + +@testset "_bucket_cutoffs_and_weights" begin + # Test 1: Basic test with 2x2 matrix and nbits=2 + heldout_avg_residual = [0.0f0 0.2f0; 0.4f0 0.6f0; 0.8f0 1.0f0] + nbits = 2 + cutoffs, weights = 
_bucket_cutoffs_and_weights(nbits, heldout_avg_residual) + expected_cutoffs = Float32[0.25, 0.5, 0.75] + expected_weights = Float32[0.125, 0.375, 0.625, 0.875] + @test cutoffs ≈ expected_cutoffs + @test weights ≈ expected_weights + + # Test 2: Uniform values + value = rand(Float32) + heldout_avg_residual = value * ones(Float32, rand(1:20), rand(1:20)) + nbits = rand(1:10) + cutoffs, weights = _bucket_cutoffs_and_weights(nbits, heldout_avg_residual) + @test all(isequal(value), cutoffs) + @test all(isequal(value), weights) + + # Test 3: Shapes and types + heldout_avg_residual = rand(Float32, rand(1:20), rand(1:20)) + nbits = rand(1:10) + cutoffs, weights = _bucket_cutoffs_and_weights(nbits, heldout_avg_residual) + @test length(cutoffs) == (1 << nbits) - 1 + @test length(weights) == 1 << nbits + @test cutoffs isa Vector{Float32} + @test weights isa Vector{Float32} +end + +@testset "_compute_avg_residuals!" begin + # Test 1: centroids and heldout_avg_residual have the same columns with different perms + nbits = rand(1:20) + centroids = rand(Float32, rand(1:20), rand(1:20)) + _normalize_array!(centroids; dims = 1) + perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))] + heldout = centroids[:, perm] + codes = Vector{UInt32}(undef, size(heldout, 2)) + bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!( + nbits, centroids, heldout, codes) + @test all(iszero, bucket_cutoffs) + @test all(iszero, bucket_weights) + @test iszero(avg_residual) + + # Test 2: some tolerance level + tol = 1e-5 + nbits = rand(1:20) + centroids = rand(Float32, rand(1:20), rand(1:20)) + _normalize_array!(centroids; dims = 1) + perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))] + heldout = centroids[:, perm] + for col in eachcol(heldout) + col .+= -tol + 2 * tol * rand() + end + codes = Vector{UInt32}(undef, size(heldout, 2)) + bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!( + nbits, centroids, heldout, codes) + @test all(<=(tol), bucket_cutoffs) + @test all(<=(tol), bucket_weights) + @test avg_residual <= tol + + # Test 3: Shapes and types + nbits = rand(1:20) + dim = rand(1:20) + centroids = rand(Float32, dim, rand(1:20)) + heldout = rand(Float32, dim, rand(1:20)) + codes = Vector{UInt32}(undef, size(heldout, 2)) + bucket_cutoffs, bucket_weights, avg_residual = _compute_avg_residuals!( + nbits, centroids, heldout, codes) + @test length(bucket_cutoffs) == (1 << nbits) - 1 + @test length(bucket_weights) == 1 << nbits + @test bucket_cutoffs isa Vector{Float32} + @test bucket_weights isa Vector{Float32} + @test avg_residual isa Float32 + + # Test 4: Correct errors are thrown + nbits = 2 + centroids = Float32[1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0] # (3, 3) matrix + heldout = Float32[1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0] # (3, 3) matrix + codes = UInt32[0, 0] # Length is 2, but `heldout` has 3 columns + # Check for DimensionMismatch error + @test_throws DimensionMismatch _compute_avg_residuals!( + nbits, centroids, heldout, codes) +end diff --git a/test/modelling/checkpoint.jl b/test/modelling/checkpoint.jl deleted file mode 100644 index e69de29..0000000 diff --git a/test/modelling/embedding_utils.jl b/test/modelling/embedding_utils.jl index a1e10a8..52ce3fa 100644 --- a/test/modelling/embedding_utils.jl +++ b/test/modelling/embedding_utils.jl @@ -84,7 +84,7 @@ end # Test Case 4: Skip all tokens dim, len, bsize = rand(1:20, 3) D = rand(Float32, dim, len, bsize) - integer_ids = rand(int32, len, bsize) + integer_ids = rand(Int32, len, bsize) skiplist = 
unique(Int.(vec(integer_ids))) expected_D = similar(D) expected_D .= 0.0f0 diff --git a/test/modelling/tokenization/tokenizer_utils.jl b/test/modelling/tokenization/tokenizer_utils.jl new file mode 100644 index 0000000..d4572f7 --- /dev/null +++ b/test/modelling/tokenization/tokenizer_utils.jl @@ -0,0 +1,19 @@ +using ColBERT: _add_marker_row + +@testset "_add_marker_row" begin + for type in [INT_TYPES; FLOAT_TYPES] + # Test 1: Generic + num_rows, num_cols = rand(1:20), rand(1:20) + x = rand(type, num_rows, num_cols) + x = _add_marker_row(x, zero(type)) + @test isequal(size(x), (num_rows + 1, num_cols)) + @test isequal(x[2, :], repeat([zero(type)], num_cols)) + + # Test 2: Edge case, empty array + num_cols = rand(1:20) + x = rand(type, 0, num_cols) + x = _add_marker_row(x, zero(type)) + @test isequal(size(x), (1, num_cols)) + @test isequal(x[1, :], repeat([zero(type)], num_cols)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ec4d48a..7721ed4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,15 @@ +using Base: SimpleLogger, NullLogger, global_logger using ColBERT using .Iterators using LinearAlgebra +using Logging using Random using Test +# turn off logging +logger = NullLogger() +global_logger(logger) + const INT_TYPES = [ Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128] const FLOAT_TYPES = [Float16, Float32, Float64] @@ -12,8 +18,10 @@ const FLOAT_TYPES = [Float16, Float32, Float64] # indexing operations include("indexing/codecs/residual.jl") +include("indexing/collection_indexer.jl") # modelling operations +include("modelling/tokenization/tokenizer_utils.jl") include("modelling/embedding_utils.jl") # utils
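For a quick intuition of what the new `_bucket_cutoffs_and_weights` helper computes, here is a small standalone sketch (not part of the patch) that mirrors Test 1 of `test/indexing/collection_indexer.jl`. It assumes the `quantile` used by the helper is the one from the `Statistics` stdlib; the residual matrix and the expected values are taken straight from that test.

```julia
using Statistics: quantile

# Mirrors Test 1 in test/indexing/collection_indexer.jl: a 3×2 residual matrix, nbits = 2.
nbits = 2
heldout_avg_residual = Float32[0.0 0.2; 0.4 0.6; 0.8 1.0]

num_options = 1 << nbits                                 # 2^nbits buckets
quantiles = collect(0:(num_options - 1)) / num_options   # [0.0, 0.25, 0.5, 0.75]
cutoff_quantiles = quantiles[2:end]                      # interior quantiles -> bucket boundaries
weight_quantiles = quantiles .+ (0.5 / num_options)      # bucket midpoints -> bucket weights

# quantile over all residual entries (vec just flattens the matrix)
cutoffs = Float32.(quantile(vec(heldout_avg_residual), cutoff_quantiles))
weights = Float32.(quantile(vec(heldout_avg_residual), weight_quantiles))

@assert cutoffs ≈ Float32[0.25, 0.5, 0.75]
@assert weights ≈ Float32[0.125, 0.375, 0.625, 0.875]
```

With `nbits` bits there are `2^nbits` buckets, so the helper returns `2^nbits - 1` cutoffs and `2^nbits` weights, which is exactly the shape that Test 3 of the `_bucket_cutoffs_and_weights` testset asserts.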