Skip to content

Commit

Permalink
Adding tests for utils + minor changes to src.
Browse files Browse the repository at this point in the history
  • Loading branch information
codetalker7 committed Sep 8, 2024
1 parent f48ccf2 commit b0ac2ac
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 14 deletions.
34 changes: 29 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
function compute_distances_kernel!(batch_distances::AbstractMatrix{Float32},
batch_data::AbstractMatrix{Float32},
centroids::AbstractMatrix{Float32})
isequal(size(batch_distances), (size(centroids, 2), size(batch_data, 2))) ||
throw(DimensionMismatch("batch_distances should have size " *
"(num_centroids, point_bsize)!"))
isequal(size(batch_data, 1), size(centroids, 1)) ||
throw(DimensionMismatch("batch_data and centroids should have " *
"the same embedding dimension!"))

batch_distances .= 0.0f0
# Compute squared distances: (a-b)^2 = a^2 + b^2 - 2ab
# a^2 term
Expand All @@ -54,20 +61,30 @@ end
function update_centroids_kernel!(new_centroids::AbstractMatrix{Float32},
        batch_data::AbstractMatrix{Float32},
        batch_one_hot::AbstractMatrix{Float32})
    # Accumulate `batch_data * batch_one_hot'` into `new_centroids` in place
    # and return `new_centroids`.
    #
    # `new_centroids` must have size
    # `(size(batch_data, 1), size(batch_one_hot, 1))`; otherwise a
    # `DimensionMismatch` is thrown before anything is written.
    expected_shape = (size(batch_data, 1), size(batch_one_hot, 1))
    if size(new_centroids) != expected_shape
        throw(DimensionMismatch("new_centroids should have the right shape " *
                                "for multiplying batch_data and batch_one_hot! "))
    end
    # 5-argument mul! computes C = A * B * alpha + C * beta; with
    # alpha = beta = 1 the product is added to the existing contents.
    return mul!(new_centroids, batch_data, adjoint(batch_one_hot), 1.0f0, 1.0f0)
end

function assign_clusters_kernel!(batch_assignments::AbstractVector{Int32},
        batch_distances::AbstractMatrix{Float32})
    # For every column of `batch_distances`, store the row index of its
    # smallest entry in the matching slot of `batch_assignments`.
    #
    # Throws `DimensionMismatch` unless `batch_assignments` has one entry per
    # column of `batch_distances`. Returns `batch_assignments`.
    num_points = size(batch_distances, 2)
    if length(batch_assignments) != num_points
        throw(DimensionMismatch("length(batch_assignments) " *
                                "should be equal to the point " *
                                "batch size of batch_distances!"))
    end
    # One CartesianIndex per column; its first component is the row index.
    nearest = argmin(batch_distances; dims = 1)
    return batch_assignments .= vec(getindex.(nearest, 1))
end

function onehot_encode!(batch_one_hot::AbstractArray{Float32},
        batch_assignments::AbstractVector{Int32}, k::Int)
    # Set `batch_one_hot[batch_assignments[j], j]` to 1 for every point `j`.
    #
    # Only the selected entries are written; every other entry of
    # `batch_one_hot` is left untouched.
    #
    # Throws `DimensionMismatch` unless `batch_one_hot` has size
    # `(k, length(batch_assignments))`.
    num_points = length(batch_assignments)
    isequal(size(batch_one_hot), (k, num_points)) ||
        throw(DimensionMismatch("batch_one_hot should have shape " *
                                "(k, length(batch_assignments))!"))

    # Column index of every point, allocated with `similar` so the buffer has
    # the same array type as `batch_assignments` (the original comment says
    # this "respects device").
    col_indices = similar(batch_assignments, num_points)
    copyto!(col_indices, collect(1:num_points))

    # Column-major linear index of entry (row, col) in a k-row matrix is
    # row + (col - 1) * k; broadcasting writes all selected entries at once.
    batch_one_hot[batch_assignments .+ (col_indices .- 1) .* k] .= 1
end

Expand Down Expand Up @@ -236,7 +253,14 @@ julia> centroids
function kmeans_gpu_onehot!(
data::AbstractMatrix{Float32}, centroids::AbstractMatrix{Float32}, k::Int; max_iters::Int = 10,
tol::Float32 = 1.0f-4, point_bsize::Int = 1000)
@assert size(centroids)[2] == k
# TODO: move point_bsize to config?
size(centroids, 2) == k ||
throw(DimensionMismatch("size(centroids, 2) must be k!"))

