Skip to content

Commit

Permalink
use raw nccl calls
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Oct 3, 2024
1 parent b6cb776 commit ac26507
Show file tree
Hide file tree
Showing 15 changed files with 63 additions and 72 deletions.
1 change: 0 additions & 1 deletion conda/environments/all_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dependencies:
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- make
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dependencies:
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- make
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/all_cuda-125_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ dependencies:
- libcusolver-dev
- libcusparse-dev
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- make
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ dependencies:
- libcusolver-dev
- libcusparse-dev
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- make
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/bench_ann_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ dependencies:
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- matplotlib
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ dependencies:
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- matplotlib
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/bench_ann_cuda-125_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ dependencies:
- libcusolver-dev
- libcusparse-dev
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- matplotlib
- nccl>=2.19
- ninja
Expand Down
1 change: 0 additions & 1 deletion conda/environments/bench_ann_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ dependencies:
- libcusolver-dev
- libcusparse-dev
- librmm==24.10.*,>=0.0.0a0
- libucxx==0.40.*,>=0.0.0a0
- matplotlib
- nccl>=2.19
- ninja
Expand Down
3 changes: 0 additions & 3 deletions conda/recipes/libcuvs/conda_build_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,3 @@ cuda11_cuda_profiler_api_host_version:

cuda11_cuda_profiler_api_run_version:
- ">=11.4.240,<12"
-
ucxx_version:
- "0.40.*"
6 changes: 0 additions & 6 deletions conda/recipes/libcuvs/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ outputs:
- librmm ={{ minor_version }}
- libraft-headers ={{ minor_version }}
- nccl {{ nccl_version }}
- libucxx {{ ucxx_version }}
- distributed-ucxx {{ ucxx_version }}
- cuda-version ={{ cuda_version }}
{% if cuda_major == "11" %}
- cuda-profiler-api {{ cuda11_cuda_profiler_api_host_version }}
Expand Down Expand Up @@ -274,8 +272,6 @@ outputs:
- librmm ={{ minor_version }}
- libraft-headers ={{ minor_version }}
- nccl {{ nccl_version }}
- libucxx {{ ucxx_version }}
- distributed-ucxx {{ ucxx_version }}
- {{ pin_subpackage('libcuvs', exact=True) }}
- cuda-version ={{ cuda_version }}
{% if cuda_major == "11" %}
Expand Down Expand Up @@ -307,8 +303,6 @@ outputs:
- libcusolver
- libcusparse
{% endif %}
- libucxx {{ ucxx_version }}
- distributed-ucxx {{ ucxx_version }}
- {{ pin_subpackage('libcuvs', exact=True) }}
about:
home: https://rapids.ai/
Expand Down
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ if(BUILD_MG_ALGOS)
src/neighbors/mg/mg_cagra_half_uint32_t.cu
src/neighbors/mg/mg_cagra_int8_t_uint32_t.cu
src/neighbors/mg/mg_cagra_uint8_t_uint32_t.cu
src/neighbors/mg/std_comms.cpp
src/neighbors/mg/omp_checks.cpp
src/neighbors/mg/nccl_comm.cpp
)
endif()

Expand Down
90 changes: 54 additions & 36 deletions cpp/src/neighbors/mg/mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique,
int dev_id = clique.device_ids_[rank];
const raft::device_resources& dev_res = clique.device_resources_[rank];
auto& ann_if = index.ann_interfaces_[rank];
const auto& comms = resource::get_comms(dev_res);
RAFT_CUDA_TRY(cudaSetDevice(dev_id));

if (rank == clique.root_rank_) { // root rank
Expand All @@ -242,21 +241,25 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique,
dev_res, ann_if, search_params, query_partition, d_neighbors, d_distances);

