Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] CAGRA filtering with BFKNN when sparsity matching threshold #378

Open
wants to merge 4 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,13 @@ struct index : cuvs::neighbors::index {
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
"Dataset and knn_graph must have equal number of rows");
update_graph(res, knn_graph);
if constexpr (raft::is_device_mdspan_v<decltype(dataset)>) {
contiguous_dataset_ =
raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
} else {
contiguous_dataset_ =
raft::make_host_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
}

raft::resource::sync_stream(res);
}
Expand All @@ -417,13 +424,19 @@ struct index : cuvs::neighbors::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
contiguous_dataset_ = dataset;
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::layout_stride> dataset)
{
contiguous_dataset_ = std::monostate{};
if (dataset.stride(0) == dataset.extent(1) && dataset.stride(1) == 1) {
contiguous_dataset_ =
raft::make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1));
}
dataset_ = make_aligned_dataset(res, dataset, 16);
}

Expand All @@ -436,7 +449,8 @@ struct index : cuvs::neighbors::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset)
{
dataset_ = make_aligned_dataset(res, dataset, 16);
contiguous_dataset_ = dataset;
dataset_ = make_aligned_dataset(res, dataset, 16);
}

/**
Expand All @@ -447,14 +461,16 @@ struct index : cuvs::neighbors::index {
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<int64_t>, DatasetT>>
{
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
contiguous_dataset_ = std::monostate{};
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
dataset_ = std::move(dataset);
contiguous_dataset_ = std::monostate{};
dataset_ = std::move(dataset);
}

/**
Expand Down Expand Up @@ -492,11 +508,17 @@ struct index : cuvs::neighbors::index {
graph_view_ = graph_.view();
}

auto contiguous_dataset() const { return contiguous_dataset_; }

private:
cuvs::distance::DistanceType metric_;
raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
std::unique_ptr<neighbors::dataset<int64_t>> dataset_;
std::variant<std::monostate,
raft::device_matrix_view<const T, int64_t, raft::row_major>,
raft::host_matrix_view<const T, int64_t, raft::row_major>>
contiguous_dataset_ = std::monostate{};
};
/**
* @}
Expand Down
57 changes: 57 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/unary_op.cuh>

#include <cuvs/distance/distance.hpp>

#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/cagra.hpp>

// TODO: Fix these when ivf methods are moved over
Expand Down Expand Up @@ -140,6 +142,61 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<DistanceT, int64_t, raft::row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
if constexpr (!std::is_same_v<CagraSampleFilterT,
cuvs::neighbors::filtering::none_sample_filter> &&
(std::is_same_v<T, float> || std::is_same_v<T, half>)) {
auto n_queries = queries.extent(0);
auto n_dataset = index.size();

auto bitset_filter_view = sample_filter.bitset_view_;
auto dataset_view = index.contiguous_dataset();

auto sparsity = bitset_filter_view.sparsity(res);
constexpr double threshold_to_bf = 0.9;

// TODO: Support host dataset in `brute_force::build`
if (sparsity >= threshold_to_bf &&
std::holds_alternative<raft::device_matrix_view<const T, int64_t, raft::row_major>>(
dataset_view)) {
using bitmap_view_t = cuvs::core::bitmap_view<const uint32_t, int64_t>;

auto stream = raft::resource::get_cuda_stream(res);
auto bitmap_n_elements =
bitmap_view_t::eval_n_elements(bitset_filter_view.size() * n_queries);

rmm::device_uvector<uint32_t> raw_bitmap(bitmap_n_elements, stream);
rmm::device_uvector<int64_t> raw_neighbors(neighbors.size(), stream);

bitset_filter_view.repeat(res, n_queries, raw_bitmap.data());

auto brute_force_filter = bitmap_view_t(raw_bitmap.data(), n_queries, n_dataset);

auto brute_force_neighbors = raft::make_device_matrix_view<int64_t, int64_t, raft::row_major>(
raw_neighbors.data(), neighbors.extent(0), neighbors.extent(1));
auto brute_force_dataset =
std::get_if<raft::device_matrix_view<const T, int64_t, raft::row_major>>(&dataset_view);

if (brute_force_dataset) {
RAFT_LOG_DEBUG("CAGRA is switching to brute force with sparsity:%d", sparsity);
auto brute_force_idx =
cuvs::neighbors::brute_force::build(res, *brute_force_dataset, index.metric());
cuvs::neighbors::brute_force::search(
res,
brute_force_idx,
queries,
brute_force_neighbors,
distances,
cuvs::neighbors::filtering::bitmap_filter(brute_force_filter));
raft::linalg::unaryOp(neighbors.data_handle(),
brute_force_neighbors.data_handle(),
neighbors.size(),
raft::cast_op<InternalIdxT>(),
raft::resource::get_cuda_stream(res));
return;
}
}
}

auto stream = raft::resource::get_cuda_stream(res);
const auto& graph = index.graph();
auto graph_internal = raft::make_device_matrix_view<const InternalIdxT, int64_t, raft::row_major>(
Expand Down
68 changes: 49 additions & 19 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ namespace cuvs::neighbors::cagra {
namespace {

struct test_cagra_sample_filter {
static constexpr unsigned offset = 300;
inline _RAFT_HOST_DEVICE auto operator()(
// query index
const uint32_t query_ix,
// the index of the current sample inside the current inverted list
const uint32_t sample_ix) const
const uint32_t sample_ix,
const uint32_t offset) const
{
return sample_ix >= offset;
}
Expand Down Expand Up @@ -276,6 +276,7 @@ struct AnnCagraInputs {
bool include_serialized_dataset;
// std::optional<double>
double min_recall; // = std::nullopt;
uint32_t filter_offset = 300;
std::optional<float> ivf_pq_search_refine_ratio = std::nullopt;
std::optional<vpq_params> compression = std::nullopt;

Expand Down Expand Up @@ -702,21 +703,20 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
rmm::device_uvector<DistanceT> distances_naive_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
auto* database_filtered_ptr = database.data() + test_cagra_sample_filter::offset * ps.dim;
cuvs::neighbors::naive_knn<DistanceT, DataT, IdxT>(
handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - test_cagra_sample_filter::offset,
ps.dim,
ps.k,
ps.metric);
auto* database_filtered_ptr = database.data() + ps.filter_offset * ps.dim;
cuvs::neighbors::naive_knn<DistanceT, DataT, IdxT>(handle_,
distances_naive_dev.data(),
indices_naive_dev.data(),
search_queries.data(),
database_filtered_ptr,
ps.n_queries,
ps.n_rows - ps.filter_offset,
ps.dim,
ps.k,
ps.metric);
raft::linalg::addScalar(indices_naive_dev.data(),
indices_naive_dev.data(),
IdxT(test_cagra_sample_filter::offset),
IdxT(ps.filter_offset),
queries_size,
stream_);
raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
Expand Down Expand Up @@ -787,7 +787,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto dists_out_view = raft::make_device_matrix_view<DistanceT, int64_t>(
distances_dev.data(), ps.n_queries, ps.k);
auto removed_indices =
raft::make_device_vector<int64_t, int64_t>(handle_, test_cagra_sample_filter::offset);
raft::make_device_vector<int64_t, int64_t>(handle_, ps.filter_offset);
thrust::sequence(
raft::resource::get_thrust_policy(handle_),
thrust::device_pointer_cast(removed_indices.data_handle()),
Expand All @@ -813,8 +813,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
bool unacceptable_node = false;
for (int q = 0; q < ps.n_queries; q++) {
for (int i = 0; i < ps.k; i++) {
const auto n = indices_Cagra[q * ps.k + i];
unacceptable_node = unacceptable_node | !test_cagra_sample_filter()(q, n);
const auto n = indices_Cagra[q * ps.k + i];
unacceptable_node =
unacceptable_node | !test_cagra_sample_filter()(q, n, ps.filter_offset);
}
}
EXPECT_FALSE(unacceptable_node);
Expand Down Expand Up @@ -1002,6 +1003,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{false, true},
{false},
{0.99},
{uint32_t(300)},
{1.0f, 2.0f, 3.0f});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

Expand All @@ -1028,6 +1030,34 @@ inline std::vector<AnnCagraInputs> generate_inputs()
return inputs;
}

const std::vector<AnnCagraInputs> inputs = generate_inputs();
inline std::vector<AnnCagraInputs> generate_bf_inputs()
{
// Add test cases for brute force as sparsity >= 0.9.
std::vector<AnnCagraInputs> inputs_for_brute_force;
auto inputs_original = raft::util::itertools::product<AnnCagraInputs>(
{100},
{10000, 100000},
{1, 8, 17},
{1, 16, 256}, // k
{graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT},
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL},
{0, 1, 10, 100},
{0},
{256},
{1},
{cuvs::distance::DistanceType::L2Expanded},
{false},
{true},
{1.0});
for (auto input : inputs_original) {
input.filter_offset = 0.90 * input.n_rows;
inputs_for_brute_force.push_back(input);
}

return inputs_for_brute_force;
}

const std::vector<AnnCagraInputs> inputs = generate_inputs();
const std::vector<AnnCagraInputs> inputs_brute_force = generate_bf_inputs();

} // namespace cuvs::neighbors::cagra
3 changes: 3 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,8 @@ INSTANTIATE_TEST_CASE_P(AnnCagraAddNodesTest,
AnnCagraAddNodesTestF_U32,
::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterTest, AnnCagraFilterTestF_U32, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(AnnCagraFilterToBruteForceTest,
AnnCagraFilterTestF_U32,
::testing::ValuesIn(inputs_brute_force));

} // namespace cuvs::neighbors::cagra
Loading