Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions csrc/multidevice/cuda_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ namespace nvfuser {

enum class P2pProtocol { Get, Put };

P2pProtocol getP2pProtocol();

std::ostream& operator<<(std::ostream& os, P2pProtocol protocol);

// Returns the prescribed P2P protocol based on NVFUSER_ENABLE option
Expand Down
91 changes: 27 additions & 64 deletions tests/cpp/test_multidevice_host_ir_overlap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <host_ir/evaluator.h>
#include <host_ir/host_ir.h>
#include <ir/utils.h>
#include <multidevice/cuda_p2p.h>
#include <ops/all_ops.h>
#include <tests/cpp/multidevice.h>

Expand Down Expand Up @@ -1094,10 +1095,13 @@ TEST_F(

TEST_F(
RingAllgatherOverlapTest,
DISABLED_RingAllgatherBasedPipeliningHostIRImplementationCudaIpc) {
RingAllgatherBasedPipeliningHostIRImplementationCudaIpc) {
if (communicator_->size() == 1) {
GTEST_SKIP() << "Skipping test for single device";
}
if (getP2pProtocol() == P2pProtocol::Put) {
GTEST_SKIP() << "Skipping test for CudaIpc P2pProtocol::Put";
}

auto hic = std::make_unique<hir::HostIrContainer>();
FusionGuard::setCurFusion(hic.get());
Expand Down Expand Up @@ -1141,10 +1145,13 @@ TEST_F(
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);

auto* stream_index =
mod(add(i, j), IrBuilder::create<Val>(params.number_of_streams));
auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(
IrBuilder::create<hir::Stream>(stream_index));
auto* num_streams = IrBuilder::create<Val>(params.number_of_streams);
auto* curr_stream_index = mod(add(i, j), num_streams);
auto* next_stream_index = mod(add(i, add(j, step_j)), num_streams);
auto* set_curr_stream = IrBuilder::create<hir::SetCurrentStream>(
IrBuilder::create<hir::Stream>(curr_stream_index));
auto* set_next_stream = IrBuilder::create<hir::SetCurrentStream>(
IrBuilder::create<hir::Stream>(next_stream_index));

auto* my_device_index_val = IrBuilder::create<Val>(my_device_index_);
auto* number_of_steps_per_ring_val =
Expand Down Expand Up @@ -1194,33 +1201,34 @@ TEST_F(
std::vector<P2PCommunication*> grouped_communications = {send, recv};
auto share_mem_handles = IrBuilder::create<hir::ShareMemHandles>(
std::move(grouped_communications));
auto* wait_send = IrBuilder::create<hir::Wait>(send);
auto* wait_recv = IrBuilder::create<hir::Wait>(recv);

auto* comm_cond = ne(j, sub(stop_j, hic->oneVal()));
auto* comm_predicate = IrBuilder::create<kir::Predicate>(comm_cond);
auto* if_not_last_ring_step_post_comms =
IrBuilder::create<kir::IfThenElse>(comm_predicate);

// Nonblocking--just signals the buffer is ready for the get transfer
if_not_last_ring_step_post_comms->thenBody().push_back(send);
if_not_last_ring_step_post_comms->thenBody().push_back(set_next_stream);
// Block in recvPost on the next stream to do the get transfer
if_not_last_ring_step_post_comms->thenBody().push_back(recv);
if_not_last_ring_step_post_comms->thenBody().push_back(set_curr_stream);

auto* cond = ne(j, hic->zeroVal());
auto* wait_predicate = IrBuilder::create<kir::Predicate>(cond);
auto* if_not_first_ring_step_wait =
IrBuilder::create<kir::IfThenElse>(wait_predicate);
if_not_first_ring_step_wait->thenBody().push_back(wait_send);
if_not_first_ring_step_wait->thenBody().push_back(wait_recv);
// For the get protocol, recvWait is a NOP
// At the same time, sendWait will block waiting for the buffer to be
// IpcSemaphore::kReady but since on this stream we recvPosted the buffer last
// iteration, when that finishes it will be marked kReady anyways. So waiting
// for it to be kReady is unnecessary

std::vector<Expr*> loop_j_body = {
set_stream,
set_curr_stream,
tmp1->definition(),
tmp2->definition(),
tmp3->definition(),
tva_j_curr_slice->definition(),
tva_j_next_slice->definition(),
tvc_j->definition(),
share_mem_handles,
if_not_first_ring_step_wait,
if_not_last_ring_step_post_comms,
mm};
for (Expr* expr : loop_j_body) {
Expand All @@ -1230,30 +1238,6 @@ TEST_F(

hic->pushBackTopLevelExprs(for_loop_i);

// Synchronize all streams
auto* i_stream =
IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
auto* start_stream = hic->zeroVal();
auto* stop_stream =
IrBuilder::create<Val>(params.number_of_streams, DataType::Index);
auto* step_stream = hic->oneVal();
auto* for_loop_stream = IrBuilder::create<kir::ForLoop>(
/*IterDomain=*/makeContigConcreteTensor({params.number_of_streams})
->axis(0),
/*index=*/i_stream,
start_stream,
stop_stream,
step_stream,
/*vectorize=*/false,
/*vectorize_shift=*/nullptr,
/*unroll_required=*/false,
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);
auto* sync_stream = IrBuilder::create<hir::Synchronize>(
IrBuilder::create<hir::Stream>(i_stream));
for_loop_stream->body().push_back(sync_stream);
hic->pushBackTopLevelExprs(for_loop_stream);

hic->addOutput(tmp1);
hic->addOutput(tmp2);
hic->addOutput(tmp3);
Expand All @@ -1274,8 +1258,11 @@ TEST_F(
{tvc_unsharded, tc_unsharded_}};

hie.runWithInput(std::move(inputs));

cudaDeviceSynchronize();
communicator_->barrier();
validate();
cudaDeviceSynchronize();
communicator_->barrier();
}
}

Expand Down Expand Up @@ -1405,30 +1392,6 @@ TEST_F(

hic->pushBackTopLevelExprs(for_loop_i);

// Synchronize all streams
auto* i_stream =
IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
auto* start_stream = hic->zeroVal();
auto* stop_stream =
IrBuilder::create<Val>(params.number_of_streams, DataType::Index);
auto* step_stream = hic->oneVal();
auto* for_loop_stream = IrBuilder::create<kir::ForLoop>(
/*IterDomain=*/makeContigConcreteTensor({params.number_of_streams})
->axis(0),
/*index=*/i_stream,
start_stream,
stop_stream,
step_stream,
/*vectorize=*/false,
/*vectorize_shift=*/nullptr,
/*unroll_required=*/false,
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);
auto* sync_stream = IrBuilder::create<hir::Synchronize>(
IrBuilder::create<hir::Stream>(i_stream));
for_loop_stream->body().push_back(sync_stream);
hic->pushBackTopLevelExprs(for_loop_stream);

hic->addOutput(tmp1);
hic->addOutput(tmp2);
hic->addOutput(tmp3);
Expand Down
Loading