Skip to content

Commit

Permalink
Expose IVF methods
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Feb 7, 2024
1 parent 6976acc commit 9baadc8
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
51 changes: 50 additions & 1 deletion cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,53 @@ using search_params = raft::neighbors::ivf_flat::search_params;
template <typename T, typename IdxT>
using index = raft::neighbors::ivf_flat::index<T, IdxT>;

} // namespace cuvs::neighbors::ivf_flat
#define CUVS_IVF_FLAT(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset); \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::neighbors::ivf_flat::index<T, IdxT>& idx); \
\
auto extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const cuvs::neighbors::ivf_flat::index<T, IdxT>& orig_index); \
\
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
cuvs::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
void search(raft::resources const& handle, \
const cuvs::neighbors::ivf_flat::search_params& params, \
cuvs::neighbors::ivf_flat::index<T, IdxT>& index, \
raft::device_matrix_view<const T, IdxT, raft::row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors, \
raft::device_matrix_view<float, IdxT, raft::row_major> distances); \
\
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const cuvs::neighbors::ivf_flat::index<T, IdxT>& index); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
cuvs::neighbors::ivf_flat::index<T, IdxT>* index); \
\
void serialize(raft::resources const& handle, \
std::string& str, \
const cuvs::neighbors::ivf_flat::index<T, IdxT>& index); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
cuvs::neighbors::ivf_flat::index<T, IdxT>* index);

CUVS_IVF_FLAT(float, uint64_t);
CUVS_IVF_FLAT(int8_t, uint64_t);
CUVS_IVF_FLAT(uint8_t, uint64_t);

#undef CUVS_IVF_FLAT

} // namespace cuvs::neighbors::ivf_flat
41 changes: 41 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,45 @@ using search_params = raft::neighbors::ivf_pq::search_params;
template <typename IdxT>
using index = raft::neighbors::ivf_pq::index<IdxT>;

#define CUVS_IVF_PQ(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset); \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx); \
\
auto extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const cuvs::neighbors::ivf_pq::index<IdxT>& orig_index); \
\
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx); \
\
void search(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::search_params& params, \
cuvs::neighbors::ivf_pq::index<IdxT>& index, \
raft::device_matrix_view<const T, IdxT, raft::row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors, \
raft::device_matrix_view<float, IdxT, raft::row_major> distances); \
\
void serialize(raft::resources const& handle, \
std::string& filename, \
const cuvs::neighbors::ivf_pq::index<IdxT>& index); \
\
void deserialize(raft::resources const& handle, \
const std::string& filename, \
cuvs::neighbors::ivf_pq::index<IdxT>* index);

CUVS_IVF_PQ(float, uint64_t);
CUVS_IVF_PQ(int8_t, uint64_t);
CUVS_IVF_PQ(uint8_t, uint64_t);

#undef CUVS_IVF_PQ

} // namespace cuvs::neighbors::ivf_pq

0 comments on commit 9baadc8

Please sign in to comment.