# randomly initialize centroids
centroids .= data[:, randperm(size(data, 2))[1:k]]

# allocations
d, n = size(data) # dimension, number of inputs
assignments = Vector{Int32}(undef, n) |> Flux.gpu
distances = Matrix{Float32}(undef, k, point_bsize) |> Flux.gpu
Expand Down Expand Up @@ -303,7 +327,7 @@ end
function _topk(data::Matrix{T}, k::Int; dims::Int = 1) where {T <: Number}
    # Return the indices of the `k` largest entries of every slice of `data`
    # along dimension `dims`, ordered from largest to smallest.
    #
    # Throws `DomainError` when `dims` is neither 1 nor 2.
    #
    # TODO: only works on CPU; make it work on GPUs?
    # partialsortperm is not available in CUDA.jl
    #
    # Validate with an explicit throw rather than `@assert`, so callers see a
    # DomainError that carries the offending value alongside the message.
    dims in (1, 2) || throw(DomainError(dims, "dims must be 1 or 2!"))
    mapslices(v -> partialsortperm(v, 1:k, rev = true), data, dims = dims)
end

Expand Down
7 changes: 3 additions & 4 deletions test/indexing/collection_indexer.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LinearAlgebra: __normalize!
using ColBERT: _sample_pids, _heldout_split, setup, _bucket_cutoffs_and_weights,
_normalize_array!, _compute_avg_residuals!

Expand Down Expand Up @@ -61,7 +60,7 @@ end
@test plan_dict["num_chunks"] == div(length(collection), chunksize)

