Skip to content

Commit

Permalink
Adding tests for modelling functions + related minor modifications in
Browse files Browse the repository at this point in the history
src.
  • Loading branch information
codetalker7 committed Sep 6, 2024
1 parent f752c00 commit f0ba596
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 19 deletions.
17 changes: 8 additions & 9 deletions src/modelling/embedding_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,16 @@ end

function _clear_masked_embeddings!(D::AbstractArray{Float32, 3},
integer_ids::AbstractMatrix{Int32}, skiplist::Vector{Int})
@assert isequal(size(D)[2:end], size(integer_ids))
"size(D): $(size(D)), size(integer_ids): $(size(integer_ids))"

isequal(size(D)[2:end], size(integer_ids)) ||
throw(DomainError("The number of embeddings in D and tokens " *
"in integer_ids must be equal!"))
# set everything to true
mask = similar(integer_ids, Bool) # respects the device as well
mask .= true
mask_skiplist!(mask, integer_ids, skiplist) # (doc_maxlen, current_batch_size)
mask = reshape(mask, (1, size(mask)...)) # (1, doc_maxlen, current_batch_size)

@assert isequal(size(mask)[2:end], size(D)[2:end])
"size(mask): $(size(mask)), size(D): $(size(D))"
@assert mask isa AbstractArray{Bool} "$(typeof(mask))"

# clear embeddings
D .= D .* mask # clear embeddings of masked tokens
mask
end
Expand All @@ -201,6 +198,8 @@ end

