From 0faf8894085155e1eabd13e20af5ccfcf22e363c Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 2 Oct 2024 11:56:13 -0700 Subject: [PATCH 1/2] [Feat] CAGRA filtering with BFKNN when sparsity matching threshold --- cpp/include/cuvs/neighbors/cagra.hpp | 27 ++++++-- .../neighbors/detail/cagra/cagra_search.cuh | 56 +++++++++++++++ cpp/test/neighbors/ann_cagra.cuh | 68 +++++++++++++------ .../ann_cagra/test_float_uint32_t.cu | 3 + 4 files changed, 131 insertions(+), 23 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index e48050756..5b7a5ab0f 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -403,6 +403,13 @@ struct index : cuvs::neighbors::index { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); update_graph(res, knn_graph); + if constexpr (raft::is_device_mdspan_v) { + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + } else { + contiguous_dataset_ = + raft::make_host_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + } raft::resource::sync_stream(res); } @@ -417,13 +424,16 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - dataset_ = make_aligned_dataset(res, dataset, 16); + contiguous_dataset_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); dataset_ = make_aligned_dataset(res, dataset, 16); } @@ -436,7 +446,8 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - dataset_ = make_aligned_dataset(res, dataset, 16); + contiguous_dataset_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -447,14 +458,16 @@ struct index : cuvs::neighbors::index { auto update_dataset(raft::resources const& res, DatasetT&& dataset) -> std::enable_if_t, DatasetT>> { - dataset_ = std::make_unique(std::move(dataset)); + contiguous_dataset_ = std::monostate{}; + dataset_ = std::make_unique(std::move(dataset)); } template auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) -> std::enable_if_t, DatasetT>> { - dataset_ = std::move(dataset); + contiguous_dataset_ = std::monostate{}; + dataset_ = std::move(dataset); } /** @@ -492,11 +505,17 @@ struct index : cuvs::neighbors::index { graph_view_ = graph_.view(); } + auto contiguous_dataset() const { return contiguous_dataset_; } + private: cuvs::distance::DistanceType metric_; raft::device_matrix graph_; raft::device_matrix_view graph_view_; std::unique_ptr> dataset_; + std::variant, + raft::host_matrix_view> + contiguous_dataset_ = std::monostate{}; }; /** * @} diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 4c15b8e14..5a1b764d0 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -26,9 +26,11 @@ #include #include #include +#include #include +#include #include // TODO: Fix these when ivf methods are moved over @@ -140,6 +142,60 @@ void search_main(raft::resources const& res, raft::device_matrix_view distances, CagraSampleFilterT sample_filter = CagraSampleFilterT()) { + if constexpr (!std::is_same_v && + (std::is_same_v || std::is_same_v)) { + auto n_queries = queries.extent(0); + auto n_dataset = index.size(); + + auto bitset_filter_view = sample_filter.bitset_view_; + auto dataset_view = index.contiguous_dataset(); + + auto sparsity = bitset_filter_view.sparsity(res); + constexpr double threshold_to_bf = 0.9; + + // TODO: Support host dataset in `brute_force::build` + if (sparsity >= threshold_to_bf && + std::holds_alternative>( + dataset_view)) { + using bitmap_view_t = cuvs::core::bitmap_view; + + auto stream = raft::resource::get_cuda_stream(res); + auto bitmap_n_elements = + bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries); + + rmm::device_uvector raw_bitmap(bitmap_n_elements, stream); + rmm::device_uvector raw_neighbors(neighbors.size(), stream); + + bitset_filter_view.repeat(res, n_queries, raw_bitmap.data()); + + auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset); + + auto brute_force_neighbors = raft::make_device_matrix_view( + raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1)); + auto brute_force_dataset = + std::get_if>(&dataset_view); + + if (brute_force_dataset) { + auto brute_force_idx = + cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); + cuvs::neighbors::brute_force::search( + res, + brute_force_idx, + queries, + brute_force_neighbors, + distances, + cuvs::neighbors::filtering::bitmap_filter(brute_force_filter)); + raft::linalg::unaryOp(neighbors.data_handle(), + brute_force_neighbors.data_handle(), + neighbors.size(), + raft::cast_op(), + raft::resource::get_cuda_stream(res)); + return; + } + } + } + auto stream = raft::resource::get_cuda_stream(res); const auto& graph = index.graph(); auto graph_internal = raft::make_device_matrix_view( diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 37d42dd1d..512e7a60d 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -51,12 +51,12 @@ namespace cuvs::neighbors::cagra { namespace { struct test_cagra_sample_filter { - static constexpr unsigned offset = 300; inline _RAFT_HOST_DEVICE auto operator()( // query index const uint32_t query_ix, // the index of the current sample inside the current inverted list - const uint32_t sample_ix) const + const uint32_t sample_ix, + const uint32_t offset) const { return sample_ix >= offset; } @@ -276,6 +276,7 @@ struct AnnCagraInputs { bool include_serialized_dataset; // std::optional double min_recall; // = std::nullopt; + uint32_t filter_offset = 300; std::optional ivf_pq_search_refine_ratio = std::nullopt; std::optional compression = std::nullopt; @@ -702,21 +703,20 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { { rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim; - cuvs::neighbors::naive_knn( - handle_, - distances_naive_dev.data(), - indices_naive_dev.data(), - search_queries.data(), - database_filtered_ptr, - ps.n_queries, - ps.n_rows - test_cagra_sample_filter::offset, - ps.dim, - ps.k, - ps.metric); + auto* database_filtered_ptr = database.data() + ps.filter_offset * ps.dim; + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.n_queries, + ps.n_rows - ps.filter_offset, + ps.dim, + ps.k, + ps.metric); raft::linalg::addScalar(indices_naive_dev.data(), indices_naive_dev.data(), - IdxT(test_cagra_sample_filter::offset), + IdxT(ps.filter_offset), queries_size, stream_); raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); @@ -787,7 +787,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { auto dists_out_view = raft::make_device_matrix_view( distances_dev.data(), ps.n_queries, ps.k); auto removed_indices = - raft::make_device_vector(handle_, test_cagra_sample_filter::offset); + raft::make_device_vector(handle_, ps.filter_offset); thrust::sequence( raft::resource::get_thrust_policy(handle_), thrust::device_pointer_cast(removed_indices.data_handle()), @@ -813,8 +813,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { bool unacceptable_node = false; for (int q = 0; q < ps.n_queries; q++) { for (int i = 0; i < ps.k; i++) { - const auto n = indices_Cagra[q * ps.k + i]; - unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n); + const auto n = indices_Cagra[q * ps.k + i]; + unacceptable_node = + unacceptable_node | !test_cagra_sample_filter()(q, n, ps.filter_offset); } } EXPECT_FALSE(unacceptable_node); @@ -1002,6 +1003,7 @@ inline std::vector generate_inputs() {false, true}, {false}, {0.99}, + {uint32_t(300)}, {1.0f, 2.0f, 3.0f}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -1028,6 +1030,34 @@ inline std::vector generate_inputs() return inputs; } -const std::vector inputs = generate_inputs(); +inline std::vector generate_bf_inputs() +{ + // Add test cases for brute force as sparsity >= 0.9. + std::vector inputs_for_brute_force; + auto inputs_original = raft::util::itertools::product( + {100}, + {10000, 100000}, + {1, 8, 17}, + {1, 16, 256}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, + {0, 1, 10, 100}, + {0}, + {256}, + {1}, + {cuvs::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {1.0}); + for (auto input : inputs_original) { + input.filter_offset = 0.90 * input.n_rows; + inputs_for_brute_force.push_back(input); + } + + return inputs_for_brute_force; +} + +const std::vector inputs = generate_inputs(); +const std::vector inputs_brute_force = generate_bf_inputs(); } // namespace cuvs::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index ca188d132..a98c31510 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -34,5 +34,8 @@ INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest, AnnCagraAddNodesTestF_U32, ::testing::ValuesIn(inputs)); INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest, + AnnCagraFilterTestF_U32, + ::testing::ValuesIn(inputs_brute_force)); } // namespace cuvs::neighbors::cagra From f14be712214b975a6e99896611677e184fc2454d Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 3 Oct 2024 16:50:16 -0700 Subject: [PATCH 2/2] revert: update_dataset on strided matrix --- cpp/include/cuvs/neighbors/cagra.hpp | 7 +++++-- cpp/src/neighbors/detail/cagra/cagra_search.cuh | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 5b7a5ab0f..83d9eec12 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -432,8 +432,11 @@ struct index : cuvs::neighbors::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - contiguous_dataset_ = - raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + contiguous_dataset_ = std::monostate{}; + if (dataset.stride(0) == dataset.extent(1) && dataset.stride(1) == 1) { + contiguous_dataset_ = + raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + } dataset_ = make_aligned_dataset(res, dataset, 16); } diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 5a1b764d0..ba0d82831 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -177,6 +177,7 @@ void search_main(raft::resources const& res, std::get_if>(&dataset_view); if (brute_force_dataset) { + RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity); auto brute_force_idx = cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric()); cuvs::neighbors::brute_force::search(