diff --git a/.gitignore b/.gitignore index 17258e3de..97eab287d 100644 --- a/.gitignore +++ b/.gitignore @@ -81,4 +81,4 @@ ivf_pq_index # cuvs_bench datasets/ -/*.json \ No newline at end of file +/*.json diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index b72d7f165..78c67d9c8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -202,6 +202,7 @@ endif() add_library( cuvs-cagra-search STATIC src/neighbors/cagra_search_float.cu + src/neighbors/cagra_search_half.cu src/neighbors/cagra_search_int8.cu src/neighbors/cagra_search_uint8.cu src/neighbors/detail/cagra/compute_distance.cu @@ -257,14 +258,10 @@ add_library( src/neighbors/detail/cagra/search_multi_cta_half_uint32.cu src/neighbors/detail/cagra/search_multi_cta_int8_uint32.cu src/neighbors/detail/cagra/search_multi_cta_uint8_uint32.cu - src/neighbors/detail/cagra/search_multi_cta_float_uint64.cu - src/neighbors/detail/cagra/search_multi_cta_half_uint64.cu src/neighbors/detail/cagra/search_single_cta_float_uint32.cu src/neighbors/detail/cagra/search_single_cta_half_uint32.cu src/neighbors/detail/cagra/search_single_cta_int8_uint32.cu src/neighbors/detail/cagra/search_single_cta_uint8_uint32.cu - src/neighbors/detail/cagra/search_single_cta_float_uint64.cu - src/neighbors/detail/cagra/search_single_cta_half_uint64.cu ) file(GLOB_RECURSE compute_distance_sources "src/neighbors/detail/cagra/compute_distance_*.cu") @@ -293,9 +290,14 @@ target_compile_options( add_library( cuvs SHARED src/cluster/kmeans_balanced_fit_float.cu + src/cluster/kmeans_fit_mg_float.cu + src/cluster/kmeans_fit_mg_double.cu + src/cluster/kmeans_fit_double.cu src/cluster/kmeans_fit_float.cu src/cluster/kmeans_auto_find_k_float.cu + src/cluster/kmeans_fit_predict_double.cu src/cluster/kmeans_fit_predict_float.cu + src/cluster/kmeans_predict_double.cu src/cluster/kmeans_predict_float.cu src/cluster/kmeans_balanced_fit_float.cu src/cluster/kmeans_balanced_fit_predict_float.cu @@ -303,6 +305,7 @@ add_library( 
src/cluster/kmeans_balanced_fit_int8.cu src/cluster/kmeans_balanced_fit_predict_int8.cu src/cluster/kmeans_balanced_predict_int8.cu + src/cluster/kmeans_transform_double.cu src/cluster/kmeans_transform_float.cu src/cluster/single_linkage_float.cu src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -345,11 +348,14 @@ add_library( src/distance/detail/pairwise_matrix/dispatch_russel_rao_half_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_rbf.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu src/distance/detail/fused_distance_nn.cu src/distance/distance.cu src/distance/pairwise_distance.cu src/neighbors/brute_force.cu src/neighbors/cagra_build_float.cu + src/neighbors/cagra_build_half.cu src/neighbors/cagra_build_int8.cu src/neighbors/cagra_build_uint8.cu src/neighbors/cagra_extend_float.cu @@ -357,6 +363,7 @@ add_library( src/neighbors/cagra_extend_uint8.cu src/neighbors/cagra_optimize.cu src/neighbors/cagra_serialize_float.cu + src/neighbors/cagra_serialize_half.cu src/neighbors/cagra_serialize_int8.cu src/neighbors/cagra_serialize_uint8.cu src/neighbors/detail/cagra/cagra_build.cpp @@ -378,6 +385,7 @@ add_library( src/neighbors/ivf_pq/ivf_pq_serialize.cu src/neighbors/ivf_pq/ivf_pq_deserialize.cu src/neighbors/ivf_pq/detail/ivf_pq_build_extend_float_int64_t.cu + src/neighbors/ivf_pq/detail/ivf_pq_build_extend_half_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_build_extend_int8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_build_extend_uint8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -395,15 +403,19 @@ add_library( src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu 
src/neighbors/ivf_pq/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu + src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_float_int64_t.cu + src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_int8_t_int64_t.cu src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_uint8_t_int64_t.cu src/neighbors/nn_descent.cu src/neighbors/nn_descent_float.cu + src/neighbors/nn_descent_half.cu src/neighbors/nn_descent_int8.cu src/neighbors/nn_descent_uint8.cu + src/neighbors/reachability.cu src/neighbors/refine/detail/refine_device_float_float.cu src/neighbors/refine/detail/refine_device_half_float.cu src/neighbors/refine/detail/refine_device_int8_t_float.cu @@ -414,6 +426,7 @@ add_library( src/neighbors/refine/detail/refine_host_uint8_t_float.cpp src/neighbors/sample_filter.cu src/selection/select_k_float_int64_t.cu + src/selection/select_k_float_int32_t.cu src/selection/select_k_float_uint32_t.cu src/selection/select_k_half_uint32_t.cu src/stats/silhouette_score.cu diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index 5d7b8934f..db3e533e0 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -687,17 +687,17 @@ inline auto run_main(int argc, char** argv) -> int override_kv, metric_objective, threads); - // } else if (dtype == "half") { - // dispatch_benchmark(cmdline - // conf, - // force_overwrite, - // build_mode, - // search_mode, - // data_prefix, - // index_prefix, - // override_kv, - // metric_objective, - // threads); + } else if (dtype == "half") { + dispatch_benchmark(cmdline, + conf, + force_overwrite, + build_mode, + search_mode, + data_prefix, + index_prefix, 
+ override_kv, + metric_objective, + threads); } else if (dtype == "uint8") { dispatch_benchmark(cmdline, conf, diff --git a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h index 92274e263..11e0e4ad3 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h +++ b/cpp/bench/ann/src/cuvs/cuvs_ann_bench_utils.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -75,13 +76,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool */ class shared_raft_resources { public: - using pool_mr_type = rmm::mr::pool_memory_resource; - using mr_type = rmm::mr::failure_callback_resource_adaptor; + using pool_mr_type = rmm::mr::pool_memory_resource; + using mr_type = rmm::mr::failure_callback_resource_adaptor; + using large_mr_type = rmm::mr::managed_memory_resource; shared_raft_resources() try : orig_resource_{rmm::mr::get_current_device_resource()}, pool_resource_(orig_resource_, 1024 * 1024 * 1024ull), - resource_(&pool_resource_, rmm_oom_callback, nullptr) { + resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() { rmm::mr::set_current_device_resource(&resource_); } catch (const std::exception& e) { auto cuda_status = cudaGetLastError(); @@ -104,10 +106,16 @@ class shared_raft_resources { ~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); } + auto get_large_memory_resource() noexcept + { + return static_cast(&large_mr_); + } + private: rmm::mr::device_memory_resource* orig_resource_; pool_mr_type pool_resource_; mr_type resource_; + large_mr_type large_mr_; }; /** @@ -130,6 +138,12 @@ class configured_raft_resources { res_{std::make_unique( rmm::cuda_stream_view(get_stream_from_global_pool()))} { + // set the large workspace resource to the raft handle, but without the deleter + // (this resource is managed by the shared_res). 
+ raft::resource::set_large_workspace_resource( + *res_, + std::shared_ptr(shared_res_->get_large_memory_resource(), + raft::void_op{})); } /** Default constructor creates all resources anew. */ diff --git a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu index a7495c23a..a956ab139 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_benchmark.cu @@ -121,7 +121,7 @@ auto create_search_param(const std::string& algo_name, const nlohmann::json& con }; // namespace cuvs::bench REGISTER_ALGO_INSTANCE(float); -// REGISTER_ALGO_INSTANCE(half); +REGISTER_ALGO_INSTANCE(half); REGISTER_ALGO_INSTANCE(std::int8_t); REGISTER_ALGO_INSTANCE(std::uint8_t); diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_half.cu b/cpp/bench/ann/src/cuvs/cuvs_cagra_half.cu index 6768034a2..b4a3235c4 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_half.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_half.cu @@ -16,5 +16,5 @@ #include "cuvs_cagra_wrapper.h" namespace cuvs::bench { -// template class cuvs_cagra; +template class cuvs_cagra; } // namespace cuvs::bench diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h index 9ca41cab0..ff854f890 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h @@ -289,7 +289,11 @@ void cuvs_cagra::save(const std::string& file) const template void cuvs_cagra::save_to_hnswlib(const std::string& file) const { - cuvs::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_); + if constexpr (!std::is_same_v) { + cuvs::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_); + } else { + RAFT_FAIL("Cannot save fp16 index to hnswlib format"); + } } template diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq.cu index 3ffdd4a25..2df460966 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq.cu @@ -17,7 +17,7 @@ namespace 
cuvs::bench { template class cuvs_ivf_pq; -// template class cuvs_ivf_pq; +template class cuvs_ivf_pq; template class cuvs_ivf_pq; template class cuvs_ivf_pq; } // namespace cuvs::bench diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 75205fa4f..89b3acc24 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -153,7 +153,7 @@ struct balanced_params : base_params { * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::params params; + * cuvs::cluster::kmeans::params params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * @@ -203,7 +203,159 @@ void fit(raft::resources const& handle, * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::params params; + * cuvs::cluster::kmeans::params params; + * int64_t n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, + * n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids, + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. 
+ */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm. + * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids, + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm. 
+ * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int64_t n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, + * n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids, + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm. + * Initial centroids are chosen with k-means++ algorithm. Empty + * clusters are reinitialized by choosing new centroids with + * k-means++ algorithm. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... 
+ * raft::resources handle; + * cuvs::cluster::kmeans::params params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * @@ -250,7 +402,7 @@ void fit(raft::resources const& handle, * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; + * cuvs::cluster::kmeans::balanced_params params; * int n_features = 15; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * @@ -284,7 +436,7 @@ void fit(const raft::resources& handle, * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; + * cuvs::cluster::kmeans::balanced_params params; * int n_features = 15; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * @@ -308,7 +460,6 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids); - /** * @brief Predict the closest cluster each sample in X belongs to. * @@ -318,7 +469,7 @@ void fit(const raft::resources& handle, * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::params params; + * cuvs::cluster::kmeans::params params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * @@ -363,7 +514,7 @@ void predict(raft::resources const& handle, raft::device_matrix_view X, std::optional> sample_weight, raft::device_matrix_view centroids, - raft::device_vector_view labels, + raft::device_vector_view labels, bool normalize_weight, raft::host_scalar_view inertia); @@ -376,7 +527,7 @@ void predict(raft::resources const& handle, * using namespace cuvs::cluster; * ... 
* raft::resources handle; - * cuvs::cluster::kmeans::params params; + * cuvs::cluster::kmeans::params params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * @@ -388,7 +539,7 @@ void predict(raft::resources const& handle, * raft::make_scalar_view(&inertia), * raft::make_scalar_view(&n_iter)); * ... - * auto labels = raft::make_device_vector(handle, X.extent(0)); + * auto labels = raft::make_device_vector(handle, X.extent(0)); * * kmeans::predict(handle, * params, @@ -404,18 +555,26 @@ void predict(raft::resources const& handle, * @param[in] params Parameters for KMeans model. * @param[in] X New data to predict. * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] * @param[in] centroids Cluster centroids. The data must be in * row-major format. * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized * @param[out] labels Index of the cluster each sample in X * belongs to. * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. */ -void predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, raft::device_matrix_view centroids, - raft::device_vector_view labels); + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); /** * @brief Predict the closest cluster each sample in X belongs to. @@ -426,22 +585,144 @@ void predict(const raft::resources& handle, * using namespace cuvs::cluster; * ... 
* raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; - * int n_features = 15; + * cuvs::cluster::kmeans::params params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * ... + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * false, + * labels.view(), + * raft::make_scalar_view(&inertia)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... 
+ * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * ... + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * false, + * labels.view(), + * raft::make_scalar_view(&inertia)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X New data to predict. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[in] centroids Cluster centroids. The data must be in + * row-major format. + * [dim = n_clusters x n_features] + * @param[in] normalize_weight True if the weights should be normalized + * @param[out] labels Index of the cluster each sample in X + * belongs to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to + * their closest cluster center. + */ +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... 
+ * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * * kmeans::fit(handle, * params, * X, - * centroids.view()); + * std::nullopt, + * centroids.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); * ... * auto labels = raft::make_device_vector(handle, X.extent(0)); * * kmeans::predict(handle, * params, * X, + * std::nullopt, * centroids.view(), - * labels.view()); + * false, + * labels.view(), + * raft::make_scalar_view(&inertia)); * @endcode * * @param[in] handle The raft handle. @@ -457,7 +738,7 @@ void predict(const raft::resources& handle, */ void predict(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, + raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view labels); @@ -471,7 +752,7 @@ void predict(const raft::resources& handle, * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::params params; + * cuvs::cluster::kmeans::params params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * auto labels = raft::make_device_vector(handle, X.extent(0)); @@ -516,6 +797,171 @@ void fit_predict(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter); +/** + * @brief Compute k-means clustering and predicts cluster index for each sample + * in the input. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... 
+ * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int64_t n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, + * n_features); auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::fit_predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * labels.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must be + * in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids Optional + * [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] labels Index of the cluster each sample in X belongs + * to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Compute k-means clustering and predicts cluster index for each sample + * in the input. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... 
+ * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::fit_predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * labels.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must be + * in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids Optional + * [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] labels Index of the cluster each sample in X belongs + * to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Compute k-means clustering and predicts cluster index for each sample + * in the input. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... 
+ * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int64_t n_features = 15, inertia, n_iter; + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, + * n_features); auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::fit_predict(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * labels.view(), + * raft::make_scalar_view(&inertia), + * raft::make_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must be + * in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids Optional + * [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] labels Index of the cluster each sample in X belongs + * to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + /** * @brief Compute balanced k-means clustering and predicts cluster index for each sample * in the input. @@ -526,7 +972,7 @@ void fit_predict(raft::resources const& handle, * using namespace cuvs::cluster; * ... 
* raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; + * cuvs::cluster::kmeans::balanced_params params; * int n_features = 15; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * auto labels = raft::make_device_vector(handle, X.extent(0)); @@ -570,7 +1016,7 @@ void fit_predict(const raft::resources& handle, * using namespace cuvs::cluster; * ... * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; + * cuvs::cluster::kmeans::balanced_params params; * int n_features = 15; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); * auto labels = raft::make_device_vector(handle, X.extent(0)); @@ -623,6 +1069,24 @@ void transform(raft::resources const& handle, raft::device_matrix_view centroids, raft::device_matrix_view X_new); +/** + * @brief Transform X to a cluster-distance space. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format + * [dim = n_samples x n_features] + * @param[in] centroids Cluster centroids. The data must be in row-major format. + * [dim = n_clusters x n_features] + * @param[out] X_new X transformed in the new space. 
+ * [dim = n_samples x n_clusters] + */ +void transform(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_matrix_view X_new); /** * @} */ diff --git a/cpp/include/cuvs/core/c_api.h b/cpp/include/cuvs/core/c_api.h index 4db7fd12c..c8c8d3934 100644 --- a/cpp/include/cuvs/core/c_api.h +++ b/cpp/include/cuvs/core/c_api.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #ifdef __cplusplus @@ -138,10 +139,12 @@ cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes); * available memory * @param[in] max_pool_size_percent The maximum pool size as a percentage of the total * available memory + * @param[in] managed Whether to use a managed memory resource as upstream resource or not * @return cuvsError_t */ cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, - int max_pool_size_percent); + int max_pool_size_percent, + bool managed); /** * @brief Resets the memory resource to use the default memory resource (cuda_memory_resource) * @return cuvsError_t diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index fec95b563..20db7e8b7 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -613,6 +613,78 @@ auto build(raft::resources const& res, * * @return the constructed cagra index */ +auto build(raft::resources const& res, + const cuvs::neighbors::cagra::index_params& params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::cagra::index; + +/** + * @brief Build the index from the dataset for efficient search. + * + * The build consists of two steps: build an intermediate knn-graph, and optimize it to + * create the final graph. The index_params struct controls the node degree of these + * graphs. 
+ * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * // use default search parameters + * cagra::search_params search_params; + * // search K nearest neighbours + * auto neighbors = raft::make_device_matrix(res, n_queries, k); + * auto distances = raft::make_device_matrix(res, n_queries, k); + * cagra::search(res, search_params, index, queries, neighbors, distances); + * @endcode + * + * @param[in] res + * @param[in] params parameters for building the index + * @param[in] dataset a matrix view (host) to a row-major matrix [n_rows, dim] + * + * @return the constructed cagra index + */ +auto build(raft::resources const& res, + const cuvs::neighbors::cagra::index_params& params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::cagra::index; + +/** + * @brief Build the index from the dataset for efficient search. + * + * The build consists of two steps: build an intermediate knn-graph, and optimize it to + * create the final graph. The index_params struct controls the node degree of these + * graphs. 
+ * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * cagra::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = cagra::build(res, index_params, dataset); + * // use default search parameters + * cagra::search_params search_params; + * // search K nearest neighbours + * auto neighbors = raft::make_device_matrix(res, n_queries, k); + * auto distances = raft::make_device_matrix(res, n_queries, k); + * cagra::search(res, search_params, index, queries, neighbors, distances); + * @endcode + * + * @param[in] res + * @param[in] params parameters for building the index + * @param[in] dataset a matrix view (device) to a row-major matrix [n_rows, dim] + * + * @return the constructed cagra index + */ auto build(raft::resources const& res, const cuvs::neighbors::cagra::index_params& params, raft::device_matrix_view dataset) @@ -975,9 +1047,6 @@ void extend( * * See the [cagra::build](#cagra::build) documentation for a usage example. * - * @tparam T data element type - * @tparam IdxT type of the indices - * * @param[in] res raft resources * @param[in] params configure the search * @param[in] idx cagra index @@ -1000,8 +1069,26 @@ void search(raft::resources const& res, * * See the [cagra::build](#cagra::build) documentation for a usage example. 
* - * @tparam T data element type - * @tparam IdxT type of the indices + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] index cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +void search(raft::resources const& res, + cuvs::neighbors::cagra::search_params const& params, + const cuvs::neighbors::cagra::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + +/** + * @brief Search ANN using the constructed index. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. * * @param[in] res raft resources * @param[in] params configure the search @@ -1024,9 +1111,6 @@ void search(raft::resources const& res, * * See the [cagra::build](#cagra::build) documentation for a usage example. * - * @tparam T data element type - * @tparam IdxT type of the indices - * * @param[in] res raft resources * @param[in] params configure the search * @param[in] index cagra index @@ -1156,6 +1240,111 @@ void serialize(raft::resources const& handle, void deserialize(raft::resources const& handle, std::istream& is, cuvs::neighbors::cagra::index* index); +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + * + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::cagra::index& index, + bool include_dataset = true); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index the cagra index + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::cagra::index* index); + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + std::ostream& os, + const cuvs::neighbors::cagra::index& index, + bool include_dataset = true); + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, is, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index the cagra index + */ +void deserialize(raft::resources const& handle, + std::istream& is, + cuvs::neighbors::cagra::index* index); /** * Save the index to file. diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index b2db96686..8c378b1f0 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -16,6 +16,8 @@ #pragma once +#include + #include #include @@ -547,6 +549,52 @@ void build(raft::resources const& handle, raft::device_matrix_view dataset, cuvs::neighbors::ivf_pq::index* idx); +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_pq::index index; + * ivf_pq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_pq::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::device_matrix_view dataset, + cuvs::neighbors::ivf_pq::index* idx); /** * @brief Build the index from the dataset for efficient search. * @@ -726,6 +774,53 @@ void build(raft::resources const& handle, * * @return the constructed ivf-pq index */ +auto build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_pq::index index; + * ivf_pq::build(handle, index_params, dataset, index); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset raft::host_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_pq::index + * + */ +void build(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index_params& index_params, + raft::host_matrix_view dataset, + cuvs::neighbors::ivf_pq::index* idx); + +/** + * @brief Build the index from the dataset for efficient search. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset); + * @endcode + * + * @param[in] handle + * @param[in] index_params configure the index building + * @param[in] dataset a host_matrix_view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ auto build(raft::resources const& handle, const cuvs::neighbors::ivf_pq::index_params& index_params, raft::host_matrix_view dataset) @@ -887,6 +982,62 @@ void extend(raft::resources const& handle, std::optional> new_indices, cuvs::neighbors::ivf_pq::index* idx); +/** + * @brief Extend the index with the new data. 
+ * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_pq::extend(handle, new_vectors, no_op, index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +auto extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const cuvs::neighbors::ivf_pq::index& idx) + -> cuvs::neighbors::ivf_pq::index; + +/** + * @brief Extend the index with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_pq::extend(handle, new_vectors, no_op, &index_empty); + * @endcode + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. 
+ * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +void extend(raft::resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + cuvs::neighbors::ivf_pq::index* idx); /** * @brief Extend the index with the new data. * @@ -1257,6 +1408,47 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances); +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @code{.cpp} + * ... + * // use default search parameters + * ivf_pq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_pq::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_pq::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... 
+ * @endcode + * + * @param[in] handle + * @param[in] search_params configure the search + * @param[in] index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +void search(raft::resources const& handle, + const cuvs::neighbors::ivf_pq::search_params& search_params, + cuvs::neighbors::ivf_pq::index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances); + /** * @brief Search ANN using the constructed index. * @@ -1372,6 +1564,39 @@ void search_with_filtering( raft::device_matrix_view distances, cuvs::neighbors::filtering::bitset_filter sample_filter); +/** + * @brief Search ANN using the constructed index with the given filter. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. 
+ * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + * @param[in] sample_filter a device bitset filter function that greenlights samples for a given + * query. + */ +void search_with_filtering( + raft::resources const& handle, + const search_params& params, + index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + cuvs::neighbors::filtering::bitset_filter sample_filter); + /** * @brief Search ANN using the constructed index with the given filter. * @@ -1924,6 +2149,12 @@ void reconstruct_list_data(raft::resources const& res, uint32_t label, uint32_t offset); +void reconstruct_list_data(raft::resources const& res, + const index& index, + raft::device_matrix_view out_vectors, + uint32_t label, + uint32_t offset); + void reconstruct_list_data(raft::resources const& res, const index& index, raft::device_matrix_view out_vectors, @@ -1972,6 +2203,11 @@ void reconstruct_list_data(raft::resources const& res, raft::device_vector_view in_cluster_indices, raft::device_matrix_view out_vectors, uint32_t label); +void reconstruct_list_data(raft::resources const& res, + const index& index, + raft::device_vector_view in_cluster_indices, + raft::device_matrix_view out_vectors, + uint32_t label); void reconstruct_list_data(raft::resources const& res, const index& index, raft::device_vector_view in_cluster_indices, diff --git a/cpp/include/cuvs/neighbors/nn_descent.hpp b/cpp/include/cuvs/neighbors/nn_descent.hpp index 9f4300177..347ccf889 100644 --- a/cpp/include/cuvs/neighbors/nn_descent.hpp +++ b/cpp/include/cuvs/neighbors/nn_descent.hpp 
@@ -27,6 +27,8 @@ #include +#include + namespace cuvs::neighbors::nn_descent { /** * @defgroup nn_descent_cpp_index_params The nn-descent algorithm parameters. @@ -237,6 +239,68 @@ auto build(raft::resources const& res, raft::host_matrix_view dataset) -> cuvs::neighbors::nn_descent::index; +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = nn_descent::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @param[in] res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @return index index containing all-neighbors knn graph in host memory + */ +auto build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::nn_descent::index; + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = nn_descent::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * 
@tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param[in] res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @return index index containing all-neighbors knn graph in host memory + */ +auto build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::nn_descent::index; + /** * @brief Build nn-descent Index with dataset in device memory * diff --git a/cpp/include/cuvs/neighbors/reachability.hpp b/cpp/include/cuvs/neighbors/reachability.hpp new file mode 100644 index 000000000..9746ac856 --- /dev/null +++ b/cpp/include/cuvs/neighbors/reachability.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include + +namespace cuvs::neighbors::reachability { + +/** + * @defgroup reachability_cpp Mutual Reachability + * @{ + */ +/** + * Constructs a mutual reachability graph, which is a k-nearest neighbors + * graph projected into mutual reachability space using the following + * function for each data point, where core_distance is the distance + * to the kth neighbor: max(core_distance(a), core_distance(b), d(a, b)) + * + * Unfortunately, points in the tails of the pdf (e.g. in sparse regions + * of the space) can have very large neighborhoods, which will impact + * nearby neighborhoods. Because of this, it's possible for the + * radius for points in the main mass, which might have a very small + * radius initially, to grow very large. As a result, the initial + * knn which was used to compute the core distances may no longer + * capture the actual neighborhoods after projection into mutual + * reachability space. + * + * For the experimental version, we execute the knn twice: once + * to compute the radii (core distances) and again to capture + * the final neighborhoods. Future iterations of this algorithm + * will improve upon this "exact" version, by using + * more specialized data structures, such as space-partitioning + * structures. It has also been shown that approximate nearest + * neighbors can yield reasonable neighborhoods as the + * data sizes increase. + * + * @param[in] handle raft handle for resource reuse + * @param[in] X input data points (size m * n) + * @param[in] min_samples this neighborhood will be selected for core distances + * @param[out] indptr CSR indptr of output knn graph (size m + 1) + * @param[out] core_dists output core distances array (size m) + * @param[out] out COO object, uninitialized on entry, on exit it stores the + * (symmetrized) maximum reachability distance for the k nearest + * neighbors. 
+ * @param[in] metric distance metric to use, default Euclidean + * @param[in] alpha weight applied when internal distance is chosen for + * mutual reachability (value of 1.0 disables the weighting) + */ +void mutual_reachability_graph( + const raft::resources& handle, + raft::device_matrix_view X, + int min_samples, + raft::device_vector_view indptr, + raft::device_vector_view core_dists, + raft::sparse::COO& out, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded, + float alpha = 1.0); +/** + * @} + */ +} // namespace cuvs::neighbors::reachability diff --git a/cpp/include/cuvs/selection/select_k.hpp b/cpp/include/cuvs/selection/select_k.hpp index dc34caf41..e4dfdb12c 100644 --- a/cpp/include/cuvs/selection/select_k.hpp +++ b/cpp/include/cuvs/selection/select_k.hpp @@ -87,6 +87,16 @@ void select_k( SelectAlgo algo = SelectAlgo::kAuto, std::optional> len_i = std::nullopt); +void select_k(raft::resources const& handle, + raft::device_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min, + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + std::optional> len_i = std::nullopt); + /** * Select k smallest or largest key/values from each row in the input data. * diff --git a/cpp/src/cluster/detail/connectivities.cuh b/cpp/src/cluster/detail/connectivities.cuh index ada424192..e61c9166f 100644 --- a/cpp/src/cluster/detail/connectivities.cuh +++ b/cpp/src/cluster/detail/connectivities.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../../distance/distance.cuh" +#include "./kmeans_common.cuh" #include #include #include @@ -153,7 +153,11 @@ void pairwise_distances(const raft::resources& handle, // TODO: It would ultimately be nice if the MST could accept // dense inputs directly so we don't need to double the memory // usage to hand it a sparse array here. 
- distance::pairwise_distance(handle, X, X, data, m, m, n, metric); + auto X_view = raft::make_device_matrix_view(X, m, n); + + cuvs::cluster::kmeans::detail::pairwise_distance_kmeans( + handle, X_view, X_view, raft::make_device_matrix_view(data, m, m), metric); + // self-loops get max distance auto transform_in = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data)); diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index e7d4bdf76..9b673bca3 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -198,7 +198,7 @@ void kmeansPlusPlus(raft::resources const& handle, // Output - pwd [n_trials x n_samples] auto pwd = distBuffer.view(); cuvs::cluster::kmeans::detail::pairwise_distance_kmeans( - handle, centroidCandidates.view(), X, pwd, workspace, metric); + handle, centroidCandidates.view(), X, pwd, metric); // Update nearest cluster distance for each centroid candidate // Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values. 
@@ -1247,7 +1247,7 @@ void kmeans_transform(raft::resources const& handle, // calculate pairwise distance between cluster centroids and current batch // of input dataset pairwise_distance_kmeans( - handle, datasetView, centroids, pairwiseDistanceView, workspace, metric); + handle, datasetView, centroids, pairwiseDistanceView, metric); } } diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 04c1a6802..eec71b5d2 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -293,7 +293,6 @@ void pairwise_distance_kmeans(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_matrix_view pairwiseDistance, - rmm::device_uvector& workspace, cuvs::distance::DistanceType metric) { auto n_samples = X.extent(0); @@ -303,15 +302,23 @@ void pairwise_distance_kmeans(raft::resources const& handle, ASSERT(X.extent(1) == centroids.extent(1), "# features in dataset and centroids are different (must be same)"); - cuvs::distance::pairwise_distance(handle, - X.data_handle(), - centroids.data_handle(), - pairwiseDistance.data_handle(), - n_samples, - n_clusters, - n_features, - workspace, - metric); + if (metric == cuvs::distance::DistanceType::L2Expanded) { + cuvs::distance::distance(handle, X, centroids, pairwiseDistance); + } else if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + cuvs::distance::distance(handle, X, centroids, pairwiseDistance); + } else { + RAFT_FAIL("kmeans requires L2Expanded or L2SqrtExpanded distance, have %i", metric); + } } // shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores @@ -461,7 +468,7 @@ void minClusterAndDistanceCompute( // calculate pairwise distance between current tile of cluster centroids // and input dataset pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); + handle, datasetView, centroidsView, 
pairwiseDistanceView, metric); // argmin reduction returning pair // calculates the closest centroid and the distance to the closest @@ -591,7 +598,7 @@ void minClusterDistanceCompute(raft::resources const& handle, // calculate pairwise distance between current tile of cluster centroids // and input dataset pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); + handle, datasetView, centroidsView, pairwiseDistanceView, metric); raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), pairwiseDistanceView.data_handle(), diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh new file mode 100644 index 000000000..b0f435502 --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -0,0 +1,781 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../kmeans.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cuvs::cluster::kmeans::mg::detail { + +#define CUVS_LOG_KMEANS(handle, fmt, ...) 
\ + do { \ + bool isRoot = true; \ + if (raft::resource::comms_initialized(handle)) { \ + const auto& comm = raft::resource::get_comms(handle); \ + const int my_rank = comm.get_rank(); \ + isRoot = my_rank == 0; \ + } \ + if (isRoot) { RAFT_LOG_DEBUG(fmt, ##__VA_ARGS__); } \ + } while (0) + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const raft::KeyValuePair& a) const + { + return a.key; + } +}; + +#define KMEANS_COMM_ROOT 0 + +static cuvs::cluster::kmeans::params default_params; + +// Selects 'n_clusters' samples randomly from X +template +void initRandom(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids) +{ + const auto& comm = raft::resource::get_comms(handle); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_local_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + const int my_rank = comm.get_rank(); + const int n_ranks = comm.get_size(); + + std::vector nCentroidsSampledByRank(n_ranks, 0); + std::vector nCentroidsElementsToReceiveFromRank(n_ranks, 0); + + const int nranks_reqd = std::min(n_ranks, n_clusters); + ASSERT(KMEANS_COMM_ROOT < nranks_reqd, "KMEANS_COMM_ROOT must be in [0, %d)\n", nranks_reqd); + + for (int rank = 0; rank < nranks_reqd; ++rank) { + int nCentroidsSampledInRank = n_clusters / nranks_reqd; + if (rank == KMEANS_COMM_ROOT) { + nCentroidsSampledInRank += n_clusters - nCentroidsSampledInRank * nranks_reqd; + } + nCentroidsSampledByRank[rank] = nCentroidsSampledInRank; + nCentroidsElementsToReceiveFromRank[rank] = nCentroidsSampledInRank * n_features; + } + + auto nCentroidsSampledInRank = nCentroidsSampledByRank[my_rank]; + ASSERT((IndexT)nCentroidsSampledInRank <= (IndexT)n_local_samples, + "# random samples requested from rank-%d is larger than the available " + "samples at the rank (requested is %lu, available is %lu)", + 
my_rank, + (size_t)nCentroidsSampledInRank, + (size_t)n_local_samples); + + auto centroidsSampledInRank = + raft::make_device_matrix(handle, nCentroidsSampledInRank, n_features); + + cuvs::cluster::kmeans::shuffle_and_gather( + handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, params.rng_state.seed); + + std::vector displs(n_ranks); + thrust::exclusive_scan(thrust::host, + nCentroidsElementsToReceiveFromRank.begin(), + nCentroidsElementsToReceiveFromRank.end(), + displs.begin()); + + // gather centroids from all ranks + comm.allgatherv(centroidsSampledInRank.data_handle(), // sendbuff + centroids.data_handle(), // recvbuff + nCentroidsElementsToReceiveFromRank.data(), // recvcount + displs.data(), + stream); +} + +/* + * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm + * Scalable kmeans++ pseudocode + * 1: C = sample a point uniformly at random from X + * 2: psi = phi_X (C) + * 3: for O( log(psi) ) times do + * 4: C' = sample each point x in X independently with probability + * p_x = l * ( d^2(x, C) / phi_X (C) ) + * 5: C = C U C' + * 6: end for + * 7: For x in C, set w_x to be the number of points in X closer to x than any + * other point in C + * 8: Recluster the weighted points in C into k clusters + */ +template +void initKMeansPlusPlus(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroidsRawData, + rmm::device_uvector& workspace) +{ + const auto& comm = raft::resource::get_comms(handle); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + const int my_rank = comm.get_rank(); + const int n_rank = comm.get_size(); + + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + raft::random::RngState rng(params.rng_state.seed, raft::random::GeneratorType::GenPhilox); + + // <<<< Step-1 >>> : C <- sample a point uniformly at random from X 
+ // 1.1 - Select a rank r' at random from the available n_rank ranks with a + // probability of 1/n_rank [Note - with same seed all rank selects + // the same r' which avoids a call to comm] + // 1.2 - Rank r' samples a point uniformly at random from the local dataset + // X which will be used as the initial centroid for kmeans++ + // 1.3 - Communicate the initial centroid chosen by rank-r' to all other + // ranks + std::mt19937 gen(params.rng_state.seed); + std::uniform_int_distribution<> dis(0, n_rank - 1); + int rp = dis(gen); + + // buffer to flag the sample that is chosen as initial centroids + std::vector h_isSampleCentroid(n_samples); + std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); + + auto initialCentroid = raft::make_device_matrix(handle, 1, n_features); + CUVS_LOG_KMEANS( + handle, "@Rank-%d : KMeans|| : initial centroid is sampled at rank-%d\n", my_rank, rp); + + // 1.2 - Rank r' samples a point uniformly at random from the local dataset + // X which will be used as the initial centroid for kmeans++ + if (my_rank == rp) { + std::mt19937 gen(params.rng_state.seed); + std::uniform_int_distribution<> dis(0, n_samples - 1); + + int cIdx = dis(gen); + auto centroidsView = raft::make_device_matrix_view( + X.data_handle() + cIdx * n_features, 1, n_features); + + raft::copy( + initialCentroid.data_handle(), centroidsView.data_handle(), centroidsView.size(), stream); + + h_isSampleCentroid[cIdx] = 1; + } + + // 1.3 - Communicate the initial centroid chosen by rank-r' to all other ranks + comm.bcast(initialCentroid.data_handle(), initialCentroid.size(), rp, stream); + + // device buffer to flag the sample that is chosen as initial centroid + auto isSampleCentroid = raft::make_device_vector(handle, n_samples); + + raft::copy( + isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); + + rmm::device_uvector centroidsBuf(0, stream); + + // reset buffer to store the chosen centroid + 
centroidsBuf.resize(initialCentroid.size(), stream); + raft::copy(centroidsBuf.begin(), initialCentroid.data_handle(), initialCentroid.size(), stream); + + auto potentialCentroids = raft::make_device_matrix_view( + centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); + // <<< End of Step-1 >>> + + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // L2 norm of X: ||x||^2 + auto L2NormX = raft::make_device_vector(handle, n_samples); + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm(L2NormX.data_handle(), + X.data_handle(), + X.extent(1), + X.extent(0), + raft::linalg::L2Norm, + true, + stream); + } + + auto minClusterDistance = raft::make_device_vector(handle, n_samples); + auto uniformRands = raft::make_device_vector(handle, n_samples); + + // <<< Step-2 >>>: psi <- phi_X (C) + auto clusterCost = raft::make_device_scalar(handle, 0); + + cuvs::cluster::kmeans::min_cluster_distance(handle, + X, + potentialCentroids, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // compute partial cluster cost from the samples in rank + cuvs::cluster::kmeans::cluster_cost( + handle, + minClusterDistance.view(), + workspace, + clusterCost.view(), + cuda::proclaim_return_type( + [] __device__(const DataT& a, const DataT& b) { return a + b; })); + + // compute total cluster cost by accumulating the partial cost from all the + // ranks + comm.allreduce( + clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); + + DataT psi = 0; + raft::copy(&psi, clusterCost.data_handle(), 1, stream); + + // <<< End of Step-2 >>> + + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. 
This can result from " + "a failed rank"); + + // Scalable kmeans++ paper claims 8 rounds is sufficient + int niter = std::min(8, (int)ceil(log(psi))); + CUVS_LOG_KMEANS(handle, + "@Rank-%d:KMeans|| :phi - %f, max # of iterations for kmeans++ loop - " + "%d\n", + my_rank, + psi, + niter); + + // <<<< Step-3 >>> : for O( log(psi) ) times do + for (int iter = 0; iter < niter; ++iter) { + CUVS_LOG_KMEANS(handle, + "@Rank-%d:KMeans|| - Iteration %d: # potential centroids sampled - " + "%d\n", + my_rank, + iter, + potentialCentroids.extent(0)); + + cuvs::cluster::kmeans::min_cluster_distance(handle, + X, + potentialCentroids, + minClusterDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + + cuvs::cluster::kmeans::cluster_cost( + handle, + minClusterDistance.view(), + workspace, + clusterCost.view(), + cuda::proclaim_return_type( + [] __device__(const DataT& a, const DataT& b) { return a + b; })); + comm.allreduce( + clusterCost.data_handle(), clusterCost.data_handle(), 1, raft::comms::op_t::SUM, stream); + raft::copy(&psi, clusterCost.data_handle(), 1, stream); + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. 
This can result " + "from a failed rank"); + + // <<<< Step-4 >>> : Sample each point x in X independently and identify new + // potentialCentroids + raft::random::uniform( + handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); + cuvs::cluster::kmeans::SamplingOp select_op(psi, + params.oversampling_factor, + n_clusters, + uniformRands.data_handle(), + isSampleCentroid.data_handle()); + + rmm::device_uvector inRankCp(0, stream); + cuvs::cluster::kmeans::sample_centroids(handle, + X, + minClusterDistance.view(), + isSampleCentroid.view(), + select_op, + inRankCp, + workspace); + /// <<<< End of Step-4 >>>> + + int* nPtsSampledByRank; + RAFT_CUDA_TRY(cudaMallocHost(&nPtsSampledByRank, n_rank * sizeof(int))); + + /// <<<< Step-5 >>> : C = C U C' + // append the data in Cp from all ranks to the buffer holding the + // potentialCentroids + // RAFT_CUDA_TRY(cudaMemsetAsync(nPtsSampledByRank, 0, n_rank * sizeof(int), stream)); + std::fill(nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); + nPtsSampledByRank[my_rank] = inRankCp.size() / n_features; + comm.allgather(&(nPtsSampledByRank[my_rank]), nPtsSampledByRank, 1, stream); + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. 
This can result " + "from a failed rank"); + + auto nPtsSampled = + thrust::reduce(thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, 0); + + // gather centroids from all ranks + std::vector sizes(n_rank); + thrust::transform( + thrust::host, nPtsSampledByRank, nPtsSampledByRank + n_rank, sizes.begin(), [&](int val) { + return val * n_features; + }); + + RAFT_CUDA_TRY_NO_THROW(cudaFreeHost(nPtsSampledByRank)); + + std::vector displs(n_rank); + thrust::exclusive_scan(thrust::host, sizes.begin(), sizes.end(), displs.begin()); + + centroidsBuf.resize(centroidsBuf.size() + nPtsSampled * n_features, stream); + comm.allgatherv(inRankCp.data(), + centroidsBuf.end() - nPtsSampled * n_features, + sizes.data(), + displs.data(), + stream); + + auto tot_centroids = potentialCentroids.extent(0) + nPtsSampled; + potentialCentroids = + raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); + /// <<<< End of Step-5 >>> + } /// <<<< Step-6 >>> + + CUVS_LOG_KMEANS(handle, + "@Rank-%d:KMeans||: # potential centroids sampled - %d\n", + my_rank, + potentialCentroids.extent(0)); + + if ((IndexT)potentialCentroids.extent(0) > (IndexT)n_clusters) { + // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X + // temporary buffer to store the sample count per cluster, destructor + // releases the resource + + auto weight = raft::make_device_vector(handle, potentialCentroids.extent(0)); + + cuvs::cluster::kmeans::count_samples_in_cluster( + handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); + + // merge the local histogram from all ranks + comm.allreduce(weight.data_handle(), // sendbuff + weight.data_handle(), // recvbuff + weight.size(), // count + raft::comms::op_t::SUM, + stream); + + // <<< end of Step-7 >>> + + // Step-8: Recluster the weighted points in C into k clusters + // Note - reclustering step is duplicated across all ranks and with the same + // seed they should generate the same 
potentialCentroids + auto const_centroids = raft::make_device_matrix_view( + potentialCentroids.data_handle(), potentialCentroids.extent(0), potentialCentroids.extent(1)); + cuvs::cluster::kmeans::init_plus_plus( + handle, params, const_centroids, centroidsRawData, workspace); + + auto inertia = raft::make_host_scalar(0); + auto n_iter = raft::make_host_scalar(0); + auto weight_view = + raft::make_device_vector_view(weight.data_handle(), weight.extent(0)); + cuvs::cluster::kmeans::params params_copy = params; + params_copy.rng_state = default_params.rng_state; + + cuvs::cluster::kmeans::fit_main(handle, + params_copy, + const_centroids, + weight_view, + centroidsRawData, + inertia.view(), + n_iter.view(), + workspace); + + } else if ((IndexT)potentialCentroids.extent(0) < (IndexT)n_clusters) { + // supplement with random + auto n_random_clusters = n_clusters - potentialCentroids.extent(0); + CUVS_LOG_KMEANS(handle, + "[Warning!] KMeans||: found fewer than %d centroids during " + "initialization (found %d centroids, remaining %d centroids will be " + "chosen randomly from input samples)\n", + n_clusters, + potentialCentroids.extent(0), + n_random_clusters); + + // generate `n_random_clusters` centroids + cuvs::cluster::kmeans::params rand_params = params; + rand_params.rng_state = default_params.rng_state; + rand_params.init = cuvs::cluster::kmeans::params::InitMethod::Random; + rand_params.n_clusters = n_random_clusters; + initRandom(handle, rand_params, X, centroidsRawData); + + // copy centroids generated during kmeans|| iteration to the buffer + raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, + potentialCentroids.data_handle(), + potentialCentroids.size(), + stream); + + } else { + // found the required n_clusters + raft::copy(centroidsRawData.data_handle(), + potentialCentroids.data_handle(), + potentialCentroids.size(), + stream); + } +} + +template +void checkWeights(const raft::resources& handle, + rmm::device_uvector& workspace, 
+ raft::device_vector_view weight) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + rmm::device_scalar wt_aggr(stream); + + const auto& comm = raft::resource::get_comms(handle); + + auto n_samples = weight.extent(0); + size_t temp_storage_bytes = 0; + RAFT_CUDA_TRY(cub::DeviceReduce::Sum( + nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); + + workspace.resize(temp_storage_bytes, stream); + + RAFT_CUDA_TRY(cub::DeviceReduce::Sum( + workspace.data(), temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); + + comm.allreduce(wt_aggr.data(), // sendbuff + wt_aggr.data(), // recvbuff + 1, // count + raft::comms::op_t::SUM, + stream); + DataT wt_sum = wt_aggr.value(stream); + raft::resource::sync_stream(handle, stream); + + if (wt_sum != n_samples) { + CUVS_LOG_KMEANS(handle, + "[Warning!] KMeans: normalizing the user provided sample weights to " + "sum up to %d samples", + n_samples); + + DataT scale = n_samples / wt_sum; + raft::linalg::unaryOp( + weight.data_handle(), + weight.data_handle(), + weight.size(), + cuda::proclaim_return_type([=] __device__(const DataT& wt) { return wt * scale; }), + stream); + } +} + +template +void fit(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter, + rmm::device_uvector& workspace) +{ + const auto& comm = raft::resource::get_comms(handle); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + auto weight = raft::make_device_vector(handle, n_samples); + if (sample_weight) { + raft::copy(weight.data_handle(), sample_weight->data_handle(), n_samples, stream); + } else { + 
thrust::fill(raft::resource::get_thrust_policy(handle), + weight.data_handle(), + weight.data_handle() + weight.size(), + 1); + } + + // check if weights sum up to n_samples + checkWeights(handle, workspace, weight.view()); + + if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + // initializing with random samples from input dataset + CUVS_LOG_KMEANS(handle, + "KMeans.fit: initialize cluster centers by randomly choosing from the " + "input data.\n"); + initRandom(handle, params, X, centroids); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + // default method to initialize is kmeans++ + CUVS_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n"); + initKMeansPlusPlus(handle, params, X, centroids, workspace); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + CUVS_LOG_KMEANS(handle, + "KMeans.fit: initialize cluster centers from the ndarray array input " + "passed to init argument.\n"); + + } else { + THROW("unknown initialization method to select initial centers"); + } + + // stores (key, value) pair corresponding to each sample where + // - key is the index of nearest cluster + // - value is the distance to the nearest cluster + auto minClusterAndDistance = + raft::make_device_vector, IndexT>(handle, n_samples); + + // temporary buffer to store L2 norm of centroids or distance matrix, + // destructor releases the resource + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // temporary buffer to store intermediate centroids, destructor releases the + // resource + auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); + + // temporary buffer to store the weights per cluster, destructor releases + // the resource + auto wtInCluster = raft::make_device_vector(handle, n_clusters); + + // L2 norm of X: ||x||^2 + auto L2NormX = raft::make_device_vector(handle, n_samples); + if (metric == 
cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm(L2NormX.data_handle(), + X.data_handle(), + X.extent(1), + X.extent(0), + raft::linalg::L2Norm, + true, + stream); + } + + DataT priorClusteringCost = 0; + for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + CUVS_LOG_KMEANS(handle, + "KMeans.fit: Iteration-%d: fitting the model using the initialize " + "cluster centers\n", + n_iter[0]); + + auto const_centroids = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + // computes minClusterAndDistance[0:n_samples) where + // minClusterAndDistance[i] is a pair where + // 'key' is index to an sample in 'centroids' (index of the nearest + // centroid) and 'value' is the distance between the sample 'X[i]' and the + // 'centroid[key]' + cuvs::cluster::kmeans::min_cluster_and_distance(handle, + X, + const_centroids, + minClusterAndDistance.view(), + L2NormX.view(), + L2NormBuf_OR_DistBuf, + params.metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // Using TransformInputIteratorT to dereference an array of + // cub::KeyValuePair and converting them to just return the Key to be used + // in reduce_rows_by_key prims + KeyValueIndexOp conversion_op; + cub::TransformInputIterator, + raft::KeyValuePair*> + itr(minClusterAndDistance.data_handle(), conversion_op); + + workspace.resize(n_samples, stream); + + // Calculates weighted sum of all the samples assigned to cluster-i and + // store the result in newCentroids[i] + raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), + X.extent(1), + itr, + weight.data_handle(), + workspace.data(), + X.extent(0), + X.extent(1), + static_cast(n_clusters), + newCentroids.data_handle(), + stream); + + // Reduce weights by key to compute weight in each cluster + raft::linalg::reduce_cols_by_key(weight.data_handle(), + itr, + wtInCluster.data_handle(), + (IndexT)1, + 
(IndexT)weight.extent(0), + (IndexT)n_clusters, + stream); + + // merge the local histogram from all ranks + comm.allreduce(wtInCluster.data_handle(), // sendbuff + wtInCluster.data_handle(), // recvbuff + wtInCluster.size(), // count + raft::comms::op_t::SUM, + stream); + + // reduces newCentroids from all ranks + comm.allreduce(newCentroids.data_handle(), // sendbuff + newCentroids.data_handle(), // recvbuff + newCentroids.size(), // count + raft::comms::op_t::SUM, + stream); + + // Computes newCentroids[i] = newCentroids[i]/wtInCluster[i] where + // newCentroids[n_clusters x n_features] - 2D array, newCentroids[i] has + // sum of all the samples assigned to cluster-i + // wtInCluster[n_clusters] - 1D array, wtInCluster[i] contains # of + // samples in cluster-i. + // Note - when wtInCluster[i] is 0, newCentroid[i] is reset to 0 + + raft::linalg::matrixVectorOp( + newCentroids.data_handle(), + newCentroids.data_handle(), + wtInCluster.data_handle(), + newCentroids.extent(1), + newCentroids.extent(0), + true, + false, + cuda::proclaim_return_type([=] __device__(DataT mat, DataT vec) { + if (vec == 0) + return DataT(0); + else + return mat / vec; + }), + stream); + + // copy the centroids[i] to newCentroids[i] when wtInCluster[i] is 0 + cub::ArgIndexInputIterator itr_wt(wtInCluster.data_handle()); + raft::matrix::gather_if( + centroids.data_handle(), + centroids.extent(1), + centroids.extent(0), + itr_wt, + itr_wt, + wtInCluster.extent(0), + newCentroids.data_handle(), + cuda::proclaim_return_type( + [=] __device__(raft::KeyValuePair map) { // predicate + // copy when the # of samples in the cluster is 0 + if (map.value == 0) + return true; + else + return false; + }), + cuda::proclaim_return_type( + [=] __device__(raft::KeyValuePair map) { // map + return map.key; + }), + stream); + + // compute the squared norm between the newCentroids and the original + // centroids, destructor releases the resource + auto sqrdNorm = raft::make_device_scalar(handle, 1); + 
raft::linalg::mapThenSumReduce( + sqrdNorm.data_handle(), + newCentroids.size(), + cuda::proclaim_return_type([=] __device__(const DataT a, const DataT b) { + DataT diff = a - b; + return diff * diff; + }), + stream, + centroids.data_handle(), + newCentroids.data_handle()); + + DataT sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); + + raft::copy(centroids.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); + + bool done = false; + if (params.inertia_check) { + rmm::device_scalar> clusterCostD(stream); + + // calculate cluster cost phi_x(C) + cuvs::cluster::kmeans::cluster_cost( + handle, + minClusterAndDistance.view(), + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + cuda::proclaim_return_type>( + [] __device__(const raft::KeyValuePair& a, + const raft::KeyValuePair& b) { + raft::KeyValuePair res; + res.key = 0; + res.value = a.value + b.value; + return res; + })); + + // Cluster cost phi_x(C) from all ranks + comm.allreduce(&(clusterCostD.data()->value), + &(clusterCostD.data()->value), + 1, + raft::comms::op_t::SUM, + stream); + + DataT curClusteringCost = 0; + raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); + + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. This can result " + "from a failed rank"); + ASSERT(curClusteringCost != (DataT)0.0, + "Too few points and centroids being found is getting 0 cost from " + "centers\n"); + + if (n_iter[0] > 0) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + } + priorClusteringCost = curClusteringCost; + } + + raft::resource::sync_stream(handle, stream); + if (sqrdNormError < params.tol) done = true; + + if (done) { + CUVS_LOG_KMEANS( + handle, "Threshold triggered after %d iterations. 
Terminating early.\n", n_iter[0]); + break; + } + } +} + +}; // namespace cuvs::cluster::kmeans::mg::detail diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 1d12142da..5e6d756cc 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -17,10 +17,12 @@ #include "detail/kmeans.cuh" #include "detail/kmeans_auto_find_k.cuh" +#include "kmeans_mg.hpp" #include #include #include #include +#include #include #include @@ -94,8 +96,13 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::kmeans_fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise + if (raft::resource::comms_initialized(handle)) { + cuvs::cluster::kmeans::mg::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + } else { + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); + } } /** diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu new file mode 100644 index 000000000..4f193da09 --- /dev/null +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kmeans.cuh" +#include + +namespace cuvs::cluster::kmeans { + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} +} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 89862a46c..3888ae492 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -30,4 +30,16 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_mg_double.cu b/cpp/src/cluster/kmeans_fit_mg_double.cu new file mode 100644 index 000000000..15081dfba --- /dev/null +++ b/cpp/src/cluster/kmeans_fit_mg_double.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "./detail/kmeans_mg.cuh" +#include "kmeans_mg.hpp" +#include + +namespace cuvs::cluster::kmeans::mg { + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); + + cuvs::cluster::kmeans::mg::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); +} + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); + + cuvs::cluster::kmeans::mg::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); +} +} // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_fit_mg_float.cu b/cpp/src/cluster/kmeans_fit_mg_float.cu new file mode 100644 index 000000000..54fbd6763 --- /dev/null +++ b/cpp/src/cluster/kmeans_fit_mg_float.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "./detail/kmeans_mg.cuh" +#include "kmeans_mg.hpp" +#include + +namespace cuvs::cluster::kmeans::mg { + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); + + cuvs::cluster::kmeans::mg::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); +} + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); + + cuvs::cluster::kmeans::mg::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); +} +} // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_fit_predict_double.cu b/cpp/src/cluster/kmeans_fit_predict_double.cu new file mode 100644 index 000000000..28a1d70c0 --- /dev/null +++ b/cpp/src/cluster/kmeans_fit_predict_double.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kmeans.cuh" +#include + +namespace cuvs::cluster::kmeans { + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} +} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_predict_float.cu b/cpp/src/cluster/kmeans_fit_predict_float.cu index f043f7624..be3652db5 100644 --- a/cpp/src/cluster/kmeans_fit_predict_float.cu +++ b/cpp/src/cluster/kmeans_fit_predict_float.cu @@ -32,4 +32,18 @@ void fit_predict(raft::resources const& handle, cuvs::cluster::kmeans::fit_predict( handle, params, X, sample_weight, centroids, labels, inertia, n_iter); } + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_mg.hpp b/cpp/src/cluster/kmeans_mg.hpp new file mode 100644 index 000000000..34f38314a --- /dev/null +++ b/cpp/src/cluster/kmeans_mg.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include + +namespace cuvs::cluster::kmeans::mg { + +/** + * @brief MNMG kmeans fit + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X. + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. 
+ */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); +} // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_predict_double.cu b/cpp/src/cluster/kmeans_predict_double.cu new file mode 100644 index 000000000..1fcc393ac --- /dev/null +++ b/cpp/src/cluster/kmeans_predict_double.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kmeans.cuh" +#include + +namespace cuvs::cluster::kmeans { + +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) + +{ + cuvs::cluster::kmeans::predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} + +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) + +{ + cuvs::cluster::kmeans::predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} +} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_predict_float.cu b/cpp/src/cluster/kmeans_predict_float.cu index d092152f1..b5f9f9e51 100644 --- a/cpp/src/cluster/kmeans_predict_float.cu +++ b/cpp/src/cluster/kmeans_predict_float.cu @@ -32,4 +32,17 @@ void predict(raft::resources const& handle, cuvs::cluster::kmeans::predict( handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } +void predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) + +{ + cuvs::cluster::kmeans::predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_transform_double.cu b/cpp/src/cluster/kmeans_transform_double.cu new file mode 100644 index 000000000..4a026812e --- /dev/null +++ b/cpp/src/cluster/kmeans_transform_double.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024, 
NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kmeans.cuh" +#include + +namespace cuvs::cluster::kmeans { + +void transform(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_matrix_view X_new) + +{ + cuvs::cluster::kmeans::transform(handle, params, X, centroids, X_new); +} +} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/core/c_api.cpp b/cpp/src/core/c_api.cpp index a75e5a1dd..cfbeed2d5 100644 --- a/cpp/src/core/c_api.cpp +++ b/cpp/src/core/c_api.cpp @@ -21,6 +21,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -83,10 +86,14 @@ extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes) }); } -thread_local std::unique_ptr> pool_mr; +thread_local std::shared_ptr< + rmm::mr::owning_wrapper, + rmm::mr::device_memory_resource>> + pool_mr; extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, - int max_pool_size_percent) + int max_pool_size_percent, + bool managed) { return cuvs::core::translate_exceptions([=] { // Upstream memory resource needs to be a cuda_memory_resource @@ -95,10 +102,22 @@ extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_per if (cuda_mr_casted == nullptr) { throw std::runtime_error("Current memory resource is not a cuda_memory_resource"); } + auto initial_size = 
rmm::percent_of_free_device_memory(initial_pool_size_percent); auto max_size = rmm::percent_of_free_device_memory(max_pool_size_percent); - pool_mr = std::make_unique>( - cuda_mr_casted, initial_size, max_size); + + auto mr = std::shared_ptr(); + if (managed) { + mr = std::static_pointer_cast( + std::make_shared()); + } else { + mr = std::static_pointer_cast( + std::make_shared()); + } + + pool_mr = + rmm::mr::make_owning_wrapper(mr, initial_size, max_size); + rmm::mr::set_current_device_resource(pool_mr.get()); }); } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh index 3107f0fa4..4fd194f6c 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -111,6 +111,8 @@ instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( cuvs::distance::detail::ops::l1_distance_op, int); instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( cuvs::distance::detail::ops::l2_exp_distance_op, int); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::l2_exp_distance_op, int64_t); instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( cuvs::distance::detail::ops::l2_unexp_distance_op, int); instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( @@ -124,5 +126,8 @@ instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo( int64_t, cuvs::distance::kernels::detail::rbf_fin_op); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( + cuvs::distance::detail::ops::l2_exp_distance_op, int64_t); + #undef instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo #undef instantiate_cuvs_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py 
index 1bd51aef9..d0913833f 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -15,7 +15,7 @@ # NOTE: this template is not perfectly formatted. Use pre-commit to get # everything in shape again. header = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -197,4 +197,30 @@ def arch_headers(archs): f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") -print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") + + print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") + +# L2 with int64_t indices for kmeans code +int64_t_op_instances = [ + dict( + path_prefix="l2_expanded", + OpT="cuvs::distance::detail::ops::l2_exp_distance_op", + archs = [60, 80], + )] + +for op in int64_t_op_instances: + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + + IdxT = "int64_t" + path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" + with open(path, "w") as f: + f.write(header) + f.write(arch_headers(op["archs"])) + f.write(macro) + + OpT = op['OpT'] + FinOpT = "raft::identity_op" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + print(f"src/distance/detail/pairwise_matrix/{path}") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu new file mode 100644 index 000000000..756739158 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu 
@@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, + double, + double, + double, + raft::identity_op, + int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu new file mode 100644 index 000000000..94910875c --- /dev/null +++ 
b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu index 1cb0ed8ae..3c8f25109 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu +++ 
b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/distance/distance-ext.cuh b/cpp/src/distance/distance-ext.cuh index e7fa30f03..8ce7ef690 100644 --- a/cpp/src/distance/distance-ext.cuh +++ b/cpp/src/distance/distance-ext.cuh @@ -244,6 +244,15 @@ instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::Linf); instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::LpUnexpanded); instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::RusselRaoExpanded); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int64_t); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int64_t); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int64_t); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int64_t); + #undef instantiate_cuvs_distance_distance_by_algo #undef instantiate_cuvs_distance_distance diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu index 72be93f10..c1d39f360 100644 --- a/cpp/src/distance/distance.cu +++ b/cpp/src/distance/distance.cu @@ -105,6 +105,16 @@ instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::Linf); instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::LpUnexpanded); instantiate_cuvs_distance_distance_by_algo(cuvs::distance::DistanceType::RusselRaoExpanded); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int64_t); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, 
double, double, int64_t); + +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int64_t); +instantiate_cuvs_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int64_t); + #undef instantiate_cuvs_distance_distance_by_algo #undef instantiate_cuvs_distance_distance diff --git a/cpp/src/neighbors/cagra_build_half.cu b/cpp/src/neighbors/cagra_build_half.cu new file mode 100644 index 000000000..2aba1dada --- /dev/null +++ b/cpp/src/neighbors/cagra_build_half.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cagra.cuh" +#include +#include + +namespace cuvs::neighbors::cagra { + +cuvs::neighbors::cagra::index build( + raft::resources const& handle, + const cuvs::neighbors::cagra::index_params& params, + raft::device_matrix_view dataset) +{ + return cuvs::neighbors::cagra::build(handle, params, dataset); +} + +cuvs::neighbors::cagra::index build( + raft::resources const& handle, + const cuvs::neighbors::cagra::index_params& params, + raft::host_matrix_view dataset) +{ + return cuvs::neighbors::cagra::build(handle, params, dataset); +} + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/cagra_search_half.cu b/cpp/src/neighbors/cagra_search_half.cu new file mode 100644 index 000000000..d80f2bc00 --- /dev/null +++ b/cpp/src/neighbors/cagra_search_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cagra.cuh" +#include + +namespace cuvs::neighbors::cagra { + +#define CUVS_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + cuvs::neighbors::cagra::search_params const& params, \ + const cuvs::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + cuvs::neighbors::cagra::search(handle, params, index, queries, neighbors, distances); \ + } + +CUVS_INST_CAGRA_SEARCH(half, uint32_t); + +#undef CUVS_INST_CAGRA_SEARCH + +} // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64.cu b/cpp/src/neighbors/cagra_serialize_half.cu similarity index 53% rename from cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64.cu rename to cpp/src/neighbors/cagra_serialize_half.cu index 88167b843..92ebd9b71 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64.cu +++ b/cpp/src/neighbors/cagra_serialize_half.cu @@ -14,21 +14,12 @@ * limitations under the License. 
*/ -/* - * NOTE: this file is generated by search_multi_cta_00_generate.py - * - * Make changes there and run in this directory: - * - * > python search_multi_cta_00_generate.py - * - */ +#include "cagra_serialize.cuh" + +#include -#include "search_multi_cta_inst.cuh" +namespace cuvs::neighbors::cagra { -namespace cuvs::neighbors::cagra::detail::multi_cta_search { -instantiate_kernel_selection(float, - uint64_t, - float, - cuvs::neighbors::filtering::none_cagra_sample_filter); +CUVS_INST_CAGRA_SERIALIZE(half); -} // namespace cuvs::neighbors::cagra::detail::multi_cta_search +} // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 4a927add5..e5495dc3e 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -33,8 +33,6 @@ #include // TODO: Fixme- this needs to be migrated -#include "../../ivf_pq/ivf_pq_build.cuh" -#include "../../ivf_pq/ivf_pq_search.cuh" #include "../../nn_descent.cuh" // TODO: This shouldn't be calling spatial/knn APIs @@ -156,8 +154,7 @@ void build_knn_graph( }(); RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str()); - auto index = - cuvs::neighbors::ivf_pq::detail::build(res, pq.build_params, dataset); + auto index = cuvs::neighbors::ivf_pq::build(res, pq.build_params, dataset); // // search top (k + 1) neighbors @@ -169,7 +166,8 @@ void build_knn_graph( const auto num_queries = dataset.extent(0); // Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed. - using cuvs::neighbors::ivf_pq::detail::kMaxQueries; + constexpr uint32_t kMaxQueries = 4096; + // Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the // rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good // multiple of what is required for the I/O batching below. 
diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 24cc2a22f..f86ed9ef6 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -194,6 +194,8 @@ void serialize_to_hnswlib(raft::resources const& res, auto data_elem = static_cast(host_dataset(i, j)); os.write(reinterpret_cast(&data_elem), sizeof(int)); } + } else { + RAFT_FAIL("Unsupported dataset type while saving CAGRA dataset to HNSWlib format"); } os.write(reinterpret_cast(&i), sizeof(std::size_t)); diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py index f8584c62e..aef31d161 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/compute_distance_00_generate.py @@ -66,8 +66,6 @@ half_uint32=("half", "uint32_t", "float"), int8_uint32=("int8_t", "uint32_t", "float"), uint8_uint32=("uint8_t", "uint32_t", "float"), - # float_uint64=("float", "uint64_t", "float"), - # half_uint64=("half", "uint64_t", "float"), ) metric_prefix = 'DistanceType::' diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index 3153a3a9f..4e3983e3f 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -58,8 +58,6 @@ half_uint32=("half", "uint32_t", "float"), int8_uint32=("int8_t", "uint32_t", "float"), uint8_uint32=("uint8_t", "uint32_t", "float"), - float_uint64=("float", "uint64_t", "float"), - half_uint64=("half", "uint64_t", "float"), ) # knn for type_path, (data_t, idx_t, distance_t) in search_types.items(): diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index e37ceb1fa..4693cd54d 
100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -60,8 +60,6 @@ half_uint32=("half", "uint32_t", "float"), int8_uint32=("int8_t", "uint32_t", "float"), uint8_uint32=("uint8_t", "uint32_t", "float"), - float_uint64=("float", "uint64_t", "float"), - half_uint64=("half", "uint64_t", "float"), ) # knn diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64.cu deleted file mode 100644 index 0ef5c366f..000000000 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -/* - * NOTE: this file is generated by search_single_cta_00_generate.py - * - * Make changes there and run in this directory: - * - * > python search_single_cta_00_generate.py - * - */ - -#include "search_single_cta_inst.cuh" - -namespace cuvs::neighbors::cagra::detail::single_cta_search { -instantiate_kernel_selection(float, - uint64_t, - float, - cuvs::neighbors::filtering::none_cagra_sample_filter); - -} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index cf27bcde7..3aa1d7529 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -62,7 +62,10 @@ namespace cuvs::neighbors::detail { * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -78,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t max_col_tile_size = 0, const DistanceT* precomputed_index_norms = nullptr, const DistanceT* precomputed_search_norms = nullptr, - const uint32_t* filter_bitmap = nullptr) + const uint32_t* filter_bitmap = nullptr, + DistanceEpilogue distance_epilogue = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -207,7 +211,8 @@ void tiled_brute_force_knn(const raft::resources& handle, IndexType col = j + (idx % current_centroid_size); cuvs::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); - return l2_op(row_norms[row], col_norms[col], dist[idx]); + auto val = l2_op(row_norms[row], col_norms[col], dist[idx]); + return distance_epilogue(val, row, col); }); } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { auto row_norms = precomputed_search_norms ? 
precomputed_search_norms : search_norms.data(); @@ -221,8 +226,22 @@ void tiled_brute_force_knn(const raft::resources& handle, IndexType row = i + (idx / current_centroid_size); IndexType col = j + (idx % current_centroid_size); auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]); - return val; + return distance_epilogue(val, row, col); }); + } else { + // if we're not l2 distance, and we have a distance epilogue - run it now + if constexpr (!std::is_same_v) { + auto distances_ptr = temp_distances.data(); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(temp_distances.data(), + current_query_size * current_centroid_size), + [=] __device__(size_t idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + return distance_epilogue(distances_ptr[idx], row, col); + }); + } } if (filter_bitmap != nullptr) { diff --git a/cpp/src/neighbors/detail/reachability.cuh b/cpp/src/neighbors/detail/reachability.cuh new file mode 100644 index 000000000..903c6f1da --- /dev/null +++ b/cpp/src/neighbors/detail/reachability.cuh @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "./knn_brute_force.cuh" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cuvs::neighbors::detail::reachability { + +/** + * Extract core distances from KNN graph. This is essentially + * performing a knn_dists[:,min_pts] + * @tparam value_idx data type for integrals + * @tparam value_t data type for distance + * @tparam tpb block size for kernel + * @param[in] knn_dists knn distance array (size n * k) + * @param[in] min_samples this neighbor will be selected for core distances + * @param[in] n_neighbors the number of neighbors of each point in the knn graph + * @param[in] n number of samples + * @param[out] out output array (size n) + * @param[in] stream stream for which to order cuda operations + */ +template +void core_distances( + value_t* knn_dists, int min_samples, int n_neighbors, size_t n, value_t* out, cudaStream_t stream) +{ + ASSERT(n_neighbors >= min_samples, + "the size of the neighborhood should be greater than or equal to min_samples"); + + auto exec_policy = rmm::exec_policy(stream); + + auto indices = thrust::make_counting_iterator(0); + + thrust::transform(exec_policy, indices, indices + n, out, [=] __device__(value_idx row) { + return knn_dists[row * n_neighbors + (min_samples - 1)]; + }); +} + +/** + * Wraps the brute force knn API, to be used for both training and prediction + * @tparam value_idx data type for integrals + * @tparam value_t data type for distance + * @param[in] handle raft handle for resource reuse + * @param[in] X input data points (size m * n) + * @param[out] inds nearest neighbor indices (size n_search_items * k) + * @param[out] dists nearest neighbor distances (size n_search_items * k) + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] search_items array of items to search of dimensionality D (size n_search_items * n) + * @param[in] n_search_items number of rows in 
search_items + * @param[in] k number of nearest neighbors + * @param[in] metric distance metric to use + */ +template +void compute_knn(const raft::resources& handle, + const value_t* X, + value_idx* inds, + value_t* dists, + size_t m, + size_t n, + const value_t* search_items, + size_t n_search_items, + int k, + cuvs::distance::DistanceType metric) +{ + // perform knn + tiled_brute_force_knn(handle, X, search_items, m, n_search_items, n, k, dists, inds, metric); +} + +/* + @brief Internal function for CPU->GPU interop + to compute core_dists +*/ +template +void _compute_core_dists(const raft::resources& handle, + const value_t* X, + value_t* core_dists, + size_t m, + size_t n, + cuvs::distance::DistanceType metric, + int min_samples) +{ + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, + "Currently only L2 expanded distance is supported"); + + auto stream = raft::resource::get_cuda_stream(handle); + + rmm::device_uvector inds(min_samples * m, stream); + rmm::device_uvector dists(min_samples * m, stream); + + // perform knn + compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); + + // Slice core distances (distances to kth nearest neighbor) + core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); +} + +// Functor to post-process distances into reachability space +template +struct ReachabilityPostProcess { + DI value_t operator()(value_t value, value_idx row, value_idx col) const + { + return max(core_dists[col], max(core_dists[row], alpha * value)); + } + + const value_t* core_dists; + value_t alpha; +}; + +/** + * Given core distances, Fuses computations of L2 distances between all + * points, projection into mutual reachability space, and k-selection. 
+ * @tparam value_idx + * @tparam value_t + * @param[in] handle raft handle for resource reuse + * @param[out] out_inds output indices array (size m * k) + * @param[out] out_dists output distances array (size m * k) + * @param[in] X input data points (size m * n) + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] k neighborhood size (includes self-loop) + * @param[in] core_dists array of core distances (size m) + */ +template +void mutual_reachability_knn_l2(const raft::resources& handle, + value_idx* out_inds, + value_t* out_dists, + const value_t* X, + size_t m, + size_t n, + int k, + value_t* core_dists, + value_t alpha) +{ + // Create a functor to postprocess distances into mutual reachability space + // Note that we can't use a lambda for this here, since we get errors like: + // `A type local to a function cannot be used in the template argument of the + // enclosing parent function (and any parent classes) of an extended __device__ + // or __host__ __device__ lambda` + auto epilogue = ReachabilityPostProcess{core_dists, alpha}; + + cuvs::neighbors::detail:: + tiled_brute_force_knn>( + handle, + X, + X, + m, + m, + n, + k, + out_dists, + out_inds, + cuvs::distance::DistanceType::L2SqrtExpanded, + 2.0, + 0, + 0, + nullptr, + nullptr, + nullptr, + epilogue); +} + +template +void mutual_reachability_graph(const raft::resources& handle, + const value_t* X, + size_t m, + size_t n, + cuvs::distance::DistanceType metric, + int min_samples, + value_t alpha, + value_idx* indptr, + value_t* core_dists, + raft::sparse::COO& out) +{ + RAFT_EXPECTS(metric == cuvs::distance::DistanceType::L2SqrtExpanded, + "Currently only L2 expanded distance is supported"); + + auto stream = raft::resource::get_cuda_stream(handle); + auto exec_policy = raft::resource::get_thrust_policy(handle); + + rmm::device_uvector coo_rows(min_samples * m, stream); + rmm::device_uvector inds(min_samples * m, stream); + rmm::device_uvector dists(min_samples * m, 
stream); + + // perform knn + compute_knn(handle, X, inds.data(), dists.data(), m, n, X, m, min_samples, metric); + + // Slice core distances (distances to kth nearest neighbor) + core_distances(dists.data(), min_samples, min_samples, m, core_dists, stream); + + /** + * Compute L2 norm + */ + mutual_reachability_knn_l2( + handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha); + + // self-loops get max distance + auto coo_rows_counting_itr = thrust::make_counting_iterator(0); + thrust::transform(exec_policy, + coo_rows_counting_itr, + coo_rows_counting_itr + (m * min_samples), + coo_rows.data(), + [min_samples] __device__(value_idx c) -> value_idx { return c / min_samples; }); + + raft::sparse::linalg::symmetrize( + handle, coo_rows.data(), inds.data(), dists.data(), m, m, min_samples * m, out); + + raft::sparse::convert::sorted_coo_to_csr(out.rows(), out.nnz, indptr, m + 1, stream); + + // self-loops get max distance + auto transform_in = + thrust::make_zip_iterator(thrust::make_tuple(out.rows(), out.cols(), out.vals())); + + thrust::transform(exec_policy, + transform_in, + transform_in + out.nnz, + out.vals(), + [=] __device__(const thrust::tuple& tup) { + return thrust::get<0>(tup) == thrust::get<1>(tup) + ? 
std::numeric_limits::max() + : thrust::get<2>(tup); + }); +} + +} // namespace cuvs::neighbors::detail::reachability diff --git a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py index 878c7ee21..9b3083c3b 100644 --- a/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py +++ b/cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py @@ -57,6 +57,7 @@ types = dict( float_int64_t=("float", "int64_t"), + half_int64_t=("half", "int64_t"), int8_t_int64_t=("int8_t", "int64_t"), uint8_t_int64_t=("uint8_t", "int64_t"), ) diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_extend_half_int64_t.cu similarity index 55% rename from cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64.cu rename to cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_extend_half_int64_t.cu index dafb89cc3..2d7270957 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64.cu +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_extend_half_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,20 +15,21 @@ */ /* - * NOTE: this file is generated by search_multi_cta_00_generate.py + * NOTE: this file is generated by generate_ivf_pq.py * * Make changes there and run in this directory: * - * > python search_multi_cta_00_generate.py + * > python generate_ivf_pq.py * */ -#include "search_multi_cta_inst.cuh" +#include -namespace cuvs::neighbors::cagra::detail::multi_cta_search { -instantiate_kernel_selection(half, - uint64_t, - float, - cuvs::neighbors::filtering::none_cagra_sample_filter); +#include "ivf_pq_build_extend_inst.cuh" -} // namespace cuvs::neighbors::cagra::detail::multi_cta_search +namespace cuvs::neighbors::ivf_pq { +CUVS_INST_IVF_PQ_BUILD_EXTEND(half, int64_t); + +#undef CUVS_INST_IVF_PQ_BUILD_EXTEND + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu new file mode 100644 index 000000000..e5556e593 --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_half_int64_t.cu @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include + +#include "../ivf_pq_search.cuh" + +namespace cuvs::neighbors::ivf_pq { + +#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + cuvs::neighbors::ivf_pq::detail::search(handle, params, index, queries, neighbors, distances); \ + } +CUVS_INST_IVF_PQ_SEARCH(half, int64_t); + +#undef CUVS_INST_IVF_PQ_SEARCH + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu new file mode 100644 index 000000000..5874fba6c --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_half_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by generate_ivf_pq.py + * + * Make changes there and run in this directory: + * + * > python generate_ivf_pq.py + * + */ + +#include + +#include "../ivf_pq_search.cuh" + +namespace cuvs::neighbors::ivf_pq { + +#define CUVS_INST_IVF_PQ_SEARCH_FILTER(T, IdxT) \ + void search_with_filtering( \ + raft::resources const& handle, \ + const cuvs::neighbors::ivf_pq::search_params& params, \ + cuvs::neighbors::ivf_pq::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + cuvs::neighbors::filtering::bitset_filter sample_filter) \ + { \ + cuvs::neighbors::ivf_pq::detail::search_with_filtering( \ + handle, params, index, queries, neighbors, distances, sample_filter); \ + } +CUVS_INST_IVF_PQ_SEARCH_FILTER(half, int64_t); + +#undef CUVS_INST_IVF_PQ_SEARCH_FILTER + +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/nn_descent_half.cu b/cpp/src/neighbors/nn_descent_half.cu new file mode 100644 index 000000000..587993031 --- /dev/null +++ b/cpp/src/neighbors/nn_descent_half.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nn_descent.cuh" +#include + +namespace cuvs::neighbors::nn_descent { + +#define CUVS_INST_NN_DESCENT_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::device_matrix_view dataset) \ + ->cuvs::neighbors::nn_descent::index \ + { \ + return cuvs::neighbors::nn_descent::build(handle, params, dataset); \ + }; \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::host_matrix_view dataset) \ + ->cuvs::neighbors::nn_descent::index \ + { \ + return cuvs::neighbors::nn_descent::build(handle, params, dataset); \ + }; + +CUVS_INST_NN_DESCENT_BUILD(half, uint32_t); + +#undef CUVS_INST_NN_DESCENT_BUILD + +} // namespace cuvs::neighbors::nn_descent diff --git a/cpp/src/neighbors/reachability.cu b/cpp/src/neighbors/reachability.cu new file mode 100644 index 000000000..2e366106c --- /dev/null +++ b/cpp/src/neighbors/reachability.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "./detail/reachability.cuh" + +namespace cuvs::neighbors::reachability { + +void mutual_reachability_graph(const raft::resources& handle, + raft::device_matrix_view X, + int min_samples, + raft::device_vector_view indptr, + raft::device_vector_view core_dists, + raft::sparse::COO& out, + cuvs::distance::DistanceType metric, + float alpha) +{ + RAFT_EXPECTS(core_dists.extent(0) == static_cast(X.extent(0)), + "core_dists doesn't have expected size"); + RAFT_EXPECTS(indptr.extent(0) == static_cast(X.extent(0) + 1), + "indptr doesn't have expected size"); + + cuvs::neighbors::detail::reachability::mutual_reachability_graph( + handle, + X.data_handle(), + X.extent(0), + X.extent(1), + metric, + min_samples, + alpha, + indptr.data_handle(), + core_dists.data_handle(), + out); +} +} // namespace cuvs::neighbors::reachability diff --git a/cpp/src/selection/select_k_float_int32_t.cu b/cpp/src/selection/select_k_float_int32_t.cu new file mode 100644 index 000000000..66672a642 --- /dev/null +++ b/cpp/src/selection/select_k_float_int32_t.cu @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "./select_k.cuh" + +instantiate_cuvs_selection_select_k(float, int); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index e04c39318..b81ef6bfa 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -139,6 +139,7 @@ if(BUILD_TESTS) NEIGHBORS_ANN_CAGRA_TEST PATH neighbors/ann_cagra/test_float_uint32_t.cu + neighbors/ann_cagra/test_half_uint32_t.cu neighbors/ann_cagra/test_int8_t_uint32_t.cu neighbors/ann_cagra/test_uint8_t_uint32_t.cu GPUS diff --git a/cpp/test/cluster/kmeans_mg.cu b/cpp/test/cluster/kmeans_mg.cu new file mode 100644 index 000000000..b9e06b2f1 --- /dev/null +++ b/cpp/test/cluster/kmeans_mg.cu @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +#include + +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t res = cmd; \ + if (res != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(res)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace cuvs { + +template +struct KmeansInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + bool weighted; +}; + +template +class KmeansTest : public ::testing::TestWithParam> { + protected: + KmeansTest() + : stream(handle.get_stream()), + d_labels(0, stream), + d_labels_ref(0, stream), + d_centroids(0, stream), + d_sample_weight(0, stream) + { + } + + void basicTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + ncclComm_t nccl_comm; + NCCLCHECK(ncclCommInitAll(&nccl_comm, 1, {0})); + raft::comms::build_comms_nccl_only(&handle, nccl_comm, 1, 0); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 1; + + auto stream = handle.get_stream(); + rmm::device_uvector X(n_samples * n_features, stream); + rmm::device_uvector labels(n_samples, stream); + + raft::random::make_blobs(handle, + X.data(), + labels.data(), + n_samples, + n_features, + params.n_clusters, + true, + nullptr, + nullptr, + 1.0, + false, + -10.0f, + 10.0f, + 1234ULL); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + + std::optional> d_sw = std::nullopt; + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + thrust::fill(thrust::cuda::par.on(stream), + d_sample_weight.data(), + d_sample_weight.data() + n_samples, + 1); + d_sw = raft::make_device_vector_view(d_sample_weight.data(), n_samples); + } + 
raft::copy(d_labels_ref.data(), labels.data(), n_samples, stream); + + handle.sync_stream(stream); + + T inertia = 0; + int n_iter = 0; + + auto X_view = raft::make_device_matrix_view(X.data(), n_samples, n_features); + auto centroids_view = + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + + cuvs::cluster::kmeans::fit(handle, + params, + X_view, + d_sw, + centroids_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + + cuvs::cluster::kmeans::predict( + handle, + params, + X_view, + d_sw, + d_centroids.data(), + raft::make_device_vector_view(d_labels.data(), n_samples), + true, + raft::make_host_scalar_view(&inertia)); + score = raft::stats::adjusted_rand_index( + d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); + handle.sync_stream(stream); + + if (score < 0.99) { + std::cout << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream) + << std::endl; + std::cout << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream) + << std::endl; + std::cout << "score = " << score << std::endl; + } + ncclCommDestroy(nccl_comm); + } + + void SetUp() override { basicTest(); } + + protected: + raft::handle_t handle; + cudaStream_t stream; + KmeansInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + rmm::device_uvector d_sample_weight; + double score; + cuvs::cluster::kmeans::params params; +}; + +const std::vector> inputsf2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}}; + +const std::vector> inputsd2 = {{1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 
10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, + {10000, 100, 50, 0.0001, true}, + {10000, 100, 50, 0.0001, false}}; + +typedef KmeansTest KmeansTestF; +TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score >= 0.99); } + +typedef KmeansTest KmeansTestD; +TEST_P(KmeansTestD, Result) { ASSERT_TRUE(score >= 0.99); } + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); + +} // end namespace cuvs diff --git a/cpp/test/core/c_api.c b/cpp/test/core/c_api.c index 27973c2dd..a3dae6004 100644 --- a/cpp/test/core/c_api.c +++ b/cpp/test/core/c_api.c @@ -33,34 +33,49 @@ int main() // Allocate memory void* ptr; - size_t bytes = 1024; - cuvsError_t alloc_error = cuvsRMMAlloc(res, &ptr, bytes); - if (alloc_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + size_t bytes = 1024; + cuvsError_t error = cuvsRMMAlloc(res, &ptr, bytes); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } // Free memory - cuvsError_t free_error = cuvsRMMFree(res, ptr, bytes); - if (free_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + error = cuvsRMMFree(res, ptr, bytes); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } // Enable pool memory resource - cuvsError_t pool_error = cuvsRMMPoolMemoryResourceEnable(10, 100); - if (pool_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + error = cuvsRMMPoolMemoryResourceEnable(10, 100, false); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } // Allocate memory again - void* ptr2; - cuvsError_t alloc_error_pool = cuvsRMMAlloc(res, &ptr2, 1024); - if (alloc_error_pool == CUVS_ERROR) { exit(EXIT_FAILURE); } + error = cuvsRMMAlloc(res, &ptr, 1024); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } // Free memory - cuvsError_t free_error_pool = cuvsRMMFree(res, ptr2, 1024); - if (free_error_pool == CUVS_ERROR) { exit(EXIT_FAILURE); } + error = cuvsRMMFree(res, ptr, 1024); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } // Reset pool memory resource - 
cuvsError_t reset_error = cuvsRMMMemoryResourceReset(); - if (reset_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + error = cuvsRMMMemoryResourceReset(); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Enable pool memory resource (managed) + error = cuvsRMMPoolMemoryResourceEnable(10, 100, true); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Allocate memory again + error = cuvsRMMAlloc(res, &ptr, 1024); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Free memory + error = cuvsRMMFree(res, ptr, 1024); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Reset pool memory resource + error = cuvsRMMMemoryResourceReset(); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } // Destroy resources - cuvsError_t destroy_error = cuvsResourcesDestroy(res); - if (destroy_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + error = cuvsResourcesDestroy(res); + if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } return 0; } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64.cu b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu similarity index 53% rename from cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64.cu rename to cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu index b96ed0b22..f03de69d2 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64.cu +++ b/cpp/test/neighbors/ann_cagra/test_half_uint32_t.cu @@ -14,21 +14,15 @@ * limitations under the License. 
*/ -/* - * NOTE: this file is generated by search_single_cta_00_generate.py - * - * Make changes there and run in this directory: - * - * > python search_single_cta_00_generate.py - * - */ +#include + +#include "../ann_cagra.cuh" + +namespace cuvs::neighbors::cagra { -#include "search_single_cta_inst.cuh" +typedef AnnCagraTest AnnCagraTestF16_U32; +TEST_P(AnnCagraTestF16_U32, AnnCagra) { this->testCagra(); } -namespace cuvs::neighbors::cagra::detail::single_cta_search { -instantiate_kernel_selection(half, - uint64_t, - float, - cuvs::neighbors::filtering::none_cagra_sample_filter); +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF16_U32, ::testing::ValuesIn(inputs)); -} // namespace cuvs::neighbors::cagra::detail::single_cta_search +} // namespace cuvs::neighbors::cagra diff --git a/python/cuvs_bench/cuvs_bench/plot/__init__.py b/python/cuvs_bench/cuvs_bench/plot/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/cuvs_bench/cuvs_bench/plot/__main__.py b/python/cuvs_bench/cuvs_bench/plot/__main__.py new file mode 100644 index 000000000..93deb69c7 --- /dev/null +++ b/python/cuvs_bench/cuvs_bench/plot/__main__.py @@ -0,0 +1,617 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# This script is inspired by +# 1: https://github.com/erikbern/ann-benchmarks/blob/main/plot.py +# 2: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/utils.py # noqa: E501 +# 3: https://github.com/erikbern/ann-benchmarks/blob/main/ann_benchmarks/plotting/metrics.py # noqa: E501 +# License: https://github.com/rapidsai/cuvs/blob/branch-24.10/thirdparty/LICENSES/LICENSE.ann-benchmark # noqa: E501 + +import itertools +import os +from collections import OrderedDict + +import click +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +mpl.use("Agg") + +metrics = { + "k-nn": { + "description": "Recall", + "worst": float("-inf"), + "lim": [0.0, 1.03], + }, + "throughput": { + "description": "Queries per second (1/s)", + "worst": float("-inf"), + }, + "latency": { + "description": "Search Latency (s)", + "worst": float("inf"), + }, +} + + +def positive_int(value): + ivalue = int(value) + if ivalue <= 0: + raise click.BadParameter(f"{value} is not a positive integer") + return ivalue + + +def positive_float(value): + fvalue = float(value) + if fvalue <= 0: + raise click.BadParameter(f"{value} is not a positive float") + return fvalue + + +def generate_n_colors(n): + vs = np.linspace(0.3, 0.9, 7) + colors = [(0.9, 0.4, 0.4, 1.0)] + + def euclidean(a, b): + return sum((x - y) ** 2 for x, y in zip(a, b)) + + while len(colors) < n: + new_color = max( + itertools.product(vs, vs, vs), + key=lambda a: min(euclidean(a, b) for b in colors), + ) + colors.append(new_color + (1.0,)) + return colors + + +def create_linestyles(unique_algorithms): + colors = dict( + zip(unique_algorithms, generate_n_colors(len(unique_algorithms))) + ) + linestyles = dict( + (algo, ["--", "-.", "-", ":"][i % 4]) + for i, algo in enumerate(unique_algorithms) + ) + markerstyles = dict( + (algo, ["+", "<", "o", "*", "x"][i % 5]) + for i, algo in enumerate(unique_algorithms) + ) + faded = dict( + (algo, (r, g, b, 0.3)) for algo, 
(r, g, b, a) in colors.items() + ) + return dict( + ( + algo, + (colors[algo], faded[algo], linestyles[algo], markerstyles[algo]), + ) + for algo in unique_algorithms + ) + + +def create_plot_search( + all_data, + x_scale, + y_scale, + fn_out, + linestyles, + dataset, + k, + batch_size, + mode, + time_unit, + x_start, +): + xn = "k-nn" + xm, ym = (metrics[xn], metrics[mode]) + xm["lim"][0] = x_start + # Now generate each plot + handles = [] + labels = [] + plt.figure(figsize=(12, 9)) + + # Sorting by mean y-value helps aligning plots with labels + def mean_y(algo): + points = np.array(all_data[algo], dtype=object) + return -np.log(np.array(points[:, 3], dtype=np.float32)).mean() + + # Find range for logit x-scale + min_x, max_x = 1, 0 + for algo in sorted(all_data.keys(), key=mean_y): + points = np.array(all_data[algo], dtype=object) + xs = points[:, 2] + ys = points[:, 3] + min_x = min([min_x] + [x for x in xs if x > 0]) + max_x = max([max_x] + [x for x in xs if x < 1]) + color, faded, linestyle, marker = linestyles[algo] + (handle,) = plt.plot( + xs, + ys, + "-", + label=algo, + color=color, + ms=7, + mew=3, + lw=3, + marker=marker, + ) + handles.append(handle) + + labels.append(algo) + + ax = plt.gca() + y_description = ym["description"] + if mode == "latency": + y_description = y_description.replace("(s)", f"({time_unit})") + ax.set_ylabel(y_description) + ax.set_xlabel("Recall") + # Custom scales of the type --x-scale a3 + if x_scale[0] == "a": + alpha = float(x_scale[1:]) + + def fun(x): + return 1 - (1 - x) ** (1 / alpha) + + def inv_fun(x): + return 1 - (1 - x) ** alpha + + ax.set_xscale("function", functions=(fun, inv_fun)) + if alpha <= 3: + ticks = [inv_fun(x) for x in np.arange(0, 1.2, 0.2)] + plt.xticks(ticks) + if alpha > 3: + from matplotlib import ticker + + ax.xaxis.set_major_formatter(ticker.LogitFormatter()) + # plt.xticks(ticker.LogitLocator().tick_values(min_x, max_x)) + plt.xticks([0, 1 / 2, 1 - 1e-1, 1 - 1e-2, 1 - 1e-3, 1 - 1e-4, 1]) + # 
Other x-scales + else: + ax.set_xscale(x_scale) + ax.set_yscale(y_scale) + ax.set_title(f"{dataset} k={k} batch_size={batch_size}") + plt.gca().get_position() + # plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height]) + ax.legend( + handles, + labels, + loc="center left", + bbox_to_anchor=(1, 0.5), + prop={"size": 9}, + ) + plt.grid(visible=True, which="major", color="0.65", linestyle="-") + plt.setp(ax.get_xminorticklabels(), visible=True) + + # Logit scale has to be a subset of (0,1) + if "lim" in xm and x_scale != "logit": + x0, x1 = xm["lim"] + plt.xlim(max(x0, 0), min(x1, 1)) + elif x_scale == "logit": + plt.xlim(min_x, max_x) + if "lim" in ym: + plt.ylim(ym["lim"]) + + # Workaround for bug https://github.com/matplotlib/matplotlib/issues/6789 + ax.spines["bottom"]._adjust_location() + + print(f"writing search output to {fn_out}") + plt.savefig(fn_out, bbox_inches="tight") + plt.close() + + +def create_plot_build( + build_results, search_results, linestyles, fn_out, dataset, k, batch_size +): + bt_80 = [0] * len(linestyles) + + bt_90 = [0] * len(linestyles) + + bt_95 = [0] * len(linestyles) + + bt_99 = [0] * len(linestyles) + + data = OrderedDict() + colors = OrderedDict() + + # Sorting by mean y-value helps aligning plots with labels + + def mean_y(algo): + points = np.array(search_results[algo], dtype=object) + return -np.log(np.array(points[:, 3], dtype=np.float32)).mean() + + for pos, algo in enumerate(sorted(search_results.keys(), key=mean_y)): + points = np.array(search_results[algo], dtype=object) + # x is recall, ls is algo_name, idxs is index_name + xs = points[:, 2] + ls = points[:, 0] + idxs = points[:, 1] + + len_80, len_90, len_95, len_99 = 0, 0, 0, 0 + for i in range(len(xs)): + if xs[i] >= 0.80 and xs[i] < 0.90: + bt_80[pos] = bt_80[pos] + build_results[(ls[i], idxs[i])][0][2] + len_80 = len_80 + 1 + elif xs[i] >= 0.9 and xs[i] < 0.95: + bt_90[pos] = bt_90[pos] + build_results[(ls[i], idxs[i])][0][2] + len_90 = len_90 + 1 + elif 
xs[i] >= 0.95 and xs[i] < 0.99: + bt_95[pos] = bt_95[pos] + build_results[(ls[i], idxs[i])][0][2] + len_95 = len_95 + 1 + elif xs[i] >= 0.99: + bt_99[pos] = bt_99[pos] + build_results[(ls[i], idxs[i])][0][2] + len_99 = len_99 + 1 + if len_80 > 0: + bt_80[pos] = bt_80[pos] / len_80 + if len_90 > 0: + bt_90[pos] = bt_90[pos] / len_90 + if len_95 > 0: + bt_95[pos] = bt_95[pos] / len_95 + if len_99 > 0: + bt_99[pos] = bt_99[pos] / len_99 + data[algo] = [ + bt_80[pos], + bt_90[pos], + bt_95[pos], + bt_99[pos], + ] + colors[algo] = linestyles[algo][0] + + index = [ + "@80% Recall", + "@90% Recall", + "@95% Recall", + "@99% Recall", + ] + + df = pd.DataFrame(data, index=index) + df.replace(0.0, np.nan, inplace=True) + df = df.dropna(how="all") + plt.figure(figsize=(12, 9)) + ax = df.plot.bar(rot=0, color=colors) + fig = ax.get_figure() + print(f"writing build output to {fn_out}") + plt.title( + "Average Build Time within Recall Range " + f"for k={k} batch_size={batch_size}" + ) + plt.suptitle(f"{dataset}") + plt.ylabel("Build Time (s)") + fig.savefig(fn_out) + + +def load_lines(results_path, result_files, method, index_key, mode, time_unit): + results = dict() + + for result_filename in result_files: + try: + with open(os.path.join(results_path, result_filename), "r") as f: + lines = f.readlines() + lines = lines[:-1] if lines[-1] == "\n" else lines + + if method == "build": + key_idx = [2] + elif method == "search": + y_idx = 3 if mode == "throughput" else 4 + key_idx = [2, y_idx] + + for line in lines[1:]: + split_lines = line.split(",") + + algo_name = split_lines[0] + index_name = split_lines[1] + + if index_key == "algo": + dict_key = algo_name + elif index_key == "index": + dict_key = (algo_name, index_name) + if dict_key not in results: + results[dict_key] = [] + to_add = [algo_name, index_name] + for key_i in key_idx: + to_add.append(float(split_lines[key_i])) + if ( + mode == "latency" + and time_unit != "s" + and method == "search" + ): + to_add[-1] = ( + 
to_add[-1] * (10**3) + if time_unit == "ms" + else to_add[-1] * (10**6) + ) + results[dict_key].append(to_add) + except Exception: + print( + f"An error occurred processing file {result_filename}. " + "Skipping..." + ) + + return results + + +def load_all_results( + dataset_path, + algorithms, + groups, + algo_groups, + k, + batch_size, + method, + index_key, + raw, + mode, + time_unit, +): + results_path = os.path.join(dataset_path, "result", method) + result_files = os.listdir(results_path) + if method == "build": + result_files = [ + result_file + for result_file in result_files + if ".csv" in result_file + ] + elif method == "search": + if raw: + suffix = ",raw" + else: + suffix = f",{mode}" + result_files = [ + result_file + for result_file in result_files + if f"{suffix}.csv" in result_file + ] + if len(result_files) == 0: + raise FileNotFoundError(f"No CSV result files found in {results_path}") + + if method == "search": + filter_k_bs = [] + for result_filename in result_files: + filename_split = result_filename.split(",") + if ( + int(filename_split[-3][1:]) == k + and int(filename_split[-2][2:]) == batch_size + ): + filter_k_bs.append(result_filename) + result_files = filter_k_bs + + algo_group_files = [ + result_filename.replace(".csv", "").split(",")[:2] + for result_filename in result_files + ] + algo_group_files = list(zip(*algo_group_files)) + + if len(algorithms) > 0: + final_results = [ + result_files[i] + for i in range(len(result_files)) + if (algo_group_files[0][i] in algorithms) + and (algo_group_files[1][i] in groups) + ] + else: + final_results = [ + result_files[i] + for i in range(len(result_files)) + if (algo_group_files[1][i] in groups) + ] + + if len(algo_groups) > 0: + split_algo_groups = [ + algo_group.split(".") for algo_group in algo_groups + ] + split_algo_groups = list(zip(*split_algo_groups)) + final_algo_groups = [ + result_files[i] + for i in range(len(result_files)) + if (algo_group_files[0][i] in split_algo_groups[0]) + and 
(algo_group_files[1][i] in split_algo_groups[1]) + ] + final_results = final_results + final_algo_groups + final_results = set(final_results) + + results = load_lines( + results_path, final_results, method, index_key, mode, time_unit + ) + + return results + + +@click.command() +@click.option("--dataset", default="glove-100-inner", help="Dataset to plot.") +@click.option( + "--dataset-path", + default=lambda: os.getenv( + "RAPIDS_DATASET_ROOT_DIR", os.path.join(os.getcwd(), "datasets/") + ), + help="Path to dataset folder.", +) +@click.option( + "--output-filepath", + default=os.getcwd(), + help="Directory where PNG will be saved.", +) +@click.option( + "--algorithms", + default=None, + help="Comma-separated list of named algorithms to plot. If `groups` and " + "`algo-groups` are both undefined, then group `base` is plotted by " + "default.", +) +@click.option( + "--groups", + default="base", + help="Comma-separated groups of parameters to plot.", +) +@click.option( + "--algo-groups", + help="Comma-separated . to plot. Example usage: " + '--algo-groups=raft_cagra.large,hnswlib.large".', +) +@click.option( + "-k", + "--count", + default=10, + type=positive_int, + help="The number of nearest neighbors to search for.", +) +@click.option( + "-bs", + "--batch-size", + default=10000, + type=positive_int, + help="Number of query vectors to use in each query trial.", +) +@click.option("--build", is_flag=True, help="Flag to indicate build mode.") +@click.option("--search", is_flag=True, help="Flag to indicate search mode.") +@click.option( + "--x-scale", + default="linear", + help="Scale to use when drawing the X-axis. 
Typically linear, " + "logit, or a2.", +) +@click.option( + "--y-scale", + type=click.Choice( + ["linear", "log", "symlog", "logit"], case_sensitive=False + ), + default="linear", + help="Scale to use when drawing the Y-axis.", +) +@click.option( + "--x-start", + default=0.8, + type=positive_float, + help="Recall values to start the x-axis from.", +) +@click.option( + "--mode", + type=click.Choice(["throughput", "latency"], case_sensitive=False), + default="throughput", + help="Search mode whose Pareto frontier is used on the Y-axis.", +) +@click.option( + "--time-unit", + type=click.Choice(["s", "ms", "us"], case_sensitive=False), + default="ms", + help="Time unit to plot when mode is latency.", +) +@click.option( + "--raw", + is_flag=True, + help="Show raw results (not just Pareto frontier) of the mode argument.", +) +def main( + dataset: str, + dataset_path: str, + output_filepath: str, + algorithms: str, + groups: str, + algo_groups: str, + count: int, + batch_size: int, + build: bool, + search: bool, + x_scale: str, + y_scale: str, + x_start: float, + mode: str, + time_unit: str, + raw: bool, +) -> None: + + args = locals() + + if args["algorithms"]: + algorithms = args["algorithms"].split(",") + else: + algorithms = [] + groups = args["groups"].split(",") + if args["algo_groups"]: + algo_groups = args["algo_groups"].split(",") + else: + algo_groups = [] + k = args["count"] + batch_size = args["batch_size"] + if not args["build"] and not args["search"]: + build = True + search = True + else: + build = args["build"] + search = args["search"] + + search_output_filepath = os.path.join( + args["output_filepath"], + f"search-{args['dataset']}-k{k}-batch_size{batch_size}.png", + ) + build_output_filepath = os.path.join( + args["output_filepath"], + f"build-{args['dataset']}-k{k}-batch_size{batch_size}.png", + ) + + search_results = load_all_results( + os.path.join(args["dataset_path"], args["dataset"]), + algorithms, + groups, + algo_groups, + k, + batch_size, + 
"search", + "algo", + args["raw"], + args["mode"], + args["time_unit"], + ) + linestyles = create_linestyles(sorted(search_results.keys())) + if search: + create_plot_search( + search_results, + args["x_scale"], + args["y_scale"], + search_output_filepath, + linestyles, + args["dataset"], + k, + batch_size, + args["mode"], + args["time_unit"], + args["x_start"], + ) + if build: + build_results = load_all_results( + os.path.join(args["dataset_path"], args["dataset"]), + algorithms, + groups, + algo_groups, + k, + batch_size, + "build", + "index", + args["raw"], + args["mode"], + args["time_unit"], + ) + create_plot_build( + build_results, + search_results, + linestyles, + build_output_filepath, + args["dataset"], + k, + batch_size, + ) + + +if __name__ == "__main__": + main() diff --git a/python/cuvs_bench/cuvs_bench/run/__main__.py b/python/cuvs_bench/cuvs_bench/run/__main__.py index b5d99a4bf..bf9f8586d 100644 --- a/python/cuvs_bench/cuvs_bench/run/__main__.py +++ b/python/cuvs_bench/cuvs_bench/run/__main__.py @@ -19,6 +19,7 @@ from typing import Optional import click +from data_export import convert_json_to_csv_build, convert_json_to_csv_search from run import run_benchmark @@ -138,6 +139,16 @@ "run execute the benchmarks but will not actually execute " "the command.", ) +@click.option( + "--data-export", + is_flag=True, + help="By default, the intermediate JSON outputs produced by " + "cuvs_bench.run to more easily readable CSV files is done " + "automatically, which are needed to build charts made by " + "cuvs_bench.plot. 
But if some of the benchmark runs failed or " + "were interrupted, use this option to convert those intermediate " + "files manually.", +) @click.option( "--raft-log-level", default="info", @@ -165,6 +176,7 @@ def main( search_mode: str, search_threads: Optional[str], dry_run: bool, + data_export: bool, raft_log_level: str, ) -> None: """ @@ -209,7 +221,11 @@ def main( """ - run_benchmark(**locals()) + if not data_export: + run_benchmark(**locals()) + + convert_json_to_csv_build(dataset, dataset_path) + convert_json_to_csv_search(dataset, dataset_path) if __name__ == "__main__": diff --git a/python/cuvs_bench/cuvs_bench/run/data_export.py b/python/cuvs_bench/cuvs_bench/run/data_export.py new file mode 100644 index 000000000..997dab500 --- /dev/null +++ b/python/cuvs_bench/cuvs_bench/run/data_export.py @@ -0,0 +1,326 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import os +import traceback +import warnings + +import pandas as pd + +skip_build_cols = set( + [ + "algo_name", + "index_name", + "time", + "name", + "family_index", + "per_family_instance_index", + "run_name", + "run_type", + "repetitions", + "repetition_index", + "iterations", + "real_time", + "time_unit", + "index_size", + ] +) + +skip_search_cols = ( + set(["recall", "qps", "latency", "items_per_second", "Recall", "Latency"]) + | skip_build_cols +) + +metrics = { + "k-nn": { + "description": "Recall", + "worst": float("-inf"), + "lim": [0.0, 1.03], + }, + "throughput": { + "description": "Queries per second (1/s)", + "worst": float("-inf"), + }, + "latency": { + "description": "Search Latency (s)", + "worst": float("inf"), + }, +} + + +def read_json_files(dataset, dataset_path, method): + """ + Yield file paths, algo names, and loaded JSON data as pandas DataFrames. + + Parameters + ---------- + dataset : str + The name of the dataset. + dataset_path : str + The base path where datasets are stored. + method : str + The method subdirectory to search within (e.g., "build" or "search"). + + Yields + ------ + tuple + A tuple containing the file path, algorithm name, and the + DataFrame of JSON content. + """ + dir_path = os.path.join(dataset_path, dataset, "result", method) + for file in os.listdir(dir_path): + if file.endswith(".json"): + file_path = os.path.join(dir_path, file) + try: + with open(file_path, "r", encoding="ISO-8859-1") as f: + data = json.load(f) + df = pd.DataFrame(data["benchmarks"]) + algo_name = tuple(file.split(",")[:2]) + yield file_path, algo_name, df + except Exception as e: + print(f"Error processing file {file}: {e}. Skipping...") + traceback.print_exc() + + +def clean_algo_name(algo_name): + """ + Clean and format the algorithm name. + + Parameters + ---------- + algo_name : tuple + Tuple containing parts of the algorithm name. + + Returns + ------- + str + Cleaned algorithm name. 
+ """ + + return algo_name[0] if "base" in algo_name[1] else "_".join(algo_name) + + +def write_csv(file, algo_name, df, extra_columns=None, skip_cols=None): + """ + Write a DataFrame to CSV with specified columns skipped. + + Parameters + ---------- + file : str + The path to the file to be written. + algo_name : str + The algorithm name to be included in the CSV. + df : pandas.DataFrame + The DataFrame containing the data to write. + extra_columns : list, optional + List of extra columns to add (default is None). + skip_cols : set, optional + Set of columns to skip when writing to CSV (default is None). + """ + df["name"] = df["name"].str.split("/").str[0] + write_data = pd.DataFrame( + { + "algo_name": [algo_name] * len(df), + "index_name": df["name"], + "time": df["real_time"], + } + ) + # Add extra columns if provided + if extra_columns: + for col in extra_columns: + write_data[col] = None + # Include columns not in skip list + for name in df: + if name not in skip_cols: + write_data[name] = df[name] + write_data.to_csv(file.replace(".json", ".csv"), index=False) + + +def convert_json_to_csv_build(dataset, dataset_path): + """ + Convert build JSON files to CSV format. + + Parameters + ---------- + dataset : str + The name of the dataset. + dataset_path : str + The base path where datasets are stored. + """ + for file, algo_name, df in read_json_files(dataset, dataset_path, "build"): + try: + algo_name = clean_algo_name(algo_name) + write_csv(file, algo_name, df, skip_cols=skip_build_cols) + except Exception as e: + print(f"Error processing build file {file}: {e}. Skipping...") + traceback.print_exc() + + +def append_build_data(write, build_file): + """ + Append build data to the search DataFrame. + + Parameters + ---------- + write : pandas.DataFrame + The DataFrame containing the search data to which build + data will be appended. + build_file : str + The file path to the build CSV file. 
+ """ + if os.path.exists(build_file): + build_df = pd.read_csv(build_file) + write_ncols = len(write.columns) + # Initialize columns for build data + build_columns = [ + "build time", + "build threads", + "build cpu_time", + "build GPU", + ] + write = write.assign(**{col: None for col in build_columns}) + # Append additional columns if available + for col_name in build_df.columns[6:]: + write[col_name] = None + # Match build rows with search rows by index_name + for s_index, search_row in write.iterrows(): + for b_index, build_row in build_df.iterrows(): + if search_row["index_name"] == build_row["index_name"]: + write.iloc[s_index, write_ncols:] = build_row[2:].values + break + else: + warnings.warn( + f"Build CSV not found for {build_file}, build params not appended." + ) + + +def convert_json_to_csv_search(dataset, dataset_path): + """ + Convert search JSON files to CSV format. + + Parameters + ---------- + dataset : str + The name of the dataset. + dataset_path : str + The base path where datasets are stored. + """ + for file, algo_name, df in read_json_files( + dataset, dataset_path, "search" + ): + try: + build_file = os.path.join( + dataset_path, + dataset, + "result", + "build", + f"{','.join(algo_name)}.csv", + ) + algo_name = clean_algo_name(algo_name) + df["name"] = df["name"].str.split("/").str[0] + write_data = pd.DataFrame( + { + "algo_name": [algo_name] * len(df), + "index_name": df["name"], + "recall": df["Recall"], + "throughput": df["items_per_second"], + "latency": df["Latency"], + } + ) + # Append build data + append_build_data(write_data, build_file) + # Write search data and compute frontiers + write_data.to_csv(file.replace(".json", ",raw.csv"), index=False) + write_frontier(file, write_data, "throughput") + write_frontier(file, write_data, "latency") + except Exception as e: + print(f"Error processing search file {file}: {e}. 
Skipping...") + traceback.print_exc() + + +def create_pointset(data, xn, yn): + """ + Create a pointset by sorting and filtering data based on metrics. + + Parameters + ---------- + data : list + A list of data points. + xn : str + X-axis metric name. + yn : str + Y-axis metric name. + + Returns + ------- + list + Filtered list of data points sorted by x and y metrics. + """ + xm, ym = metrics[xn], metrics[yn] + rev_x, rev_y = (-1 if xm["worst"] < 0 else 1), ( + -1 if ym["worst"] < 0 else 1 + ) + # Sort data based on x and y metrics + data.sort(key=lambda t: (rev_y * t[4], rev_x * t[2])) + lines = [] + last_x = xm["worst"] + comparator = ( + (lambda xv, lx: xv > lx) if last_x < 0 else (lambda xv, lx: xv < lx) + ) + for d in data: + if comparator(d[2], last_x): + last_x = d[2] + lines.append(d) + return lines + + +def get_frontier(df, metric): + """ + Get the frontier of the data for a given metric. + + Parameters + ---------- + df : pandas.DataFrame + The DataFrame containing the data. + metric : str + The metric for which to compute the frontier. + + Returns + ------- + pandas.DataFrame + DataFrame containing the frontier points for the given metric. + """ + lines = create_pointset(df.values.tolist(), "k-nn", metric) + return pd.DataFrame(lines, columns=df.columns) + + +def write_frontier(file, write_data, metric): + """ + Write the frontier data to CSV for a given metric. + + Parameters + ---------- + file : str + The path to the file to write the frontier data. + write_data : pandas.DataFrame + The DataFrame containing the original data. + metric : str + The metric for which the frontier is computed + (e.g., "throughput", "latency"). 
+ """ + frontier_data = get_frontier(write_data, metric) + frontier_data.to_csv(file.replace(".json", f",{metric}.csv"), index=False) diff --git a/python/cuvs_bench/cuvs_bench/run/run.py b/python/cuvs_bench/cuvs_bench/run/run.py index dbedcc183..a65d4b5fe 100644 --- a/python/cuvs_bench/cuvs_bench/run/run.py +++ b/python/cuvs_bench/cuvs_bench/run/run.py @@ -580,6 +580,7 @@ def run_benchmark( search_mode: str, search_threads: int, dry_run: bool, + data_export: bool, raft_log_level: int, ) -> None: """