From 9baadc8755c5c42d8d70bf70441a230b6ad3ec2a Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 7 Feb 2024 16:08:52 +0100 Subject: [PATCH] Expose IVF methods --- cpp/include/cuvs/neighbors/ivf_flat.hpp | 51 ++++++++++++++++++++++++- cpp/include/cuvs/neighbors/ivf_pq.hpp | 41 ++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index cd74b9ff3..327490db0 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -27,4 +27,53 @@ using search_params = raft::neighbors::ivf_flat::search_params; template using index = raft::neighbors::ivf_flat::index; -} // namespace cuvs::neighbors::ivf_flat +#define CUVS_IVF_FLAT(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_flat::index& idx); \ + \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_flat::index& orig_index); \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_flat::index* idx); \ + \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::search_params& params, \ + cuvs::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_flat::index& index); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_flat::index* index); \ + \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const cuvs::neighbors::ivf_flat::index& index); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + cuvs::neighbors::ivf_flat::index* index); + +CUVS_IVF_FLAT(float, uint64_t); +CUVS_IVF_FLAT(int8_t, uint64_t); +CUVS_IVF_FLAT(uint8_t, uint64_t); + +#undef CUVS_IVF_FLAT + +} // namespace cuvs::neighbors::ivf_flat \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 3a1cfbff0..5dec958eb 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -27,4 +27,45 @@ using search_params = raft::neighbors::ivf_pq::search_params; template using index = raft::neighbors::ivf_pq::index; +#define CUVS_IVF_PQ(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_pq::index* idx); \ + \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_pq::index& orig_index); \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_pq::index* idx); \ + \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + void serialize(raft::resources const& handle, \ + std::string& filename, \ + const cuvs::neighbors::ivf_pq::index& index); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_pq::index* index); + +CUVS_IVF_PQ(float, uint64_t); +CUVS_IVF_PQ(int8_t, uint64_t); +CUVS_IVF_PQ(uint8_t, uint64_t); + +#undef CUVS_IVF_PQ + } // namespace cuvs::neighbors::ivf_pq