function _remove_masked_tokens(
D::AbstractMatrix{Float32}, mask::AbstractMatrix{Bool})
D[:, reshape(mask, prod(size(mask)))]
size(D, 2) == prod(size(mask)) ||
throw(DimensionMismatch("The total number of embeddings " * "
in D must be equal to the total number of tokens represented by mask!"))
D[:, vec(mask)]
end

10 changes: 0 additions & 10 deletions test/modelling/checkpoint.jl
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
@testset "doc_tokenization.jl" begin

end

@testset "query_tokenization.jl" begin
end

@testset "checkpoint.jl" begin
# use the defaults for the config; gpu tests will be separate
end
156 changes: 156 additions & 0 deletions test/modelling/embedding_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
using ColBERT: mask_skiplist!, _clear_masked_embeddings!, _flatten_embeddings,
_remove_masked_tokens

@testset "mask_skiplist!" begin
# Test Case 1: Simple case with no skips
mask = trues(3, 3)
integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
skiplist = Int[]
expected_mask = trues(3, 3)
mask_skiplist!(mask, integer_ids, skiplist)
@test mask == expected_mask

# Test Case 2: Skip one value
mask = trues(3, 3)
integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
skiplist = [5]
expected_mask = [true true true; true false true; true true true]
mask_skiplist!(mask, integer_ids, skiplist)
@test mask == expected_mask

# Test Case 3: Skip multiple values
mask = trues(3, 3)
integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
skiplist = [2, 6, 9]
expected_mask = [true false true; true true false; true true false]
mask_skiplist!(mask, integer_ids, skiplist)
@test mask == expected_mask

# Test Case 4: All values in skiplist
mask = trues(3, 3)
integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
skiplist = [1, 2, 3, 4, 5, 6, 7, 8, 9]
expected_mask = falses(3, 3)
mask_skiplist!(mask, integer_ids, skiplist)
@test mask == expected_mask

# Test Case 5: Empty integer_ids matrix
mask = trues(0, 0)
integer_ids = rand(Int32, 0, 0)
skiplist = [1]
expected_mask = trues(0, 0)
mask_skiplist!(mask, integer_ids, skiplist)
@test mask == expected_mask

# Test Case 6: Skiplist with no matching values
mask = trues(3, 3)
integer_ids = Int32[1 2 3; 4 5 6; 7 8 9]
skiplist = [10, 11]
expected_mask = trues(3, 3)
mask_skiplist!(mask, integer_ids, skiplist)
@test mask == expected_mask
end

@testset "_clear_masked_embeddings!" begin
# Test Case 1: No skiplist entries
dim, len, bsize = rand(1:20, 3)
D = rand(Float32, dim, len, bsize)
integer_ids = rand(Int32, len, bsize)
skiplist = Int[]
expected_D = copy(D)
_clear_masked_embeddings!(D, integer_ids, skiplist)
@test D == expected_D

# Test Case 2: Single skiplist entry
dim, len, bsize = rand(1:20, 3)
D = rand(Float32, dim, len, bsize)
integer_ids = rand(Int32, len, bsize)
skiplist = Int[integer_ids[rand(1:(len * bsize))]]
expected_D = copy(D)
expected_D[:, findall(in(skiplist), integer_ids)] .= 0.0f0
_clear_masked_embeddings!(D, integer_ids, skiplist)
@test D == expected_D

# Test Case 3: Multiple skiplist entries
dim, len, bsize = rand(1:20, 3)
D = rand(Float32, dim, len, bsize)
integer_ids = rand(Int32, len, bsize)
skiplist = unique(Int.(rand(vec(integer_ids), rand(1:(len * bsize)))))
expected_D = copy(D)
expected_D[:, findall(in(skiplist), integer_ids)] .= 0.0f0
_clear_masked_embeddings!(D, integer_ids, skiplist)
@test D == expected_D

# Test Case 4: Skip all tokens
dim, len, bsize = rand(1:20, 3)
D = rand(Float32, dim, len, bsize)
integer_ids = rand(int32, len, bsize)
skiplist = unique(Int.(vec(integer_ids)))
expected_D = similar(D)
expected_D .= 0.0f0
_clear_masked_embeddings!(D, integer_ids, skiplist)
@test D == expected_D

# Test Case 5: Skiplist with no matching tokens
dim, len, bsize = rand(1:20, 3)
D = rand(Float32, dim, len, bsize)
integer_ids = Int32.(rand(1:100, len, bsize))
skiplist = unique(rand(101:1000, rand(1:20)))
expected_D = copy(D)
_clear_masked_embeddings!(D, integer_ids, skiplist)
@test D == expected_D

# Test 6: Types and shapes
dim, len, bsize = rand(1:20, 3)
D = rand(Float32, dim, len, bsize)
integer_ids = rand(Int32, len, bsize)
skiplist = unique(rand(Int, rand(1:20)))
mask = _clear_masked_embeddings!(D, integer_ids, skiplist)
@test mask isa Array{Bool, 3}
@test isequal(size(mask), (1, size(D)[2:end]...))
end

@testset "_flatten_embeddings" begin
# Test Case 1: Generic case; len will correspond to a vector of constants
dim, len, bsize = rand(1:20, 3)
D = Array{Float32}(undef, dim, len, bsize)
for idx in 1:len
D[:, idx, :] .= idx
end
expected = Matrix{Float32}(undef, dim, len * bsize)
for idx in 1:len
expected[:, [idx + k * len for k in 0:(bsize - 1)]] .= idx
end
@test _flatten_embeddings(D) == expected

# Test Case 2: Edge case with 0x3x2 array (should return 0x6 array)
D = Float32[]
D = reshape(D, 0, 3, 2)
expected_output = reshape(Float32[], 0, 6)
@test _flatten_embeddings(D) == expected_output
end

@testset "_remove_masked_tokens" begin
# Test 1: Generic case; build a skiplist, and manually build the expected tensor
dim, len, bsize = rand(1:20, 3)
mask = trues(len, bsize)
skiplist = unique(rand(1:len, rand(1:len)))
for id in skiplist
mask[id, :] .= false
end
D = Matrix{Float32}(undef, dim, len * bsize)
for idx in 1:len
D[:, [idx + k * len for k in 0:(bsize - 1)]] .= idx
end
expected = rand(Float32, dim, 0)
for emb_id in 1:size(D, 2)
if !(D[1, emb_id] in skiplist)
expected = hcat(expected, D[:, emb_id])
end
end
@test _remove_masked_tokens(D, mask) == expected

# Test 2: Test for errors
@test_throws DimensionMismatch _remove_masked_tokens(
rand(Float32, 12, 20), rand(Bool, 4, 4))
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,12 @@ const INT_TYPES = [
const FLOAT_TYPES = [Float16, Float32, Float64]

# include("Aqua.jl")

# indexing operations
include("indexing/codecs/residual.jl")

# modelling operations
include("modelling/embedding_utils.jl")

# utils
include("utils.jl")

0 comments on commit f0ba596

Please sign in to comment.