From a631f2a8f2d907599b4a4f4e8ac05e2cb5083998 Mon Sep 17 00:00:00 2001
From: Siddhant Chaudhary
Date: Mon, 9 Sep 2024 02:58:40 +0530
Subject: [PATCH] Completing tests for indexing.

---
 src/indexing.jl                     |   6 +-
 src/indexing/collection_indexer.jl  |  76 +++++++-------
 test/Project.toml                   |   1 +
 test/indexing/collection_indexer.jl | 150 +++++++++++++++++++++++++++-
 test/runtests.jl                    |   1 +
 5 files changed, 187 insertions(+), 47 deletions(-)

diff --git a/src/indexing.jl b/src/indexing.jl
index a691e12..b80091e 100644
--- a/src/indexing.jl
+++ b/src/indexing.jl
@@ -110,9 +110,6 @@ function index(indexer::Indexer)
         indexer.skiplist, plan_dict["num_chunks"], plan_dict["chunksize"],
         centroids, bucket_cutoffs, indexer.config.nbits)
 
-    # check if all relevant files are saved
-    _check_all_files_are_saved(indexer.config.index_path)
-
     # collect embedding offsets and more metadata for chunks
     chunk_emb_counts = load_chunk_metadata_property(
         indexer.config.index_path, "num_embeddings")
@@ -139,4 +136,7 @@ function index(indexer::Indexer)
     ivf_lengths_path = joinpath(indexer.config.index_path, "ivf_lengths.jld2")
     JLD2.save_object(ivf_path, ivf)
     JLD2.save_object(ivf_lengths_path, ivf_lengths)
+
+    # check if all relevant files are saved
+    _check_all_files_are_saved(indexer.config.index_path)
 end
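Context for the hunk above: the `_check_all_files_are_saved` call moves below the
`ivf.jld2`/`ivf_lengths.jld2` writes because the rewritten check (next file) also
requires those two files, so it can only succeed once they are on disk. A minimal
sketch of the resulting tail of `index` (abbreviated from the diff, not additional
code):

    JLD2.save_object(ivf_path, ivf)                          # writes ivf.jld2
    JLD2.save_object(ivf_lengths_path, ivf_lengths)          # writes ivf_lengths.jld2
    _check_all_files_are_saved(indexer.config.index_path)    # now sees every file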
diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl
index 6a45ae7..f885019 100644
--- a/src/indexing/collection_indexer.jl
+++ b/src/indexing/collection_indexer.jl
@@ -220,9 +220,6 @@ function train(
     # TODO: put point_bsize in the config!
     kmeans_gpu_onehot!(
         sample, centroids, num_partitions; max_iters = kmeans_niters)
-    @assert(size(centroids, 2)==num_partitions,
-        "size(centroids): $(size(centroids)), num_partitions: $(num_partitions)")
-    @assert(centroids isa AbstractMatrix{Float32}, "$(typeof(centroids))")
 
     # computing average residuals
     heldout = heldout |> Flux.gpu
@@ -278,51 +275,52 @@ function index(index_path::String, bert::HF.HGFBertModel, linear::Layers.Dense,
     end
 end
 
-"""
-    check_chunk_exists(saver::IndexSaver, chunk_idx::Int)
-
-Check if the index chunk exists for the given `chunk_idx`.
-
-# Arguments
-
-  - `saver`: The `IndexSaver` object that contains the indexing settings.
-  - `chunk_idx`: The index of the chunk to check.
-
-# Returns
+function _check_all_files_are_saved(index_path::String)
+    @info "Checking if all index files are saved."
 
-A boolean indicating whether all relevant files for the chunk exist.
-"""
-function check_chunk_exists(index_path::String, chunk_idx::Int)
-    path_prefix = joinpath(index_path, string(chunk_idx))
-    codes_path = "$(path_prefix).codes.jld2"
-    residuals_path = "$(path_prefix).residuals.jld2"
-    doclens_path = joinpath(index_path, "doclens.$(chunk_idx).jld2")
-    metadata_path = joinpath(index_path, "$(chunk_idx).metadata.json")
-
-    for file in [codes_path, residuals_path, doclens_path, metadata_path]
-        if !isfile(file)
-            return false
-        end
+    # first get the plan
+    isfile(joinpath(index_path, "plan.json")) || begin
+        @info "plan.json is missing from the index!"
+        return false
     end
-
-    true
-end
-
-function _check_all_files_are_saved(index_path::String)
     plan_metadata = JSON.parsefile(joinpath(index_path, "plan.json"))
-    @info "Checking if all files are saved."
-    for chunk_idx in 1:(plan_metadata["num_chunks"])
-        if !(check_chunk_exists(index_path, chunk_idx))
-            @error "Some files for chunk $(chunk_idx) are missing!"
-        end
+
+    # get the non-chunk files
+    files = [
+        joinpath(index_path, "config.json"),
+        joinpath(index_path, "centroids.jld2"),
+        joinpath(index_path, "bucket_cutoffs.jld2"),
+        joinpath(index_path, "bucket_weights.jld2"),
+        joinpath(index_path, "avg_residual.jld2"),
+        joinpath(index_path, "ivf.jld2"),
+        joinpath(index_path, "ivf_lengths.jld2")
+    ]
+
+    # get the chunk files
+    for chunk_idx in 1:plan_metadata["num_chunks"]
+        append!(files,
+            [
+                joinpath(index_path, "$(chunk_idx).codes.jld2"),
+                joinpath(index_path, "$(chunk_idx).residuals.jld2"),
+                joinpath(index_path, "doclens.$(chunk_idx).jld2"),
+                joinpath(index_path, "$(chunk_idx).metadata.json")
+            ])
     end
+
+    # check for any missing files
+    missing_files = findall(!isfile, files)
+    isempty(missing_files) || begin
+        @info "$(files[missing_files]) are missing!"
+        return false
+    end
+
+    @info "Found all files!"
+
     true
 end
 
 function _collect_embedding_id_offset(chunk_emb_counts::Vector{Int})
-    length(chunk_emb_counts) > 0 || return zeros(Int, 1)
-    chunk_embedding_offsets = cat([1], chunk_emb_counts[1:(end - 1)], dims = 1)
+    length(chunk_emb_counts) > 0 || return 0, zeros(Int, 1)
+    chunk_embedding_offsets = [1; _head(chunk_emb_counts)]
     chunk_embedding_offsets = cumsum(chunk_embedding_offsets)
     sum(chunk_emb_counts), chunk_embedding_offsets
 end
diff --git a/test/Project.toml b/test/Project.toml
index 97bba88..29425ae 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,5 +1,6 @@
 [deps]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
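For reference, the new `_collect_embedding_id_offset` computes 1-based embedding
id offsets per chunk. A self-contained sketch of the logic, with a local stand-in
for the package's `_head` helper (assumed here to drop the last element):

    _head(v) = v[1:(end - 1)]                         # assumption: mirrors ColBERT's helper
    chunk_emb_counts = [3, 5, 2]
    offsets = cumsum([1; _head(chunk_emb_counts)])    # [1, 4, 9]
    total = sum(chunk_emb_counts)                     # 10

These are exactly the values asserted in Test 1 of the
`_collect_embedding_id_offset` testset below.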
diff --git a/test/indexing/collection_indexer.jl b/test/indexing/collection_indexer.jl
index 7109d91..e987e1c 100644
--- a/test/indexing/collection_indexer.jl
+++ b/test/indexing/collection_indexer.jl
@@ -1,5 +1,7 @@
 using ColBERT: _sample_pids, _heldout_split, setup, _bucket_cutoffs_and_weights,
-               _normalize_array!, _compute_avg_residuals!
+               _normalize_array!, _compute_avg_residuals!, train,
+               _check_all_files_are_saved, _collect_embedding_id_offset,
+               _build_ivf
 
 @testset "_sample_pids tests" begin
     # Test 1: More pids than given can't be sampled
@@ -110,8 +112,8 @@ end
 @testset "_compute_avg_residuals!" begin
     # Test 1: centroids and heldout_avg_residual have the same columns with different perms
-    nbits = rand(2:20)
-    centroids = rand(Float32, rand(1:20), rand(1:20))
+    nbits = rand(1:20)
+    centroids = rand(Float32, rand(2:20), rand(1:20))
     _normalize_array!(centroids; dims = 1)
     perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
     heldout = centroids[:, perm]
@@ -124,8 +126,8 @@ end
 
     # Test 2: some tolerance level
     tol = 1e-5
-    nbits = rand(2:20)
-    centroids = rand(Float32, rand(1:20), rand(1:20))
+    nbits = rand(1:20)
+    centroids = rand(Float32, rand(2:20), rand(1:20))
     _normalize_array!(centroids; dims = 1)
     perm = randperm(size(centroids, 2))[1:rand(1:size(centroids, 2))]
     heldout = centroids[:, perm]
@@ -162,3 +164,141 @@ end
     @test_throws DimensionMismatch _compute_avg_residuals!(
         nbits, centroids, heldout, codes)
 end
+
+@testset "train" begin
+    # Test 1: all inputs identical; also check output shapes and types
+    dim = rand(2:20)
+    nbits = rand(1:5)
+    kmeans_niters = rand(1:5)
+    sample = ones(Float32, dim, rand(1:20))
+    heldout = ones(Float32, dim, rand(1:size(sample, 2)))
+    num_partitions = rand(1:size(sample, 2))
+    centroids, bucket_cutoffs, bucket_weights, avg_residual = train(
+        sample, heldout, num_partitions, nbits, kmeans_niters)
+    @test all(iszero, bucket_cutoffs)
+    @test all(iszero, bucket_weights)
+    @test iszero(avg_residual)
+    @test centroids isa Matrix{Float32}
+    @test bucket_cutoffs isa Vector{Float32}
+    @test bucket_weights isa Vector{Float32}
+    @test avg_residual isa Float32
+    @test isequal(size(centroids), (dim, num_partitions))
+    @test length(bucket_cutoffs) == (1 << nbits) - 1
+    @test length(bucket_weights) == (1 << nbits)
+end
+
+@testset "_check_all_files_are_saved" begin
+    temp_dir = mktempdir()
+
+    # Create plan.json with the required structure
+    plan_data = Dict(
+        "num_chunks" => 2,
+        "avg_doclen_est" => 10,
+        "num_documents" => 100,
+        "num_embeddings_est" => 200,
+        "num_embeddings" => 200,
+        "embeddings_offsets" => [0, 100],
+        "num_partitions" => 4,
+        "chunksize" => 50
+    )
+    open(joinpath(temp_dir, "plan.json"), "w") do f
+        JSON.print(f, plan_data)
+    end
+
+    # Create the non-chunk files
+    non_chunk_files = [
+        "config.json",
+        "centroids.jld2",
+        "bucket_cutoffs.jld2",
+        "bucket_weights.jld2",
+        "avg_residual.jld2",
+        "ivf.jld2",
+        "ivf_lengths.jld2"
+    ]
+    for file in non_chunk_files
+        touch(joinpath(temp_dir, file))
+    end
+
+    # Create the chunk files
+    for chunk_idx in 1:plan_data["num_chunks"]
+        chunk_metadata = Dict(
+            "num_passages" => 50,
+            "num_embeddings" => 100,
+            "passage_offset" => 0
+        )
+        open(joinpath(temp_dir, "$(chunk_idx).metadata.json"), "w") do f
+            JSON.print(f, chunk_metadata)
+        end
+        touch(joinpath(temp_dir, "$(chunk_idx).codes.jld2"))
+        touch(joinpath(temp_dir, "$(chunk_idx).residuals.jld2"))
+        touch(joinpath(temp_dir, "doclens.$(chunk_idx).jld2"))
+    end
+
+    # Test 1: Check that all files exist
+    @test _check_all_files_are_saved(temp_dir)
+
+    # Test 2: Remove one file at a time and check that the function returns false
+    all_files = [
+        non_chunk_files...,
+        "1.codes.jld2", "1.residuals.jld2", "doclens.1.jld2", "1.metadata.json",
+        "2.codes.jld2", "2.residuals.jld2", "doclens.2.jld2", "2.metadata.json"
+    ]
+
+    for file in all_files
+        rm(joinpath(temp_dir, file))
+        @test !_check_all_files_are_saved(temp_dir)
+        touch(joinpath(temp_dir, file))    # recreate the file for the next iteration
+    end
+    rm(joinpath(temp_dir, "plan.json"))
+    @test !_check_all_files_are_saved(temp_dir)
+
+    # Clean up
+    rm(temp_dir, recursive = true)
+end
+
+@testset "_collect_embedding_id_offset" begin
+    # Test 1: Small test with fixed values
+    chunk_emb_counts = [3, 5, 2]
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum == 10
+    @test offsets == [1, 4, 9]
+
+    # Test 2: Edge case with empty input
+    chunk_emb_counts = Int[]
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum == 0    # No elements, so sum is 0
+    @test offsets == [0]    # When empty, it should return [0]
+
+    # Test 3: All elements are ones
+    chunk_emb_counts = ones(Int, rand(1:20))
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum == length(chunk_emb_counts)
+    @test offsets == collect(1:length(chunk_emb_counts))
+
+    # Test 4: Types of outputs
+    chunk_emb_counts = rand(Int, rand(1:20))
+    total_sum, offsets = _collect_embedding_id_offset(chunk_emb_counts)
+    @test total_sum isa Int
+    @test offsets isa Vector{Int}
+end
+
+@testset "_build_ivf" begin
+    # Test 1: Typical input case
+    codes = UInt32[5, 3, 8, 2, 5, 5, 4, 2, 2, 1, 3]
+    num_partitions = 10
+    ivf, ivf_lengths = _build_ivf(codes, num_partitions)
+    @test ivf == [10, 4, 8, 9, 2, 11, 7, 1, 5, 6, 3]
+    @test ivf_lengths == [1, 3, 2, 1, 3, 0, 0, 1, 0, 0]
+
+    # Test 2: Types, shapes and ranges of values
+    num_partitions = rand(1:1000)
+    codes = UInt32.(rand(1:num_partitions, 10000))    # large array with random values
+    ivf, ivf_lengths = _build_ivf(codes, num_partitions)
+    @test length(ivf) == length(codes)
+    @test isperm(ivf)    # the IVF should be a permutation of the embedding ids
+    @test sum(ivf_lengths) == length(codes)
+    @test length(ivf_lengths) == num_partitions
+    @test all(in(ivf), codes)
+    @test ivf isa Vector{Int}
+    @test ivf_lengths isa Vector{Int}
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 96f3cac..1a1af4e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
 using ColBERT
 using .Iterators
+using JSON
 using LinearAlgebra
 using Logging
 using Random
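As a compact reference for what the `_build_ivf` tests encode (a sketch of the
expected semantics, not necessarily the package's implementation): the IVF is the
list of embedding ids grouped by centroid code, together with the number of ids
assigned to each centroid.

    # Reference semantics only; `build_ivf_reference` is a hypothetical name.
    function build_ivf_reference(codes::Vector{UInt32}, num_partitions::Int)
        ivf = sortperm(codes)    # stable: ids with equal codes keep their order
        ivf_lengths = [count(==(p), codes) for p in 1:num_partitions]
        ivf, ivf_lengths
    end

    codes = UInt32[5, 3, 8, 2, 5, 5, 4, 2, 2, 1, 3]
    build_ivf_reference(codes, 10)
    # -> ([10, 4, 8, 9, 2, 11, 7, 1, 5, 6, 3], [1, 3, 2, 1, 3, 0, 0, 1, 0, 0])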