Commit a631f2a: Completing tests for indexing.
codetalker7 committed Sep 8, 2024
1 parent b0ac2ac commit a631f2a
Showing 5 changed files with 187 additions and 47 deletions.
6 changes: 3 additions & 3 deletions src/indexing.jl
@@ -110,9 +110,6 @@ function index(indexer::Indexer)
indexer.skiplist, plan_dict["num_chunks"], plan_dict["chunksize"],
centroids, bucket_cutoffs, indexer.config.nbits)

-    # check if all relevant files are saved
-    _check_all_files_are_saved(indexer.config.index_path)
-
# collect embedding offsets and more metadata for chunks
chunk_emb_counts = load_chunk_metadata_property(
indexer.config.index_path, "num_embeddings")
@@ -139,4 +136,7 @@ function index(indexer::Indexer)
ivf_lengths_path = joinpath(indexer.config.index_path, "ivf_lengths.jld2")
JLD2.save_object(ivf_path, ivf)
JLD2.save_object(ivf_lengths_path, ivf_lengths)

+    # check if all relevant files are saved
+    _check_all_files_are_saved(indexer.config.index_path)
end
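The reordering above matters: `ivf.jld2` and `ivf_lengths.jld2` are only written at the very end of `index`, so running the completeness check before those saves would always report missing files. A minimal sketch of the save-then-verify pattern (paths and data here are illustrative, not the package's actual layout):

using JLD2

# Illustrative sketch: verify artifacts only after every one is written.
function save_and_verify(dir::String)
    ivf_path = joinpath(dir, "ivf.jld2")
    ivf_lengths_path = joinpath(dir, "ivf_lengths.jld2")
    JLD2.save_object(ivf_path, rand(Int, 10))          # dummy data
    JLD2.save_object(ivf_lengths_path, rand(Int, 5))   # dummy data
    all(isfile, [ivf_path, ivf_lengths_path])          # true only now
end

save_and_verify(mktempdir())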
76 changes: 37 additions & 39 deletions src/indexing/collection_indexer.jl
@@ -220,9 +220,6 @@ function train(
# TODO: put point_bsize in the config!
kmeans_gpu_onehot!(
sample, centroids, num_partitions; max_iters = kmeans_niters)
-    @assert(size(centroids, 2)==num_partitions,
-        "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)")
-    @assert(centroids isa AbstractMatrix{Float32}, "$(typeof(centroids))")

# computing average residuals
heldout = heldout |> Flux.gpu
@@ -278,51 +275,52 @@ function index(index_path::String, bert::HF.HGFBertModel, linear::Layers.Dense,
end
end

"""
check_chunk_exists(saver::IndexSaver, chunk_idx::Int)
Check if the index chunk exists for the given `chunk_idx`.
# Arguments
- `saver`: The `IndexSaver` object that contains the indexing settings.
- `chunk_idx`: The index of the chunk to check.
# Returns
function _check_all_files_are_saved(index_path::String)
@info "Checking if all index files are saved."

A boolean indicating whether all relevant files for the chunk exist.
"""
function check_chunk_exists(index_path::String, chunk_idx::Int)
path_prefix = joinpath(index_path, string(chunk_idx))
codes_path = "$(path_prefix).codes.jld2"
residuals_path = "$(path_prefix).residuals.jld2"
doclens_path = joinpath(index_path, "doclens.$(chunk_idx).jld2")
metadata_path = joinpath(index_path, "$(chunk_idx).metadata.json")

for file in [codes_path, residuals_path, doclens_path, metadata_path]
if !isfile(file)
return false
end
# first get the plan
isfile(joinpath(index_path, "plan.json")) || begin
@info "plan.json is missing from the index!"
return false
end

true
end

function _check_all_files_are_saved(index_path::String)
plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json"))

@info "Checking if all files are saved."
for chunk_idx in 1:(plan_metadata["num_chunks"])
if !(check_chunk_exists(index_path, chunk_idx))
@error "Some files for chunk $(chunk_idx) are missing!"
end
# get the non-chunk files
files = [
joinpath(index_path, "config.json"),
joinpath(index_path, "centroids.jld2"),
joinpath(index_path, "bucket_cutoffs.jld2"),
joinpath(index_path, "bucket_weights.jld2"),
joinpath(index_path, "avg_residual.jld2"),
joinpath(index_path, "ivf.jld2"),
joinpath(index_path, "ivf_lengths.jld2")
]

# get the chunk files
for chunk_idx in 1:plan_metadata["num_chunks"]
append!(files,
[
joinpath(index_path, "$(chunk_idx).codes.jld2"),
joinpath(index_path, "$(chunk_idx).residuals.jld2"),
joinpath(index_path, "doclens.$(chunk_idx).jld2"),
joinpath(index_path, "$(chunk_idx).metadata.json")
])
end

# check for any missing files
missing_files = findall(!isfile, files)
isempty(missing_files) || begin
@info "$(files[missing_files]) are missing!"
return false
end

@info "Found all files!"
true
end
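The rewritten check builds one flat list of expected paths and uses `findall(!isfile, files)`, so every missing file is reported in a single message instead of failing chunk by chunk. A small standalone illustration of the idiom:

dir = mktempdir()
files = joinpath.(dir, ["a.jld2", "b.jld2", "c.jld2"])
touch(files[1]); touch(files[3])         # leave "b.jld2" missing
missing_files = findall(!isfile, files)  # -> [2]
println(files[missing_files])            # reports every missing path at once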

function _collect_embedding_id_offset(chunk_emb_counts::Vector{Int})
-    length(chunk_emb_counts) > 0 || return zeros(Int, 1)
-    chunk_embedding_offsets = cat([1], chunk_emb_counts[1:(end - 1)], dims = 1)
+    length(chunk_emb_counts) > 0 || return 0, zeros(Int, 1)
+    chunk_embedding_offsets = [1; _head(chunk_emb_counts)]
chunk_embedding_offsets = cumsum(chunk_embedding_offsets)
sum(chunk_emb_counts), chunk_embedding_offsets
end
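For reference, a worked example of the rewritten offset computation (`_head` is assumed here to drop the last element, i.e. `xs[1:end - 1]`):

chunk_emb_counts = [3, 5, 2]
head = chunk_emb_counts[1:(end - 1)]  # [3, 5]
offsets = cumsum([1; head])           # [1, 4, 9]
total = sum(chunk_emb_counts)         # 10
# chunk 1 holds embeddings 1:3, chunk 2 holds 4:8, chunk 3 holds 9:10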
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
150 changes: 145 additions & 5 deletions test/indexing/collection_indexer.jl
@@ -1,5 +1,7 @@
using ColBERT: _sample_pids, _heldout_split, setup, _bucket_cutoffs_and_weights,
-               _normalize_array!, _compute_avg_residuals!
+               _normalize_array!, _compute_avg_residuals!, train,
+               _check_all_files_are_saved, _collect_embedding_id_offset,
+               _build_ivf

@testset "_sample_pids tests" begin
# Test 1: More pids than given can't be sampled
Expand Down Expand Up @@ -110,8 +112,8 @@ end

@testset "_compute_avg_residuals!" begin
# Test 1: centroids and heldout_avg_residual have the same columns with different perms
-    nbits = rand(2:20)
-    centroids = rand(Float32, rand(1:20), rand(1:20))
+    nbits = rand(1:20)
+    centroids = rand(Float32, rand(2:20), rand(1:20))
_normalize_array!(centroids; dims = 1)
perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
heldout = centroids[:, perm]
@@ -124,8 +126,8 @@ end

# Test 2: some tolerance level
tol = 1e-5
-    nbits = rand(2:20)
-    centroids = rand(Float32, rand(1:20), rand(1:20))
+    nbits = rand(1:20)
+    centroids = rand(Float32, rand(2:20), rand(1:20))
_normalize_array!(centroids; dims = 1)
perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
heldout = centroids[:, perm]
@@ -162,3 +164,141 @@ end
@test_throws DimensionMismatch _compute_avg_residuals!(
nbits, centroids, heldout, codes)
end

@testset "train" begin
# Test 1: When all inputs are the same + testing shapes, types
dim = rand(2:20)
nbits = rand(1:5)
kmeans_niters = rand(1:5)
sample = ones(Float32, dim, rand(1:20))
heldout = ones(Float32, dim, rand(1:size(sample, 2)))
num_partitions = rand(1:size(sample, 2))
centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
sample, heldout, num_partitions, nbits, kmeans_niters)
@test all(iszero(bucket_cutoffs))
@test all(iszero(bucket_weights))
@test iszero(avg_residual)
@test centroids isa Matrix{Float32}
@test bucket_cutoffs isa Vector{Float32}
@test bucket_weights isa Vector{Float32}
@test avg_residual isa Float32
@test isequal(size(centroids), (dim, num_partitions))
@test length(bucket_cutoffs) == (1 << nbits) - 1
@test length(bucket_weights) == (1 << nbits)
end
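The two length assertions follow from the residual quantizer's design: `nbits` bits address `2^nbits` buckets, which are separated by `2^nbits - 1` interior cutoffs. A hedged sketch of how such cutoffs could be derived from residual quantiles (the package's actual computation may differ):

using Statistics

nbits = 2
residuals = rand(Float32, 10_000)               # dummy residual values
probs = [i / 2^nbits for i in 1:(2^nbits - 1)]  # [0.25, 0.5, 0.75]
bucket_cutoffs = Float32.(quantile(residuals, probs))
@assert length(bucket_cutoffs) == 2^nbits - 1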

@testset "_check_all_files_are_saved" begin
temp_dir = mktempdir()

# Create plan.json with required structure
plan_data = Dict(
"num_chunks" => 2,
"avg_doclen_est" => 10,
"num_documents" => 100,
"num_embeddings_est" => 200,
"num_embeddings" => 200,
"embeddings_offsets" => [0, 100],
"num_partitions" => 4,
"chunksize" => 50
)
open(joinpath(temp_dir, "plan.json"), "w") do f
JSON.print(f, plan_data)
end

# Create non-chunk files
non_chunk_files = [
"config.json",
"centroids.jld2",
"bucket_cutoffs.jld2",
"bucket_weights.jld2",
"avg_residual.jld2",
"ivf.jld2",
"ivf_lengths.jld2"
]
for file in non_chunk_files
touch(joinpath(temp_dir, file))
end

# Create chunk files
for chunk_idx in 1:plan_data["num_chunks"]
chunk_metadata = Dict(
"num_passages" => 50,
"num_embeddings" => 100,
"passage_offset" => 0
)
open(joinpath(temp_dir, "$(chunk_idx).metadata.json"), "w") do f
JSON.print(f, chunk_metadata)
end
touch(joinpath(temp_dir, "$(chunk_idx).codes.jld2"))
touch(joinpath(temp_dir, "$(chunk_idx).residuals.jld2"))
touch(joinpath(temp_dir, "doclens.$(chunk_idx).jld2"))
end

# Test 1: Check that all files exist
@test _check_all_files_are_saved(temp_dir)

# Test 2: Remove one file at a time and check the function returns false
    all_files = [
        non_chunk_files...,
        "$(1).codes.jld2", "$(1).residuals.jld2", "doclens.1.jld2", "1.metadata.json",
        "$(2).codes.jld2", "$(2).residuals.jld2", "doclens.2.jld2", "2.metadata.json"
    ]

for file in all_files
rm(joinpath(temp_dir, file))
@test !_check_all_files_are_saved(temp_dir)
touch(joinpath(temp_dir, file)) # Recreate the file for the next iteration
end
rm(joinpath(temp_dir, "plan.json"))
@test !_check_all_files_are_saved(temp_dir)

# Clean up
rm(temp_dir, recursive = true)
end
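A side note on the cleanup pattern: the do-block form of `mktempdir` deletes the directory automatically even if a test throws, which would avoid the explicit `rm` at the end:

mktempdir() do temp_dir
    touch(joinpath(temp_dir, "plan.json"))
    # ... run checks against temp_dir ...
end  # temp_dir is removed here, even on error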

@testset "_collect_embedding_id_offset" begin
# Test 1: Small test with fixed values
chunk_emb_counts = [3, 5, 2]
total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
@test total_sum == 10
@test offsets == [1, 4, 9]

# Test 2: Edge case with empty inputs
chunk_emb_counts = Int[]
total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
@test total_sum == 0 # No elements, so sum is 0
@test offsets == [0] # When empty, it should return [0]

# Test 3: All elements are ones
chunk_emb_counts = ones(Int, rand(1:20))
total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
@test total_sum == length(chunk_emb_counts)
@test offsets == collect(1:length(chunk_emb_counts))

# Test 4: Type of outputs
chunk_emb_counts = rand(Int, rand(1:20))
total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
@test total_sum isa Int
@test offsets isa Vector{Int}
end

@testset "_build_ivf" begin
# Test 1: Typical input case
codes = UInt32[5, 3, 8, 2, 5, 5, 4, 2, 2, 1, 3]
num_partitions = 10
ivf, ivf_lengths = _build_ivf(codes, num_partitions)
@test ivf == [10, 4, 8, 9, 2, 11, 7, 1, 5, 6, 3]
@test ivf_lengths == [1, 3, 2, 1, 3, 0, 0, 1, 0, 0]

# Test 2: Testing types, shapes and range of vals
num_partitions = rand(1:1000)
codes = UInt32.(rand(1:num_partitions, 10000)) # Large array with random values
ivf, ivf_lengths = _build_ivf(codes, num_partitions)
@test length(ivf) == length(codes)
@test sum(ivf_lengths) == length(codes)
@test length(ivf_lengths) == num_partitions
@test all(in(ivf), codes)
@test ivf isa Vector{Int}
@test ivf_lengths isa Vector{Int}
end
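Test 1's expected values can be derived by hand: the inverted file (IVF) is just the embedding ids grouped by ascending centroid code, and `ivf_lengths` counts how many ids fall into each partition. A hedged sketch of one way to compute this (ColBERT.jl's actual `_build_ivf` may be implemented differently):

# Sketch only; the package's _build_ivf may differ in implementation.
function build_ivf_sketch(codes::Vector{UInt32}, num_partitions::Int)
    ivf = sortperm(codes)        # stable: ids grouped by code, ties in original order
    ivf_lengths = zeros(Int, num_partitions)
    for code in codes
        ivf_lengths[code] += 1   # one count per embedding id
    end
    ivf, ivf_lengths
end

codes = UInt32[5, 3, 8, 2, 5, 5, 4, 2, 2, 1, 3]
ivf, ivf_lengths = build_ivf_sketch(codes, 10)
# ivf == [10, 4, 8, 9, 2, 11, 7, 1, 5, 6, 3]
# ivf_lengths == [1, 3, 2, 1, 3, 0, 0, 1, 0, 0]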
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,5 +1,6 @@
using ColBERT
using .Iterators
+using JSON
using LinearAlgebra
using Logging
using Random