
SNMG ANN #231

Merged on Oct 3, 2024 (67 commits; changes shown from 39 commits)
cc1e45a
SNMG ANN
viclafargue Jul 18, 2024
279c345
nccl_clique as header
viclafargue Jul 18, 2024
b10d01d
update linking, build system and conda env
viclafargue Jul 18, 2024
d178155
Answered review
viclafargue Jul 19, 2024
4bc9d9c
Merge branch 'branch-24.08' into snmg-ann
viclafargue Jul 19, 2024
1459248
Apply review
viclafargue Jul 22, 2024
f3a65fc
Answer reviews + small changes
viclafargue Jul 25, 2024
ee2dcc3
Adding documentation
viclafargue Jul 26, 2024
5236cc2
Merge branch 'branch-24.08' into snmg-ann
viclafargue Jul 26, 2024
60bd621
removing unnecessary omp barriers
viclafargue Jul 29, 2024
17f62d2
int64_t change
viclafargue Jul 30, 2024
f523251
tree reduction merge implementation
viclafargue Jul 30, 2024
3e79a44
tree merge solidification
viclafargue Jul 31, 2024
d4cabe0
Adding bench code
viclafargue Aug 6, 2024
37f9755
Merge branch 'branch-24.08' into snmg-ann
viclafargue Aug 6, 2024
504b0c3
Auto max throughput for replicated search
viclafargue Aug 9, 2024
2d0a950
improve batching
viclafargue Aug 20, 2024
169eb15
branch-24.10 merge
viclafargue Sep 6, 2024
686f81d
answering reviews 1
viclafargue Sep 6, 2024
c8d3864
Updating params
viclafargue Sep 9, 2024
51291d8
iface free functions
viclafargue Sep 9, 2024
80cf875
free functions
viclafargue Sep 10, 2024
d60e583
NCCL clique from RAFT handle
viclafargue Sep 18, 2024
3419dfa
load balancing mechanism
viclafargue Sep 19, 2024
7970fdc
Merge branch 'branch-24.10' into snmg-ann
viclafargue Sep 19, 2024
6a220b5
update doc
viclafargue Sep 19, 2024
c5e955f
moving iface struct
viclafargue Sep 23, 2024
60fbef1
include fix
viclafargue Sep 23, 2024
5ea9b9b
small fixes
viclafargue Sep 24, 2024
8b0c8c7
RAFT handle update
viclafargue Sep 26, 2024
bcf97c9
RAFT handle update
viclafargue Sep 26, 2024
9418f7e
smallSearchBatchSize as constexpr
viclafargue Sep 27, 2024
fa457f4
Merge branch 'branch-24.10' into snmg-ann
viclafargue Sep 30, 2024
dc2ccdd
add half type
viclafargue Sep 30, 2024
ed68cd8
fix bench
viclafargue Sep 30, 2024
9e659c4
Update build system
viclafargue Oct 2, 2024
f3bc98a
update iface to only expose device-only search function
viclafargue Oct 2, 2024
d9a83e5
Adding replicated search mode (load-balancer and round-robin)
viclafargue Oct 2, 2024
e6a73c6
CAGRA bench consolidation
viclafargue Oct 2, 2024
d68f572
Adding --mg to conda recipes
viclafargue Oct 2, 2024
6a673c3
resolving merge conflict
viclafargue Oct 2, 2024
55fbb36
enable multi-GPU by default, add a CMake option to control it
jameslamb Oct 2, 2024
5649a49
empty commit to re-trigger CI
jameslamb Oct 2, 2024
a208d49
Merge branch 'branch-24.10' into snmg-ann
jameslamb Oct 2, 2024
e0c232a
revert CUVS_EXPLICIT_INSTANTIATE_ONLY re-introduction
jameslamb Oct 2, 2024
1a5a2f2
Merge branch 'snmg-ann' of github.com:viclafargue/cuvs into snmg-ann
jameslamb Oct 2, 2024
fef0fc9
Removing std comms
cjnolet Oct 2, 2024
c028dca
Remove UCP
cjnolet Oct 2, 2024
a43c4f9
Adding nccl to rapids_build
cjnolet Oct 2, 2024
3b2feb7
add back NCCL dependency, pin to NCCL>=2.19
jameslamb Oct 2, 2024
d77a4e9
Revert "Removing std comms"
cjnolet Oct 3, 2024
4af2c2e
Renaming comms source file
cjnolet Oct 3, 2024
cecb372
Merge branch 'snmg-ann' of github.com:viclafargue/cuvs into snmg-ann
cjnolet Oct 3, 2024
f7a73fd
Merge branch 'branch-24.10' into snmg-ann
cjnolet Oct 3, 2024
ceb6287
Adding ucp to cmakelists
cjnolet Oct 3, 2024
ce37b71
Merge branch 'snmg-ann' of github.com:viclafargue/cuvs into snmg-ann
cjnolet Oct 3, 2024
1f0f5e9
More renames
cjnolet Oct 3, 2024
cb8ed0c
Adding libucxx
cjnolet Oct 3, 2024
fe5b6f8
Adding ucxx
cjnolet Oct 3, 2024
e257282
Adding to run time
cjnolet Oct 3, 2024
b6cb776
Adding libucxx to libcuvs y
cjnolet Oct 3, 2024
ac26507
use raw nccl calls
viclafargue Oct 3, 2024
4a10a6c
Removing ucp from cmake
cjnolet Oct 3, 2024
c9515d5
changing serialization path and disabling sharded mode testing
viclafargue Oct 3, 2024
d77704c
round robin check improvement + temporary disable of CAGRA
viclafargue Oct 3, 2024
c2c810c
Merge branch 'branch-24.10' into snmg-ann
viclafargue Oct 3, 2024
4e7398a
fix merge
viclafargue Oct 3, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -81,4 +81,4 @@ ivf_pq_index

# cuvs_bench
datasets/
/*.json
/*.json
9 changes: 8 additions & 1 deletion build.sh
@@ -18,7 +18,7 @@ ARGS=$*
# scripts, and that this script resides in the repo dir!
REPODIR=$(cd $(dirname $0); pwd)

VALIDARGS="clean libcuvs python rust docs tests bench-ann examples --uninstall -v -g -n --compile-static-lib --allgpuarch --no-nvtx --show_depr_warn --incl-cache-stats --time -h"
VALIDARGS="clean libcuvs python rust docs tests bench-ann examples --uninstall -v -g -n --compile-static-lib --allgpuarch --mg --no-nvtx --show_depr_warn --incl-cache-stats --time -h"
HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<tool>] [--limit-tests=<targets>] [--limit-bench-ann=<targets>] [--build-metrics=<filename>]
where <target> is:
clean - remove all existing build artifacts and configuration (start over)
@@ -40,6 +40,7 @@ HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<to
--limit-tests - semicolon-separated list of test executables to compile (e.g. NEIGHBORS_TEST;CLUSTER_TEST)
--limit-bench-ann - semicolon-separated list of ann benchmark executables to compute (e.g. HNSWLIB_ANN_BENCH;RAFT_IVF_PQ_ANN_BENCH)
--allgpuarch - build for all supported GPU architectures
--mg - build MG features
--no-nvtx - disable nvtx (profiling markers), but allow enabling it in downstream projects
--show_depr_warn - show cmake deprecation warnings
--build-metrics - filename for generating build metrics report for libcuvs
@@ -65,6 +66,7 @@ CMAKE_LOG_LEVEL=""
VERBOSE_FLAG=""
BUILD_ALL_GPU_ARCH=0
BUILD_TESTS=ON
BUILD_MG_ALGOS=OFF
BUILD_TYPE=Release
COMPILE_LIBRARY=OFF
INSTALL_TARGET=install
@@ -261,6 +263,10 @@ if hasArg --allgpuarch; then
BUILD_ALL_GPU_ARCH=1
fi

if hasArg --mg; then
BUILD_MG_ALGOS=ON
fi

if hasArg tests || (( ${NUMARGS} == 0 )); then
BUILD_TESTS=ON
CMAKE_TARGET="${CMAKE_TARGET};${TEST_TARGETS}"
@@ -353,6 +359,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libcuvs || hasArg docs || hasArg tests || has
-DBUILD_C_TESTS=${BUILD_TESTS} \
-DBUILD_CUVS_BENCH=${BUILD_CUVS_BENCH} \
-DBUILD_CPU_ONLY=${BUILD_CPU_ONLY} \
-DBUILD_MG_ALGOS=${BUILD_MG_ALGOS} \
-DCMAKE_MESSAGE_LOG_LEVEL=${CMAKE_LOG_LEVEL} \
${CACHE_ARGS} \
${EXTRA_CMAKE_ARGS}
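The new `--mg` flag rides on build.sh's existing `hasArg` pattern. A minimal, standalone sketch of that logic (simplified from the real script; `ARGS`/`NUMARGS` values here are illustrative, not taken from an actual invocation):

```shell
# Sketch of build.sh-style flag handling (assumed simplification, not the real script).
ARGS="libcuvs --mg"
NUMARGS=2

# hasArg: true when the given word appears in the argument list.
hasArg() {
    (( NUMARGS != 0 )) && (echo " ${ARGS} " | grep -q " $1 ")
}

BUILD_MG_ALGOS=OFF
if hasArg --mg; then
    BUILD_MG_ALGOS=ON
fi
echo "BUILD_MG_ALGOS=${BUILD_MG_ALGOS}"
```

The resulting `BUILD_MG_ALGOS` value is then forwarded to CMake as `-DBUILD_MG_ALGOS=${BUILD_MG_ALGOS}`, as the diff below this hunk shows.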
39 changes: 39 additions & 0 deletions cpp/CMakeLists.txt
@@ -287,6 +287,24 @@ target_compile_options(
"$<$<COMPILE_LANGUAGE:CUDA>:${CUVS_CUDA_FLAGS}>"
)

if(BUILD_MG_ALGOS)
set(CUVS_MG_ALGOS
src/neighbors/mg/mg_flat_float_int64_t.cu
src/neighbors/mg/mg_flat_int8_t_int64_t.cu
src/neighbors/mg/mg_flat_uint8_t_int64_t.cu
src/neighbors/mg/mg_pq_float_int64_t.cu
src/neighbors/mg/mg_pq_half_int64_t.cu
src/neighbors/mg/mg_pq_int8_t_int64_t.cu
src/neighbors/mg/mg_pq_uint8_t_int64_t.cu
src/neighbors/mg/mg_cagra_float_uint32_t.cu
src/neighbors/mg/mg_cagra_half_uint32_t.cu
src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu
src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu
src/neighbors/mg/std_comms.cu
src/neighbors/mg/omp_checks.cu
)
endif()

add_library(
cuvs SHARED
src/cluster/kmeans_balanced_fit_float.cu
@@ -358,6 +376,17 @@ add_library(
src/neighbors/cagra_serialize_half.cu
src/neighbors/cagra_serialize_int8.cu
src/neighbors/cagra_serialize_uint8.cu
src/neighbors/iface/iface_cagra_float_uint32_t.cu
src/neighbors/iface/iface_cagra_half_uint32_t.cu
src/neighbors/iface/iface_cagra_int8_t_uint32_t.cu
src/neighbors/iface/iface_cagra_uint8_t_uint32_t.cu
src/neighbors/iface/iface_flat_float_int64_t.cu
src/neighbors/iface/iface_flat_int8_t_int64_t.cu
src/neighbors/iface/iface_flat_uint8_t_int64_t.cu
src/neighbors/iface/iface_pq_float_int64_t.cu
src/neighbors/iface/iface_pq_half_int64_t.cu
src/neighbors/iface/iface_pq_int8_t_int64_t.cu
src/neighbors/iface/iface_pq_uint8_t_int64_t.cu
src/neighbors/detail/cagra/cagra_build.cpp
src/neighbors/detail/cagra/topk_for_cagra/topk.cu
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
@@ -421,8 +450,13 @@ add_library(
src/selection/select_k_half_uint32_t.cu
src/stats/silhouette_score.cu
src/stats/trustworthiness_score.cu
${CUVS_MG_ALGOS}
)

if(BUILD_MG_ALGOS)
target_compile_definitions(cuvs PUBLIC CUVS_BUILD_MG_ALGOS)
endif()

target_compile_definitions(cuvs PRIVATE "CUVS_EXPLICIT_INSTANTIATE_ONLY")

target_compile_options(
@@ -454,11 +488,16 @@ if(NOT BUILD_CPU_ONLY)
${CUVS_CUSPARSE_DEPENDENCY} ${CUVS_CURAND_DEPENDENCY}
)

if(BUILD_MG_ALGOS)
set(CUVS_COMMS_DEPENDENCY ucp ucs ucxx nccl)
endif()

# Keep cuVS as lightweight as possible. Only CUDA libs and rmm should be used in global target.
target_link_libraries(
cuvs
PUBLIC rmm::rmm raft::raft ${CUVS_CTK_MATH_DEPENDENCIES}
PRIVATE nvidia::cutlass::cutlass $<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX> cuvs-cagra-search
${CUVS_COMMS_DEPENDENCY}
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
)
endif()

18 changes: 18 additions & 0 deletions cpp/bench/ann/CMakeLists.txt
@@ -32,6 +32,7 @@ option(CUVS_ANN_BENCH_USE_CUVS_BRUTE_FORCE "Include cuVS brute force knn in benc
option(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB "Include cuVS CAGRA with HNSW search in benchmark" ON)
option(CUVS_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(CUVS_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" OFF)
option(CUVS_ANN_BENCH_USE_CUVS_MG "Include cuVS ann mg algorithm in benchmark" ${BUILD_MG_ALGOS})
option(CUVS_ANN_BENCH_SINGLE_EXE
"Make a single executable with benchmark as shared library modules" OFF
)
@@ -55,6 +56,7 @@ if(BUILD_CPU_ONLY)
set(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB OFF)
set(CUVS_ANN_BENCH_USE_GGNN OFF)
set(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE OFF)
set(CUVS_ANN_BENCH_USE_CUVS_MG OFF)
else()
set(CUVS_FAISS_ENABLE_GPU ON)
endif()
@@ -66,6 +68,7 @@ if(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ
OR CUVS_ANN_BENCH_USE_CUVS_CAGRA
OR CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB
OR CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE
OR CUVS_ANN_BENCH_USE_CUVS_MG
)
set(CUVS_ANN_BENCH_USE_CUVS ON)
endif()
@@ -247,6 +250,21 @@ if(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB)
)
endif()

if(CUVS_ANN_BENCH_USE_CUVS_MG)
ConfigureAnnBench(
NAME
CUVS_MG
PATH
src/cuvs/cuvs_benchmark.cu
$<$<BOOL:${CUVS_ANN_BENCH_USE_CUVS_MG}>:src/cuvs/cuvs_mg_ivf_flat.cu>
$<$<BOOL:${CUVS_ANN_BENCH_USE_CUVS_MG}>:src/cuvs/cuvs_mg_ivf_pq.cu>
$<$<BOOL:${CUVS_ANN_BENCH_USE_CUVS_MG}>:src/cuvs/cuvs_mg_cagra.cu>
LINKS
cuvs
nccl
)
endif()

message("CUVS_FAISS_TARGETS: ${CUVS_FAISS_TARGETS}")
message("CUDAToolkit_LIBRARY_DIR: ${CUDAToolkit_LIBRARY_DIR}")
if(CUVS_ANN_BENCH_USE_FAISS_CPU_FLAT)
18 changes: 15 additions & 3 deletions cpp/bench/ann/src/cuvs/cuvs_ann_bench_param_parser.h
@@ -45,7 +45,18 @@ extern template class cuvs::bench::cuvs_cagra<uint8_t, uint32_t>;
extern template class cuvs::bench::cuvs_cagra<int8_t, uint32_t>;
#endif

#ifdef CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT
#ifdef CUVS_ANN_BENCH_USE_CUVS_MG
#include "cuvs_ivf_flat_wrapper.h"
#include "cuvs_mg_ivf_flat_wrapper.h"

#include "cuvs_ivf_pq_wrapper.h"
#include "cuvs_mg_ivf_pq_wrapper.h"

#include "cuvs_cagra_wrapper.h"
#include "cuvs_mg_cagra_wrapper.h"
#endif

#if defined(CUVS_ANN_BENCH_USE_CUVS_IVF_FLAT) || defined(CUVS_ANN_BENCH_USE_CUVS_MG)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename cuvs::bench::cuvs_ivf_flat<T, IdxT>::build_param& param)
@@ -64,7 +75,7 @@ void parse_search_param(const nlohmann::json& conf,
#endif

#if defined(CUVS_ANN_BENCH_USE_CUVS_IVF_PQ) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA) || \
defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB)
defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB) || defined(CUVS_ANN_BENCH_USE_CUVS_MG)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf,
typename cuvs::bench::cuvs_ivf_pq<T, IdxT>::build_param& param)
@@ -130,7 +141,8 @@ void parse_search_param(const nlohmann::json& conf,
}
#endif

#if defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB)
#if defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA) || defined(CUVS_ANN_BENCH_USE_CUVS_CAGRA_HNSWLIB) || \
defined(CUVS_ANN_BENCH_USE_CUVS_MG)
template <typename T, typename IdxT>
void parse_build_param(const nlohmann::json& conf, cuvs::neighbors::nn_descent::index_params& param)
{
89 changes: 89 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_benchmark.cu
@@ -29,6 +29,43 @@

namespace cuvs::bench {

#ifdef CUVS_ANN_BENCH_USE_CUVS_MG
void add_distribution_mode(cuvs::neighbors::mg::distribution_mode* dist_mode,
const nlohmann::json& conf)
{
if (conf.contains("distribution_mode")) {
std::string distribution_mode = conf.at("distribution_mode");
if (distribution_mode == "replicated") {
*dist_mode = cuvs::neighbors::mg::distribution_mode::REPLICATED;
} else if (distribution_mode == "sharded") {
*dist_mode = cuvs::neighbors::mg::distribution_mode::SHARDED;
} else {
throw std::runtime_error("invalid value for distribution_mode");
}
} else {
// default
*dist_mode = cuvs::neighbors::mg::distribution_mode::SHARDED;
}
};

void add_merge_mode(cuvs::neighbors::mg::sharded_merge_mode* merge_mode, const nlohmann::json& conf)
{
if (conf.contains("merge_mode")) {
std::string sharded_merge_mode = conf.at("merge_mode");
if (sharded_merge_mode == "tree_merge") {
*merge_mode = cuvs::neighbors::mg::sharded_merge_mode::TREE_MERGE;
} else if (sharded_merge_mode == "merge_on_root_rank") {
*merge_mode = cuvs::neighbors::mg::sharded_merge_mode::MERGE_ON_ROOT_RANK;
} else {
throw std::runtime_error("invalid value for merge_mode");
}
} else {
// default
*merge_mode = cuvs::neighbors::mg::sharded_merge_mode::TREE_MERGE;
}
};
#endif
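Based on the two parsers above, a benchmark configuration selects the modes through optional JSON keys. A hypothetical fragment (key names and accepted values come from the code; `distribution_mode` is read from the build parameters and `merge_mode` from the search parameters; when a key is absent the defaults are `sharded` and `tree_merge`):

```json
{
  "distribution_mode": "sharded",
  "merge_mode": "merge_on_root_rank"
}
```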

template <typename T>
auto create_algo(const std::string& algo_name,
const std::string& distance,
@@ -71,6 +108,32 @@ auto create_algo(const std::string& algo_name,
parse_build_param<T, uint32_t>(conf, param);
a = std::make_unique<cuvs::bench::cuvs_cagra<T, uint32_t>>(metric, dim, param);
}
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_MG
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, uint8_t> ||
std::is_same_v<T, int8_t>) {
if (algo_name == "raft_mg_ivf_flat" || algo_name == "cuvs_mg_ivf_flat") {
typename cuvs::bench::cuvs_mg_ivf_flat<T, int64_t>::build_param param;
parse_build_param<T, int64_t>(conf, param);
add_distribution_mode(&param.mode, conf);
a = std::make_unique<cuvs::bench::cuvs_mg_ivf_flat<T, int64_t>>(metric, dim, param);
}
}

if (algo_name == "raft_mg_ivf_pq" || algo_name == "cuvs_mg_ivf_pq") {
typename cuvs::bench::cuvs_mg_ivf_pq<T, int64_t>::build_param param;
parse_build_param<T, int64_t>(conf, param);
add_distribution_mode(&param.mode, conf);
a = std::make_unique<cuvs::bench::cuvs_mg_ivf_pq<T, int64_t>>(metric, dim, param);
}

if (algo_name == "raft_mg_cagra" || algo_name == "cuvs_mg_cagra") {
typename cuvs::bench::cuvs_mg_cagra<T, uint32_t>::build_param param;
parse_build_param<T, uint32_t>(conf, param);
add_distribution_mode(&param.mode, conf);
a = std::make_unique<cuvs::bench::cuvs_mg_cagra<T, uint32_t>>(metric, dim, param);
}

#endif

if (!a) { throw std::runtime_error("invalid algo: '" + algo_name + "'"); }
@@ -113,6 +176,32 @@ auto create_search_param(const std::string& algo_name, const nlohmann::json& con
return param;
}
#endif
#ifdef CUVS_ANN_BENCH_USE_CUVS_MG
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, uint8_t> ||
std::is_same_v<T, int8_t>) {
if (algo_name == "raft_mg_ivf_flat" || algo_name == "cuvs_mg_ivf_flat") {
auto param =
std::make_unique<typename cuvs::bench::cuvs_mg_ivf_flat<T, int64_t>::search_param>();
parse_search_param<T, int64_t>(conf, *param);
add_merge_mode(&param->merge_mode, conf);
return param;
}
}

if (algo_name == "raft_mg_ivf_pq" || algo_name == "cuvs_mg_ivf_pq") {
auto param = std::make_unique<typename cuvs::bench::cuvs_mg_ivf_pq<T, int64_t>::search_param>();
parse_search_param<T, int64_t>(conf, *param);
add_merge_mode(&param->merge_mode, conf);
return param;
}

if (algo_name == "raft_mg_cagra" || algo_name == "cuvs_mg_cagra") {
auto param = std::make_unique<typename cuvs::bench::cuvs_mg_cagra<T, uint32_t>::search_param>();
parse_search_param<T, uint32_t>(conf, *param);
add_merge_mode(&param->merge_mode, conf);
return param;
}
#endif

// else
throw std::runtime_error("invalid algo: '" + algo_name + "'");
38 changes: 18 additions & 20 deletions cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h
@@ -72,6 +72,23 @@ class cuvs_cagra : public algo<T>, public algo_gpu {
std::optional<float> ivf_pq_refine_rate = std::nullopt;
std::optional<cuvs::neighbors::ivf_pq::index_params> ivf_pq_build_params = std::nullopt;
std::optional<cuvs::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt;

void prepare_build_params(const raft::extent_2d<IdxT>& dataset_extents)
{
if (algo == CagraBuildAlgo::kIvfPq) {
auto pq_params = cuvs::neighbors::cagra::graph_build_params::ivf_pq_params(
dataset_extents, cagra_params.metric);
if (ivf_pq_build_params) { pq_params.build_params = *ivf_pq_build_params; }
if (ivf_pq_search_params) { pq_params.search_params = *ivf_pq_search_params; }
if (ivf_pq_refine_rate) { pq_params.refinement_rate = *ivf_pq_refine_rate; }
cagra_params.graph_build_params = pq_params;
} else if (algo == CagraBuildAlgo::kNnDescent) {
auto nn_params = cuvs::neighbors::cagra::graph_build_params::nn_descent_params(
cagra_params.intermediate_graph_degree);
if (nn_descent_params) { nn_params = *nn_descent_params; }
cagra_params.graph_build_params = nn_params;
}
}
};

cuvs_cagra(Metric metric, int dim, const build_param& param, int concurrent_searches = 1)
@@ -168,28 +185,9 @@ template <typename T, typename IdxT>
void cuvs_cagra<T, IdxT>::build(const T* dataset, size_t nrow)
{
auto dataset_extents = raft::make_extents<IdxT>(nrow, dimension_);
index_params_.prepare_build_params(dataset_extents);

auto& params = index_params_.cagra_params;

if (index_params_.algo == CagraBuildAlgo::kIvfPq) {
auto pq_params =
cuvs::neighbors::cagra::graph_build_params::ivf_pq_params(dataset_extents, params.metric);
if (index_params_.ivf_pq_build_params) {
pq_params.build_params = *index_params_.ivf_pq_build_params;
}
if (index_params_.ivf_pq_search_params) {
pq_params.search_params = *index_params_.ivf_pq_search_params;
}
if (index_params_.ivf_pq_refine_rate) {
pq_params.refinement_rate = *index_params_.ivf_pq_refine_rate;
}
params.graph_build_params = pq_params;
} else if (index_params_.algo == CagraBuildAlgo::kNnDescent) {
auto nn_params = cuvs::neighbors::cagra::graph_build_params::nn_descent_params(
params.intermediate_graph_degree);
if (index_params_.nn_descent_params) { nn_params = *index_params_.nn_descent_params; }
params.graph_build_params = nn_params;
}
auto dataset_view_host =
raft::make_mdspan<const T, IdxT, raft::row_major, true, false>(dataset, dataset_extents);
auto dataset_view_device =
23 changes: 23 additions & 0 deletions cpp/bench/ann/src/cuvs/cuvs_mg_cagra.cu
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuvs_mg_cagra_wrapper.h"

namespace cuvs::bench {
template class cuvs_mg_cagra<float, uint32_t>;
template class cuvs_mg_cagra<half, uint32_t>;
template class cuvs_mg_cagra<uint8_t, uint32_t>;
template class cuvs_mg_cagra<int8_t, uint32_t>;
} // namespace cuvs::bench