// wait for other ranks
comms.group_start();
ncclGroupStart();
for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) {
if (from_rank == clique.root_rank_) continue;

batch_offset = from_rank * part_size;
comms.device_recv(in_neighbors.data_handle() + batch_offset,
part_size,
from_rank,
resource::get_cuda_stream(dev_res));
comms.device_recv(in_distances.data_handle() + batch_offset,
part_size,
from_rank,
resource::get_cuda_stream(dev_res));
ncclRecv(in_neighbors.data_handle() + batch_offset,
part_size * sizeof(IdxT),
ncclUint8,
from_rank,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
ncclRecv(in_distances.data_handle() + batch_offset,
part_size * sizeof(float),
ncclUint8,
from_rank,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
}
comms.group_end();
ncclGroupEnd();
resource::sync_stream(dev_res);
} else { // non-root ranks
auto d_neighbors = raft::make_device_matrix<IdxT, int64_t, row_major>(
Expand All @@ -267,16 +270,20 @@ void sharded_search_with_direct_merge(const raft::comms::nccl_clique& clique,
dev_res, ann_if, search_params, query_partition, d_neighbors.view(), d_distances.view());

// send results to root rank
comms.group_start();
comms.device_send(d_neighbors.data_handle(),
part_size,
clique.root_rank_,
resource::get_cuda_stream(dev_res));
comms.device_send(d_distances.data_handle(),
part_size,
clique.root_rank_,
resource::get_cuda_stream(dev_res));
comms.group_end();
ncclGroupStart();
ncclSend(d_neighbors.data_handle(),
part_size * sizeof(IdxT),
ncclUint8,
clique.root_rank_,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
ncclSend(d_distances.data_handle(),
part_size * sizeof(float),
ncclUint8,
clique.root_rank_,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
ncclGroupEnd();
resource::sync_stream(dev_res);
}
}
Expand Down Expand Up @@ -345,7 +352,6 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique,
int dev_id = clique.device_ids_[rank];
const raft::device_resources& dev_res = clique.device_resources_[rank];
auto& ann_if = index.ann_interfaces_[rank];
const auto& comms = resource::get_comms(dev_res);
RAFT_CUDA_TRY(cudaSetDevice(dev_id));

int64_t part_size = n_rows_of_current_batch * n_neighbors;
Expand Down Expand Up @@ -381,31 +387,43 @@ void sharded_search_with_tree_merge(const raft::comms::nccl_clique& clique,
while (remaining > 1) {
bool received_something = false;
int64_t offset = radix / 2;
comms.group_start();
ncclGroupStart();
if (rank % radix == 0) // This is one of the receivers
{
int other_id = rank + offset;
if (other_id < index.num_ranks_) // Make sure someone's sending anything
{
comms.device_recv(tmp_neighbors.data_handle() + part_size,
part_size,
other_id,
resource::get_cuda_stream(dev_res));
comms.device_recv(tmp_distances.data_handle() + part_size,
part_size,
other_id,
resource::get_cuda_stream(dev_res));
ncclRecv(tmp_neighbors.data_handle() + part_size,
part_size * sizeof(IdxT),
ncclUint8,
other_id,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
ncclRecv(tmp_distances.data_handle() + part_size,
part_size * sizeof(float),
ncclUint8,
other_id,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
received_something = true;
}
} else if (rank % radix == offset) // This is one of the senders
{
int other_id = rank - offset;
comms.device_send(
tmp_neighbors.data_handle(), part_size, other_id, resource::get_cuda_stream(dev_res));
comms.device_send(
tmp_distances.data_handle(), part_size, other_id, resource::get_cuda_stream(dev_res));
ncclSend(tmp_neighbors.data_handle(),
part_size * sizeof(IdxT),
ncclUint8,
other_id,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
ncclSend(tmp_distances.data_handle(),
part_size * sizeof(float),
ncclUint8,
other_id,
clique.nccl_comms_[rank],
resource::get_cuda_stream(dev_res));
}
comms.group_end();
ncclGroupEnd();

remaining = (remaining + 1) / 2;
radix *= 2;
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/neighbors/mg/nccl_comm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <nccl.h>
#include <raft/core/resources.hpp>

namespace raft::comms {
/**
 * @brief No-op stand-in for raft's NCCL-only comms initialization.
 *
 * This commit moves the multi-GPU search path to raw NCCL calls
 * (ncclSend/ncclRecv on clique.nccl_comms_[rank]) instead of routing through a
 * raft comms_t, so there is nothing to attach to the resources handle here.
 * NOTE(review): presumably this empty definition is kept only so existing
 * references to the symbol still link — confirm before deleting it outright.
 *
 * @param handle    raft resources handle; deliberately left untouched.
 * @param nccl_comm already-initialized NCCL communicator; unused.
 * @param num_ranks total number of ranks in the clique; unused.
 * @param rank      rank of the calling device; unused.
 */
void build_comms_nccl_only([[maybe_unused]] raft::resources* handle,
                           [[maybe_unused]] ncclComm_t nccl_comm,
                           [[maybe_unused]] int num_ranks,
                           [[maybe_unused]] int rank)
{
  // Intentionally empty: communication is performed with raw NCCL primitives
  // by the callers, so no comms object is constructed or registered.
}
}  // namespace raft::comms
17 changes: 0 additions & 17 deletions cpp/src/neighbors/mg/std_comms.cpp

This file was deleted.

1 change: 0 additions & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ dependencies:
- c-compiler
- cxx-compiler
- nccl>=2.19
- libucxx==0.40.*,>=0.0.0a0
specific:
- output_types: conda
matrices:
Expand Down

0 comments on commit ac26507

Please sign in to comment.