## with remainders
chunksize = rand(1:20)
chunksize = rand(1:20) + 1
collection = string.(rand(
'a':'z', chunksize * rand(1:100) + rand(1:(chunksize - 1))))
plan_dict = setup(
Expand Down Expand Up @@ -111,7 +110,7 @@ end

@testset "_compute_avg_residuals!" begin
# Test 1: centroids and heldout_avg_residual have the same columns with different perms
nbits = rand(1:20)
nbits = rand(2:20)
centroids = rand(Float32, rand(1:20), rand(1:20))
_normalize_array!(centroids; dims = 1)
perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
Expand All @@ -125,7 +124,7 @@ end

# Test 2: some tolerance level
tol = 1e-5
nbits = rand(1:20)
nbits = rand(2:20)
centroids = rand(Float32, rand(1:20), rand(1:20))
_normalize_array!(centroids; dims = 1)
perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Base: SimpleLogger, NullLogger, global_logger
using ColBERT
using .Iterators
using LinearAlgebra
Expand Down
205 changes: 201 additions & 4 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,213 @@
using ColBERT: compute_distances_kernel!, update_centroids_kernel!,
assign_clusters_kernel!, onehot_encode!, kmeans_gpu_onehot!,
_normalize_array!, _topk, _head

@testset "compute_distances_kernel!" begin
    # Test 1: when all entries are the same, every point coincides with every
    # centroid, so all squared distances must be zero.
    dim = rand(1:20)
    batch_data = ones(Float32, dim, rand(1:20))
    centroids = ones(Float32, dim, rand(1:20))
    batch_distances = Matrix{Float32}(
        undef, size(centroids, 2), size(batch_data, 2))
    compute_distances_kernel!(batch_distances, batch_data, centroids)
    @test all(iszero, batch_distances)

    # Test 2: Edge case, single point and centroid.
    # (1 - 2)^2 + (2 - 3)^2 = 2
    batch_data = reshape(Float32[1.0; 2.0], 2, 1)
    centroids = reshape(Float32[2.0; 3.0], 2, 1)
    batch_distances = Matrix{Float32}(undef, 1, 1)
    compute_distances_kernel!(batch_distances, batch_data, centroids)
    @test batch_distances ≈ Float32[2]

    # Test 3: Special case. Column `idx` of both matrices is the constant
    # vector `idx`, so the squared distance between centroid i and point j
    # is dim * (i - j)^2.
    dim = rand(1:20)
    bsize = rand(1:20)
    batch_data = ones(Float32, dim, bsize)
    centroids = ones(Float32, dim, bsize)
    for idx in 1:bsize
        batch_data[:, idx] .*= idx
        centroids[:, idx] .*= idx
    end
    expected_distances = ones(Float32, bsize, bsize)
    for (i, j) in product(1:bsize, 1:bsize)
        expected_distances[i, j] = dim * (i - j)^2
    end
    batch_distances = Matrix{Float32}(undef, bsize, bsize)
    compute_distances_kernel!(batch_distances, batch_data, centroids)
    @test isequal(expected_distances, batch_distances)

    # Test 4: Correct errors are thrown
    batch_data = Float32[1.0 2.0; 3.0 4.0]      # 2x2 matrix
    centroids = Float32[1.0 0.0; 0.0 1.0]       # 2x2 matrix
    batch_distances = zeros(Float32, 3, 2)      # Incorrect size: should be 2x2
    @test_throws DimensionMismatch compute_distances_kernel!(
        batch_distances, batch_data, centroids)

    batch_data = Float32[1.0 2.0; 3.0 4.0]      # 2x2 matrix
    centroids = Float32[1.0 0.0 1.0; 0.0 1.0 0.0; 1.0 1.0 1.0]  # 3x3 matrix, different row count
    batch_distances = zeros(Float32, 3, 2)      # Should match 3x2, but embedding dim is wrong
    @test_throws DimensionMismatch compute_distances_kernel!(
        batch_distances, batch_data, centroids)
end

@testset "update_centroids_kernel!" begin
    # Test 1: results are accumulated on top of the existing centroids.
    # Every point is the all-ones vector, so each centroid column should grow
    # by the number of points assigned to it.
    dim = rand(1:20)
    num_centroids = rand(1:20)
    num_points = rand(1:20)
    point_to_centroid = rand(1:num_centroids, num_points)
    new_centroids = ones(Float32, dim, num_centroids)
    batch_data = ones(Float32, dim, num_points)
    batch_one_hot = zeros(Float32, num_centroids, num_points)
    for (point, centroid) in enumerate(point_to_centroid)
        batch_one_hot[centroid, point] = 1.0f0
    end
    # Start from the initial all-ones centroids and add one per assigned point.
    expected = ones(Float32, dim, num_centroids)
    for centroid in point_to_centroid
        expected[:, centroid] .+= 1.0f0
    end
    update_centroids_kernel!(new_centroids, batch_data, batch_one_hot)
    @test isequal(new_centroids, expected)

    # Test 2: error, `new_centroids` is 3x2 where 2x2 is required
    batch_data = Float32[1.0 2.0; 3.0 4.0]
    batch_one_hot = Float32[1.0 0.0; 0.0 1.0]
    new_centroids = zeros(Float32, 3, 2)
    @test_throws DimensionMismatch update_centroids_kernel!(
        new_centroids, batch_data, batch_one_hot)

    # Test 3: error, `batch_one_hot` is 2x3 where 2x2 is required;
    # `new_centroids` itself has the correct size here
    batch_data = Float32[1.0 2.0; 3.0 4.0]
    batch_one_hot = Float32[1.0 0.0 0.0; 0.0 1.0 0.0]
    new_centroids = zeros(Float32, 2, 2)
    @test_throws DimensionMismatch update_centroids_kernel!(
        new_centroids, batch_data, batch_one_hot)
end

@testset "assign_clusters_kernel!" begin
    # Test 1: each column holds a random permutation of 1:num_rows, so the
    # expected assignment is the row where that permutation places its 1.
    num_points = rand(1:100)
    batch_assignments = Vector{Int32}(undef, num_points)
    batch_distances = Matrix{Float32}(undef, rand(1:100), num_points)
    expected_assignments = Vector{Int32}(undef, num_points)
    num_rows = size(batch_distances, 1)
    for point in 1:num_points
        perm = randperm(num_rows)
        batch_distances[:, point] .= Float32.(perm)
        expected_assignments[point] = argmin(perm)
    end
    assign_clusters_kernel!(batch_assignments, batch_distances)
    @test isequal(expected_assignments, batch_assignments)

    # Test 2: one assignment slot for two points must be rejected
    batch_distances = Float32[1.0 2.0;
                              4.0 5.0]
    batch_assignments = Int32[0]
    @test_throws DimensionMismatch assign_clusters_kernel!(
        batch_assignments, batch_distances)
end

@testset "onehot_encode!" begin
    # Test 1: point i assigned to cluster i yields the identity matrix
    k = rand(1:100)
    batch_assignments = collect(Int32, 1:k)
    batch_one_hot = zeros(Float32, k, k)
    onehot_encode!(batch_one_hot, batch_assignments, k)
    @test isequal(batch_one_hot, I(k))

    # Test 2: a non-trivial permutation of the clusters
    batch_assignments = Int32[4, 2, 3, 1]
    batch_one_hot = zeros(Float32, 4, 4)
    onehot_encode!(batch_one_hot, batch_assignments, 4)
    expected = Float32[0 0 0 1;
                       0 1 0 0;
                       0 0 1 0;
                       1 0 0 0]
    @test batch_one_hot == expected

    # Test 3: edge case with a single cluster (k = 1)
    batch_assignments = Int32[1, 1, 1]
    batch_one_hot = zeros(Float32, 1, 3)
    onehot_encode!(batch_one_hot, batch_assignments, 1)
    @test batch_one_hot == Float32[1 1 1]

    # Test 4: a 3x3 buffer for only two assignments must be rejected
    batch_assignments = Int32[1, 2]
    batch_one_hot = zeros(Float32, 3, 3)
    @test_throws DimensionMismatch onehot_encode!(
        batch_one_hot, batch_assignments, 3)
end

@testset "kmeans_gpu_onehot!" begin
# Test 1: When all points are centroids
# k is set to the number of points, and the test expects every point to be
# reproduced exactly by the centroid of the cluster it gets assigned to.
data = rand(Float32, rand(1:100), rand(1:100))
# `similar` gives `centroids` the same size as `data`, i.e. one column per point.
centroids = similar(data)
# NOTE(review): point_bsize is computed here but never passed to
# kmeans_gpu_onehot! below, so the call uses the keyword's default value.
# Presumably `point_bsize = point_bsize` was intended — confirm.
point_bsize = rand(1:size(data, 2))
cluster_ids = kmeans_gpu_onehot!(data, centroids, size(data, 2))
@test isequal(centroids[:, cluster_ids], data)
end

@testset "_normalize_array!" begin
    # column normalization: every column should have unit norm afterwards
    X = rand(Float32, rand(1:100), rand(1:100))
    _normalize_array!(X, dims = 1)
    for col in eachcol(X)
        @test isapprox(norm(col), 1)
    end

    # row normalization: every row should have unit norm afterwards
    X = rand(Float32, rand(1:100), rand(1:100))
    _normalize_array!(X, dims = 2)
    for row in eachrow(X)
        @test isapprox(norm(row), 1)
    end
end

@testset "_topk" begin
    data = [3.0 1.0 4.0;
            1.0 5.0 9.0;
            2.0 6.0 5.0]
    k = 2

    # Test 1: top-k along dimension 1, i.e. the k largest row indices of
    # every column, largest first
    per_column = _topk(data, k; dims = 1)
    @test per_column == [1 3 2;
                         3 2 3]

    # Test 2: top-k along dimension 2, i.e. the k largest column indices of
    # every row, largest first
    per_row = _topk(data, k; dims = 2)
    @test per_row == [3 1;
                      3 2;
                      2 3]

    # Test 3: any dims other than 1 or 2 is rejected with a DomainError
    @test_throws DomainError _topk(data, k; dims = 3)
end

@testset "_head" begin
    # Each case checks that `_head` returns everything except the last element.

    # Test 1: non-empty integer vector
    @test _head([1, 2, 3, 4]) == [1, 2, 3]

    # Test 2: a single-element vector leaves nothing behind
    @test _head([10]) == Int[]

    # Test 3: an empty vector stays empty
    @test _head(Int[]) == Int[]

    # Test 4: vector of strings
    @test _head(["a", "b", "c"]) == ["a", "b"]

    # Test 5: vector of floating-point numbers
    @test _head([1.5, 2.5, 3.5]) == [1.5, 2.5]

    # Test 6: vector of characters
    @test _head(['a', 'b', 'c']) == ['a', 'b']
end

0 comments on commit b0ac2ac

Please sign in to comment.