diff --git a/.gitignore b/.gitignore index 9fcde3fb3..68996dbdf 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,9 @@ docs/source/_static/rust # clang tooling compile_commands.json .clangd/ + +# serialized ann indexes +cagra_index +ivf_flat_index +ivf_pq_index + diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f6d6a2223..c2d311bf2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -187,6 +187,8 @@ include(cmake/thirdparty/get_cutlass.cmake) add_library( cuvs SHARED + src/neighbors/brute_force_index.cu + src/neighbors/brute_force.cu src/neighbors/cagra_build_float.cpp src/neighbors/cagra_build_int8.cpp src/neighbors/cagra_build_uint8.cpp @@ -197,6 +199,30 @@ add_library( src/neighbors/cagra_serialize_float.cpp src/neighbors/cagra_serialize_int8.cpp src/neighbors/cagra_serialize_uint8.cpp + src/neighbors/ivf_flat_index.cpp + src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cpp + src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cpp + src/neighbors/ivf_pq_index.cpp + src/neighbors/ivf_pq/ivf_pq_build_float_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_build_int8_t_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_build_uint8_t_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_extend_float_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_extend_int8_t_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_extend_uint8_t_int64_t.cpp + 
src/neighbors/ivf_pq/ivf_pq_search_float_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_search_int8_t_int64_t.cpp + src/neighbors/ivf_pq/ivf_pq_search_uint8_t_int64_t.cpp + src/neighbors/ivf_pq_serialize.cpp ) target_compile_options( @@ -297,7 +323,14 @@ target_link_options(cuvs PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") # ################################################################################################## # * cuvs_c ------------------------------------------------------------------------------- if(BUILD_C_LIBRARY) - add_library(cuvs_c SHARED src/core/c_api.cpp src/neighbors/cagra_c.cpp) + add_library( + cuvs_c SHARED + src/core/c_api.cpp + src/neighbors/brute_force_c.cpp + src/neighbors/ivf_flat_c.cpp + src/neighbors/ivf_pq_c.cpp + src/neighbors/cagra_c.cpp + ) add_library(cuvs::c_api ALIAS cuvs_c) diff --git a/cpp/include/cuvs/core/detail/interop.hpp b/cpp/include/cuvs/core/detail/interop.hpp index f218dc554..208daaae7 100644 --- a/cpp/include/cuvs/core/detail/interop.hpp +++ b/cpp/include/cuvs/core/detail/interop.hpp @@ -53,20 +53,20 @@ DLDataType data_type_to_DLDataType() } } -bool is_dlpack_device_compatible(DLTensor tensor) +inline bool is_dlpack_device_compatible(DLTensor tensor) { return tensor.device.device_type == kDLCUDAManaged || tensor.device.device_type == kDLCUDAHost || tensor.device.device_type == kDLCUDA; } -bool is_dlpack_host_compatible(DLTensor tensor) +inline bool is_dlpack_host_compatible(DLTensor tensor) { return tensor.device.device_type == kDLCUDAManaged || tensor.device.device_type == kDLCUDAHost || tensor.device.device_type == kDLCPU; } template > -MdspanType from_dlpack(DLManagedTensor* managed_tensor) +inline MdspanType from_dlpack(DLManagedTensor* managed_tensor) { auto tensor = managed_tensor->dl_tensor; diff --git a/cpp/include/cuvs/core/interop.hpp b/cpp/include/cuvs/core/interop.hpp index 9510022db..2462f02ec 100644 --- a/cpp/include/cuvs/core/interop.hpp +++ b/cpp/include/cuvs/core/interop.hpp @@ -33,7 +33,7 @@ 
namespace cuvs::core { * @param[in] tensor DLTensor object to check underlying memory type * @return bool */ -bool is_dlpack_device_compatible(DLTensor tensor) +inline bool is_dlpack_device_compatible(DLTensor tensor) { return detail::is_dlpack_device_compatible(tensor); } @@ -46,7 +46,7 @@ bool is_dlpack_device_compatible(DLTensor tensor) * @param tensor DLTensor object to check underlying memory type * @return bool */ -bool is_dlpack_host_compatible(DLTensor tensor) +inline bool is_dlpack_host_compatible(DLTensor tensor) { return detail::is_dlpack_host_compatible(tensor); } @@ -72,7 +72,7 @@ bool is_dlpack_host_compatible(DLTensor tensor) * @return MdspanType */ template > -MdspanType from_dlpack(DLManagedTensor* managed_tensor) +inline MdspanType from_dlpack(DLManagedTensor* managed_tensor) { return detail::from_dlpack(managed_tensor); } diff --git a/cpp/include/cuvs/distance/distance_types.h b/cpp/include/cuvs/distance/distance_types.h new file mode 100644 index 000000000..8e9a4149c --- /dev/null +++ b/cpp/include/cuvs/distance/distance_types.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef __cplusplus +extern "C" { +#endif + +/** enum to tell how to compute distance */ +enum DistanceType { + + /** evaluate as dist_ij = sum(x_ik^2) + sum(y_jk^2) - 2*sum(x_ik * y_jk) */ + L2Expanded = 0, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtExpanded = 1, + /** cosine distance */ + CosineExpanded = 2, + /** L1 distance */ + L1 = 3, + /** evaluate as dist_ij += (x_ik - y_jk)^2 */ + L2Unexpanded = 4, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtUnexpanded = 5, + /** basic inner product **/ + InnerProduct = 6, + /** Chebyshev (Linf) distance **/ + Linf = 7, + /** Canberra distance **/ + Canberra = 8, + /** Generalized Minkowski distance **/ + LpUnexpanded = 9, + /** Correlation distance **/ + CorrelationExpanded = 10, + /** Jaccard distance **/ + JaccardExpanded = 11, + /** Hellinger distance **/ + HellingerExpanded = 12, + /** Haversine distance **/ + Haversine = 13, + /** Bray-Curtis distance **/ + BrayCurtis = 14, + /** Jensen-Shannon distance **/ + JensenShannon = 15, + /** Hamming distance **/ + HammingUnexpanded = 16, + /** KLDivergence **/ + KLDivergence = 17, + /** RusselRao **/ + RusselRaoExpanded = 18, + /** Dice-Sorensen distance **/ + DiceExpanded = 19, + /** Precomputed (special value) **/ + Precomputed = 100 +}; + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/brute_force.h b/cpp/include/cuvs/neighbors/brute_force.h new file mode 100644 index 000000000..0bb4d6bdb --- /dev/null +++ b/cpp/include/cuvs/neighbors/brute_force.h @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @defgroup bruteforce_c_index Bruteforce index + * @{ + */ +/** + * @brief Struct to hold address of cuvs::neighbors::brute_force::index and its active trained dtype + * + */ +typedef struct { + uintptr_t addr; + DLDataType dtype; +} bruteForceIndex; + +typedef bruteForceIndex* cuvsBruteForceIndex_t; + +/** + * @brief Allocate BRUTEFORCE index + * + * @param[in] index cuvsBruteForceIndex_t to allocate + * @return cuvsError_t + */ +cuvsError_t bruteForceIndexCreate(cuvsBruteForceIndex_t* index); + +/** + * @brief De-allocate BRUTEFORCE index + * + * @param[in] index cuvsBruteForceIndex_t to de-allocate + */ +cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index); +/** + * @} + */ + +/** + * @defgroup bruteforce_c_index_build Bruteforce index build + * @{ + */ +/** + * @brief Build a BRUTEFORCE index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`, + * or `kDLCPU`. Also, acceptable underlying types are: + * 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` + * 3. 
`kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * + * // Create BRUTEFORCE index + * cuvsBruteForceIndex_t index; + * cuvsError_t index_create_status = bruteForceIndexCreate(&index); + * + * // Build the BRUTEFORCE Index + * cuvsError_t build_status = bruteForceBuild(res, &dataset, L2Expanded, 0.f, index); + * + * // de-allocate `index` and `res` + * cuvsError_t index_destroy_status = bruteForceIndexDestroy(index); + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] dataset DLManagedTensor* training dataset + * @param[in] metric distance metric to use (enum DistanceType) + * @param[in] metric_arg argument used by some distance metrics + * @param[out] index cuvsBruteForceIndex_t Newly built BRUTEFORCE index + * @return cuvsError_t + */ +cuvsError_t bruteForceBuild(cuvsResources_t res, + DLManagedTensor* dataset, + enum DistanceType metric, + float metric_arg, + cuvsBruteForceIndex_t index); +/** + * @} + */ + +/** + * @defgroup bruteforce_c_index_search Bruteforce index search + * @{ + */ +/** + * @brief Search a BRUTEFORCE index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`. + * It is also important to note that the BRUTEFORCE index must have been built + * with the same type of `queries`, such that `index.dtype.code == + * queries.dl_tensor.dtype.code` Types for input are: + * 1. `queries`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32` + * 3. 
`distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * DLManagedTensor queries; + * DLManagedTensor neighbors; + * + * // Search the `index` built using `bruteForceBuild` + * cuvsError_t search_status = bruteForceSearch(res, index, &queries, &neighbors, &distances); + * + * // de-allocate `res` + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] index bruteForceIndex which has been returned by `bruteForceBuild` + * @param[in] queries DLManagedTensor* queries dataset to search + * @param[out] neighbors DLManagedTensor* output `k` neighbors for queries + * @param[out] distances DLManagedTensor* output `k` distances for queries + */ +cuvsError_t bruteForceSearch(cuvsResources_t res, + cuvsBruteForceIndex_t index, + DLManagedTensor* queries, + DLManagedTensor* neighbors, + DLManagedTensor* distances); +/** + * @} + */ + +#ifdef __cplusplus +} +#endif diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp new file mode 100644 index 000000000..26951e1ec --- /dev/null +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include +#include +#include + +namespace cuvs::neighbors::brute_force { + +/** + * @defgroup bruteforce_cpp_index Bruteforce index + * @{ + */ +/** + * @brief Brute Force index. + * + * The index stores the dataset and norms for the dataset in device memory. + * + * @tparam T data element type + */ +template +struct index : cuvs::neighbors::ann::index { + public: + index(const index&) = delete; + index(index&&) = default; + index& operator=(const index&) = delete; + index& operator=(index&&) = default; + ~index() = default; + index(void* raft_index); + + /** Distance metric used for retrieval */ + cuvs::distance::DistanceType metric() const noexcept; + + /** Metric argument */ + T metric_arg() const noexcept; + + /** Total length of the index (number of vectors). */ + size_t size() const noexcept; + + /** Dimensionality of the data. */ + size_t dim() const noexcept; + + /** Dataset [size, dim] */ + raft::device_matrix_view dataset() const noexcept; + + /** Dataset norms */ + raft::device_vector_view norms() const; + + /** Whether or not this index has dataset norms */ + bool has_norms() const noexcept; + + // Get pointer to underlying RAFT index, not meant to be used outside of cuVS + inline const void* get_raft_index() const noexcept { return raft_index_.get(); } + + private: + std::unique_ptr raft_index_; +}; +/** + * @} + */ + +/** + * @defgroup bruteforce_cpp_index_build Bruteforce index build + * @{ + */ +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // create and fill the index from a [N, D] dataset + * auto index = brute_force::build(handle, dataset, metric); + * @endcode + * + * @param[in] handle + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] metric cuvs::distance::DistanceType + * @param[in] metric_arg metric argument + * + * @return the constructed brute-force index + */ +auto build(raft::resources const& handle, + raft::device_matrix_view dataset, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded, + float metric_arg = 0) -> cuvs::neighbors::brute_force::index; +/** + * @} + */ + +/** + * @defgroup bruteforce_cpp_index_search Bruteforce index search + * @{ + */ +/** + * @brief Search ANN using the constructed index. + * + * See the [brute_force::build](#brute_force::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * brute_force::search(handle, index, queries1, out_inds1, out_dists1); + * brute_force::search(handle, index, queries2, out_inds2, out_dists2); + * brute_force::search(handle, index, queries3, out_inds3, out_dists3); + * ... 
+ * @endcode + * + * @param[in] handle + * @param[in] index brute-force constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index.dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::brute_force::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); +/** + * @} + */ + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index ae8f8ea01..912430a47 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -273,7 +273,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res, * It is also important to note that the CAGRA Index must have been built * with the same type of `queries`, such that `index.dtype.code == * queries.dl_tensor.dtype.code` Types for input are: - * 1. `queries`: kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 1. `queries`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32` * 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` * diff --git a/cpp/include/cuvs/neighbors/ivf_flat.h b/cpp/include/cuvs/neighbors/ivf_flat.h new file mode 100644 index 000000000..08200ae7d --- /dev/null +++ b/cpp/include/cuvs/neighbors/ivf_flat.h @@ -0,0 +1,283 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @defgroup ivf_flat_c_index_params IVF-Flat index build parameters + * @{ + */ +/** + * @brief Supplemental parameters to build IVF-Flat Index + * + */ +struct ivfFlatIndexParams { + /** Distance type. */ + enum DistanceType metric; + /** The argument used by some distance metrics. */ + float metric_arg; + /** + * Whether to add the dataset content to the index, i.e.: + * + * - `true` means the index is filled with the dataset vectors and ready to search after calling + * `build`. + * - `false` means `build` only trains the underlying model (e.g. quantizer or clustering), but + * the index is left empty; you'd need to call `extend` on the index afterwards to populate it. + */ + bool add_data_on_build; + /** The number of inverted lists (clusters) */ + uint32_t n_lists; + /** The number of iterations searching for kmeans centers (index building). */ + uint32_t kmeans_n_iters; + /** The fraction of data to use during iterative kmeans building. */ + double kmeans_trainset_fraction; + /** + * By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`, + * and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index + * from scratch after invoking (`ivf_flat::extend`) a few times with new data, the distribution of + * which is no longer representative of the original training set. 
+ * + * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new + * data when it is added. In this case, `index.centers()` are always exactly the centroids of the + * data in the corresponding clusters. The drawback of this behavior is that the centroids depend + * on the order of adding new data (through the classification of the added data); that is, + * `index.centers()` "drift" together with the changing distribution of the newly added data. + */ + bool adaptive_centers; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. + */ + bool conservative_memory_allocation; +}; + +typedef struct ivfFlatIndexParams* cuvsIvfFlatIndexParams_t; + +/** + * @brief Allocate IVF-Flat Index params, and populate with default values + * + * @param[in] index_params cuvsIvfFlatIndexParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsIvfFlatIndexParamsCreate(cuvsIvfFlatIndexParams_t* index_params); + +/** + * @brief De-allocate IVF-Flat Index params + * + * @param[in] index_params + * @return cuvsError_t + */ +cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t index_params); +/** + * @} + */ + +/** + * @defgroup ivf_flat_c_search_params IVF-Flat index search parameters + * @{ + */ +/** + * @brief Supplemental parameters to search IVF-Flat index + * + */ +struct ivfFlatSearchParams { + /** The number of clusters to search. 
*/ + uint32_t n_probes; +}; + +typedef struct ivfFlatSearchParams* cuvsIvfFlatSearchParams_t; + +/** + * @brief Allocate IVF-Flat search params, and populate with default values + * + * @param[in] params cuvsIvfFlatSearchParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsIvfFlatSearchParamsCreate(cuvsIvfFlatSearchParams_t* params); + +/** + * @brief De-allocate IVF-Flat search params + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t params); +/** + * @} + */ + +/** + * @defgroup ivf_flat_c_index IVF-Flat index + * @{ + */ +/** + * @brief Struct to hold address of cuvs::neighbors::ivf_flat::index and its active trained dtype + * + */ +typedef struct { + uintptr_t addr; + DLDataType dtype; +} ivfFlatIndex; + +typedef ivfFlatIndex* cuvsIvfFlatIndex_t; + +/** + * @brief Allocate IVF-Flat index + * + * @param[in] index cuvsIvfFlatIndex_t to allocate + * @return cuvsError_t + */ +cuvsError_t ivfFlatIndexCreate(cuvsIvfFlatIndex_t* index); + +/** + * @brief De-allocate IVF-Flat index + * + * @param[in] index cuvsIvfFlatIndex_t to de-allocate + */ +cuvsError_t ivfFlatIndexDestroy(cuvsIvfFlatIndex_t index); +/** + * @} + */ + +/** + * @defgroup ivf_flat_c_index_build IVF-Flat index build + * @{ + */ +/** + * @brief Build an IVF-Flat index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`, + * or `kDLCPU`. Also, acceptable underlying types are: + * 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` + * 3. 
`kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * + * // Create default index params + * cuvsIvfFlatIndexParams_t index_params; + * cuvsError_t params_create_status = cuvsIvfFlatIndexParamsCreate(&index_params); + * + * // Create IVF-Flat index + * cuvsIvfFlatIndex_t index; + * cuvsError_t index_create_status = ivfFlatIndexCreate(&index); + * + * // Build the IVF-Flat Index + * cuvsError_t build_status = ivfFlatBuild(res, index_params, &dataset, index); + * + * // de-allocate `index_params`, `index` and `res` + * cuvsError_t params_destroy_status = cuvsIvfFlatIndexParamsDestroy(index_params); + * cuvsError_t index_destroy_status = ivfFlatIndexDestroy(index); + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] index_params cuvsIvfFlatIndexParams_t used to build IVF-Flat index + * @param[in] dataset DLManagedTensor* training dataset + * @param[out] index cuvsIvfFlatIndex_t Newly built IVF-Flat index + * @return cuvsError_t + */ +cuvsError_t ivfFlatBuild(cuvsResources_t res, + cuvsIvfFlatIndexParams_t index_params, + DLManagedTensor* dataset, + cuvsIvfFlatIndex_t index); +/** + * @} + */ + +/** + * @defgroup ivf_flat_c_index_search IVF-Flat index search + * @{ + */ +/** + * @brief Search a IVF-Flat index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`. + * It is also important to note that the IVF-Flat Index must have been built + * with the same type of `queries`, such that `index.dtype.code == + * queries.dl_tensor.dtype.code` Types for input are: + * 1. `queries`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. 
`neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32` + * 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * DLManagedTensor queries; + * DLManagedTensor neighbors; + * + * // Create default search params + * cuvsIvfFlatSearchParams_t search_params; + * cuvsError_t params_create_status = cuvsIvfFlatSearchParamsCreate(&search_params); + * + * // Search the `index` built using `ivfFlatBuild` + * cuvsError_t search_status = ivfFlatSearch(res, search_params, index, &queries, &neighbors, + * &distances); + * + * // de-allocate `search_params` and `res` + * cuvsError_t params_destroy_status = cuvsIvfFlatSearchParamsDestroy(search_params); + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] search_params cuvsIvfFlatSearchParams_t used to search IVF-Flat index + * @param[in] index ivfFlatIndex which has been returned by `ivfFlatBuild` + * @param[in] queries DLManagedTensor* queries dataset to search + * @param[out] neighbors DLManagedTensor* output `k` neighbors for queries + * @param[out] distances DLManagedTensor* output `k` distances for queries + */ +cuvsError_t ivfFlatSearch(cuvsResources_t res, + cuvsIvfFlatSearchParams_t search_params, + cuvsIvfFlatIndex_t index, + DLManagedTensor* queries, + DLManagedTensor* neighbors, + DLManagedTensor* distances); +/** + * @} + */ + +#ifdef __cplusplus +} +#endif diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp new file mode 100644 index 000000000..efb32e024 --- /dev/null +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -0,0 +1,1026 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +namespace cuvs::neighbors::ivf_flat { +/** + * @defgroup ivf_flat_cpp_index_params IVF-Flat index build parameters + * @{ + */ +struct index_params : ann::index_params { + /** The number of inverted lists (clusters) */ + uint32_t n_lists = 1024; + /** The number of iterations searching for kmeans centers (index building). */ + uint32_t kmeans_n_iters = 20; + /** The fraction of data to use during iterative kmeans building. */ + double kmeans_trainset_fraction = 0.5; + /** + * By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`, + * and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index + * from scratch after invoking (`ivf_flat::extend`) a few times with new data, the distribution of + * which is no longer representative of the original training set. + * + * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new + * data when it is added. In this case, `index.centers()` are always exactly the centroids of the + * data in the corresponding clusters. The drawback of this behavior is that the centroids depend + * on the order of adding new data (through the classification of the added data); that is, + * `index.centers()` "drift" together with the changing distribution of the newly added data. 
+ */ + bool adaptive_centers = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. + */ + bool conservative_memory_allocation = false; + + /** Build a raft IVF_FLAT index params from an existing cuvs IVF_FLAT index params. */ + operator raft::neighbors::ivf_flat::index_params() const + { + return raft::neighbors::ivf_flat::index_params{ + { + .metric = static_cast((int)this->metric), + .metric_arg = this->metric_arg, + .add_data_on_build = this->add_data_on_build, + }, + .n_lists = n_lists, + .kmeans_n_iters = kmeans_n_iters, + .kmeans_trainset_fraction = kmeans_trainset_fraction, + .adaptive_centers = adaptive_centers, + .conservative_memory_allocation = conservative_memory_allocation}; + } +}; +/** + * @} + */ + +/** + * @defgroup ivf_flat_cpp_search_params IVF-Flat index search parameters + * @{ + */ +struct search_params : ann::search_params { + /** The number of clusters to search. */ + uint32_t n_probes = 20; + + /** Build a raft IVF_FLAT search params from an existing cuvs IVF_FLAT search params. */ + operator raft::neighbors::ivf_flat::search_params() const + { + raft::neighbors::ivf_flat::search_params result = {{}, n_probes}; + return result; + } +}; +/** + * @} + */ + +/** + * @defgroup ivf_flat_cpp_index IVF-Flat index + * @{ + */ +/** + * @brief IVF-flat index. 
+ * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + */ +template +struct index : ann::index { + static_assert(!raft::is_narrowing_v, + "IdxT must be able to represent all values of uint32_t"); + + public: + index(const index&) = delete; + index(index&&) = default; + index& operator=(const index&) = delete; + index& operator=(index&&) = default; + ~index() = default; + /** Construct an empty index. It needs to be trained and then populated. */ + index(raft::resources const& res, const index_params& params, uint32_t dim); + /** Construct an empty index. It needs to be trained and then populated. */ + index(raft::resources const& res, + cuvs::distance::DistanceType metric, + uint32_t n_lists, + bool adaptive_centers, + bool conservative_memory_allocation, + uint32_t dim); + index(raft::neighbors::ivf_flat::index&& raft_idx); + + /** + * Vectorized load/store size in elements, determines the size of interleaved data chunks. + */ + uint32_t veclen() const noexcept; + + /** Distance metric used for clustering. */ + cuvs::distance::DistanceType metric() const noexcept; + + /** Whether `centers()` change upon extending the index (ivf_flat::extend). */ + bool adaptive_centers() const noexcept; + + /** + * Inverted list data [size, dim]. + * + * The data consists of the dataset rows, grouped by their labels (into clusters/lists). + * Within each list (cluster), the data is grouped into blocks of `kIndexGroupSize` interleaved + * vectors. Note, the total index length is slightly larger than the source dataset length, + * because each cluster is padded by `kIndexGroupSize` elements. + * + * Interleaving pattern: + * within groups of `kIndexGroupSize` rows, the data is interleaved with the block size equal to + * `veclen * sizeof(T)`. That is, a chunk of `veclen` consecutive components of one row is + * followed by a chunk of the same size of the next row, and so on. 
+ * + * __Example__: veclen = 2, dim = 6, kIndexGroupSize = 32, list_size = 31 + * + * x[ 0, 0], x[ 0, 1], x[ 1, 0], x[ 1, 1], ... x[14, 0], x[14, 1], x[15, 0], x[15, 1], + * x[16, 0], x[16, 1], x[17, 0], x[17, 1], ... x[30, 0], x[30, 1], - , - , + * x[ 0, 2], x[ 0, 3], x[ 1, 2], x[ 1, 3], ... x[14, 2], x[14, 3], x[15, 2], x[15, 3], + * x[16, 2], x[16, 3], x[17, 2], x[17, 3], ... x[30, 2], x[30, 3], - , - , + * x[ 0, 4], x[ 0, 5], x[ 1, 4], x[ 1, 5], ... x[14, 4], x[14, 5], x[15, 4], x[15, 5], + * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - , + * + */ + /** Sizes of the lists (clusters) [n_lists] + * NB: This may differ from the actual list size if the shared lists have been extended by another + * index + */ + raft::device_vector_view list_sizes() noexcept; + raft::device_vector_view list_sizes() const noexcept; + + /** k-means cluster centers corresponding to the lists [n_lists, dim] */ + raft::device_matrix_view centers() noexcept; + raft::device_matrix_view centers() const noexcept; + + /** + * (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists]. + * + * NB: this may be empty if the index is empty or if the metric does not require the center norms + * calculation. + */ + std::optional> center_norms() noexcept; + std::optional> center_norms() const noexcept; + + /** Total length of the index. */ + IdxT size() const noexcept; + + /** Dimensionality of the data. */ + uint32_t dim() const noexcept; + + /** Number of clusters/inverted lists. */ + uint32_t n_lists() const noexcept; + raft::device_vector_view data_ptrs() noexcept; + raft::device_vector_view data_ptrs() const noexcept; + + /** Pointers to the inverted lists (clusters) indices [n_lists]. 
*/ + raft::device_vector_view inds_ptrs() noexcept; + raft::device_vector_view inds_ptrs() const noexcept; + + /** + * Whether to use conservative memory allocation when extending the list (cluster) data + * (see index_params.conservative_memory_allocation). + */ + bool conservative_memory_allocation() const noexcept; + + /** Lists' data and indices. */ + std::vector>>& lists() noexcept; + const std::vector>>& lists() + const noexcept; + + // Get pointer to underlying RAFT index, not meant to be used outside of cuVS + inline raft::neighbors::ivf_flat::index* get_raft_index() noexcept + { + return raft_index_.get(); + } + inline const raft::neighbors::ivf_flat::index* get_raft_index() const noexcept + { + return raft_index_.get(); + } + + private: + std::unique_ptr> raft_index_; +}; +/** + * @} + */ + +/** + * @defgroup ivf_flat_cpp_index_build IVF-Flat index build + * @{ + */ +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_flat::index; + +/** + * @brief Build the index from the dataset for efficient search.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_flat::index& idx); + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_flat::index; + +/** + * @brief Build the index from the dataset for efficient search.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_flat::index& idx); + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_flat::index; + +/** + * @brief Build the index from the dataset for efficient search.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_flat::index& idx); +/** + * @} + */ + +/** + * @defgroup ivf_flat_cpp_index_extend IVF-Flat index extend + * @{ + */ + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_flat::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] idx original index + * + * @return the constructed extended ivf-flat index + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_flat::index& idx) + -> cuvs::neighbors::ivf_flat::index; + +/** + * @brief Extend the index in-place with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_op, &index_empty); + * @endcode + * + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to index, to be overwritten in-place + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_flat::index* idx); + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_flat::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] idx original index + * + * @return the constructed extended ivf-flat index + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_flat::index& idx) + -> cuvs::neighbors::ivf_flat::index; + +/** + * @brief Extend the index in-place with the new data.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_op, &index_empty); + * @endcode + * + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to index, to be overwritten in-place + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_flat::index* idx); + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_flat::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] idx original index + * + * @return the constructed extended ivf-flat index + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_flat::index& idx) + -> cuvs::neighbors::ivf_flat::index; + +/** + * @brief Extend the index in-place with the new data.
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_op, &index_empty); + * @endcode + * + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx pointer to index, to be overwritten in-place + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_flat::index* idx); +/** + * @} + */ + +/** + * @defgroup ivf_flat_cpp_index_search IVF-Flat index search + * @{ + */ + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ...
+ * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries raft::device_matrix_view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors raft::device_matrix_view to the indices of the neighbors in the source + * dataset [n_queries, k] + * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors + * [n_queries, k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::search_params& params, + cuvs::neighbors::ivf_flat::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... 
+ * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries raft::device_matrix_view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors raft::device_matrix_view to the indices of the neighbors in the source + * dataset [n_queries, k] + * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors + * [n_queries, k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::search_params& params, + cuvs::neighbors::ivf_flat::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... 
+ * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries raft::device_matrix_view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors raft::device_matrix_view to the indices of the neighbors in the source + * dataset [n_queries, k] + * @param[out] distances raft::device_matrix_view to the distances to the selected neighbors + * [n_queries, k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_flat::search_params& params, + cuvs::neighbors::ivf_flat::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); +/** + * @} + */ + +/** + * @defgroup ivf_flat_cpp_serialize IVF-Flat index serialize + * @{ + */ + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_flat::build(...);` + * cuvs::serialize_file(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-Flat index + * + */ +void serialize_file(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_flat::index& index); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` + * cuvs::deserialize_file(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index IVF-Flat index + * + */ +void deserialize_file(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_flat::index* index); + +/** + * Write the index to an output string + * + * Experimental, both the API and the serialization format are subject to change.
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output string + * std::string str; + * // create an index with `auto index = ivf_flat::build(...);` + * cuvs::serialize(handle, str, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[out] str output string + * @param[in] index IVF-Flat index + * + */ +void serialize(raft::resources const& handle, + std::string& str, + const cuvs::neighbors::ivf_flat::index& index); + +/** + * Load index from input string + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input string + * std::string str; + * using T = float; // data element type + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` + * cuvs::deserialize(handle, str, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] str input string + * @param[out] index IVF-Flat index + * + */ +void deserialize(raft::resources const& handle, + const std::string& str, + cuvs::neighbors::ivf_flat::index* index); + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_flat::build(...);` + * cuvs::serialize_file(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-Flat index + * + */ +void serialize_file(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_flat::index& index); + +/** + * Load index from file.
+ * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` + * cuvs::deserialize_file(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index IVF-Flat index + * + */ +void deserialize_file(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_flat::index* index); + +/** + * Write the index to an output string + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output string + * std::string str; + * // create an index with `auto index = ivf_flat::build(...);` + * cuvs::serialize(handle, str, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[out] str output string + * @param[in] index IVF-Flat index + * + */ +void serialize(raft::resources const& handle, + std::string& str, + const cuvs::neighbors::ivf_flat::index& index); + +/** + * Load index from input string + * + * Experimental, both the API and the serialization format are subject to change.
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input string + * std::string str; + * using T = float; // data element type + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` + * cuvs::deserialize(handle, str, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] str input string + * @param[out] index IVF-Flat index + * + */ +void deserialize(raft::resources const& handle, + const std::string& str, + cuvs::neighbors::ivf_flat::index* index); + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_flat::build(...);` + * cuvs::serialize_file(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-Flat index + * + */ +void serialize_file(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_flat::index& index); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change.
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using T = float; // data element type + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` + * cuvs::deserialize_file(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index IVF-Flat index + * + */ +void deserialize_file(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_flat::index* index); + +/** + * Write the index to an output string + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output string + * std::string str; + * // create an index with `auto index = ivf_flat::build(...);` + * cuvs::serialize(handle, str, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[out] str output string + * @param[in] index IVF-Flat index + * + */ +void serialize(raft::resources const& handle, + std::string& str, + const cuvs::neighbors::ivf_flat::index& index); + +/** + * Load index from input string + * + * Experimental, both the API and the serialization format are subject to change.
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input string + * std::string str; + * using T = float; // data element type + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` + * cuvs::deserialize(handle, str, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] str input string + * @param[out] index IVF-Flat index + * + */ +void deserialize(raft::resources const& handle, + const std::string& str, + cuvs::neighbors::ivf_flat::index* index); + +/** + * @} + */ +} // namespace cuvs::neighbors::ivf_flat \ No newline at end of file diff --git a/cpp/include/cuvs/neighbors/ivf_pq.h b/cpp/include/cuvs/neighbors/ivf_pq.h new file mode 100644 index 000000000..c1fcaed86 --- /dev/null +++ b/cpp/include/cuvs/neighbors/ivf_pq.h @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @defgroup ivf_pq_c_index_params IVF-PQ index build parameters + * @{ + */ +/** + * @brief A type for specifying how PQ codebooks are created + * + */ +enum codebook_gen { // NOLINT + PER_SUBSPACE = 0, // NOLINT + PER_CLUSTER = 1, // NOLINT +}; + +/** + * @brief Supplemental parameters to build IVF-PQ Index + * + */ +struct ivfPqIndexParams { + /** Distance type. */ + enum DistanceType metric; + /** The argument used by some distance metrics. */ + float metric_arg; + /** + * Whether to add the dataset content to the index, i.e.: + * + * - `true` means the index is filled with the dataset vectors and ready to search after calling + * `build`. + * - `false` means `build` only trains the underlying model (e.g. quantizer or clustering), but + * the index is left empty; you'd need to call `extend` on the index afterwards to populate it. + */ + bool add_data_on_build; + /** + * The number of inverted lists (clusters) + * + * Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to + * 10,000. + */ + uint32_t n_lists; + /** The number of iterations searching for kmeans centers (index building). */ + uint32_t kmeans_n_iters; + /** The fraction of data to use during iterative kmeans building. */ + double kmeans_trainset_fraction; + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits; + /** + * The dimensionality of the vector after compression by PQ. When zero, an optimal value is + * selected using a heuristic. + * + * NB: `pq_dim * pq_bits` must be a multiple of 8. + * + * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but + * lower recall. 
If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiples of 8 are + * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. + * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' + * should also be a divisor of the dataset dim. + */ + uint32_t pq_dim; + /** How PQ codebooks are created. */ + enum codebook_gen codebook_kind; + /** + * Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. + * + * Note: if `dim` is not a multiple of `pq_dim`, a random rotation is always applied to the input + * data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly + * larger than the original space and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). + * However, this transform is not necessary when `dim` is a multiple of `pq_dim` + * (`dim == rot_dim`, hence no need to add "extra" data columns / features). + * + * By default, if `dim == rot_dim`, the rotation transform is initialized with the identity + * matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated + * regardless of the values of `dim` and `pq_dim`. + */ + bool force_random_rotation; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This amortizes the cost of memory allocation and reduces the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible.
+ */ + bool conservative_memory_allocation; +}; + +typedef struct ivfPqIndexParams* cuvsIvfPqIndexParams_t; + +/** + * @brief Allocate IVF-PQ Index params, and populate with default values + * + * @param[in] index_params cuvsIvfPqIndexParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsIvfPqIndexParamsCreate(cuvsIvfPqIndexParams_t* index_params); + +/** + * @brief De-allocate IVF-PQ Index params + * + * @param[in] index_params + * @return cuvsError_t + */ +cuvsError_t cuvsIvfPqIndexParamsDestroy(cuvsIvfPqIndexParams_t index_params); +/** + * @} + */ + +/** + * @defgroup ivf_pq_c_search_params IVF-PQ index search parameters + * @{ + */ +/** + * @brief Supplemental parameters to search IVF-PQ index + * + */ +struct ivfPqSearchParams { + /** The number of clusters to search. */ + uint32_t n_probes; + /** + * Data type of look up table to be created dynamically at search time. + * + * Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] + * + * The use of low-precision types reduces the amount of shared memory required at search time, so + * fast shared memory kernels can be used even for datasets with large dimensionality. Note that + * the recall is slightly degraded when a low-precision type is selected. + */ + cudaDataType_t lut_dtype; + /** + * Storage data type for distance/similarity computed at search time. + * + * Possible values: [CUDA_R_16F, CUDA_R_32F] + * + * If the performance limiter at search time is device memory access, selecting FP16 will improve + * performance slightly. + */ + cudaDataType_t internal_distance_dtype; + /** + * Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. + * + * Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. + * + * One wants to increase the carveout to ensure good GPU occupancy for the main search + * kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this + * value is interpreted only as a hint.
Moreover, a GPU usually allows only a fixed set of cache + * configurations, so the provided value is rounded up to the nearest configuration. Refer to the + * NVIDIA tuning guide for the target GPU architecture. + * + * Note, this is a low-level tuning parameter that can have drastic negative effects on the search + * performance if tweaked incorrectly. + */ + double preferred_shmem_carveout; +}; + +typedef struct ivfPqSearchParams* cuvsIvfPqSearchParams_t; + +/** + * @brief Allocate IVF-PQ search params, and populate with default values + * + * @param[in] params cuvsIvfPqSearchParams_t to allocate + * @return cuvsError_t + */ +cuvsError_t cuvsIvfPqSearchParamsCreate(cuvsIvfPqSearchParams_t* params); + +/** + * @brief De-allocate IVF-PQ search params + * + * @param[in] params + * @return cuvsError_t + */ +cuvsError_t cuvsIvfPqSearchParamsDestroy(cuvsIvfPqSearchParams_t params); +/** + * @} + */ + +/** + * @defgroup ivf_pq_c_index IVF-PQ index + * @{ + */ +/** + * @brief Struct to hold address of cuvs::neighbors::ivf_pq::index and its active trained dtype + * + */ +typedef struct { + uintptr_t addr; + DLDataType dtype; +} ivfPqIndex; + +typedef ivfPqIndex* cuvsIvfPqIndex_t; + +/** + * @brief Allocate IVF-PQ index + * + * @param[in] index cuvsIvfPqIndex_t to allocate + * @return cuvsError_t + */ +cuvsError_t ivfPqIndexCreate(cuvsIvfPqIndex_t* index); + +/** + * @brief De-allocate IVF-PQ index + * + * @param[in] index cuvsIvfPqIndex_t to de-allocate + * @return cuvsError_t + */ +cuvsError_t ivfPqIndexDestroy(cuvsIvfPqIndex_t index); +/** + * @} + */ + +/** + * @defgroup ivf_pq_c_index_build IVF-PQ index build + * @{ + */ +/** + * @brief Build an IVF-PQ index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`, + * or `kDLCPU`. Also, acceptable underlying types are: + * 1. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8` + * 3.
`kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * + * // Create default index params + * cuvsIvfPqIndexParams_t index_params; + * cuvsError_t params_create_status = cuvsIvfPqIndexParamsCreate(&index_params); + * + * // Create IVF-PQ index + * cuvsIvfPqIndex_t index; + * cuvsError_t index_create_status = ivfPqIndexCreate(&index); + * + * // Build the IVF-PQ Index + * cuvsError_t build_status = ivfPqBuild(res, index_params, &dataset, index); + * + * // de-allocate `index_params`, `index` and `res` + * cuvsError_t params_destroy_status = cuvsIvfPqIndexParamsDestroy(index_params); + * cuvsError_t index_destroy_status = ivfPqIndexDestroy(index); + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] params cuvsIvfPqIndexParams_t used to build IVF-PQ index + * @param[in] dataset DLManagedTensor* training dataset + * @param[out] index cuvsIvfPqIndex_t Newly built IVF-PQ index + * @return cuvsError_t + */ +cuvsError_t ivfPqBuild(cuvsResources_t res, + cuvsIvfPqIndexParams_t params, + DLManagedTensor* dataset, + cuvsIvfPqIndex_t index); +/** + * @} + */ + +/** + * @defgroup ivf_pq_c_index_search IVF-PQ index search + * @{ + */ +/** + * @brief Search a IVF-PQ index with a `DLManagedTensor` which has underlying + * `DLDeviceType` equal to `kDLCUDA`, `kDLCUDAHost`, `kDLCUDAManaged`. + * It is also important to note that the IVF-PQ Index must have been built + * with the same type of `queries`, such that `index.dtype.code == + * queries.dl_tensor.dtype.code` Types for input are: + * 1. `queries`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * 2. 
`neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32` + * 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32` + * + * @code {.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // Assume a populated `DLManagedTensor` type here + * DLManagedTensor dataset; + * DLManagedTensor queries; + * DLManagedTensor neighbors; + * + * // Create default search params + * cuvsIvfPqSearchParams_t search_params; + * cuvsError_t params_create_status = cuvsIvfPqSearchParamsCreate(&search_params); + * + * // Search the `index` built using `ivfPqBuild` + * cuvsError_t search_status = ivfPqSearch(res, search_params, index, &queries, &neighbors, + * &distances); + * + * // de-allocate `search_params` and `res` + * cuvsError_t params_destroy_status = cuvsIvfPqSearchParamsDestroy(search_params); + * cuvsError_t res_destroy_status = cuvsResourcesDestroy(res); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] search_params cuvsIvfPqSearchParams_t used to search IVF-PQ index + * @param[in] index ivfPqIndex which has been returned by `ivfPqBuild` + * @param[in] queries DLManagedTensor* queries dataset to search + * @param[out] neighbors DLManagedTensor* output `k` neighbors for queries + * @param[out] distances DLManagedTensor* output `k` distances for queries + */ +cuvsError_t ivfPqSearch(cuvsResources_t res, + cuvsIvfPqSearchParams_t search_params, + cuvsIvfPqIndex_t index, + DLManagedTensor* queries, + DLManagedTensor* neighbors, + DLManagedTensor* distances); +/** + * @} + */ + +#ifdef __cplusplus +} +#endif diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp new file mode 100644 index 000000000..b2fc9a366 --- /dev/null +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -0,0 +1,868 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +namespace cuvs::neighbors::ivf_pq { + +/** + * @defgroup ivf_pq_cpp_index_params IVF-PQ index build parameters + * @{ + */ +/** A type for specifying how PQ codebooks are created. */ +enum class codebook_gen { // NOLINT + PER_SUBSPACE = 0, // NOLINT + PER_CLUSTER = 1, // NOLINT +}; + +struct index_params : ann::index_params { + /** + * The number of inverted lists (clusters) + * + * Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to + * 10,000. + */ + uint32_t n_lists = 1024; + /** The number of iterations searching for kmeans centers (index building). */ + uint32_t kmeans_n_iters = 20; + /** The fraction of data to use during iterative kmeans building. */ + double kmeans_trainset_fraction = 0.5; + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. When zero, an optimal value is + * selected using a heuristic. + * + * NB: `pq_dim * pq_bits` must be a multiple of 8. + * + * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but + * lower recall. 
If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiples of 8 are + * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. + * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim' + * should also be a divisor of the dataset dim. + */ + uint32_t pq_dim = 0; + /** How PQ codebooks are created. */ + codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE; + /** + * Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`. + * + * Note: if `dim` is not a multiple of `pq_dim`, a random rotation is always applied to the input + * data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly + * larger than the original space and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`). + * However, this transform is not necessary when `dim` is a multiple of `pq_dim` + * (`dim == rot_dim`, hence no need to add "extra" data columns / features). + * + * By default, if `dim == rot_dim`, the rotation transform is initialized with the identity + * matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated + * regardless of the values of `dim` and `pq_dim`. + */ + bool force_random_rotation = false; + /** + * By default, the algorithm allocates more space than necessary for individual clusters + * (`list_data`). This amortizes the cost of memory allocation and reduces the number of + * data copies during repeated calls to `extend` (extending the database). + * + * The alternative is the conservative allocation behavior; when enabled, the algorithm always + * allocates the minimum amount of memory required to store the given number of records. Set this + * flag to `true` if you prefer to use as little GPU memory for the database as possible. + */ + bool conservative_memory_allocation = false; + + /** Build a raft IVF_PQ index params from an existing cuvs IVF_PQ index params.
*/ + operator raft::neighbors::ivf_pq::index_params() const + { + return raft::neighbors::ivf_pq::index_params{ + { + .metric = static_cast((int)this->metric), + .metric_arg = this->metric_arg, + .add_data_on_build = this->add_data_on_build, + }, + .n_lists = n_lists, + .kmeans_n_iters = kmeans_n_iters, + .kmeans_trainset_fraction = kmeans_trainset_fraction, + .pq_bits = pq_bits, + .pq_dim = pq_dim, + .codebook_kind = static_cast((int)this->codebook_kind), + .force_random_rotation = force_random_rotation, + .conservative_memory_allocation = conservative_memory_allocation}; + } +}; +/** + * @} + */ + +/** + * @defgroup ivf_pq_cpp_search_params IVF-PQ index search parameters + * @{ + */ +struct search_params : ann::search_params { + /** The number of clusters to search. */ + uint32_t n_probes = 20; + /** + * Data type of look up table to be created dynamically at search time. + * + * Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U] + * + * The use of low-precision types reduces the amount of shared memory required at search time, so + * fast shared memory kernels can be used even for datasets with large dimensionality. Note that + * the recall is slightly degraded when a low-precision type is selected. + */ + cudaDataType_t lut_dtype = CUDA_R_32F; + /** + * Storage data type for distance/similarity computed at search time. + * + * Possible values: [CUDA_R_16F, CUDA_R_32F] + * + * If the performance limiter at search time is device memory access, selecting FP16 will improve + * performance slightly. + */ + cudaDataType_t internal_distance_dtype = CUDA_R_32F; + /** + * Preferred fraction of SM's unified memory / L1 cache to be used as shared memory. + * + * Possible values: [0.0 - 1.0] as a fraction of the `sharedMemPerMultiprocessor`. + * + * One wants to increase the carveout to ensure good GPU occupancy for the main search + * kernel, but not to keep it too high to leave some memory to be used as L1 cache. Note, this + * value is interpreted only as a hint.
Moreover, a GPU usually allows only a fixed set of cache + * configurations, so the provided value is rounded up to the nearest configuration. Refer to the + * NVIDIA tuning guide for the target GPU architecture. + * + * Note, this is a low-level tuning parameter that can have drastic negative effects on the search + * performance if tweaked incorrectly. + */ + double preferred_shmem_carveout = 1.0; + + /** Build a raft IVF_PQ search params from an existing cuvs IVF_PQ search params. */ + operator raft::neighbors::ivf_pq::search_params() const + { + raft::neighbors::ivf_pq::search_params result = { + {}, n_probes, lut_dtype, internal_distance_dtype, preferred_shmem_carveout}; + return result; + } +}; +/** + * @} + */ + +template +using list_data = raft::neighbors::ivf_pq::list_data; + +/** + * @defgroup ivf_pq_cpp_index IVF-PQ index + * @{ + */ +/** + * @brief IVF-PQ index. + * + * In the IVF-PQ index, a database vector y is approximated with two-level quantization: + * + * y = Q_1(y) + Q_2(y - Q_1(y)) + * + * The first-level quantizer (Q_1) maps the vector y to the nearest cluster center. The number of + * clusters is n_lists. + * + * The second quantizer encodes the residual, and it is defined as a product quantizer [1]. + * + * A product quantizer encodes a `dim` dimensional vector with a `pq_dim` dimensional vector. + * First we split the input vector into `pq_dim` subvectors (denoted by u), where each u vector + * contains `pq_len` distinct components of y + * + * y_1, y_2, ... y_{pq_len}, y_{pq_len+1}, ... y_{2*pq_len}, ... y_{dim-pq_len+1} ... y_{dim} + * \___________________/ \____________________________/ \______________________/ + * u_1 u_2 u_{pq_dim} + * + * Then each subvector is encoded with a separate quantizer q_i, and the results are concatenated + * + * Q_2(y) = q_1(u_1),q_2(u_2),...,q_{pq_dim}(u_{pq_dim}) + * + * Each quantizer q_i outputs a code with `pq_bits` bits.
The second level quantizers are also defined + * by k-means clustering in the corresponding sub-space: the reproduction values are the centroids, + * and the set of reproduction values is the codebook. + * + * When the data dimensionality `dim` is not multiple of `pq_dim`, the feature space is transformed + * using a random orthogonal matrix to have `rot_dim = pq_dim * pq_len` dimensions + * (`rot_dim >= dim`). + * + * The second-level quantizers are trained either for each subspace or for each cluster: + * (a) codebook_gen::PER_SUBSPACE: + * creates `pq_dim` second-level quantizers - one for each slice of the data along features; + * (b) codebook_gen::PER_CLUSTER: + * creates `n_lists` second-level quantizers - one for each first-level cluster. + * In either case, the centroids are again found using k-means clustering interpreting the data as + * having pq_len dimensions. + * + * [1] Product quantization for nearest neighbor search Herve Jegou, Matthijs Douze, Cordelia Schmid + * + * @tparam IdxT type of the indices in the source dataset + * + */ +template +struct index : ann::index { + static_assert(!raft::is_narrowing_v, + "IdxT must be able to represent all values of uint32_t"); + + using pq_centers_extents = typename raft::neighbors::ivf_pq::index::pq_centers_extents; + + public: + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + + /** Construct an empty index. It needs to be trained and then populated. */ + index(raft::resources const& handle, const index_params& params, uint32_t dim); + index(raft::neighbors::ivf_pq::index&& raft_idx); + + /** Total length of the index. */ + IdxT size() const noexcept; + + /** Dimensionality of the input data. */ + uint32_t dim() const noexcept; + + /** + * Dimensionality of the cluster centers: + * input data dim extended with vector norms and padded to 8 elems. 
+ */ + uint32_t dim_ext() const noexcept; + + /** + * Dimensionality of the data after transforming it for PQ processing + * (rotated and augmented to be a multiple of `pq_dim`). + */ + uint32_t rot_dim() const noexcept; + + /** The bit length of an encoded vector element after compression by PQ. */ + uint32_t pq_bits() const noexcept; + + /** The dimensionality of an encoded vector after compression by PQ. */ + uint32_t pq_dim() const noexcept; + + /** Dimensionality of a subspace, i.e. the number of vector components mapped to a subspace */ + uint32_t pq_len() const noexcept; + + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + uint32_t pq_book_size() const noexcept; + + /** Distance metric used for clustering. */ + cuvs::distance::DistanceType metric() const noexcept; + + /** How PQ codebooks are created. */ + codebook_gen codebook_kind() const noexcept; + + /** Number of clusters/inverted lists (first level quantization). */ + uint32_t n_lists() const noexcept; + + /** + * Whether to use conservative memory allocation when extending the list (cluster) data + * (see index_params.conservative_memory_allocation). + */ + bool conservative_memory_allocation() const noexcept; + + /** + * PQ cluster centers + * + * - codebook_gen::PER_SUBSPACE: [pq_dim , pq_len, pq_book_size] + * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] + */ + raft::mdspan pq_centers() noexcept; + raft::mdspan pq_centers() const noexcept; + + /** Lists' data and indices. */ + std::vector>>& lists() noexcept; + const std::vector>>& lists() const noexcept; + + /** Pointers to the inverted lists (clusters) data [n_lists]. */ + raft::device_vector_view data_ptrs() noexcept; + raft::device_vector_view data_ptrs() + const noexcept; + + /** Pointers to the inverted lists (clusters) indices [n_lists].
*/ + raft::device_vector_view inds_ptrs() noexcept; + raft::device_vector_view inds_ptrs() const noexcept; + + /** The transform matrix (original space -> rotated padded space) [rot_dim, dim] */ + raft::device_matrix_view rotation_matrix() noexcept; + raft::device_matrix_view rotation_matrix() const noexcept; + + /** + * Accumulated list sizes, sorted in descending order [n_lists + 1]. + * The last value contains the total length of the index. + * The value at index zero is always zero. + * + * That is, the content of this span is as if the `list_sizes` was sorted and then accumulated. + * + * This span is used during search to estimate the maximum size of the workspace. + */ + raft::host_vector_view accum_sorted_sizes() noexcept; + raft::host_vector_view accum_sorted_sizes() const noexcept; + + /** Sizes of the lists [n_lists]. */ + raft::device_vector_view list_sizes() noexcept; + raft::device_vector_view list_sizes() const noexcept; + + /** Cluster centers corresponding to the lists in the original space [n_lists, dim_ext] */ + raft::device_matrix_view centers() noexcept; + raft::device_matrix_view centers() const noexcept; + + /** Cluster centers corresponding to the lists in the rotated space [n_lists, rot_dim] */ + raft::device_matrix_view centers_rot() noexcept; + raft::device_matrix_view centers_rot() const noexcept; + + // Get pointer to underlying RAFT index, not meant to be used outside of cuVS + inline raft::neighbors::ivf_pq::index* get_raft_index() noexcept + { + return raft_index_.get(); + } + inline const raft::neighbors::ivf_pq::index* get_raft_index() const noexcept + { + return raft_index_.get(); + } + + private: + std::unique_ptr> raft_index_; +}; +/** + * @} + */ + +/** + * @defgroup ivf_pq_cpp_index_build IVF-PQ index build + * @{ + */ +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_pq::index index; + * ivf_pq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_pq::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_pq::index index; + * ivf_pq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_pq::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_pq::index index; + * ivf_pq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_pq::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_pq::index* idx); +/** + * @} + */ + +/** + * @defgroup ivf_pq_cpp_index_extend IVF-PQ index extend + * @{ + */ +/** + * @brief Extend the index with the new data. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_pq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_pq::index& idx) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_pq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. 
+ * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_pq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_pq::index& idx) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Extend the index with the new data. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_pq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_pq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. 
+ * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] idx the original index; it is not modified, the extended index is returned + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_pq::index& idx) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_pq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_pq::index* idx); +/** + * @} + */ + +/** + * @defgroup ivf_pq_cpp_index_search IVF-PQ index search + * @{ + */ +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown.
To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @code{.cpp} + * ... + * // use default search parameters + * ivf_pq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_pq::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_pq::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] search_params configure the search + * @param[in] index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::search_params& search_params, + cuvs::neighbors::ivf_pq::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. 
To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @code{.cpp} + * ... + * // use default search parameters + * ivf_pq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_pq::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_pq::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] search_params configure the search + * @param[in] index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::search_params& search_params, + cuvs::neighbors::ivf_pq::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. 
To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @code{.cpp} + * ... + * // use default search parameters + * ivf_pq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_pq::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_pq::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @param[in] handle + * @param[in] search_params configure the search + * @param[in] index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::search_params& search_params, + cuvs::neighbors::ivf_pq::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); +/** + * @} + */ + +/** + * @defgroup ivf_pq_cpp_serialize IVF-PQ index serialize + * @{ + */ +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_pq::build(...);` + * cuvs::neighbors::ivf_pq::serialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-PQ index + * + */ +void serialize(raft::resources const& handle, + std::string& filename, + const cuvs::neighbors::ivf_pq::index& index); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using IdxT = int64_t; // type of the index + * // create an empty index with `ivf_pq::index index(handle, index_params, dim);` + * cuvs::neighbors::ivf_pq::deserialize(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index IVF-PQ index + * + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_pq::index* index); +/** + * @} + */ + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu new file mode 100644 index 000000000..33dc2088c --- /dev/null +++ b/cpp/src/neighbors/brute_force.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace cuvs::neighbors::brute_force { + +#define CUVS_INST_BFKNN(T, IdxT) \ + auto build(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + cuvs::distance::DistanceType metric, \ + T metric_arg) \ + ->cuvs::neighbors::brute_force::index \ + { \ + auto index_on_stack = raft::neighbors::brute_force::build( \ + res, dataset, static_cast(metric), metric_arg); \ + auto index_on_heap = \ + new raft::neighbors::brute_force::index(std::move(index_on_stack)); \ + return cuvs::neighbors::brute_force::index(index_on_heap); \ + } \ + \ + void search(raft::resources const& res, \ + const cuvs::neighbors::brute_force::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + auto raft_idx = \ + reinterpret_cast*>(idx.get_raft_index()); \ + raft::neighbors::brute_force::search(res, *raft_idx, queries, neighbors, distances); \ + } \ + \ + template struct cuvs::neighbors::brute_force::index; + +CUVS_INST_BFKNN(float, int64_t); +// CUVS_INST_BFKNN(int8_t, int64_t); +// CUVS_INST_BFKNN(uint8_t, int64_t); + +#undef CUVS_INST_BFKNN + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp new file mode 100644 index 000000000..531fc9a57 --- /dev/null +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -0,0 +1,167 @@ + +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +template +void* _build(cuvsResources_t res, + DLManagedTensor* dataset_tensor, + enum DistanceType metric, + T metric_arg) +{ + auto res_ptr = reinterpret_cast(res); + + using mdspan_type = raft::device_matrix_view; + auto mds = cuvs::core::from_dlpack(dataset_tensor); + + auto index_on_stack = cuvs::neighbors::brute_force::build( + *res_ptr, mds, static_cast((int)metric), metric_arg); + auto index_on_heap = new cuvs::neighbors::brute_force::index(std::move(index_on_stack)); + + return index_on_heap; +} + +template +void _search(cuvsResources_t res, + bruteForceIndex index, + DLManagedTensor* queries_tensor, + DLManagedTensor* neighbors_tensor, + DLManagedTensor* distances_tensor) +{ + auto res_ptr = reinterpret_cast(res); + auto index_ptr = reinterpret_cast*>(index.addr); + + using queries_mdspan_type = raft::device_matrix_view; + using neighbors_mdspan_type = raft::device_matrix_view; + using distances_mdspan_type = raft::device_matrix_view; + auto queries_mds = cuvs::core::from_dlpack(queries_tensor); + auto neighbors_mds = cuvs::core::from_dlpack(neighbors_tensor); + auto distances_mds = cuvs::core::from_dlpack(distances_tensor); + + cuvs::neighbors::brute_force::search( + *res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds); +} + +} // namespace + +extern "C" cuvsError_t 
bruteForceIndexCreate(cuvsBruteForceIndex_t* index) +{ + try { + *index = new bruteForceIndex{}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t bruteForceIndexDestroy(cuvsBruteForceIndex_t index_c_ptr) +{ + try { + auto index = *index_c_ptr; + + if (index.dtype.code == kDLFloat) { + auto index_ptr = reinterpret_cast*>(index.addr); + delete index_ptr; + } else if (index.dtype.code == kDLInt) { + auto index_ptr = reinterpret_cast*>(index.addr); + delete index_ptr; + } else if (index.dtype.code == kDLUInt) { + auto index_ptr = reinterpret_cast*>(index.addr); + delete index_ptr; + } + delete index_c_ptr; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t bruteForceBuild(cuvsResources_t res, + DLManagedTensor* dataset_tensor, + enum DistanceType metric, + float metric_arg, + cuvsBruteForceIndex_t index) +{ + try { + auto dataset = dataset_tensor->dl_tensor; + + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + index->addr = + reinterpret_cast(_build(res, dataset_tensor, metric, metric_arg)); + index->dtype.code = kDLFloat; + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + return CUVS_SUCCESS; + } catch (...) 
{ + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t bruteForceSearch(cuvsResources_t res, + cuvsBruteForceIndex_t index_c_ptr, + DLManagedTensor* queries_tensor, + DLManagedTensor* neighbors_tensor, + DLManagedTensor* distances_tensor) +{ + try { + auto queries = queries_tensor->dl_tensor; + auto neighbors = neighbors_tensor->dl_tensor; + auto distances = distances_tensor->dl_tensor; + + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(queries), + "queries should have device compatible memory"); + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(neighbors), + "neighbors should have device compatible memory"); + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(distances), + "distances should have device compatible memory"); + + RAFT_EXPECTS(neighbors.dtype.code == kDLInt && neighbors.dtype.bits == 64, + "neighbors should be of type int64_t"); + RAFT_EXPECTS(distances.dtype.code == kDLFloat && distances.dtype.bits == 32, + "distances should be of type float32"); + + auto index = *index_c_ptr; + RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); + + if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) { + _search(res, index, queries_tensor, neighbors_tensor, distances_tensor); + } else { + RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", + queries.dtype.code, + queries.dtype.bits); + } + + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} diff --git a/cpp/src/neighbors/brute_force_index.cu b/cpp/src/neighbors/brute_force_index.cu new file mode 100644 index 000000000..b05fa7ced --- /dev/null +++ b/cpp/src/neighbors/brute_force_index.cu @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace cuvs::neighbors::brute_force { + +template +inline const raft::neighbors::brute_force::index* get_underlying_index( + const cuvs::neighbors::brute_force::index* idx) +{ + return reinterpret_cast*>(idx->get_raft_index()); +} + +template +index::index(void* raft_index) + : cuvs::neighbors::ann::index(), raft_index_(reinterpret_cast(raft_index)) +{ +} + +template +cuvs::distance::DistanceType index::metric() const noexcept +{ + auto raft_index = cuvs::neighbors::brute_force::get_underlying_index(this); + return static_cast((int)raft_index->metric()); +} + +template +size_t index::size() const noexcept +{ + auto raft_index = get_underlying_index(this); + return raft_index->size(); +} + +template +size_t index::dim() const noexcept +{ + auto raft_index = get_underlying_index(this); + return raft_index->dim(); +} + +template +raft::device_matrix_view index::dataset() const noexcept +{ + auto raft_index = get_underlying_index(this); + return raft_index->dataset(); +} + +template +raft::device_vector_view index::norms() const +{ + auto raft_index = get_underlying_index(this); + return raft_index->norms(); +} + +template +bool index::has_norms() const noexcept +{ + auto raft_index = get_underlying_index(this); + return raft_index->has_norms(); +} + +template +T index::metric_arg() const noexcept +{ + auto raft_index = get_underlying_index(this); + return raft_index->metric_arg(); +} + +template struct index; + +} // namespace cuvs::neighbors::brute_force diff --git a/cpp/src/neighbors/cagra_c.cpp 
b/cpp/src/neighbors/cagra_c.cpp index 9fdfe2c1e..9e5087016 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -162,14 +162,14 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(queries), "queries should have device compatible memory"); RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(neighbors), - "queries should have device compatible memory"); + "neighbors should have device compatible memory"); RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(distances), - "queries should have device compatible memory"); + "distances should have device compatible memory"); RAFT_EXPECTS(neighbors.dtype.code == kDLUInt && neighbors.dtype.bits == 32, "neighbors should be of type uint32_t"); RAFT_EXPECTS(distances.dtype.code == kDLFloat && neighbors.dtype.bits == 32, - "neighbors should be of type float32"); + "distances should be of type float32"); auto index = *index_c_ptr; RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); diff --git a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py new file mode 100644 index 000000000..bf0cad6d4 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py @@ -0,0 +1,174 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_flat { +""" + +footer = """ +} // namespace cuvs::neighbors::ivf_flat +""" + +types = dict( + float_int64_t=("float", "int64_t"), + int8_t_int64_t=("int8_t", "int64_t"), + uint8_t_int64_t=("uint8_t", "int64_t"), +) + +build_macro = """ +#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \\ + auto build(raft::resources const& handle, \\ + const cuvs::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset) \\ + ->cuvs::neighbors::ivf_flat::index \\ + { \\ + return cuvs::neighbors::ivf_flat::index( \\ + std::move(raft::runtime::neighbors::ivf_flat::build(handle, params, dataset))); \\ + } \\ + \\ + void build(raft::resources const& handle, \\ + const cuvs::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset, \\ + cuvs::neighbors::ivf_flat::index& idx) \\ + { \\ + raft::runtime::neighbors::ivf_flat::build(handle, params, dataset, *idx.get_raft_index()); \\ + } +""" + +extend_macro = """ +#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \\ + auto extend(raft::resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + const cuvs::neighbors::ivf_flat::index& orig_index) \\ + ->cuvs::neighbors::ivf_flat::index 
\\ + { \\ + return cuvs::neighbors::ivf_flat::index( \\ + std::move(raft::runtime::neighbors::ivf_flat::extend( \\ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \\ + } \\ + \\ + void extend(raft::resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + cuvs::neighbors::ivf_flat::index* idx) \\ + { \\ + raft::runtime::neighbors::ivf_flat::extend( \\ + handle, new_vectors, new_indices, idx->get_raft_index()); \\ + } +""" + +search_macro = """ +#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \\ + void search(raft::resources const& handle, \\ + const cuvs::neighbors::ivf_flat::search_params& params, \\ + cuvs::neighbors::ivf_flat::index& index, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbors, \\ + raft::device_matrix_view distances) \\ + { \\ + raft::runtime::neighbors::ivf_flat::search( \\ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \\ + } +""" + +serialize_macro = """ +#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \\ + void serialize_file(raft::resources const& handle, \\ + const std::string& filename, \\ + const cuvs::neighbors::ivf_flat::index& index) \\ + { \\ + raft::runtime::neighbors::ivf_flat::serialize_file(handle, filename, *index.get_raft_index()); \\ + } \\ + \\ + void deserialize_file(raft::resources const& handle, \\ + const std::string& filename, \\ + cuvs::neighbors::ivf_flat::index* index) \\ + { \\ + raft::runtime::neighbors::ivf_flat::deserialize_file( \\ + handle, filename, index->get_raft_index()); \\ + } \\ + \\ + void serialize(raft::resources const& handle, \\ + std::string& str, \\ + const cuvs::neighbors::ivf_flat::index& index) \\ + { \\ + raft::runtime::neighbors::ivf_flat::serialize(handle, str, *index.get_raft_index()); \\ + } \\ + \\ + void deserialize(raft::resources const& handle, \\ + const std::string& str, \\ + cuvs::neighbors::ivf_flat::index* index) \\ + { \\ + 
raft::runtime::neighbors::ivf_flat::deserialize(handle, str, index->get_raft_index()); \\ + } +""" + +macros = dict( + build=dict( + definition=build_macro, + name="CUVS_INST_IVF_FLAT_BUILD", + ), + extend=dict( + definition=extend_macro, + name="CUVS_INST_IVF_FLAT_EXTEND", + ), + search=dict( + definition=search_macro, + name="CUVS_INST_IVF_FLAT_SEARCH", + ), + serialize=dict( + definition=serialize_macro, + name="CUVS_INST_IVF_FLAT_SERIALIZE", + ), +) + +for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"ivf_flat_{macro_path}_{type_path}.cpp" + with open(path, "w") as f: + f.write(header) + f.write(macro["definition"]) + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") + f.write(footer) + + print(f"src/neighbors/ivf_flat/{path}") diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cpp new file mode 100644 index 000000000..177aaac11 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::ivf_flat::index \ + { \ + return cuvs::neighbors::ivf_flat::index( \ + std::move(raft::runtime::neighbors::ivf_flat::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_flat::index& idx) \ + { \ + raft::runtime::neighbors::ivf_flat::build(handle, params, dataset, *idx.get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_BUILD(float, int64_t); + +#undef CUVS_INST_IVF_FLAT_BUILD + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cpp new file mode 100644 index 000000000..6fe6e2b8d --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::ivf_flat::index \ + { \ + return cuvs::neighbors::ivf_flat::index( \ + std::move(raft::runtime::neighbors::ivf_flat::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_flat::index& idx) \ + { \ + raft::runtime::neighbors::ivf_flat::build(handle, params, dataset, *idx.get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_BUILD(int8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_BUILD + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cpp new file mode 100644 index 000000000..01098ed45 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::ivf_flat::index \ + { \ + return cuvs::neighbors::ivf_flat::index( \ + std::move(raft::runtime::neighbors::ivf_flat::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_flat::index& idx) \ + { \ + raft::runtime::neighbors::ivf_flat::build(handle, params, dataset, *idx.get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_BUILD(uint8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_BUILD + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cpp new file mode 100644 index 000000000..04ca3a50f --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_flat::index& orig_index) \ + ->cuvs::neighbors::ivf_flat::index \ + { \ + return cuvs::neighbors::ivf_flat::index( \ + std::move(raft::runtime::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_flat::index* idx) \ + { \ + raft::runtime::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, idx->get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_EXTEND(float, int64_t); + +#undef CUVS_INST_IVF_FLAT_EXTEND + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cpp new file mode 100644 index 000000000..accc53e04 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines the T=int8_t, IdxT=int64_t extend() overloads: the first wraps the index returned by raft::runtime::neighbors::ivf_flat::extend in a new cuvs index, the second passes the raft index pointer obtained from idx to the raft extend overload. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_flat::index& orig_index) \ + ->cuvs::neighbors::ivf_flat::index \ + { \ + return cuvs::neighbors::ivf_flat::index( \ + std::move(raft::runtime::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_flat::index* idx) \ + { \ + raft::runtime::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, idx->get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_EXTEND(int8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_EXTEND + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cpp new file mode 100644 index 000000000..e44ae51b1 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines the T=uint8_t, IdxT=int64_t extend() overloads: the first wraps the index returned by raft::runtime::neighbors::ivf_flat::extend in a new cuvs index, the second passes the raft index pointer obtained from idx to the raft extend overload. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_flat::index& orig_index) \ + ->cuvs::neighbors::ivf_flat::index \ + { \ + return cuvs::neighbors::ivf_flat::index( \ + std::move(raft::runtime::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_flat::index* idx) \ + { \ + raft::runtime::neighbors::ivf_flat::extend( \ + handle, new_vectors, new_indices, idx->get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_EXTEND(uint8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_EXTEND + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cpp new file mode 100644 index 000000000..48a584e9e --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search_float_int64_t.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines search() for T=float, IdxT=int64_t: it dereferences the raft index held by the cuvs index and forwards every argument to raft::runtime::neighbors::ivf_flat::search. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::search_params& params, \ + cuvs::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::runtime::neighbors::ivf_flat::search( \ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \ + } +CUVS_INST_IVF_FLAT_SEARCH(float, int64_t); + +#undef CUVS_INST_IVF_FLAT_SEARCH + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cpp new file mode 100644 index 000000000..5645c18e0 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines search() for T=int8_t, IdxT=int64_t: it dereferences the raft index held by the cuvs index and forwards every argument to raft::runtime::neighbors::ivf_flat::search. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::search_params& params, \ + cuvs::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::runtime::neighbors::ivf_flat::search( \ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \ + } +CUVS_INST_IVF_FLAT_SEARCH(int8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_SEARCH + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cpp new file mode 100644 index 000000000..ab1bee8b6 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines search() for T=uint8_t, IdxT=int64_t: it dereferences the raft index held by the cuvs index and forwards every argument to raft::runtime::neighbors::ivf_flat::search. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_flat::search_params& params, \ + cuvs::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::runtime::neighbors::ivf_flat::search( \ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \ + } +CUVS_INST_IVF_FLAT_SEARCH(uint8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_SEARCH + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cpp new file mode 100644 index 000000000..19a3d72d9 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines serialize_file, deserialize_file, serialize and deserialize for T=float, IdxT=int64_t; each forwards to the same-named raft::runtime::neighbors::ivf_flat function using the raft index held by the cuvs index. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + raft::runtime::neighbors::ivf_flat::serialize_file(handle, filename, *index.get_raft_index()); \ + } \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + raft::runtime::neighbors::ivf_flat::deserialize_file( \ + handle, filename, index->get_raft_index()); \ + } \ + \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + raft::runtime::neighbors::ivf_flat::serialize(handle, str, *index.get_raft_index()); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + raft::runtime::neighbors::ivf_flat::deserialize(handle, str, index->get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_SERIALIZE(float, int64_t); + +#undef CUVS_INST_IVF_FLAT_SERIALIZE + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cpp new file mode 100644 index 000000000..f65fe221d --- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* Defines serialize_file, deserialize_file, serialize and deserialize for T=int8_t, IdxT=int64_t; each forwards to the same-named raft::runtime::neighbors::ivf_flat function using the raft index held by the cuvs index. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + raft::runtime::neighbors::ivf_flat::serialize_file(handle, filename, *index.get_raft_index()); \ + } \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + raft::runtime::neighbors::ivf_flat::deserialize_file( \ + handle, filename, index->get_raft_index()); \ + } \ + \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + raft::runtime::neighbors::ivf_flat::serialize(handle, str, *index.get_raft_index()); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + raft::runtime::neighbors::ivf_flat::deserialize(handle, str, index->get_raft_index()); \ + } +CUVS_INST_IVF_FLAT_SERIALIZE(int8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_SERIALIZE + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cpp new file mode 100644 index 000000000..5f312dc0a
--- /dev/null +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by generate_ivf_flat.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_flat.py + * + */ + +#include +#include +/* The macro below defines serialize_file, deserialize_file, serialize and deserialize for one T/IdxT pair; each forwards to the same-named raft::runtime::neighbors::ivf_flat function using the raft index held by the cuvs index. */ +namespace cuvs::neighbors::ivf_flat { + +#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + raft::runtime::neighbors::ivf_flat::serialize_file(handle, filename, *index.get_raft_index()); \ + } \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + raft::runtime::neighbors::ivf_flat::deserialize_file( \ + handle, filename, index->get_raft_index()); \ + } \ + \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + raft::runtime::neighbors::ivf_flat::serialize(handle, str, *index.get_raft_index()); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + raft::runtime::neighbors::ivf_flat::deserialize(handle, str, index->get_raft_index()); \ + }
+CUVS_INST_IVF_FLAT_SERIALIZE(uint8_t, int64_t); + +#undef CUVS_INST_IVF_FLAT_SERIALIZE + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat_c.cpp b/cpp/src/neighbors/ivf_flat_c.cpp new file mode 100644 index 000000000..b9488ec36 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_c.cpp @@ -0,0 +1,246 @@ + +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace { +/* _build: copies the C ivfFlatIndexParams into cuvs index_params, heap-allocates a cuvs ivf_flat index and builds it from the DLPack dataset; the raw pointer is returned and the caller is responsible for deleting it. */ +template +void* _build(cuvsResources_t res, ivfFlatIndexParams params, DLManagedTensor* dataset_tensor) +{ + auto res_ptr = reinterpret_cast(res); + + auto build_params = cuvs::neighbors::ivf_flat::index_params(); + build_params.metric = static_cast((int)params.metric); + build_params.metric_arg = params.metric_arg; + build_params.add_data_on_build = params.add_data_on_build; + build_params.n_lists = params.n_lists; + build_params.kmeans_n_iters = params.kmeans_n_iters; + build_params.kmeans_trainset_fraction = params.kmeans_trainset_fraction; + build_params.adaptive_centers = params.adaptive_centers; + build_params.conservative_memory_allocation = params.conservative_memory_allocation; + + auto dataset = dataset_tensor->dl_tensor; + auto dim = dataset.shape[1]; /* dataset is an (n_rows, dim) matrix: shape[0] is the row count, the vector dimensionality is shape[1] */ + + auto index = new cuvs::neighbors::ivf_flat::index(*res_ptr, build_params, dim); + + using mdspan_type = raft::device_matrix_view; + auto mds =
cuvs::core::from_dlpack(dataset_tensor); + + cuvs::neighbors::ivf_flat::build(*res_ptr, build_params, mds, *index); + + return index; +} +/* _search: casts the opaque handles back to raft::resources and the typed cuvs index, copies n_probes into cuvs search_params, converts the three DLPack tensors to device matrix views and calls cuvs::neighbors::ivf_flat::search. */ +template +void _search(cuvsResources_t res, + ivfFlatSearchParams params, + ivfFlatIndex index, + DLManagedTensor* queries_tensor, + DLManagedTensor* neighbors_tensor, + DLManagedTensor* distances_tensor) +{ + auto res_ptr = reinterpret_cast(res); + auto index_ptr = reinterpret_cast*>(index.addr); + + auto search_params = cuvs::neighbors::ivf_flat::search_params(); + search_params.n_probes = params.n_probes; + + using queries_mdspan_type = raft::device_matrix_view; + using neighbors_mdspan_type = raft::device_matrix_view; + using distances_mdspan_type = raft::device_matrix_view; + auto queries_mds = cuvs::core::from_dlpack(queries_tensor); + auto neighbors_mds = cuvs::core::from_dlpack(neighbors_tensor); + auto distances_mds = cuvs::core::from_dlpack(distances_tensor); + + cuvs::neighbors::ivf_flat::search( + *res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); +} + +} // namespace +/* ivfFlatIndexCreate: allocates a value-initialized ivfFlatIndex handle and stores it in *index; any exception is reported as CUVS_ERROR. */ +extern "C" cuvsError_t ivfFlatIndexCreate(cuvsIvfFlatIndex_t* index) +{ + try { + *index = new ivfFlatIndex{}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} +/* ivfFlatIndexDestroy: deletes the typed cuvs index stored in the handle (the type is selected from dtype.code alone) and then the handle itself; any exception is reported as CUVS_ERROR. */ +extern "C" cuvsError_t ivfFlatIndexDestroy(cuvsIvfFlatIndex_t index_c_ptr) +{ + try { + auto index = *index_c_ptr; + + if (index.dtype.code == kDLFloat) { + auto index_ptr = + reinterpret_cast*>(index.addr); + delete index_ptr; + } else if (index.dtype.code == kDLInt) { + auto index_ptr = + reinterpret_cast*>(index.addr); + delete index_ptr; + } else if (index.dtype.code == kDLUInt) { + auto index_ptr = + reinterpret_cast*>(index.addr); + delete index_ptr; + } + delete index_c_ptr; + return CUVS_SUCCESS; + } catch (...)
{ + return CUVS_ERROR; + } +} +/* ivfFlatBuild: selects the build path from the dataset DLPack dtype (float32, int8 or uint8), stores the pointer returned by _build in index->addr and records the dtype code; an unsupported dtype or any other exception yields CUVS_ERROR. */ +extern "C" cuvsError_t ivfFlatBuild(cuvsResources_t res, + cuvsIvfFlatIndexParams_t params, + DLManagedTensor* dataset_tensor, + cuvsIvfFlatIndex_t index) +{ + try { + auto dataset = dataset_tensor->dl_tensor; + + if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { + index->addr = + reinterpret_cast(_build(res, *params, dataset_tensor)); + index->dtype.code = kDLFloat; + } else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) { + index->addr = + reinterpret_cast(_build(res, *params, dataset_tensor)); + index->dtype.code = kDLInt; + } else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) { + index->addr = + reinterpret_cast(_build(res, *params, dataset_tensor)); + index->dtype.code = kDLUInt; + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} +/* ivfFlatSearch: checks that queries, neighbors and distances live in device-compatible memory, that neighbors are int64 and distances float32, then searches the index behind index_c_ptr; any failed check or exception yields CUVS_ERROR. */ +extern "C" cuvsError_t ivfFlatSearch(cuvsResources_t res, + cuvsIvfFlatSearchParams_t params, + cuvsIvfFlatIndex_t index_c_ptr, + DLManagedTensor* queries_tensor, + DLManagedTensor* neighbors_tensor, + DLManagedTensor* distances_tensor) +{ + try { + auto queries = queries_tensor->dl_tensor; + auto neighbors = neighbors_tensor->dl_tensor; + auto distances = distances_tensor->dl_tensor; + + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(queries), + "queries should have device compatible memory"); + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(neighbors), + "neighbors should have device compatible memory"); + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(distances), + "distances should have device compatible memory"); + + RAFT_EXPECTS(neighbors.dtype.code == kDLInt && neighbors.dtype.bits == 64, + "neighbors should be of type int64_t"); + RAFT_EXPECTS(distances.dtype.code == kDLFloat && distances.dtype.bits == 32, + "distances should be of type float32"); + + auto index = *index_c_ptr; +
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); +/* Dispatch on the query element type; only float32, int8 and uint8 queries are handled, anything else fails. */ + if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) { + _search( + res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + } else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) { + _search( + res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + } else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) { + _search( + res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + } else { + RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", + queries.dtype.code, + queries.dtype.bits); + } + + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} +/* cuvsIvfFlatIndexParamsCreate: allocates an ivfFlatIndexParams initialized with the default values listed in the initializer and stores it in *params. */ +extern "C" cuvsError_t cuvsIvfFlatIndexParamsCreate(cuvsIvfFlatIndexParams_t* params) +{ + try { + *params = new ivfFlatIndexParams{.metric = L2Expanded, + .metric_arg = 2.0f, + .add_data_on_build = true, + .n_lists = 1024, + .kmeans_n_iters = 20, + .kmeans_trainset_fraction = 0.5, + .adaptive_centers = false, + .conservative_memory_allocation = false}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} +/* cuvsIvfFlatIndexParamsDestroy: deletes the ivfFlatIndexParams object that params points to. */ +extern "C" cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t params) +{ + try { + delete params; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} +/* cuvsIvfFlatSearchParamsCreate: allocates an ivfFlatSearchParams with n_probes = 20 and stores it in *params. */ +extern "C" cuvsError_t cuvsIvfFlatSearchParamsCreate(cuvsIvfFlatSearchParams_t* params) +{ + try { + *params = new ivfFlatSearchParams{.n_probes = 20}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} +/* cuvsIvfFlatSearchParamsDestroy: deletes the ivfFlatSearchParams object that params points to. */ +extern "C" cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t params) +{ + try { + delete params; + return CUVS_SUCCESS; + } catch (...)
{ + return CUVS_ERROR; + } +} diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp new file mode 100644 index 000000000..678bec32a --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_index.cpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace cuvs::neighbors::ivf_flat { +/* Constructs a cuvs index that owns a new raft index created from the metric, n_lists, adaptive_centers and conservative_memory_allocation fields of params plus dim. */ +template +index::index(raft::resources const& res, const index_params& params, uint32_t dim) + : ann::index(), + raft_index_(std::make_unique>( + res, + static_cast((int)params.metric), + params.n_lists, + params.adaptive_centers, + params.conservative_memory_allocation, + dim)) +{ +} +/* Takes over an existing raft index by moving it into a new heap allocation owned by this wrapper. */ +template +index::index(raft::neighbors::ivf_flat::index&& raft_idx) + : ann::index(), + raft_index_(std::make_unique>(std::move(raft_idx))) +{ +} +/* The accessors that follow forward to the same-named member of the wrapped raft index. */ +template +uint32_t index::veclen() const noexcept +{ + return raft_index_->veclen(); +} +/* metric(): converts the raft metric value to cuvs::distance::DistanceType through its integer value. */ +template +cuvs::distance::DistanceType index::metric() const noexcept +{ + return static_cast((int)raft_index_->metric()); +} + +template +bool index::adaptive_centers() const noexcept +{ + return raft_index_->adaptive_centers(); +} + +template +raft::device_vector_view index::list_sizes() noexcept +{ + return raft_index_->list_sizes(); +} + +template +raft::device_vector_view index::list_sizes() const noexcept +{ + return raft_index_->list_sizes(); +} + +template +raft::device_matrix_view index::centers() noexcept +{ + return
raft_index_->centers(); +} + +template +raft::device_matrix_view index::centers() + const noexcept +{ + return raft_index_->centers(); +} +/* center_norms(): forwarded unchanged from the raft index; the result is wrapped in std::optional. */ +template +std::optional> index::center_norms() noexcept +{ + return raft_index_->center_norms(); +} + +template +std::optional> index::center_norms() + const noexcept +{ + return raft_index_->center_norms(); +} + +template +IdxT index::size() const noexcept +{ + return raft_index_->size(); +} + +template +uint32_t index::dim() const noexcept +{ + return raft_index_->dim(); +} + +template +uint32_t index::n_lists() const noexcept +{ + return raft_index_->n_lists(); +} + +template +raft::device_vector_view index::data_ptrs() noexcept +{ + return raft_index_->data_ptrs(); +} + +template +raft::device_vector_view index::data_ptrs() const noexcept +{ + return raft_index_->data_ptrs(); +} + +template +raft::device_vector_view index::inds_ptrs() noexcept +{ + return raft_index_->inds_ptrs(); +} + +template +raft::device_vector_view index::inds_ptrs() const noexcept +{ + return raft_index_->inds_ptrs(); +} + +template +bool index::conservative_memory_allocation() const noexcept +{ + return raft_index_->conservative_memory_allocation(); +} +/* lists(): mutable and const access to the vector of lists held by the raft index. */ +template +std::vector>>& +index::lists() noexcept +{ + return raft_index_->lists(); +} + +template +const std::vector>>& +index::lists() const noexcept +{ + return raft_index_->lists(); +} +/* Explicit instantiations of the index wrapper compiled into this translation unit. */ +template struct index; +template struct index; +template struct index; + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_pq/generate_ivf_pq.py b/cpp/src/neighbors/ivf_pq/generate_ivf_pq.py new file mode 100644 index 000000000..d0a5d2b19 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/generate_ivf_pq.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Text written at the top of every generated .cpp file; it starts with the C++ license comment. +header = """/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { +""" +# Text appended to every generated file: closes the cuvs::neighbors::ivf_pq namespace. +footer = """ +} // namespace cuvs::neighbors::ivf_pq +""" +# Maps the filename suffix of each instantiation to its (T, IdxT) macro arguments. +types = dict( + float_int64_t=("float", "int64_t"), + int8_t_int64_t=("int8_t", "int64_t"), + uint8_t_int64_t=("uint8_t", "int64_t"), +) +# C++ macro text defining both build() overloads; each forwards to raft::runtime::neighbors::ivf_pq::build. +build_macro = """ +#define CUVS_INST_IVF_PQ_BUILD(T, IdxT) \\ + auto build(raft::resources const& handle, \\ + const cuvs::neighbors::ivf_pq::index_params& params, \\ + raft::device_matrix_view dataset) \\ + ->cuvs::neighbors::ivf_pq::index \\ + { \\ + return cuvs::neighbors::ivf_pq::index( \\ + std::move(raft::runtime::neighbors::ivf_pq::build(handle, params, dataset))); \\ + } \\ + \\ + void build(raft::resources const& handle, \\ + const cuvs::neighbors::ivf_pq::index_params& params, \\ + raft::device_matrix_view dataset, \\ + cuvs::neighbors::ivf_pq::index* idx) \\ + { \\ + raft::runtime::neighbors::ivf_pq::build(handle, params, dataset, idx->get_raft_index()); \\ + } +""" +# C++ macro text defining both extend() overloads; each forwards to raft::runtime::neighbors::ivf_pq::extend. +extend_macro = """ +#define CUVS_INST_IVF_PQ_EXTEND(T, IdxT) \\ + auto extend(raft::resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + const cuvs::neighbors::ivf_pq::index& orig_index) \\ + ->cuvs::neighbors::ivf_pq::index \\ + { \\ + return cuvs::neighbors::ivf_pq::index( \\ + std::move(raft::runtime::neighbors::ivf_pq::extend( \\ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \\ + } \\ + \\ + void extend(raft::resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + cuvs::neighbors::ivf_pq::index* idx) \\ + { \\ + raft::runtime::neighbors::ivf_pq::extend( \\ + handle, new_vectors, new_indices, idx->get_raft_index()); \\ + } +""" +# C++ macro text defining search(), which forwards to raft::runtime::neighbors::ivf_pq::search. +search_macro = """ +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \\ + void search(raft::resources const&
handle, \\ + const cuvs::neighbors::ivf_pq::search_params& params, \\ + cuvs::neighbors::ivf_pq::index& index, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbors, \\ + raft::device_matrix_view distances) \\ + { \\ + raft::runtime::neighbors::ivf_pq::search( \\ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \\ + } +""" +# Maps the filename component of each operation to its macro text and to the macro name that is invoked and then #undef'd. +macros = dict( + build=dict( + definition=build_macro, + name="CUVS_INST_IVF_PQ_BUILD", + ), + extend=dict( + definition=extend_macro, + name="CUVS_INST_IVF_PQ_EXTEND", + ), + search=dict( + definition=search_macro, + name="CUVS_INST_IVF_PQ_SEARCH", + ), +) +# Writes one .cpp per (type, operation) pair, relative to the current working directory: header, macro definition, one instantiation, #undef, footer. +for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"ivf_pq_{macro_path}_{type_path}.cpp" + with open(path, "w") as f: + f.write(header) + f.write(macro["definition"]) + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") + f.write(footer) + # Print the generated file name prefixed with src/neighbors/ivf_pq/. + print(f"src/neighbors/ivf_pq/{path}") diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build_float_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_build_float_int64_t.cpp new file mode 100644 index 000000000..78c4a0f67 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build_float_int64_t.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include +/* Defines the T=float, IdxT=int64_t build() overloads: the first wraps the index returned by raft::runtime::neighbors::ivf_pq::build in a new cuvs index, the second passes the raft index pointer obtained from idx to the raft build overload. */ +namespace cuvs::neighbors::ivf_pq { + +#define CUVS_INST_IVF_PQ_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::ivf_pq::index \ + { \ + return cuvs::neighbors::ivf_pq::index( \ + std::move(raft::runtime::neighbors::ivf_pq::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + raft::runtime::neighbors::ivf_pq::build(handle, params, dataset, idx->get_raft_index()); \ + } +CUVS_INST_IVF_PQ_BUILD(float, int64_t); + +#undef CUVS_INST_IVF_PQ_BUILD + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_build_int8_t_int64_t.cpp new file mode 100644 index 000000000..c9d7fc4c9 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build_int8_t_int64_t.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definitions of the two build() overloads, stamped out once below with (int8_t, int64_t): the first wraps the result of raft::runtime::neighbors::ivf_pq::build in a new cuvs index, the second passes idx->get_raft_index() to the raft overload. */ + +#define CUVS_INST_IVF_PQ_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::ivf_pq::index \ + { \ + return cuvs::neighbors::ivf_pq::index( \ + std::move(raft::runtime::neighbors::ivf_pq::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + raft::runtime::neighbors::ivf_pq::build(handle, params, dataset, idx->get_raft_index()); \ + } +CUVS_INST_IVF_PQ_BUILD(int8_t, int64_t); + +#undef CUVS_INST_IVF_PQ_BUILD + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_build_uint8_t_int64_t.cpp new file mode 100644 index 000000000..24e56592d --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build_uint8_t_int64_t.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definitions of the two build() overloads, stamped out once below with (uint8_t, int64_t): the first wraps the result of raft::runtime::neighbors::ivf_pq::build in a new cuvs index, the second passes idx->get_raft_index() to the raft overload. */ + +#define CUVS_INST_IVF_PQ_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::ivf_pq::index \ + { \ + return cuvs::neighbors::ivf_pq::index( \ + std::move(raft::runtime::neighbors::ivf_pq::build(handle, params, dataset))); \ + } \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + raft::runtime::neighbors::ivf_pq::build(handle, params, dataset, idx->get_raft_index()); \ + } +CUVS_INST_IVF_PQ_BUILD(uint8_t, int64_t); + +#undef CUVS_INST_IVF_PQ_BUILD + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_extend_float_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_extend_float_int64_t.cpp new file mode 100644 index 000000000..ec189ca9a --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_extend_float_int64_t.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definitions of the two extend() overloads, stamped out once below with (float, int64_t): the first passes *orig_index.get_raft_index() to raft::runtime::neighbors::ivf_pq::extend and wraps the returned index, the second passes idx->get_raft_index() to the raft overload. */ + +#define CUVS_INST_IVF_PQ_EXTEND(T, IdxT) \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_pq::index& orig_index) \ + ->cuvs::neighbors::ivf_pq::index \ + { \ + return cuvs::neighbors::ivf_pq::index( \ + std::move(raft::runtime::neighbors::ivf_pq::extend( \ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + raft::runtime::neighbors::ivf_pq::extend( \ + handle, new_vectors, new_indices, idx->get_raft_index()); \ + } +CUVS_INST_IVF_PQ_EXTEND(float, int64_t); + +#undef CUVS_INST_IVF_PQ_EXTEND + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_extend_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_extend_int8_t_int64_t.cpp new file mode 100644 index 000000000..27eadec72 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_extend_int8_t_int64_t.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definitions of the two extend() overloads, stamped out once below with (int8_t, int64_t): the first passes *orig_index.get_raft_index() to raft::runtime::neighbors::ivf_pq::extend and wraps the returned index, the second passes idx->get_raft_index() to the raft overload. */ + +#define CUVS_INST_IVF_PQ_EXTEND(T, IdxT) \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_pq::index& orig_index) \ + ->cuvs::neighbors::ivf_pq::index \ + { \ + return cuvs::neighbors::ivf_pq::index( \ + std::move(raft::runtime::neighbors::ivf_pq::extend( \ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + raft::runtime::neighbors::ivf_pq::extend( \ + handle, new_vectors, new_indices, idx->get_raft_index()); \ + } +CUVS_INST_IVF_PQ_EXTEND(int8_t, int64_t); + +#undef CUVS_INST_IVF_PQ_EXTEND + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_extend_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_extend_uint8_t_int64_t.cpp new file mode 100644 index 000000000..072b30bb0 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_extend_uint8_t_int64_t.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definitions of the two extend() overloads, stamped out once below with (uint8_t, int64_t): the first passes *orig_index.get_raft_index() to raft::runtime::neighbors::ivf_pq::extend and wraps the returned index, the second passes idx->get_raft_index() to the raft overload. */ + +#define CUVS_INST_IVF_PQ_EXTEND(T, IdxT) \ + auto extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const cuvs::neighbors::ivf_pq::index& orig_index) \ + ->cuvs::neighbors::ivf_pq::index \ + { \ + return cuvs::neighbors::ivf_pq::index( \ + std::move(raft::runtime::neighbors::ivf_pq::extend( \ + handle, new_vectors, new_indices, *orig_index.get_raft_index()))); \ + } \ + \ + void extend(raft::resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + cuvs::neighbors::ivf_pq::index* idx) \ + { \ + raft::runtime::neighbors::ivf_pq::extend( \ + handle, new_vectors, new_indices, idx->get_raft_index()); \ + } +CUVS_INST_IVF_PQ_EXTEND(uint8_t, int64_t); + +#undef CUVS_INST_IVF_PQ_EXTEND + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search_float_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_search_float_int64_t.cpp new file mode 100644 index 000000000..69db44d9b --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search_float_int64_t.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definition of search(), stamped out once below with (float, int64_t): it dereferences index.get_raft_index() and forwards every argument to raft::runtime::neighbors::ivf_pq::search. */ + +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::runtime::neighbors::ivf_pq::search( \ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \ + } +CUVS_INST_IVF_PQ_SEARCH(float, int64_t); + +#undef CUVS_INST_IVF_PQ_SEARCH + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search_int8_t_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_search_int8_t_int64_t.cpp new file mode 100644 index 000000000..e2e96b9a7 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search_int8_t_int64_t.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definition of search(), stamped out once below with (int8_t, int64_t): it dereferences index.get_raft_index() and forwards every argument to raft::runtime::neighbors::ivf_pq::search. */ + +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::runtime::neighbors::ivf_pq::search( \ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \ + } +CUVS_INST_IVF_PQ_SEARCH(int8_t, int64_t); + +#undef CUVS_INST_IVF_PQ_SEARCH + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search_uint8_t_int64_t.cpp b/cpp/src/neighbors/ivf_pq/ivf_pq_search_uint8_t_int64_t.cpp new file mode 100644 index 000000000..18e7f4618 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search_uint8_t_int64_t.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line definition of search(), stamped out once below with (uint8_t, int64_t): it dereferences index.get_raft_index() and forwards every argument to raft::runtime::neighbors::ivf_pq::search. */ + +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::runtime::neighbors::ivf_pq::search( \ + handle, params, *index.get_raft_index(), queries, neighbors, distances); \ + } +CUVS_INST_IVF_PQ_SEARCH(uint8_t, int64_t); + +#undef CUVS_INST_IVF_PQ_SEARCH + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq_c.cpp b/cpp/src/neighbors/ivf_pq_c.cpp new file mode 100644 index 000000000..27076cbc1 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq_c.cpp @@ -0,0 +1,235 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace { /* File-local helpers that turn the C structs and DLPack tensors into cuvs C++ calls. */ + +template +void* _build(cuvsResources_t res, ivfPqIndexParams params, DLManagedTensor* dataset_tensor) +{ /* Copies the C index params field by field, builds an index from the dataset and returns it as an owning raw pointer. */ + auto res_ptr = reinterpret_cast(res); + + auto build_params = cuvs::neighbors::ivf_pq::index_params(); + build_params.metric = static_cast((int)params.metric); + build_params.metric_arg = params.metric_arg; + build_params.add_data_on_build = params.add_data_on_build; + build_params.n_lists = params.n_lists; + build_params.kmeans_n_iters = params.kmeans_n_iters; + build_params.kmeans_trainset_fraction = params.kmeans_trainset_fraction; + build_params.pq_bits = params.pq_bits; + build_params.pq_dim = params.pq_dim; + build_params.codebook_kind = + static_cast((int)params.codebook_kind); + build_params.force_random_rotation = params.force_random_rotation; + build_params.conservative_memory_allocation = params.conservative_memory_allocation; + + /* Wrap the tensor before allocating so a failure in from_dlpack cannot leak the index. */ + using mdspan_type = raft::device_matrix_view; + auto mds = cuvs::core::from_dlpack(dataset_tensor); + + /* Dimensionality is the column count (extent 1); extent 0 is the number of rows. */ + auto dim = mds.extent(1); + auto index = new cuvs::neighbors::ivf_pq::index(*res_ptr, build_params, dim); + + /* NOTE(review): index still leaks if build() throws; hold it in a smart pointer until returned. */ + cuvs::neighbors::ivf_pq::build(*res_ptr, build_params, mds, index); + return index; +} + +template +void _search(cuvsResources_t res, + ivfPqSearchParams params, + ivfPqIndex index, + DLManagedTensor* queries_tensor, + DLManagedTensor* neighbors_tensor, + DLManagedTensor* distances_tensor) +{ /* Copies the C search params, views the three DLPack tensors as device matrices and runs cuvs ivf_pq::search. */ + auto res_ptr = reinterpret_cast(res); + auto index_ptr = reinterpret_cast*>(index.addr); + + auto search_params = cuvs::neighbors::ivf_pq::search_params(); + search_params.n_probes = params.n_probes; + search_params.lut_dtype = params.lut_dtype; + search_params.internal_distance_dtype = params.internal_distance_dtype; + search_params.preferred_shmem_carveout = params.preferred_shmem_carveout; + + using queries_mdspan_type = raft::device_matrix_view; 
+ using neighbors_mdspan_type = raft::device_matrix_view; + using distances_mdspan_type = raft::device_matrix_view; + auto queries_mds = cuvs::core::from_dlpack(queries_tensor); + auto neighbors_mds = cuvs::core::from_dlpack(neighbors_tensor); + auto distances_mds = cuvs::core::from_dlpack(distances_tensor); + + cuvs::neighbors::ivf_pq::search( + *res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds); +} + +} // namespace + +extern "C" cuvsError_t ivfPqIndexCreate(cuvsIvfPqIndex_t* index) +{ /* Allocates a value-initialized ivfPqIndex handle; any exception is reported as CUVS_ERROR. */ + try { + *index = new ivfPqIndex{}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t ivfPqIndexDestroy(cuvsIvfPqIndex_t index_c_ptr) +{ /* Deletes the C++ index stored in addr, then the handle itself. */ + try { + auto index = *index_c_ptr; + + auto index_ptr = reinterpret_cast*>(index.addr); + delete index_ptr; + delete index_c_ptr; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t ivfPqBuild(cuvsResources_t res, + cuvsIvfPqIndexParams_t params, + DLManagedTensor* dataset_tensor, + cuvsIvfPqIndex_t index) +{ /* Accepts float32, int8 and uint8 datasets; stores the built index address and the dataset dtype code in *index. */ + try { + auto dataset = dataset_tensor->dl_tensor; + + if ((dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) || + (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) || + (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8)) { + index->addr = reinterpret_cast(_build(res, *params, dataset_tensor)); + index->dtype.code = dataset.dtype.code; + } else { + RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", + dataset.dtype.code, + dataset.dtype.bits); + } + return CUVS_SUCCESS; + } catch (...) 
{ + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t ivfPqSearch(cuvsResources_t res, + cuvsIvfPqSearchParams_t params, + cuvsIvfPqIndex_t index_c_ptr, + DLManagedTensor* queries_tensor, + DLManagedTensor* neighbors_tensor, + DLManagedTensor* distances_tensor) +{ /* Checks memory placement and dtypes of queries/neighbors/distances against the index before searching. */ + try { + auto queries = queries_tensor->dl_tensor; + auto neighbors = neighbors_tensor->dl_tensor; + auto distances = distances_tensor->dl_tensor; + + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(queries), + "queries should have device compatible memory"); + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(neighbors), + "neighbors should have device compatible memory"); + RAFT_EXPECTS(cuvs::core::is_dlpack_device_compatible(distances), + "distances should have device compatible memory"); + + RAFT_EXPECTS(neighbors.dtype.code == kDLInt && neighbors.dtype.bits == 64, + "neighbors should be of type int64_t"); + RAFT_EXPECTS(distances.dtype.code == kDLFloat && distances.dtype.bits == 32, + "distances should be of type float32"); + + auto index = *index_c_ptr; + RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries"); + + if ((queries.dtype.code == kDLFloat && queries.dtype.bits == 32) || + (queries.dtype.code == kDLInt && queries.dtype.bits == 8) || + (queries.dtype.code == kDLUInt && queries.dtype.bits == 8)) { + _search(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor); + } else { + RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d", + queries.dtype.code, + queries.dtype.bits); + } + + return CUVS_SUCCESS; + } catch (...) 
{ + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t cuvsIvfPqIndexParamsCreate(cuvsIvfPqIndexParams_t* params) +{ /* Allocates index params pre-filled with the defaults listed below. */ + try { + *params = new ivfPqIndexParams{.metric = L2Expanded, + .metric_arg = 2.0f, + .add_data_on_build = true, + .n_lists = 1024, + .kmeans_n_iters = 20, + .kmeans_trainset_fraction = 0.5, + .pq_bits = 8, + .pq_dim = 0, + .codebook_kind = codebook_gen::PER_SUBSPACE, + .force_random_rotation = false, + .conservative_memory_allocation = false}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t cuvsIvfPqIndexParamsDestroy(cuvsIvfPqIndexParams_t params) +{ + try { + delete params; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t cuvsIvfPqSearchParamsCreate(cuvsIvfPqSearchParams_t* params) +{ /* Allocates search params pre-filled with the defaults listed below. */ + try { + *params = new ivfPqSearchParams{.n_probes = 20, + .lut_dtype = CUDA_R_32F, + .internal_distance_dtype = CUDA_R_32F, + .preferred_shmem_carveout = 1.0}; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} + +extern "C" cuvsError_t cuvsIvfPqSearchParamsDestroy(cuvsIvfPqSearchParams_t params) +{ + try { + delete params; + return CUVS_SUCCESS; + } catch (...) { + return CUVS_ERROR; + } +} diff --git a/cpp/src/neighbors/ivf_pq_index.cpp b/cpp/src/neighbors/ivf_pq_index.cpp new file mode 100644 index 000000000..b464da670 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq_index.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace cuvs::neighbors::ivf_pq { /* Out-of-line members of the cuvs ivf_pq index wrapper; every accessor below forwards to the raft index held in raft_index_. */ + +template +index::index(raft::resources const& handle, const index_params& params, uint32_t dim) + : ann::index(), + raft_index_(std::make_unique>( + handle, + static_cast((int)params.metric), + static_cast((int)params.codebook_kind), + params.n_lists, + dim, + params.pq_bits, + params.pq_dim, + params.conservative_memory_allocation)) +{ +} + +template +index::index(raft::neighbors::ivf_pq::index&& raft_idx) + : ann::index(), + raft_index_(std::make_unique>(std::move(raft_idx))) +{ +} + +template +IdxT index::size() const noexcept +{ + return raft_index_->size(); +} + +template +uint32_t index::dim() const noexcept +{ + return raft_index_->dim(); +} + +template +uint32_t index::dim_ext() const noexcept +{ + return raft_index_->dim_ext(); +} + +template +uint32_t index::rot_dim() const noexcept +{ + return raft_index_->rot_dim(); +} + +template +uint32_t index::pq_bits() const noexcept +{ + return raft_index_->pq_bits(); +} + +template +uint32_t index::pq_dim() const noexcept +{ + return raft_index_->pq_dim(); +} + +template +uint32_t index::pq_len() const noexcept +{ + return raft_index_->pq_len(); +} + +template +uint32_t index::pq_book_size() const noexcept +{ + return raft_index_->pq_book_size(); +} + +template +cuvs::distance::DistanceType index::metric() const noexcept +{ + return static_cast((int)raft_index_->metric()); +} + +template +codebook_gen index::codebook_kind() const noexcept +{ + return static_cast((int)raft_index_->codebook_kind()); +} + +template +uint32_t index::n_lists() const noexcept +{ + return raft_index_->n_lists(); +} + +template +bool index::conservative_memory_allocation() const noexcept +{ + return raft_index_->conservative_memory_allocation(); +} + +template +raft:: + mdspan::pq_centers_extents, raft::row_major> + index::pq_centers() noexcept +{ + return 
raft_index_->pq_centers(); +} + +template +raft::mdspan::pq_centers_extents, + raft::row_major> +index::pq_centers() const noexcept +{ + return raft_index_->pq_centers(); +} + +template +std::vector>>& index::lists() noexcept +{ + return raft_index_->lists(); +} + +template +const std::vector>>& index::lists() const noexcept +{ + return raft_index_->lists(); +} + +template +raft::device_vector_view index::data_ptrs() noexcept +{ + return raft_index_->data_ptrs(); +} + +template +raft::device_vector_view index::data_ptrs() + const noexcept +{ + return raft_index_->data_ptrs(); +} + +template +raft::device_vector_view index::inds_ptrs() noexcept +{ + return raft_index_->inds_ptrs(); +} + +template +raft::device_vector_view index::inds_ptrs() + const noexcept +{ + return raft_index_->inds_ptrs(); +} + +template +raft::device_matrix_view index::rotation_matrix() noexcept +{ + return raft_index_->rotation_matrix(); +} + +template +raft::device_matrix_view index::rotation_matrix() + const noexcept +{ + return raft_index_->rotation_matrix(); +} + +template +raft::host_vector_view index::accum_sorted_sizes() noexcept +{ + return raft_index_->accum_sorted_sizes(); +} + +template +raft::host_vector_view index::accum_sorted_sizes() + const noexcept +{ + return raft_index_->accum_sorted_sizes(); +} + +template +raft::device_vector_view index::list_sizes() noexcept +{ + return raft_index_->list_sizes(); +} + +template +raft::device_vector_view index::list_sizes() + const noexcept +{ + return raft_index_->list_sizes(); +} + +template +raft::device_matrix_view index::centers() noexcept +{ + return raft_index_->centers(); +} + +template +raft::device_matrix_view index::centers() + const noexcept +{ + return raft_index_->centers(); +} + +template +raft::device_matrix_view index::centers_rot() noexcept +{ + return raft_index_->centers_rot(); +} + +template +raft::device_matrix_view index::centers_rot() + const noexcept +{ + return raft_index_->centers_rot(); +} + +template struct 
index; /* explicit instantiation of the wrapper in this translation unit */ + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq_serialize.cpp b/cpp/src/neighbors/ivf_pq_serialize.cpp new file mode 100644 index 000000000..4bdd3b04b --- /dev/null +++ b/cpp/src/neighbors/ivf_pq_serialize.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace cuvs::neighbors::ivf_pq { /* serialize()/deserialize() unwrap the raft index and forward to raft::runtime::neighbors::ivf_pq. NOTE(review): serialize takes std::string& while deserialize takes const std::string&; confirm against the declaration in the public header. */ + +#define CUVS_INST_IVF_PQ_SERIALIZE(IdxT) \ + void serialize(raft::resources const& handle, \ + std::string& filename, \ + const cuvs::neighbors::ivf_pq::index& index) \ + { \ + raft::runtime::neighbors::ivf_pq::serialize(handle, filename, *index.get_raft_index()); \ + } \ + void deserialize(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_pq::index* index) \ + { \ + raft::runtime::neighbors::ivf_pq::deserialize(handle, filename, index->get_raft_index()); \ + } + +CUVS_INST_IVF_PQ_SERIALIZE(int64_t); + +#undef CUVS_INST_IVF_PQ_SERIALIZE + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index f33c14179..841058f60 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -88,6 +88,34 @@ endfunction() # ################################################################################################## if(BUILD_TESTS) + ConfigureTest(NAME NEIGHBORS_TEST PATH test/neighbors/brute_force.cu GPUS 1 PERCENT 100)
+ + ConfigureTest( + NAME + NEIGHBORS_ANN_IVF_FLAT_TEST + PATH + test/neighbors/ann_ivf_flat/test_float_int64_t.cu + test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu + test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu + GPUS + 1 + PERCENT + 100 + ) + + ConfigureTest( + NAME + NEIGHBORS_ANN_IVF_PQ_TEST + PATH + test/neighbors/ann_ivf_pq/test_float_int64_t.cu + test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu + test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_TEST @@ -105,6 +133,34 @@ endif() if(BUILD_C_TESTS) ConfigureTest(NAME INTEROP_TEST PATH test/core/interop.cu C_LIB) + # Each of the three C API tests added below is built from one .c source and one .cu source. + ConfigureTest( + NAME + BRUTEFORCE_C_TEST + PATH + test/neighbors/run_brute_force_c.c + test/neighbors/brute_force_c.cu + C_LIB + ) + + ConfigureTest( + NAME + IVF_FLAT_C_TEST + PATH + test/neighbors/run_ivf_flat_c.c + test/neighbors/ann_ivf_flat_c.cu + C_LIB + ) + + ConfigureTest( + NAME + IVF_PQ_C_TEST + PATH + test/neighbors/run_ivf_pq_c.c + test/neighbors/ann_ivf_pq_c.cu + C_LIB + ) + ConfigureTest(NAME CAGRA_C_TEST PATH test/neighbors/ann_cagra_c.cu C_LIB) endif() diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh new file mode 100644 index 000000000..90b70ba11 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -0,0 +1,643 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" +#include "naive_knn.cuh" + +#include +#include +#include + +#include + +namespace cuvs::neighbors::ivf_flat { + +struct test_ivf_sample_filter { + static constexpr unsigned offset = 300; +}; + +template +struct AnnIvfFlatInputs { + IdxT num_queries; + IdxT num_db_vecs; + IdxT dim; + IdxT k; + IdxT nprobe; + IdxT nlist; + cuvs::distance::DistanceType metric; + bool adaptive_centers; + // bool host_dataset; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const AnnIvfFlatInputs& p) +{ + os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " + << p.nprobe << ", " << p.nlist << ", " << static_cast(p.metric) << ", " + << p.adaptive_centers << '}' << std::endl; + return os; +} + +template +class AnnIVFFlatTest : public ::testing::TestWithParam> { + public: + AnnIVFFlatTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam>::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void testIVFFlat() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfflat(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + ps.metric); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + { + // unless something is really wrong with clustering, this could serve as a lower bound on 
+ // recall + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + + rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); + rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); + + { + // legacy interface + raft::spatial::knn::IVFFlatParam ivfParams; + ivfParams.nprobe = ps.nprobe; + ivfParams.nlist = ps.nlist; + raft::spatial::knn::knnIndex index; + + raft::spatial::knn::approx_knn_build_index( + handle_, + &index, + dynamic_cast(&ivfParams), + static_cast((int)ps.metric), + (IdxT)0, + database.data(), + ps.num_db_vecs, + ps.dim); + + raft::resource::sync_stream(handle_); + raft::spatial::knn::approx_knn_search(handle_, + distances_ivfflat_dev.data(), + indices_ivfflat_dev.data(), + &index, + ps.k, + search_queries.data(), + ps.num_queries); + + raft::update_host( + distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); + raft::update_host( + indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + { + cuvs::neighbors::ivf_flat::index_params index_params; + cuvs::neighbors::ivf_flat::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = false; + index_params.kmeans_trainset_fraction = 0.5; + index_params.metric_arg = 0; + + cuvs::neighbors::ivf_flat::index idx(handle_, index_params, ps.dim); + cuvs::neighbors::ivf_flat::index index_2(handle_, index_params, ps.dim); + + // if (!ps.host_dataset) { + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + idx = cuvs::neighbors::ivf_flat::build(handle_, index_params, database_view); + rmm::device_uvector 
vector_indices(ps.num_db_vecs, stream_); + thrust::sequence(raft::resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(vector_indices.data()), + thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); + raft::resource::sync_stream(handle_); + + IdxT half_of_data = ps.num_db_vecs / 2; + + auto half_of_data_view = raft::make_device_matrix_view( + (const DataT*)database.data(), half_of_data, ps.dim); + + const std::optional> no_opt = std::nullopt; + index_2 = cuvs::neighbors::ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); + + auto new_half_of_data_view = raft::make_device_matrix_view( + database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); + + auto new_half_of_data_indices_view = raft::make_device_vector_view( + vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); + + cuvs::neighbors::ivf_flat::extend( + handle_, + new_half_of_data_view, + std::make_optional>( + new_half_of_data_indices_view), + &index_2); + + /* + } else { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + idx = ivf_flat::build(handle_, index_params, + raft::make_const_mdspan(host_database.view())); + + auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); + std::iota(vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, 0); + + IdxT half_of_data = ps.num_db_vecs / 2; + + auto half_of_data_view = raft::make_host_matrix_view( + (const DataT*)host_database.data_handle(), half_of_data, ps.dim); + + const std::optional> no_opt = std::nullopt; + index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); + + auto new_half_of_data_view = raft::make_host_matrix_view( + host_database.data_handle() + half_of_data * ps.dim, + IdxT(ps.num_db_vecs) - half_of_data, + ps.dim); + auto new_half_of_data_indices_view = raft::make_host_vector_view( + 
vector_indices.data_handle() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, + new_half_of_data_view, + std::make_optional>( + new_half_of_data_indices_view), + &index_2); + } + */ + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_ivfflat_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_ivfflat_dev.data(), ps.num_queries, ps.k); + const std::string filename = "ivf_flat_index"; + cuvs::neighbors::ivf_flat::serialize_file(handle_, filename, index_2); + cuvs::neighbors::ivf_flat::index index_loaded(handle_, index_params, ps.dim); + cuvs::neighbors::ivf_flat::deserialize_file(handle_, filename, &index_loaded); + ASSERT_EQ(index_2.size(), index_loaded.size()); + + cuvs::neighbors::ivf_flat::search(handle_, + search_params, + index_loaded, + search_queries_view, + indices_out_view, + dists_out_view); + + raft::update_host( + distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); + raft::update_host( + indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + + // Test the centroid invariants + if (index_2.adaptive_centers()) { + // The centers must be up-to-date with the corresponding data + std::vector list_sizes(index_2.n_lists()); + std::vector list_indices(index_2.n_lists()); + rmm::device_uvector centroid(ps.dim, stream_); + raft::copy( + list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); + raft::copy( + list_indices.data(), index_2.inds_ptrs().data_handle(), index_2.n_lists(), stream_); + raft::resource::sync_stream(handle_); + for (uint32_t l = 0; l < index_2.n_lists(); l++) { + if (list_sizes[l] == 0) continue; + rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); + 
raft::spatial::knn::detail::utils::copy_selected((IdxT)list_sizes[l], + (IdxT)ps.dim, + database.data(), + list_indices[l], + (IdxT)ps.dim, + cluster_data.data(), + (IdxT)ps.dim, + stream_); + raft::stats::mean( + centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); + ASSERT_TRUE(cuvs::devArrMatch(index_2.centers().data_handle() + ps.dim * l, + centroid.data(), + ps.dim, + cuvs::CompareApprox(0.001), + stream_)); + } + } else { + // The centers must be immutable + ASSERT_TRUE(cuvs::devArrMatch(index_2.centers().data_handle(), + idx.centers().data_handle(), + index_2.centers().size(), + cuvs::Compare(), + stream_)); + } + } + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + } + + /* + void testPacker() + { + ivf_flat::index_params index_params; + ivf_flat::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = false; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = false; + index_params.kmeans_trainset_fraction = 1.0; + index_params.metric_arg = 0; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + + auto idx = ivf_flat::build(handle_, index_params, database_view); + + const std::optional> no_opt = std::nullopt; + index extend_index = ivf_flat::extend(handle_, database_view, no_opt, idx); + + auto list_sizes = raft::make_host_vector(idx.n_lists()); + raft::update_host(list_sizes.data_handle(), + extend_index.list_sizes().data_handle(), + extend_index.n_lists(), + stream_); + raft::resource::sync_stream(handle_); + + auto& lists = idx.lists(); + + // conservative memory allocation for codepacking + auto list_device_spec = list_spec{idx.dim(), false}; + + for (uint32_t label = 0; label < idx.n_lists(); label++) { + uint32_t list_size = 
list_sizes.data_handle()[label]; + + ivf::resize_list(handle_, lists[label], list_device_spec, list_size, 0); + } + + idx.recompute_internal_state(handle_); + + using interleaved_group = Pow2; + + for (uint32_t label = 0; label < idx.n_lists(); label++) { + uint32_t list_size = list_sizes.data_handle()[label]; + + if (list_size > 0) { + uint32_t padded_list_size = interleaved_group::roundUp(list_size); + uint32_t n_elems = padded_list_size * idx.dim(); + auto list_data = lists[label]->data; + auto list_inds = extend_index.lists()[label]->indices; + + // fetch the flat codes + auto flat_codes = make_device_matrix(handle_, list_size, idx.dim()); + + matrix::gather( + handle_, + make_device_matrix_view( + (const DataT*)database.data(), static_cast(ps.num_db_vecs), idx.dim()), + make_device_vector_view((const IdxT*)list_inds.data_handle(), + list_size), + flat_codes.view()); + + helpers::codepacker::pack( + handle_, make_const_mdspan(flat_codes.view()), idx.veclen(), 0, list_data.view()); + + { + auto mask = make_device_vector(handle_, n_elems); + + linalg::map_offset(handle_, + mask.view(), + [dim = idx.dim(), + list_size, + padded_list_size, + chunk_size = util::FastIntDiv(idx.veclen())] __device__(auto i) { + uint32_t max_group_offset = + interleaved_group::roundDown(list_size); if (i < max_group_offset * dim) { return true; } + uint32_t surplus = (i - max_group_offset * dim); + uint32_t ingroup_id = interleaved_group::mod(surplus / chunk_size); + return ingroup_id < (list_size - max_group_offset); + }); + + // ensure that the correct number of indices are masked out + ASSERT_TRUE(thrust::reduce(raft::resource::get_thrust_policy(handle_), + mask.data_handle(), + mask.data_handle() + n_elems, + 0) == list_size * ps.dim); + + auto packed_list_data = make_device_vector(handle_, n_elems); + + linalg::map_offset(handle_, + packed_list_data.view(), + [mask = mask.data_handle(), + list_data = list_data.data_handle()] __device__(uint32_t i) { + if (mask[i]) return 
list_data[i]; + return DataT{0}; + }); + + auto extend_data = extend_index.lists()[label]->data; + auto extend_data_filtered = make_device_vector(handle_, n_elems); + linalg::map_offset(handle_, + extend_data_filtered.view(), + [mask = mask.data_handle(), + extend_data = extend_data.data_handle()] __device__(uint32_t i) { + if (mask[i]) return extend_data[i]; + return DataT{0}; + }); + + ASSERT_TRUE(cuvs::devArrMatch(packed_list_data.data_handle(), + extend_data_filtered.data_handle(), + n_elems, + cuvs::Compare(), + stream_)); + } + + auto unpacked_flat_codes = + make_device_matrix(handle_, list_size, idx.dim()); + + helpers::codepacker::unpack( + handle_, list_data.view(), idx.veclen(), 0, unpacked_flat_codes.view()); + + ASSERT_TRUE(cuvs::devArrMatch(flat_codes.data_handle(), + unpacked_flat_codes.data_handle(), + list_size * ps.dim, + cuvs::Compare(), + stream_)); + } + } + } + */ + + void testFilter() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfflat(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto* database_filtered_ptr = database.data() + test_ivf_sample_filter::offset * ps.dim; + cuvs::neighbors::naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database_filtered_ptr, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + ps.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + 
raft::resource::sync_stream(handle_); + } + + /* + { + // unless something is really wrong with clustering, this could serve as a lower bound on + // recall + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + + auto distances_ivfflat_dev = raft::make_device_matrix(handle_, ps.num_queries, ps.k); + auto indices_ivfflat_dev = + raft::make_device_matrix(handle_, ps.num_queries, ps.k); + + { + ivf_flat::index_params index_params; + ivf_flat::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + index_params.metric_arg = 0; + + // Create IVF Flat index + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_flat::build(handle_, index_params, database_view); + + // Create Bitset filter + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + thrust::sequence(raft::resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + + test_ivf_sample_filter::offset)); + raft::resource::sync_stream(handle_); + + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + + // Search with the filter + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + ivf_flat::search_with_filtering( + handle_, + search_params, + index, + search_queries_view, + indices_ivfflat_dev.view(), + distances_ivfflat_dev.view(), + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + + raft::update_host( + distances_ivfflat.data(), distances_ivfflat_dev.data_handle(), queries_size, stream_); + raft::update_host( + 
indices_ivfflat.data(), indices_ivfflat_dev.data_handle(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + */ + } + + void SetUp() override + { + database.resize(ps.num_db_vecs * ps.dim, stream_); + search_queries.resize(ps.num_queries * ps.dim, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + raft::resource::sync_stream(handle_); + } + + void TearDown() override + { + raft::resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnIvfFlatInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +const std::vector> inputs = { + // test various dims (aligned and not aligned to vector sizes) + {1000, 10000, 1, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, 
false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + + // test dims that do not fit into kernel shared memory limits + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + + // various random combinations + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + + /* + // host input data + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {10000, 131072, 8, 10, 20, 1024, 
cuvs::distance::DistanceType::L2Expanded, false, true}, + */ + + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + + // test splitting the big query batches (> max gridDim.y) into smaller batches + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, + {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::InnerProduct, false}, + {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, true}, + + // test radix_sort for getting the cluster selection + {1000, + 10000, + 16, + 10, + raft::matrix::detail::select::warpsort::kMaxCapacity * 2, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + cuvs::distance::DistanceType::L2Expanded, + false}, + {1000, + 10000, + 16, + 10, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + cuvs::distance::DistanceType::InnerProduct, + false}, + + // The following two test cases should show very similar recall. 
+ // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers + {20000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, + {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}}; + +} // namespace cuvs::neighbors::ivf_flat \ No newline at end of file diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu new file mode 100644 index 000000000..0ce168f5e --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_ivf_flat.cuh" + +namespace cuvs::neighbors::ivf_flat { + +typedef AnnIVFFlatTest AnnIVFFlatTestF_float; +TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { this->testIVFFlat(); } + +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu new file mode 100644 index 000000000..15935fd88 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_ivf_flat.cuh" + +namespace cuvs::neighbors::ivf_flat { + +typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; +TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } + +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu new file mode 100644 index 000000000..42a8dab2e --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "../ann_ivf_flat.cuh" + +namespace cuvs::neighbors::ivf_flat { + +typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; +TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) { this->testIVFFlat(); } + +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_uint8, ::testing::ValuesIn(inputs)); + +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat_c.cu b/cpp/test/neighbors/ann_ivf_flat_c.cu new file mode 100644 index 000000000..e85450494 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat_c.cu @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include "ann_utils.cuh" +#include + +extern "C" void run_ivf_flat(int64_t n_rows, + int64_t n_queries, + int64_t n_dim, + uint32_t n_neighbors, + float* index_data, + float* query_data, + float* distances_data, + int64_t* neighbors_data, + enum DistanceType metric, + size_t n_probes, + size_t n_lists); + +template +void generate_random_data(T* devPtr, size_t size) +{ + raft::handle_t handle; + raft::random::RngState r(1234ULL); + raft::random::uniform(handle, r, devPtr, size, T(0.1), T(2.0)); +}; + +template +void recall_eval(T* query_data, + T* index_data, + IdxT* neighbors, + T* distances, + size_t n_queries, + size_t n_rows, + size_t n_dim, + size_t n_neighbors, + DistanceType metric, + size_t n_probes, + size_t n_lists) +{ + raft::handle_t handle; + auto distances_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); + auto neighbors_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); + cuvs::neighbors::naive_knn( + handle, + distances_ref.data_handle(), + neighbors_ref.data_handle(), + query_data, + index_data, + n_queries, + n_rows, + n_dim, + n_neighbors, + static_cast((uint16_t)metric)); + + size_t size = n_queries * n_neighbors; + std::vector neighbors_h(size); + std::vector distances_h(size); + std::vector neighbors_ref_h(size); + std::vector distances_ref_h(size); + + auto stream = raft::resource::get_cuda_stream(handle); + raft::copy(neighbors_h.data(), neighbors, size, stream); + raft::copy(distances_h.data(), distances, size, stream); + raft::copy(neighbors_ref_h.data(), neighbors_ref.data_handle(), size, stream); + raft::copy(distances_ref_h.data(), distances_ref.data_handle(), size, stream); + + // verify output + double min_recall = static_cast(n_probes) / static_cast(n_lists); + ASSERT_TRUE(cuvs::neighbors::eval_neighbours(neighbors_ref_h, + neighbors_h, + distances_ref_h, + distances_h, + n_queries, + n_neighbors, + 0.001, + min_recall)); +}; + +TEST(IvfFlatC, BuildSearch) 
+{ + int64_t n_rows = 8096; + int64_t n_queries = 128; + int64_t n_dim = 32; + uint32_t n_neighbors = 8; + + enum DistanceType metric = L2Expanded; + size_t n_probes = 20; + size_t n_lists = 1024; + + float *index_data, *query_data, *distances_data; + int64_t* neighbors_data; + cudaMalloc(&index_data, sizeof(float) * n_rows * n_dim); + cudaMalloc(&query_data, sizeof(float) * n_queries * n_dim); + cudaMalloc(&neighbors_data, sizeof(int64_t) * n_queries * n_neighbors); + cudaMalloc(&distances_data, sizeof(float) * n_queries * n_neighbors); + + generate_random_data(index_data, n_rows * n_dim); + generate_random_data(query_data, n_queries * n_dim); + + run_ivf_flat(n_rows, + n_queries, + n_dim, + n_neighbors, + index_data, + query_data, + distances_data, + neighbors_data, + metric, + n_probes, + n_lists); + + recall_eval(query_data, + index_data, + neighbors_data, + distances_data, + n_queries, + n_rows, + n_dim, + n_neighbors, + metric, + n_probes, + n_lists); + + // delete device memory + cudaFree(index_data); + cudaFree(query_data); + cudaFree(neighbors_data); + cudaFree(distances_data); +} diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh new file mode 100644 index 000000000..5e276cbf1 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -0,0 +1,981 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" +#include "ivf_pq_helpers.cuh" +#include "naive_knn.cuh" +#include + +#include +#include +#include + +namespace cuvs::neighbors::ivf_pq { + +using namespace raft; + +struct test_ivf_sample_filter { + static constexpr unsigned offset = 300; +}; + +struct ivf_pq_inputs { + uint32_t num_db_vecs = 4096; + uint32_t num_queries = 1024; + uint32_t dim = 64; + uint32_t k = 32; + std::optional min_recall = std::nullopt; + + cuvs::neighbors::ivf_pq::index_params index_params; + cuvs::neighbors::ivf_pq::search_params search_params; + + // Set some default parameters for tests + ivf_pq_inputs() + { + index_params.n_lists = max(32u, min(1024u, num_db_vecs / 128u)); + index_params.kmeans_trainset_fraction = 1.0; + } +}; + +inline auto operator<<(std::ostream& os, const ivf_pq::codebook_gen& p) -> std::ostream& +{ + switch (p) { + case ivf_pq::codebook_gen::PER_CLUSTER: os << "codebook_gen::PER_CLUSTER"; break; + case ivf_pq::codebook_gen::PER_SUBSPACE: os << "codebook_gen::PER_SUBSPACE"; break; + default: RAFT_FAIL("unreachable code"); + } + return os; +} + +inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream& +{ + ivf_pq_inputs dflt; + bool need_comma = false; +#define PRINT_DIFF_V(spec, val) \ + do { \ + if (dflt spec != p spec) { \ + if (need_comma) { os << ", "; } \ + os << #spec << " = " << val; \ + need_comma = true; \ + } \ + } while (0) +#define PRINT_DIFF(spec) PRINT_DIFF_V(spec, p spec) + + os << "ivf_pq_inputs {"; + PRINT_DIFF(.num_db_vecs); + PRINT_DIFF(.num_queries); + PRINT_DIFF(.dim); + PRINT_DIFF(.k); + PRINT_DIFF_V(.min_recall, p.min_recall.value_or(0)); + PRINT_DIFF_V(.index_params.metric, + cuvs::neighbors::print_metric{ + static_cast((int)p.index_params.metric)}); + PRINT_DIFF(.index_params.metric_arg); + PRINT_DIFF(.index_params.add_data_on_build); + PRINT_DIFF(.index_params.n_lists); + PRINT_DIFF(.index_params.kmeans_n_iters); + 
PRINT_DIFF(.index_params.kmeans_trainset_fraction); + PRINT_DIFF(.index_params.pq_bits); + PRINT_DIFF(.index_params.pq_dim); + PRINT_DIFF(.index_params.codebook_kind); + PRINT_DIFF(.index_params.force_random_rotation); + PRINT_DIFF(.search_params.n_probes); + PRINT_DIFF_V(.search_params.lut_dtype, cuvs::neighbors::print_dtype{p.search_params.lut_dtype}); + PRINT_DIFF_V(.search_params.internal_distance_dtype, + cuvs::neighbors::print_dtype{p.search_params.internal_distance_dtype}); + os << "}"; + return os; +} + +template +void compare_vectors_l2( + const raft::resources& res, T a, T b, uint32_t label, double compression_ratio, double eps) +{ + auto n_rows = a.extent(0); + auto dim = a.extent(1); + rmm::mr::managed_memory_resource managed_memory; + auto dist = make_device_mdarray(res, &managed_memory, make_extents(n_rows)); + linalg::map_offset(res, dist.view(), [a, b, dim] __device__(uint32_t i) { + spatial::knn::detail::utils::mapping f{}; + double d = 0.0f; + for (uint32_t j = 0; j < dim; j++) { + double t = f(a(i, j)) - f(b(i, j)); + d += t * t; + } + return sqrt(d / double(dim)); + }); + resource::sync_stream(res); + for (uint32_t i = 0; i < n_rows; i++) { + double d = dist(i); + // The theoretical estimate of the error is hard to come up with, + // the estimate below is based on experimentation + curse of dimensionality + ASSERT_LE(d, 1.2 * eps * std::pow(2.0, compression_ratio)) + << " (label = " << label << ", ix = " << i << ", eps = " << eps << ")"; + } +} + +template +auto min_output_size(const raft::resources& handle, + const ivf_pq::index& index, + uint32_t n_probes) -> IdxT +{ + auto acc_sizes = index.accum_sorted_sizes(); + uint32_t last_nonzero = index.n_lists(); + while (last_nonzero > 0 && acc_sizes(last_nonzero - 1) == acc_sizes(last_nonzero)) { + last_nonzero--; + } + return acc_sizes(last_nonzero) - acc_sizes(last_nonzero - std::min(last_nonzero, n_probes)); +} + +template +class ivf_pq_test : public ::testing::TestWithParam { + public: + 
ivf_pq_test() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void gen_data() + { + database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); + search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void calc_ref() + { + size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + cuvs::neighbors::naive_knn( + handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.num_queries, + ps.num_db_vecs, + ps.dim, + ps.k, + static_cast((int)ps.index_params.metric)); + distances_ref.resize(queries_size); + update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); + indices_ref.resize(queries_size); + update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + auto build_only() + { + auto ipams = ps.index_params; + ipams.add_data_on_build = true; + + auto index_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + return cuvs::neighbors::ivf_pq::build(handle_, ipams, index_view); + } + + auto build_2_extends() + { + auto db_indices = make_device_vector(handle_, ps.num_db_vecs); + 
linalg::map_offset(handle_, db_indices.view(), identity_op{}); + resource::sync_stream(handle_); + auto size_1 = IdxT(ps.num_db_vecs) / 2; + auto size_2 = IdxT(ps.num_db_vecs) - size_1; + auto vecs_1 = database.data(); + auto vecs_2 = database.data() + size_t(size_1) * size_t(ps.dim); + auto inds_1 = db_indices.data_handle(); + auto inds_2 = db_indices.data_handle() + size_t(size_1); + + auto ipams = ps.index_params; + ipams.add_data_on_build = false; + + auto database_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto idx = cuvs::neighbors::ivf_pq::build(handle_, ipams, database_view); + + auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); + auto inds_2_view = raft::make_device_vector_view(inds_2, size_2); + cuvs::neighbors::ivf_pq::extend(handle_, vecs_2_view, inds_2_view, &idx); + + auto vecs_1_view = + raft::make_device_matrix_view(vecs_1, size_1, ps.dim); + auto inds_1_view = raft::make_device_vector_view(inds_1, size_1); + cuvs::neighbors::ivf_pq::extend(handle_, vecs_1_view, inds_1_view, &idx); + return idx; + } + + auto build_serialize() + { + std::string filename = "ivf_pq_index"; + cuvs::neighbors::ivf_pq::serialize(handle_, filename, build_only()); + cuvs::neighbors::ivf_pq::index index(handle_, ps.index_params, ps.dim); + cuvs::neighbors::ivf_pq::deserialize(handle_, filename, &index); + return index; + } + + void check_reconstruction(const index& index, + double compression_ratio, + uint32_t label, + uint32_t n_take, + uint32_t n_skip) + { + auto& rec_list = index.lists()[label]; + auto dim = index.dim(); + n_take = std::min(n_take, rec_list->size.load()); + n_skip = std::min(n_skip, rec_list->size.load() - n_take); + + if (n_take == 0) { return; } + + auto rec_data = make_device_matrix(handle_, n_take, dim); + auto orig_data = make_device_matrix(handle_, n_take, dim); + + cuvs::neighbors::ivf_pq::helpers::reconstruct_list_data( + handle_, index, rec_data.view(), label, n_skip); + + 
matrix::gather(database.data(), + IdxT{dim}, + IdxT{n_take}, + rec_list->indices.data_handle() + n_skip, + IdxT{n_take}, + orig_data.data_handle(), + stream_); + + compare_vectors_l2(handle_, rec_data.view(), orig_data.view(), label, compression_ratio, 0.06); + } + + void check_reconstruct_extend(index* index, double compression_ratio, uint32_t label) + { + // NB: this is not reference, the list is retained; the index will have to create a new list on + // `erase_list` op. + auto old_list = index->lists()[label]; + auto n_rows = old_list->size.load(); + if (n_rows == 0) { return; } + + auto vectors_1 = make_device_matrix(handle_, n_rows, index->dim()); + auto indices = make_device_vector(handle_, n_rows); + copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); + + cuvs::neighbors::ivf_pq::helpers::reconstruct_list_data( + handle_, *index, vectors_1.view(), label, 0); + cuvs::neighbors::ivf_pq::helpers::erase_list(handle_, index, label); + // NB: passing the type parameter because const->non-const implicit conversion of the mdspans + // breaks type inference + cuvs::neighbors::ivf_pq::helpers::extend_list( + handle_, index, vectors_1.view(), indices.view(), label); + + auto& new_list = index->lists()[label]; + ASSERT_NE(old_list.get(), new_list.get()) + << "The old list should have been shared and retained after ivf_pq index has erased the " + "corresponding cluster."; + + auto vectors_2 = make_device_matrix(handle_, n_rows, index->dim()); + cuvs::neighbors::ivf_pq::helpers::reconstruct_list_data( + handle_, *index, vectors_2.view(), label, 0); + // The code search is unstable, and there's high chance of repeating values of the lvl-2 codes. + // Hence, encoding-decoding chain often leads to altering both the PQ codes and the + // reconstructed data. 
+ compare_vectors_l2( + handle_, vectors_1.view(), vectors_2.view(), label, compression_ratio, 0.04); // 0.025); + } + + void check_packing(index* index, uint32_t label) + { + auto old_list = index->lists()[label]; + auto n_rows = old_list->size.load(); + + if (n_rows == 0) { return; } + + auto codes = make_device_matrix(handle_, n_rows, index->pq_dim()); + auto indices = make_device_vector(handle_, n_rows); + copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); + + cuvs::neighbors::ivf_pq::helpers::unpack_list_data(handle_, *index, codes.view(), label, 0); + cuvs::neighbors::ivf_pq::helpers::erase_list(handle_, index, label); + cuvs::neighbors::ivf_pq::helpers::extend_list_with_codes( + handle_, index, codes.view(), indices.view(), label); + + auto& new_list = index->lists()[label]; + ASSERT_NE(old_list.get(), new_list.get()) + << "The old list should have been shared and retained after ivf_pq index has erased the " + "corresponding cluster."; + auto list_data_size = (n_rows / raft::neighbors::ivf_pq::kIndexGroupSize) * + new_list->data.extent(1) * new_list->data.extent(2) * + new_list->data.extent(3); + + ASSERT_TRUE(old_list->data.size() >= list_data_size); + ASSERT_TRUE(new_list->data.size() >= list_data_size); + ASSERT_TRUE(cuvs::devArrMatch(old_list->data.data_handle(), + new_list->data.data_handle(), + list_data_size, + cuvs::Compare{})); + + // Pack a few vectors back to the list. 
+ int row_offset = 9; + int n_vec = 3; + ASSERT_TRUE(row_offset + n_vec < n_rows); + size_t offset = row_offset * index->pq_dim(); + auto codes_to_pack = make_device_matrix_view( + codes.data_handle() + offset, n_vec, index->pq_dim()); + ivf_pq::helpers::pack_list_data(handle_, index, codes_to_pack, label, row_offset); + ASSERT_TRUE(cuvs::devArrMatch(old_list->data.data_handle(), + new_list->data.data_handle(), + list_data_size, + cuvs::Compare{})); + + // Another test with the API that take list_data directly + auto list_data = index->lists()[label]->data.view(); + uint32_t n_take = 4; + ASSERT_TRUE(row_offset + n_take < n_rows); + auto codes2 = raft::make_device_matrix(handle_, n_take, index->pq_dim()); + ivf_pq::helpers::codepacker::unpack( + handle_, list_data, index->pq_bits(), row_offset, codes2.view()); + + // Write it back + ivf_pq::helpers::codepacker::pack( + handle_, make_const_mdspan(codes2.view()), index->pq_bits(), row_offset, list_data); + ASSERT_TRUE(cuvs::devArrMatch(old_list->data.data_handle(), + new_list->data.data_handle(), + list_data_size, + cuvs::Compare{})); + } + + template + void run(BuildIndex build_index) + { + index index = build_index(); + + double compression_ratio = + static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); + + for (uint32_t label = 0; label < index.n_lists(); label++) { + switch (label % 3) { + case 0: { + // Reconstruct and re-write vectors for one label + check_reconstruct_extend(&index, compression_ratio, label); + } break; + case 1: { + // Dump and re-write codes for one label + check_packing(&index, label); + } break; + default: { + // check a small subset of data in a randomly chosen cluster to see if the data + // reconstruction works well. 
+ check_reconstruction(index, compression_ratio, label, 100, 7); + } + } + } + + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivf_pq(queries_size); + std::vector distances_ivf_pq(queries_size); + + rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); + rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); + + auto query_view = + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = raft::make_device_matrix_view( + indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = raft::make_device_matrix_view( + distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + + cuvs::neighbors::ivf_pq::search( + handle_, ps.search_params, index, query_view, inds_view, dists_view); + + update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); + update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + + // A very conservative lower bound on recall + double min_recall = + static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); + // Using a heuristic to lower the required recall due to code-packing errors + min_recall = + std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall); + // Use explicit per-test min recall value if provided. + min_recall = ps.min_recall.value_or(min_recall); + + ASSERT_TRUE(cuvs::neighbors::eval_neighbours(indices_ref, + indices_ivf_pq, + distances_ref, + distances_ivf_pq, + ps.num_queries, + ps.k, + 0.0001 * compression_ratio, + min_recall)) + << ps; + + // Test a few extra invariants + IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes); + IdxT max_oob = ps.k <= min_results ? 
0 : ps.k - min_results; + IdxT found_oob = 0; + for (uint32_t query_ix = 0; query_ix < ps.num_queries; query_ix++) { + for (uint32_t k = 0; k < ps.k; k++) { + auto flat_i = query_ix * ps.k + k; + auto found_ix = indices_ivf_pq[flat_i]; + if (found_ix == raft::neighbors::ivf_pq::kOutOfBoundsRecord) { + found_oob++; + continue; + } + ASSERT_NE(found_ix, raft::neighbors::ivf::kInvalidRecord) + << "got an invalid record at query_ix = " << query_ix << ", k = " << k + << " (distance = " << distances_ivf_pq[flat_i] << ")"; + ASSERT_LT(found_ix, ps.num_db_vecs) + << "got an impossible index = " << found_ix << " at query_ix = " << query_ix + << ", k = " << k << " (distance = " << distances_ivf_pq[flat_i] << ")"; + } + } + ASSERT_LE(found_oob, max_oob) + << "got too many records out-of-bounds (see ivf_pq::kOutOfBoundsRecord)."; + if (found_oob > 0) { + RAFT_LOG_WARN( + "Got %zu results out-of-bounds because of large top-k (%zu) and small n_probes (%u) and " + "small DB size/n_lists ratio (%zu / %u)", + size_t(found_oob), + size_t(ps.k), + ps.search_params.n_probes, + size_t(ps.num_db_vecs), + ps.index_params.n_lists); + } + } + + void SetUp() override // NOLINT + { + gen_data(); + calc_ref(); + } + + void TearDown() override // NOLINT + { + cudaGetLastError(); + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + ivf_pq_inputs ps; // NOLINT + rmm::device_uvector database; // NOLINT + rmm::device_uvector search_queries; // NOLINT + std::vector indices_ref; // NOLINT + std::vector distances_ref; // NOLINT +}; + +template +class ivf_pq_filter_test : public ::testing::TestWithParam { + public: + ivf_pq_filter_test() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void gen_data() + { + database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, 
stream_); + search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void calc_ref() + { + size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + cuvs::neighbors::naive_knn( + handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data() + test_ivf_sample_filter::offset * ps.dim, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + static_cast((int)ps.index_params.metric)); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + distances_ref.resize(queries_size); + update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); + indices_ref.resize(queries_size); + update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + auto build_only() + { + auto ipams = ps.index_params; + ipams.add_data_on_build = true; + + auto index_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + return cuvs::neighbors::ivf_pq::build(handle_, ipams, index_view); + } + + template + void run(BuildIndex build_index) + { + index index = build_index(); + + double compression_ratio = + 
static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivf_pq(queries_size); + std::vector distances_ivf_pq(queries_size); + + rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); + rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); + + auto query_view = + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = raft::make_device_matrix_view( + indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = raft::make_device_matrix_view( + distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + + // Create Bitset filter + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + thrust::sequence( + resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + test_ivf_sample_filter::offset)); + resource::sync_stream(handle_); + + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + raft::neighbors::ivf_pq::search_with_filtering( + handle_, + ps.search_params, + index, + query_view, + inds_view, + dists_view, + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())); + + update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); + update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + + // A very conservative lower bound on recall + double min_recall = + static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); + // Using a heuristic to lower the required recall due to code-packing errors + min_recall = + std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall); + // Use explicit per-test min recall value if provided. 
+ min_recall = ps.min_recall.value_or(min_recall); + + ASSERT_TRUE(cuvs::neighbors::eval_neighbours(indices_ref, + indices_ivf_pq, + distances_ref, + distances_ivf_pq, + ps.num_queries, + ps.k, + 0.0001 * compression_ratio, + min_recall)) + << ps; + } + + void SetUp() override // NOLINT + { + gen_data(); + calc_ref(); + } + + void TearDown() override // NOLINT + { + cudaGetLastError(); + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + ivf_pq_inputs ps; // NOLINT + rmm::device_uvector database; // NOLINT + rmm::device_uvector search_queries; // NOLINT + std::vector indices_ref; // NOLINT + std::vector distances_ref; // NOLINT +}; + +/* Test cases */ +using test_cases_t = std::vector; + +// concatenate parameter sets for different type +template +auto operator+(const std::vector& a, const std::vector& b) -> std::vector +{ + std::vector res = a; + res.insert(res.end(), b.begin(), b.end()); + return res; +} + +inline auto defaults() -> test_cases_t { return {ivf_pq_inputs{}}; } + +template +auto map(const std::vector& xs, F f) -> std::vector +{ + std::vector ys(xs.size()); + std::transform(xs.begin(), xs.end(), ys.begin(), f); + return ys; +} + +inline auto with_dims(const std::vector& dims) -> test_cases_t +{ + return map(dims, [](uint32_t d) { + ivf_pq_inputs x; + x.dim = d; + return x; + }); +} + +/** These will surely trigger the fastest kernel available. 
*/ +inline auto small_dims() -> test_cases_t { return with_dims({1, 2, 3, 4, 5, 8, 15, 16, 17}); } + +inline auto small_dims_per_cluster() -> test_cases_t +{ + return map(small_dims(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + y.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; + return y; + }); +} + +inline auto big_dims() -> test_cases_t +{ + // with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288, 16384}); + auto xs = with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144}); + return map(xs, [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + uint32_t pq_len = 2; + y.index_params.pq_dim = div_rounding_up_safe(x.dim, pq_len); + // This comes from pure experimentation, also the recall depens a lot on pq_len. + y.min_recall = 0.48 + 0.028 * std::log2(x.dim); + return y; + }); +} + +/** These will surely trigger no-smem-lut kernel. */ +inline auto big_dims_moderate_lut() -> test_cases_t +{ + return map(big_dims(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + uint32_t pq_len = 2; + y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u); + y.index_params.pq_bits = 6; + y.search_params.lut_dtype = CUDA_R_16F; + y.min_recall = 0.69; + return y; + }); +} + +/** Some of these should trigger no-basediff kernel. */ +inline auto big_dims_small_lut() -> test_cases_t +{ + return map(big_dims(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + uint32_t pq_len = 8; + y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u); + y.index_params.pq_bits = 6; + y.search_params.lut_dtype = CUDA_R_8U; + y.min_recall = 0.21; + return y; + }); +} + +/** + * A minimal set of tests to check various enum-like parameters. 
+ */ +inline auto enum_variety() -> test_cases_t +{ + test_cases_t xs; +#define ADD_CASE(f) \ + do { \ + xs.push_back({}); \ + ([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \ + } while (0); + + ADD_CASE({ + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; + x.index_params.pq_bits = 4; + x.min_recall = 0.79; + }); + ADD_CASE({ + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; + x.index_params.pq_bits = 5; + x.min_recall = 0.83; + }); + + ADD_CASE({ + x.index_params.pq_bits = 6; + x.min_recall = 0.84; + }); + ADD_CASE({ + x.index_params.pq_bits = 7; + x.min_recall = 0.85; + }); + ADD_CASE({ + x.index_params.pq_bits = 8; + x.min_recall = 0.86; + }); + + ADD_CASE({ + x.index_params.force_random_rotation = true; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.index_params.force_random_rotation = false; + x.min_recall = 0.86; + }); + + ADD_CASE({ + x.search_params.lut_dtype = CUDA_R_32F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.lut_dtype = CUDA_R_16F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.lut_dtype = CUDA_R_8U; + x.min_recall = 0.84; + }); + + ADD_CASE({ + x.search_params.internal_distance_dtype = CUDA_R_32F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.internal_distance_dtype = CUDA_R_16F; + x.search_params.lut_dtype = CUDA_R_16F; + x.min_recall = 0.86; + }); + + return xs; +} + +inline auto enum_variety_l2() -> test_cases_t +{ + return map(enum_variety(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + y.index_params.metric = distance::DistanceType::L2Expanded; + return y; + }); +} + +inline auto enum_variety_ip() -> test_cases_t +{ + return map(enum_variety(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + if (y.min_recall.has_value()) { + if 
(y.search_params.lut_dtype == CUDA_R_8U) { + // InnerProduct score is signed, + // thus we're forced to used signed 8-bit representation, + // thus we have one bit less precision + y.min_recall = y.min_recall.value() * 0.90; + } else { + // In other cases it seems to perform a little bit better, still worse than L2 + y.min_recall = y.min_recall.value() * 0.94; + } + } + y.index_params.metric = distance::DistanceType::InnerProduct; + return y; + }); +} + +inline auto enum_variety_l2sqrt() -> test_cases_t +{ + return map(enum_variety(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + y.index_params.metric = distance::DistanceType::L2SqrtExpanded; + return y; + }); +} + +/** + * Try different number of n_probes, some of which may trigger the non-fused version of the search + * kernel. + */ +inline auto var_n_probes() -> test_cases_t +{ + ivf_pq_inputs dflt; + std::vector xs; + for (auto x = dflt.index_params.n_lists; x >= 1; x /= 2) { + xs.push_back(x); + } + return map(xs, [](uint32_t n_probes) { + ivf_pq_inputs x; + x.search_params.n_probes = n_probes; + return x; + }); +} + +/** + * Try different number of nearest neighbours. + * Values smaller than 32 test if the code behaves well when Capacity (== 32) does not change, + * but `k <= Capacity` changes. + * + * Values between `32 and ivf_pq::detail::kMaxCapacity` test various instantiations of the + * main kernel (Capacity-templated) + * + * Values above ivf_pq::detail::kMaxCapacity should trigger the non-fused version of the kernel + * (manage_local_topk = false). + * + * Also we test here various values that are close-but-not-power-of-two to catch any problems + * related to rounding/alignment. + * + * Note, we cannot control explicitly which instance of the search kernel to choose, hence it's + * important to try a variety of different values of `k` to make sure all paths are triggered. + * + * Set the log level to DEBUG (5) or above to inspect the selected kernel instances. 
+ */ +inline auto var_k() -> test_cases_t +{ + return map( + {1, 2, 3, 5, 8, 15, 16, 32, 63, 65, 127, 128, 256, 257, 1023, 2048, 2049}, [](uint32_t k) { + ivf_pq_inputs x; + x.k = k; + // when there's not enough data, try more cluster probes + x.search_params.n_probes = max(x.search_params.n_probes, min(x.index_params.n_lists, k)); + return x; + }); +} + +/** + * Cases brought up from downstream projects. + */ +inline auto special_cases() -> test_cases_t +{ + test_cases_t xs; + +#define ADD_CASE(f) \ + do { \ + xs.push_back({}); \ + ([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \ + } while (0); + + ADD_CASE({ + x.num_db_vecs = 1183514; + x.dim = 100; + x.num_queries = 10000; + x.k = 10; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_dim = 10; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 1024; + x.search_params.n_probes = 50; + }); + + ADD_CASE({ + x.num_db_vecs = 10000; + x.dim = 16; + x.num_queries = 500; + x.k = 128; + x.index_params.metric = distance::DistanceType::L2Expanded; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 100; + x.search_params.n_probes = 100; + }); + + ADD_CASE({ + x.num_db_vecs = 10000; + x.dim = 16; + x.num_queries = 500; + x.k = 129; + x.index_params.metric = distance::DistanceType::L2Expanded; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 100; + x.search_params.n_probes = 100; + }); + + return xs; +} + +/* Test instantiations */ + +#define TEST_BUILD_SEARCH(type) \ + TEST_P(type, build_search) /* NOLINT */ \ + { \ + this->run([this]() { return this->build_only(); }); \ + } + +#define TEST_BUILD_EXTEND_SEARCH(type) \ + TEST_P(type, build_extend_search) /* NOLINT */ \ + { \ + this->run([this]() { return this->build_2_extends(); }); \ + } + +#define TEST_BUILD_SERIALIZE_SEARCH(type) \ + TEST_P(type, build_serialize_search) /* NOLINT 
*/ \ + { \ + this->run([this]() { return this->build_serialize(); }); \ + } + +#define INSTANTIATE(type, vals) \ + INSTANTIATE_TEST_SUITE_P(IvfPq, type, ::testing::ValuesIn(vals)); /* NOLINT */ + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu new file mode 100644 index 000000000..5ca4cde68 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_ivf_pq.cuh" + +namespace cuvs::neighbors::ivf_pq { + +using f32_f32_i64 = ivf_pq_test; + +TEST_BUILD_EXTEND_SEARCH(f32_f32_i64) +TEST_BUILD_SERIALIZE_SEARCH(f32_f32_i64) +INSTANTIATE(f32_f32_i64, defaults() + small_dims() + big_dims_moderate_lut()); + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu new file mode 100644 index 000000000..0763b661e --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_ivf_pq.cuh" + +namespace cuvs::neighbors::ivf_pq { + +using f32_i08_i64 = ivf_pq_test; + +TEST_BUILD_SEARCH(f32_i08_i64) +TEST_BUILD_SERIALIZE_SEARCH(f32_i08_i64) +INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k()); + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu new file mode 100644 index 000000000..93e7d5336 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../ann_ivf_pq.cuh" + +namespace cuvs::neighbors::ivf_pq { + +using f32_u08_i64 = ivf_pq_test; + +TEST_BUILD_SEARCH(f32_u08_i64) +TEST_BUILD_EXTEND_SEARCH(f32_u08_i64) +INSTANTIATE(f32_u08_i64, small_dims_per_cluster() + enum_variety()); + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq_c.cu b/cpp/test/neighbors/ann_ivf_pq_c.cu new file mode 100644 index 000000000..94d121ce2 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq_c.cu @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include "ann_utils.cuh" +#include + +extern "C" void run_ivf_pq(int64_t n_rows, + int64_t n_queries, + int64_t n_dim, + uint32_t n_neighbors, + float* index_data, + float* query_data, + float* distances_data, + int64_t* neighbors_data, + enum DistanceType metric, + size_t n_probes, + size_t n_lists); + +template +void generate_random_data(T* devPtr, size_t size) +{ + raft::handle_t handle; + raft::random::RngState r(1234ULL); + raft::random::uniform(handle, r, devPtr, size, T(0.1), T(2.0)); +}; + +template +void recall_eval(T* query_data, + T* index_data, + IdxT* neighbors, + T* distances, + size_t n_queries, + size_t n_rows, + size_t n_dim, + size_t n_neighbors, + DistanceType metric, + size_t n_probes, + size_t n_lists) +{ + raft::handle_t handle; + auto distances_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); + auto neighbors_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); + cuvs::neighbors::naive_knn( + handle, + distances_ref.data_handle(), + neighbors_ref.data_handle(), + query_data, + index_data, + n_queries, + n_rows, + n_dim, + n_neighbors, + static_cast((uint16_t)metric)); + + size_t size = n_queries * n_neighbors; + std::vector neighbors_h(size); + std::vector distances_h(size); + std::vector neighbors_ref_h(size); + std::vector distances_ref_h(size); + + auto stream = raft::resource::get_cuda_stream(handle); + raft::copy(neighbors_h.data(), neighbors, size, stream); + raft::copy(distances_h.data(), distances, size, stream); + raft::copy(neighbors_ref_h.data(), neighbors_ref.data_handle(), size, stream); + raft::copy(distances_ref_h.data(), distances_ref.data_handle(), size, stream); + + // verify output + double min_recall = static_cast(n_probes) / static_cast(n_lists); + ASSERT_TRUE(cuvs::neighbors::eval_neighbours(neighbors_ref_h, + neighbors_h, + distances_ref_h, + distances_h, + n_queries, + n_neighbors, + 0.001, + min_recall)); +}; + +TEST(IvfPqC, BuildSearch) +{ + 
int64_t n_rows = 8096; + int64_t n_queries = 128; + int64_t n_dim = 32; + uint32_t n_neighbors = 8; + + enum DistanceType metric = L2Expanded; + size_t n_probes = 20; + size_t n_lists = 1024; + + float *index_data, *query_data, *distances_data; + int64_t* neighbors_data; + cudaMalloc(&index_data, sizeof(float) * n_rows * n_dim); + cudaMalloc(&query_data, sizeof(float) * n_queries * n_dim); + cudaMalloc(&neighbors_data, sizeof(int64_t) * n_queries * n_neighbors); + cudaMalloc(&distances_data, sizeof(float) * n_queries * n_neighbors); + + generate_random_data(index_data, n_rows * n_dim); + generate_random_data(query_data, n_queries * n_dim); + + run_ivf_pq(n_rows, + n_queries, + n_dim, + n_neighbors, + index_data, + query_data, + distances_data, + neighbors_data, + metric, + n_probes, + n_lists); + + recall_eval(query_data, + index_data, + neighbors_data, + distances_data, + n_queries, + n_rows, + n_dim, + n_neighbors, + metric, + n_probes, + n_lists); + + // delete device memory + cudaFree(index_data); + cudaFree(query_data); + cudaFree(neighbors_data); + cudaFree(distances_data); +} diff --git a/cpp/test/neighbors/brute_force.cu b/cpp/test/neighbors/brute_force.cu new file mode 100644 index 000000000..d058bface --- /dev/null +++ b/cpp/test/neighbors/brute_force.cu @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" +#include +#include + +namespace cuvs::neighbors::brute_force { +struct KNNInputs { + std::vector> input; + int k; + std::vector labels; +}; + +template +RAFT_KERNEL build_actual_output( + int* output, int n_rows, int k, const int* idx_labels, const IdxT* indices) +{ + int element = threadIdx.x + blockDim.x * blockIdx.x; + if (element >= n_rows * k) return; + + output[element] = idx_labels[indices[element]]; +} + +RAFT_KERNEL build_expected_output(int* output, int n_rows, int k, const int* labels) +{ + int row = threadIdx.x + blockDim.x * blockIdx.x; + if (row >= n_rows) return; + + int cur_label = labels[row]; + for (int i = 0; i < k; i++) { + output[row * k + i] = cur_label; + } +} + +template +class KNNTest : public ::testing::TestWithParam { + public: + KNNTest() + : params_(::testing::TestWithParam::GetParam()), + stream(raft::resource::get_cuda_stream(handle)), + actual_labels_(0, stream), + expected_labels_(0, stream), + input_(0, stream), + search_data_(0, stream), + indices_(0, stream), + distances_(0, stream), + search_labels_(0, stream) + { + } + + protected: + void testBruteForce() + { + // #if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) + raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout); + std::cout << "K: " << k_ << std::endl; + raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); + // #endif + + auto index = raft::make_device_matrix_view( + (const T*)(input_.data()), rows_, cols_); + auto search = raft::make_device_matrix_view( + (const T*)(search_data_.data()), rows_, cols_); + auto indices = + raft::make_device_matrix_view(indices_.data(), rows_, k_); + auto distances = + raft::make_device_matrix_view(distances_.data(), rows_, k_); + + auto metric = cuvs::distance::DistanceType::L2Unexpanded; + auto idx = cuvs::neighbors::brute_force::build(handle, index, metric); + cuvs::neighbors::brute_force::search(handle, idx, search, indices, distances); + + 
build_actual_output<<>>( + actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); + + build_expected_output<<>>( + expected_labels_.data(), rows_, k_, search_labels_.data()); + + ASSERT_TRUE(devArrMatch( + expected_labels_.data(), actual_labels_.data(), rows_ * k_, cuvs::Compare(), stream)); + } + + void SetUp() override + { + rows_ = params_.input.size(); + cols_ = params_.input[0].size(); + k_ = params_.k; + + actual_labels_.resize(rows_ * k_, stream); + expected_labels_.resize(rows_ * k_, stream); + input_.resize(rows_ * cols_, stream); + search_data_.resize(rows_ * cols_, stream); + indices_.resize(rows_ * k_, stream); + distances_.resize(rows_ * k_, stream); + search_labels_.resize(rows_, stream); + + RAFT_CUDA_TRY( + cudaMemsetAsync(actual_labels_.data(), 0, actual_labels_.size() * sizeof(int), stream)); + RAFT_CUDA_TRY( + cudaMemsetAsync(expected_labels_.data(), 0, expected_labels_.size() * sizeof(int), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(input_.data(), 0, input_.size() * sizeof(float), stream)); + RAFT_CUDA_TRY( + cudaMemsetAsync(search_data_.data(), 0, search_data_.size() * sizeof(float), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(indices_.data(), 0, indices_.size() * sizeof(IdxT), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(distances_.data(), 0, distances_.size() * sizeof(float), stream)); + RAFT_CUDA_TRY( + cudaMemsetAsync(search_labels_.data(), 0, search_labels_.size() * sizeof(int), stream)); + + std::vector row_major_input; + for (std::size_t i = 0; i < params_.input.size(); ++i) { + for (std::size_t j = 0; j < params_.input[i].size(); ++j) { + row_major_input.push_back(params_.input[i][j]); + } + } + rmm::device_buffer input_d = + rmm::device_buffer(row_major_input.data(), row_major_input.size() * sizeof(float), stream); + float* input_ptr = static_cast(input_d.data()); + + rmm::device_buffer labels_d = + rmm::device_buffer(params_.labels.data(), params_.labels.size() * sizeof(int), stream); + int* labels_ptr = 
static_cast(labels_d.data()); + + raft::copy(input_.data(), input_ptr, rows_ * cols_, stream); + raft::copy(search_data_.data(), input_ptr, rows_ * cols_, stream); + raft::copy(search_labels_.data(), labels_ptr, rows_, stream); + raft::resource::sync_stream(handle, stream); + } + + private: + raft::resources handle; + cudaStream_t stream; + + KNNInputs params_; + int rows_; + int cols_; + rmm::device_uvector input_; + rmm::device_uvector search_data_; + rmm::device_uvector indices_; + rmm::device_uvector distances_; + int k_; + + rmm::device_uvector search_labels_; + rmm::device_uvector actual_labels_; + rmm::device_uvector expected_labels_; +}; + +const std::vector inputs = { + // 2D + {{ + {2.7810836, 2.550537003}, + {1.465489372, 2.362125076}, + {3.396561688, 4.400293529}, + {1.38807019, 1.850220317}, + {3.06407232, 3.005305973}, + {7.627531214, 2.759262235}, + {5.332441248, 2.088626775}, + {6.922596716, 1.77106367}, + {8.675418651, -0.242068655}, + {7.673756466, 3.508563011}, + }, + 2, + {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}}; + +typedef KNNTest KNNTestFint64_t; +TEST_P(KNNTestFint64_t, BruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint64_t, ::testing::ValuesIn(inputs)); +} // namespace cuvs::neighbors::brute_force \ No newline at end of file diff --git a/cpp/test/neighbors/brute_force_c.cu b/cpp/test/neighbors/brute_force_c.cu new file mode 100644 index 000000000..7730a98c6 --- /dev/null +++ b/cpp/test/neighbors/brute_force_c.cu @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "ann_utils.cuh" +#include + +extern "C" void run_brute_force(int64_t n_rows, + int64_t n_queries, + int64_t n_dim, + uint32_t n_neighbors, + float* index_data, + float* query_data, + float* distances_data, + int64_t* neighbors_data, + enum DistanceType metric); + +template +void generate_random_data(T* devPtr, size_t size) +{ + raft::handle_t handle; + raft::random::RngState r(1234ULL); + raft::random::uniform(handle, r, devPtr, size, T(0.1), T(2.0)); +}; + +template +void recall_eval(T* query_data, + T* index_data, + IdxT* neighbors, + T* distances, + size_t n_queries, + size_t n_rows, + size_t n_dim, + size_t n_neighbors, + DistanceType metric) +{ + raft::handle_t handle; + auto distances_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); + auto neighbors_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); + cuvs::neighbors::naive_knn( + handle, + distances_ref.data_handle(), + neighbors_ref.data_handle(), + query_data, + index_data, + n_queries, + n_rows, + n_dim, + n_neighbors, + static_cast((uint16_t)metric)); + + size_t size = n_queries * n_neighbors; + std::vector neighbors_h(size); + std::vector distances_h(size); + std::vector neighbors_ref_h(size); + std::vector distances_ref_h(size); + + auto stream = raft::resource::get_cuda_stream(handle); + raft::copy(neighbors_h.data(), neighbors, size, stream); + raft::copy(distances_h.data(), distances, size, stream); + raft::copy(neighbors_ref_h.data(), neighbors_ref.data_handle(), size, stream); + 
raft::copy(distances_ref_h.data(), distances_ref.data_handle(), size, stream); + + // verify output + double min_recall = 0.95; + ASSERT_TRUE(cuvs::neighbors::eval_neighbours(neighbors_ref_h, + neighbors_h, + distances_ref_h, + distances_h, + n_queries, + n_neighbors, + 0.001, + min_recall)); +}; + +TEST(BruteForceC, BuildSearch) +{ + int64_t n_rows = 8096; + int64_t n_queries = 128; + int64_t n_dim = 32; + uint32_t n_neighbors = 8; + + enum DistanceType metric = L2Expanded; + + float *index_data, *query_data, *distances_data; + int64_t* neighbors_data; + cudaMalloc(&index_data, sizeof(float) * n_rows * n_dim); + cudaMalloc(&query_data, sizeof(float) * n_queries * n_dim); + cudaMalloc(&neighbors_data, sizeof(int64_t) * n_queries * n_neighbors); + cudaMalloc(&distances_data, sizeof(float) * n_queries * n_neighbors); + + generate_random_data(index_data, n_rows * n_dim); + generate_random_data(query_data, n_queries * n_dim); + + run_brute_force(n_rows, + n_queries, + n_dim, + n_neighbors, + index_data, + query_data, + distances_data, + neighbors_data, + metric); + + recall_eval(query_data, + index_data, + neighbors_data, + distances_data, + n_queries, + n_rows, + n_dim, + n_neighbors, + metric); + + // delete device memory + cudaFree(index_data); + cudaFree(query_data); + cudaFree(neighbors_data); + cudaFree(distances_data); +} diff --git a/cpp/test/neighbors/ivf_pq_helpers.cuh b/cpp/test/neighbors/ivf_pq_helpers.cuh new file mode 100644 index 000000000..79a310f07 --- /dev/null +++ b/cpp/test/neighbors/ivf_pq_helpers.cuh @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +namespace cuvs::neighbors::ivf_pq::helpers { + +using namespace raft; +using namespace raft::neighbors; +using namespace raft::neighbors::ivf_pq; + +/** + * @defgroup ivf_pq_helpers Helper functions for manipulating IVF PQ Index + * @{ + */ + +namespace codepacker { +/** + * @brief Unpack `n_take` consecutive records of a single list (cluster) in the compressed index + * starting at given `offset`. + * + * Bit compression is removed, which means output will have pq_dim dimensional vectors (one code per + * byte, instead of ceildiv(pq_dim * pq_bits, 8) bytes of pq codes). + * + * Usage example: + * @code{.cpp} + * auto list_data = index.lists()[label]->data.view(); + * // allocate the buffer for the output + * uint32_t n_take = 4; + * auto codes = raft::make_device_matrix(res, n_take, index.pq_dim()); + * uint32_t offset = 0; + * // unpack n_take elements from the list + * ivf_pq::helpers::codepacker::unpack(res, list_data, index.pq_bits(), offset, codes.view()); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resource + * @param[in] list_data block to read from + * @param[in] pq_bits bit length of encoded vector elements + * @param[in] offset + * How many records in the list to skip. + * @param[out] codes + * the destination buffer [n_take, index.pq_dim()]. + * The length `n_take` defines how many records to unpack, + * it must be smaller than the list size. 
+ */ +inline void unpack( + raft::resources const& res, + device_mdspan::list_extents, row_major> list_data, + uint32_t pq_bits, + uint32_t offset, + device_matrix_view codes) +{ + raft::neighbors::ivf_pq::detail::unpack_list_data( + codes, list_data, offset, pq_bits, resource::get_cuda_stream(res)); +} + +/** + * Write flat PQ codes into an existing list by the given offset. + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_vec). + * + * Usage example: + * @code{.cpp} + * auto list_data = index.lists()[label]->data.view(); + * // allocate the buffer for the input codes + * auto codes = raft::make_device_matrix(res, n_vec, index.pq_dim()); + * ... prepare n_vecs to pack into the list in codes ... + * // write codes into the list starting from the 42nd position + * ivf_pq::helpers::codepacker::pack( + * res, make_const_mdspan(codes.view()), index.pq_bits(), 42, list_data); + * @endcode + * + * @param[in] res + * @param[in] codes flat PQ codes, one code per byte [n_vec, pq_dim] + * @param[in] pq_bits bit length of encoded vector elements + * @param[in] offset how many records to skip before writing the data into the list + * @param[in] list_data block to write into + */ +inline void pack( + raft::resources const& res, + device_matrix_view codes, + uint32_t pq_bits, + uint32_t offset, + device_mdspan::list_extents, row_major> list_data) +{ + raft::neighbors::ivf_pq::detail::pack_list_data( + list_data, codes, offset, pq_bits, resource::get_cuda_stream(res)); +} +} // namespace codepacker + +/** + * Write flat PQ codes into an existing list by the given offset. + * + * The list is identified by its label. + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_vec). + * + * Usage example: + * @code{.cpp} + * // We will write into the 137th cluster + * uint32_t label = 137; + * // allocate the buffer for the input codes + * auto codes = raft::make_device_matrix(res, n_vec, index.pq_dim()); + * ... 
prepare n_vecs to pack into the list in codes ... + * // write codes into the list starting from the 42nd position + * ivf_pq::helpers::pack_list_data(res, &index, codes.view(), label, 42); + * @endcode + * + * @param[in] res + * @param[inout] index IVF-PQ index. + * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] + * @param[in] label The id of the list (cluster) into which we write. + * @param[in] offset how many records to skip before writing the data into the list + */ +template +void pack_list_data(raft::resources const& res, + cuvs::neighbors::ivf_pq::index* index, + device_matrix_view codes, + uint32_t label, + uint32_t offset) +{ + raft::neighbors::ivf_pq::detail::pack_list_data( + res, index->get_raft_index(), codes, label, offset); +} + +/** + * @brief Unpack `n_take` consecutive records of a single list (cluster) in the compressed index + * starting at given `offset`, one code per byte (independently of pq_bits). + * + * Usage example: + * @code{.cpp} + * // We will unpack the fourth cluster + * uint32_t label = 3; + * // Get the list size + * uint32_t list_size = 0; + * raft::copy(&list_size, index.list_sizes().data_handle() + label, 1, + * resource::get_cuda_stream(res)); resource::sync_stream(res); + * // allocate the buffer for the output + * auto codes = raft::make_device_matrix(res, list_size, index.pq_dim()); + * // unpack the whole list + * ivf_pq::helpers::unpack_list_data(res, index, codes.view(), label, 0); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] index + * @param[out] out_codes + * the destination buffer [n_take, index.pq_dim()]. + * The length `n_take` defines how many records to unpack, + * it must be smaller than the list size. + * @param[in] label + * The id of the list (cluster) to decode. + * @param[in] offset + * How many records in the list to skip. 
+ */ +template +void unpack_list_data(raft::resources const& res, + const cuvs::neighbors::ivf_pq::index& index, + device_matrix_view out_codes, + uint32_t label, + uint32_t offset) +{ + return raft::neighbors::ivf_pq::detail::unpack_list_data( + res, *index.get_raft_index(), out_codes, label, offset); +} + +/** + * @brief Unpack a series of records of a single list (cluster) in the compressed index + * by their in-list offsets, one code per byte (independently of pq_bits). + * + * Usage example: + * @code{.cpp} + * // We will unpack the fourth cluster + * uint32_t label = 3; + * // Create the selection vector + * auto selected_indices = raft::make_device_vector(res, 4); + * ... fill the indices ... + * resource::sync_stream(res); + * // allocate the buffer for the output + * auto codes = raft::make_device_matrix(res, selected_indices.size(), index.pq_dim()); + * // decode the whole list + * ivf_pq::helpers::unpack_list_data( + * res, index, selected_indices.view(), codes.view(), label); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] index + * @param[in] in_cluster_indices + * The offsets of the selected indices within the cluster. + * @param[out] out_codes + * the destination buffer [n_take, index.pq_dim()]. + * The length `n_take` defines how many records to unpack, + * it must be smaller than the list size. + * @param[in] label + * The id of the list (cluster) to decode. + */ +template +void unpack_list_data(raft::resources const& res, + const cuvs::neighbors::ivf_pq::index& index, + device_vector_view in_cluster_indices, + device_matrix_view out_codes, + uint32_t label) +{ + return raft::neighbors::ivf_pq::detail::unpack_list_data( + res, index, out_codes, label, in_cluster_indices); +} + +/** + * @brief Decode `n_take` consecutive records of a single list (cluster) in the compressed index + * starting at given `offset`. 
+ * + * Usage example: + * @code{.cpp} + * // We will reconstruct the fourth cluster + * uint32_t label = 3; + * // Get the list size + * uint32_t list_size = 0; + * raft::copy(&list_size, index.list_sizes().data_handle() + label, 1, + * resource::get_cuda_stream(res)); resource::sync_stream(res); + * // allocate the buffer for the output + * auto decoded_vectors = raft::make_device_matrix(res, list_size, index.dim()); + * // decode the whole list + * ivf_pq::helpers::reconstruct_list_data(res, index, decoded_vectors.view(), label, 0); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] index + * @param[out] out_vectors + * the destination buffer [n_take, index.dim()]. + * The length `n_take` defines how many records to reconstruct, + * it must be smaller than the list size. + * @param[in] label + * The id of the list (cluster) to decode. + * @param[in] offset + * How many records in the list to skip. + */ +template +void reconstruct_list_data(raft::resources const& res, + const cuvs::neighbors::ivf_pq::index& index, + device_matrix_view out_vectors, + uint32_t label, + uint32_t offset) +{ + return raft::neighbors::ivf_pq::detail::reconstruct_list_data( + res, *index.get_raft_index(), out_vectors, label, offset); +} + +/** + * @brief Decode a series of records of a single list (cluster) in the compressed index + * by their in-list offsets. + * + * Usage example: + * @code{.cpp} + * // We will reconstruct the fourth cluster + * uint32_t label = 3; + * // Create the selection vector + * auto selected_indices = raft::make_device_vector(res, 4); + * ... fill the indices ... 
+ * resource::sync_stream(res); + * // allocate the buffer for the output + * auto decoded_vectors = raft::make_device_matrix( + * res, selected_indices.size(), index.dim()); + * // decode the selected records + * ivf_pq::helpers::reconstruct_list_data( + * res, index, selected_indices.view(), decoded_vectors.view(), label); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] index + * @param[in] in_cluster_indices + * The offsets of the selected indices within the cluster. + * @param[out] out_vectors + * the destination buffer [n_take, index.dim()]. + * The length `n_take` defines how many records to reconstruct, + * it must be smaller than the list size. + * @param[in] label + * The id of the list (cluster) to decode. + */ +template +void reconstruct_list_data(raft::resources const& res, + const cuvs::neighbors::ivf_pq::index& index, + device_vector_view in_cluster_indices, + device_matrix_view out_vectors, + uint32_t label) +{ + return raft::neighbors::ivf_pq::detail::reconstruct_list_data( + res, index, out_vectors, label, in_cluster_indices); +} + +/** + * @brief Extend one list of the index in-place, by the list label, skipping the classification and + * encoding steps. + * + * Usage example: + * @code{.cpp} + * // We will extend the fourth cluster + * uint32_t label = 3; + * // We will fill 4 new vectors + * uint32_t n_vec = 4; + * // Indices of the new vectors + * auto indices = raft::make_device_vector(res, n_vec); + * ... fill the indices ... + * auto new_codes = raft::make_device_matrix( + * res, n_vec, index.pq_dim()); + * ... fill codes ... 
+ * // extend list with new codes + * ivf_pq::helpers::extend_list_with_codes( + * res, &index, new_codes.view(), indices.view(), label); + * @endcode + * + * @tparam IdxT + * + * @param[in] res + * @param[inout] index + * @param[in] new_codes flat PQ codes, one code per byte [n_rows, index.pq_dim()] + * @param[in] new_indices source indices [n_rows] + * @param[in] label the id of the target list (cluster). + */ +template +void extend_list_with_codes(raft::resources const& res, + cuvs::neighbors::ivf_pq::index* index, + device_matrix_view new_codes, + device_vector_view new_indices, + uint32_t label) +{ + raft::neighbors::ivf_pq::detail::extend_list_with_codes( + res, index->get_raft_index(), new_codes, new_indices, label); +} + +/** + * @brief Extend one list of the index in-place, by the list label, skipping the classification + * step. + * + * Usage example: + * @code{.cpp} + * // We will extend the fourth cluster + * uint32_t label = 3; + * // We will extend with 4 new vectors + * uint32_t n_vec = 4; + * // Indices of the new vectors + * auto indices = raft::make_device_vector(res, n_vec); + * ... fill the indices ... + * auto new_vectors = raft::make_device_matrix( + * res, n_vec, index.dim()); + * ... fill vectors ... + * // extend list with new vectors + * ivf_pq::helpers::extend_list( + * res, &index, new_vectors.view(), indices.view(), label); + * @endcode + * + * @tparam T + * @tparam IdxT + * + * @param[in] res + * @param[inout] index + * @param[in] new_vectors data to encode [n_rows, index.dim()] + * @param[in] new_indices source indices [n_rows] + * @param[in] label the id of the target list (cluster). 
+ * + */ +template +void extend_list(raft::resources const& res, + cuvs::neighbors::ivf_pq::index* index, + device_matrix_view new_vectors, + device_vector_view new_indices, + uint32_t label) +{ + raft::neighbors::ivf_pq::detail::extend_list( + res, index->get_raft_index(), new_vectors, new_indices, label); +} + +/** + * @brief Remove all data from a single list (cluster) in the index. + * + * Usage example: + * @code{.cpp} + * // We will erase the fourth cluster (label = 3) + * ivf_pq::helpers::erase_list(res, &index, 3); + * @endcode + * + * @tparam IdxT + * @param[in] res + * @param[inout] index + * @param[in] label the id of the target list (cluster). + */ +template +void erase_list(raft::resources const& res, + cuvs::neighbors::ivf_pq::index* index, + uint32_t label) +{ + raft::neighbors::ivf_pq::detail::erase_list(res, index->get_raft_index(), label); +} + +/** @} */ +} // namespace cuvs::neighbors::ivf_pq::helpers diff --git a/cpp/test/neighbors/run_brute_force_c.c b/cpp/test/neighbors/run_brute_force_c.c new file mode 100644 index 000000000..9c7af13a6 --- /dev/null +++ b/cpp/test/neighbors/run_brute_force_c.c @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +void run_brute_force(int64_t n_rows, + int64_t n_queries, + int64_t n_dim, + uint32_t n_neighbors, + float* index_data, + float* query_data, + float* distances_data, + int64_t* neighbors_data, + enum DistanceType metric) +{ + // create cuvsResources_t + cuvsResources_t res; + cuvsResourcesCreate(&res); + + // create dataset DLTensor + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = index_data; + dataset_tensor.dl_tensor.device.device_type = kDLCUDA; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + int64_t dataset_shape[2] = {n_rows, n_dim}; + dataset_tensor.dl_tensor.shape = dataset_shape; + dataset_tensor.dl_tensor.strides = NULL; + + // create index + cuvsBruteForceIndex_t index; + bruteForceIndexCreate(&index); + + // build index + bruteForceBuild(res, &dataset_tensor, metric, 0.0f, index); + + // create queries DLTensor + DLManagedTensor queries_tensor; + queries_tensor.dl_tensor.data = (void*)query_data; + queries_tensor.dl_tensor.device.device_type = kDLCUDA; + queries_tensor.dl_tensor.ndim = 2; + queries_tensor.dl_tensor.dtype.code = kDLFloat; + queries_tensor.dl_tensor.dtype.bits = 32; + queries_tensor.dl_tensor.dtype.lanes = 1; + int64_t queries_shape[2] = {n_queries, n_dim}; + queries_tensor.dl_tensor.shape = queries_shape; + queries_tensor.dl_tensor.strides = NULL; + + // create neighbors DLTensor + DLManagedTensor neighbors_tensor; + neighbors_tensor.dl_tensor.data = (void*)neighbors_data; + neighbors_tensor.dl_tensor.device.device_type = kDLCUDA; + neighbors_tensor.dl_tensor.ndim = 2; + neighbors_tensor.dl_tensor.dtype.code = kDLInt; + neighbors_tensor.dl_tensor.dtype.bits = 64; + neighbors_tensor.dl_tensor.dtype.lanes = 1; + int64_t neighbors_shape[2] = {n_queries, n_neighbors}; + neighbors_tensor.dl_tensor.shape = neighbors_shape; + neighbors_tensor.dl_tensor.strides = NULL; + + // 
create distances DLTensor + DLManagedTensor distances_tensor; + distances_tensor.dl_tensor.data = (void*)distances_data; + distances_tensor.dl_tensor.device.device_type = kDLCUDA; + distances_tensor.dl_tensor.ndim = 2; + distances_tensor.dl_tensor.dtype.code = kDLFloat; + distances_tensor.dl_tensor.dtype.bits = 32; + distances_tensor.dl_tensor.dtype.lanes = 1; + int64_t distances_shape[2] = {n_queries, n_neighbors}; + distances_tensor.dl_tensor.shape = distances_shape; + distances_tensor.dl_tensor.strides = NULL; + + // search index + bruteForceSearch(res, index, &queries_tensor, &neighbors_tensor, &distances_tensor); + + // de-allocate index and res + bruteForceIndexDestroy(index); + cuvsResourcesDestroy(res); +} diff --git a/cpp/test/neighbors/run_ivf_flat_c.c b/cpp/test/neighbors/run_ivf_flat_c.c new file mode 100644 index 000000000..badb507a5 --- /dev/null +++ b/cpp/test/neighbors/run_ivf_flat_c.c @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +void run_ivf_flat(int64_t n_rows, + int64_t n_queries, + int64_t n_dim, + uint32_t n_neighbors, + float* index_data, + float* query_data, + float* distances_data, + int64_t* neighbors_data, + enum DistanceType metric, + size_t n_probes, + size_t n_lists) +{ + // create cuvsResources_t + cuvsResources_t res; + cuvsResourcesCreate(&res); + + // create dataset DLTensor + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = index_data; + dataset_tensor.dl_tensor.device.device_type = kDLCUDA; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + int64_t dataset_shape[2] = {n_rows, n_dim}; + dataset_tensor.dl_tensor.shape = dataset_shape; + dataset_tensor.dl_tensor.strides = NULL; + + // create index + cuvsIvfFlatIndex_t index; + ivfFlatIndexCreate(&index); + + // build index + cuvsIvfFlatIndexParams_t build_params; + cuvsIvfFlatIndexParamsCreate(&build_params); + build_params->metric = metric; + build_params->n_lists = n_lists; + ivfFlatBuild(res, build_params, &dataset_tensor, index); + + // create queries DLTensor + DLManagedTensor queries_tensor; + queries_tensor.dl_tensor.data = (void*)query_data; + queries_tensor.dl_tensor.device.device_type = kDLCUDA; + queries_tensor.dl_tensor.ndim = 2; + queries_tensor.dl_tensor.dtype.code = kDLFloat; + queries_tensor.dl_tensor.dtype.bits = 32; + queries_tensor.dl_tensor.dtype.lanes = 1; + int64_t queries_shape[2] = {n_queries, n_dim}; + queries_tensor.dl_tensor.shape = queries_shape; + queries_tensor.dl_tensor.strides = NULL; + + // create neighbors DLTensor + DLManagedTensor neighbors_tensor; + neighbors_tensor.dl_tensor.data = (void*)neighbors_data; + neighbors_tensor.dl_tensor.device.device_type = kDLCUDA; + neighbors_tensor.dl_tensor.ndim = 2; + neighbors_tensor.dl_tensor.dtype.code = kDLInt; + neighbors_tensor.dl_tensor.dtype.bits = 64; + 
neighbors_tensor.dl_tensor.dtype.lanes = 1; + int64_t neighbors_shape[2] = {n_queries, n_neighbors}; + neighbors_tensor.dl_tensor.shape = neighbors_shape; + neighbors_tensor.dl_tensor.strides = NULL; + + // create distances DLTensor + DLManagedTensor distances_tensor; + distances_tensor.dl_tensor.data = (void*)distances_data; + distances_tensor.dl_tensor.device.device_type = kDLCUDA; + distances_tensor.dl_tensor.ndim = 2; + distances_tensor.dl_tensor.dtype.code = kDLFloat; + distances_tensor.dl_tensor.dtype.bits = 32; + distances_tensor.dl_tensor.dtype.lanes = 1; + int64_t distances_shape[2] = {n_queries, n_neighbors}; + distances_tensor.dl_tensor.shape = distances_shape; + distances_tensor.dl_tensor.strides = NULL; + + // search index + cuvsIvfFlatSearchParams_t search_params; + cuvsIvfFlatSearchParamsCreate(&search_params); + search_params->n_probes = n_probes; + ivfFlatSearch(res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor); + + // de-allocate index and res + cuvsIvfFlatSearchParamsDestroy(search_params); + cuvsIvfFlatIndexParamsDestroy(build_params); + ivfFlatIndexDestroy(index); + cuvsResourcesDestroy(res); +} diff --git a/cpp/test/neighbors/run_ivf_pq_c.c b/cpp/test/neighbors/run_ivf_pq_c.c new file mode 100644 index 000000000..fece4a644 --- /dev/null +++ b/cpp/test/neighbors/run_ivf_pq_c.c @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +void run_ivf_pq(int64_t n_rows, + int64_t n_queries, + int64_t n_dim, + uint32_t n_neighbors, + float* index_data, + float* query_data, + float* distances_data, + int64_t* neighbors_data, + enum DistanceType metric, + size_t n_probes, + size_t n_lists) +{ + // create cuvsResources_t + cuvsResources_t res; + cuvsResourcesCreate(&res); + + // create dataset DLTensor + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = index_data; + dataset_tensor.dl_tensor.device.device_type = kDLCUDA; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + int64_t dataset_shape[2] = {n_rows, n_dim}; + dataset_tensor.dl_tensor.shape = dataset_shape; + dataset_tensor.dl_tensor.strides = NULL; + + // create index + cuvsIvfPqIndex_t index; + ivfPqIndexCreate(&index); + + // build index + cuvsIvfPqIndexParams_t build_params; + cuvsIvfPqIndexParamsCreate(&build_params); + build_params->metric = metric; + build_params->n_lists = n_lists; + ivfPqBuild(res, build_params, &dataset_tensor, index); + + // create queries DLTensor + DLManagedTensor queries_tensor; + queries_tensor.dl_tensor.data = (void*)query_data; + queries_tensor.dl_tensor.device.device_type = kDLCUDA; + queries_tensor.dl_tensor.ndim = 2; + queries_tensor.dl_tensor.dtype.code = kDLFloat; + queries_tensor.dl_tensor.dtype.bits = 32; + queries_tensor.dl_tensor.dtype.lanes = 1; + int64_t queries_shape[2] = {n_queries, n_dim}; + queries_tensor.dl_tensor.shape = queries_shape; + queries_tensor.dl_tensor.strides = NULL; + + // create neighbors DLTensor + DLManagedTensor neighbors_tensor; + neighbors_tensor.dl_tensor.data = (void*)neighbors_data; + neighbors_tensor.dl_tensor.device.device_type = kDLCUDA; + neighbors_tensor.dl_tensor.ndim = 2; + neighbors_tensor.dl_tensor.dtype.code = kDLInt; + neighbors_tensor.dl_tensor.dtype.bits = 64; + neighbors_tensor.dl_tensor.dtype.lanes = 1; 
+ int64_t neighbors_shape[2] = {n_queries, n_neighbors}; + neighbors_tensor.dl_tensor.shape = neighbors_shape; + neighbors_tensor.dl_tensor.strides = NULL; + + // create distances DLTensor + DLManagedTensor distances_tensor; + distances_tensor.dl_tensor.data = (void*)distances_data; + distances_tensor.dl_tensor.device.device_type = kDLCUDA; + distances_tensor.dl_tensor.ndim = 2; + distances_tensor.dl_tensor.dtype.code = kDLFloat; + distances_tensor.dl_tensor.dtype.bits = 32; + distances_tensor.dl_tensor.dtype.lanes = 1; + int64_t distances_shape[2] = {n_queries, n_neighbors}; + distances_tensor.dl_tensor.shape = distances_shape; + distances_tensor.dl_tensor.strides = NULL; + + // search index + cuvsIvfPqSearchParams_t search_params; + cuvsIvfPqSearchParamsCreate(&search_params); + search_params->n_probes = n_probes; + ivfPqSearch(res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor); + + // de-allocate index and res + cuvsIvfPqSearchParamsDestroy(search_params); + cuvsIvfPqIndexParamsDestroy(build_params); + ivfPqIndexDestroy(index); + cuvsResourcesDestroy(res); +} diff --git a/docs/source/c_api/neighbors.rst b/docs/source/c_api/neighbors.rst index 6bbd3cc7f..dc55a74dc 100644 --- a/docs/source/c_api/neighbors.rst +++ b/docs/source/c_api/neighbors.rst @@ -9,6 +9,7 @@ Nearest Neighbors :maxdepth: 2 :caption: Contents: - - - neighbors_cagra_c.rst \ No newline at end of file + neighbors_bruteforce_c.rst + neighbors_ivf_flat_c.rst + neighbors_ivf_pq_c.rst + neighbors_cagra_c.rst diff --git a/docs/source/c_api/neighbors_bruteforce_c.rst b/docs/source/c_api/neighbors_bruteforce_c.rst new file mode 100644 index 000000000..af0356eee --- /dev/null +++ b/docs/source/c_api/neighbors_bruteforce_c.rst @@ -0,0 +1,34 @@ +Bruteforce +========== + +The bruteforce method is running the KNN algorithm. It performs an extensive search, and in contrast to ANN methods produces an exact result. + +.. 
role:: py(code) + :language: c + :class: highlight + +``#include <cuvs/neighbors/brute_force.h>`` + +Index +----- + +.. doxygengroup:: bruteforce_c_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. doxygengroup:: bruteforce_c_index_build + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: bruteforce_c_index_search + :project: cuvs + :members: + :content-only: diff --git a/docs/source/c_api/neighbors_ivf_flat_c.rst b/docs/source/c_api/neighbors_ivf_flat_c.rst new file mode 100644 index 000000000..9e1ccc0d1 --- /dev/null +++ b/docs/source/c_api/neighbors_ivf_flat_c.rst @@ -0,0 +1,50 @@ +IVF-Flat +======== + +The IVF-Flat method is an ANN algorithm. It uses an inverted file index (IVF) with unmodified (that is, flat) vectors. This algorithm provides simple knobs to reduce the overall search space and to trade off accuracy for speed. + +.. role:: py(code) + :language: c + :class: highlight + +``#include <cuvs/neighbors/ivf_flat.h>`` + +Index build parameters +---------------------- + +.. doxygengroup:: ivf_flat_c_index_params + :project: cuvs + :members: + :content-only: + +Index search parameters +----------------------- + +.. doxygengroup:: ivf_flat_c_search_params + :project: cuvs + :members: + :content-only: + +Index +----- + +.. doxygengroup:: ivf_flat_c_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. doxygengroup:: ivf_flat_c_index_build + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: ivf_flat_c_index_search + :project: cuvs + :members: + :content-only: diff --git a/docs/source/c_api/neighbors_ivf_pq_c.rst b/docs/source/c_api/neighbors_ivf_pq_c.rst new file mode 100644 index 000000000..070719609 --- /dev/null +++ b/docs/source/c_api/neighbors_ivf_pq_c.rst @@ -0,0 +1,50 @@ +IVF-PQ +====== + +The IVF-PQ method is an ANN algorithm. 
Like IVF-Flat, IVF-PQ splits the points into a number of clusters (specified by the n_lists parameter) and searches only the closest clusters (how many is specified by the n_probes parameter) to compute the nearest neighbors, but it also shrinks the vectors using a technique called product quantization. + +.. role:: py(code) + :language: c + :class: highlight + +``#include <cuvs/neighbors/ivf_pq.h>`` + +Index build parameters +---------------------- + +.. doxygengroup:: ivf_pq_c_index_params + :project: cuvs + :members: + :content-only: + +Index search parameters +----------------------- + +.. doxygengroup:: ivf_pq_c_search_params + :project: cuvs + :members: + :content-only: + +Index +----- + +.. doxygengroup:: ivf_pq_c_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. doxygengroup:: ivf_pq_c_index_build + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: ivf_pq_c_index_search + :project: cuvs + :members: + :content-only: diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index 67bc7f65b..e04fff0b8 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -7,4 +7,5 @@ C++ API Documentation .. toctree:: :maxdepth: 4 + cpp_api/distance.rst cpp_api/neighbors.rst diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst new file mode 100644 index 000000000..21f4558ec --- /dev/null +++ b/docs/source/cpp_api/distance.rst @@ -0,0 +1,19 @@ +Distance +======== + +This page provides C++ class references for the publicly exposed elements of the `cuvs/distance` package. cuVS's +distances have been highly optimized and support a wide assortment of different distance measures. + +.. role:: py(code) + :language: c++ + :class: highlight + +Distance Types +-------------- + +``#include <cuvs/distance/distance_types.hpp>`` + +namespace *cuvs::distance* + +.. 
doxygenenum:: cuvs::distance::DistanceType + :project: cuvs diff --git a/docs/source/cpp_api/neighbors.rst b/docs/source/cpp_api/neighbors.rst index 61898cec8..f9006412c 100644 --- a/docs/source/cpp_api/neighbors.rst +++ b/docs/source/cpp_api/neighbors.rst @@ -9,4 +9,7 @@ Nearest Neighbors :maxdepth: 2 :caption: Contents: - neighbors_cagra.rst \ No newline at end of file + neighbors_bruteforce.rst + neighbors_ivf_flat.rst + neighbors_ivf_pq.rst + neighbors_cagra.rst diff --git a/docs/source/cpp_api/neighbors_bruteforce.rst b/docs/source/cpp_api/neighbors_bruteforce.rst new file mode 100644 index 000000000..3adcb01c5 --- /dev/null +++ b/docs/source/cpp_api/neighbors_bruteforce.rst @@ -0,0 +1,36 @@ +Bruteforce +========== + +The brute-force method runs the exact k-nearest-neighbors (KNN) algorithm. It performs an exhaustive search and, in contrast to ANN methods, produces an exact result. + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include <cuvs/neighbors/brute_force.hpp>`` + +namespace *cuvs::neighbors::bruteforce* + +Index +----- + +.. doxygengroup:: bruteforce_cpp_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. doxygengroup:: bruteforce_cpp_index_build + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: bruteforce_cpp_index_search + :project: cuvs + :members: + :content-only: diff --git a/docs/source/cpp_api/neighbors_ivf_flat.rst b/docs/source/cpp_api/neighbors_ivf_flat.rst new file mode 100644 index 000000000..3836223e1 --- /dev/null +++ b/docs/source/cpp_api/neighbors_ivf_flat.rst @@ -0,0 +1,68 @@ +IVF-Flat +======== + +The IVF-Flat method is an ANN algorithm. It uses an inverted file index (IVF) with unmodified (that is, flat) vectors. This algorithm provides simple knobs to reduce the overall search space and to trade off accuracy for speed. + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include <cuvs/neighbors/ivf_flat.hpp>`` + +namespace *cuvs::neighbors::ivf_flat* + +Index build parameters +---------------------- + +.. 
doxygengroup:: ivf_flat_cpp_index_params + :project: cuvs + :members: + :content-only: + +Index search parameters +----------------------- + +.. doxygengroup:: ivf_flat_cpp_search_params + :project: cuvs + :members: + :content-only: + +Index +----- + +.. doxygengroup:: ivf_flat_cpp_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. doxygengroup:: ivf_flat_cpp_index_build + :project: cuvs + :members: + :content-only: + +Index extend +------------ + +.. doxygengroup:: ivf_flat_cpp_index_extend + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: ivf_flat_cpp_index_search + :project: cuvs + :members: + :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_flat_cpp_serialize + :project: cuvs + :members: + :content-only: diff --git a/docs/source/cpp_api/neighbors_ivf_pq.rst b/docs/source/cpp_api/neighbors_ivf_pq.rst new file mode 100644 index 000000000..0d4d7061a --- /dev/null +++ b/docs/source/cpp_api/neighbors_ivf_pq.rst @@ -0,0 +1,68 @@ +IVF-PQ +====== + +The IVF-PQ method is an ANN algorithm. Like IVF-Flat, IVF-PQ splits the points into a number of clusters (specified by the n_lists parameter) and searches only the closest clusters (how many is specified by the n_probes parameter) to compute the nearest neighbors, but it also shrinks the vectors using a technique called product quantization. + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include <cuvs/neighbors/ivf_pq.hpp>`` + +namespace *cuvs::neighbors::ivf_pq* + +Index build parameters +---------------------- + +.. doxygengroup:: ivf_pq_cpp_index_params + :project: cuvs + :members: + :content-only: + +Index search parameters +----------------------- + +.. doxygengroup:: ivf_pq_cpp_search_params + :project: cuvs + :members: + :content-only: + +Index +----- + +.. doxygengroup:: ivf_pq_cpp_index + :project: cuvs + :members: + :content-only: + +Index build +----------- + +.. 
doxygengroup:: ivf_pq_cpp_index_build + :project: cuvs + :members: + :content-only: + +Index extend +------------ + +.. doxygengroup:: ivf_pq_cpp_index_extend + :project: cuvs + :members: + :content-only: + +Index search +------------ + +.. doxygengroup:: ivf_pq_cpp_index_search + :project: cuvs + :members: + :content-only: + +Index serialize +--------------- + +.. doxygengroup:: ivf_pq_cpp_serialize + :project: cuvs + :members: + :content-only: