2 changes: 2 additions & 0 deletions csrc/codegen.cpp
@@ -4508,6 +4508,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
.append(std::to_string(blocks_per_cluster));
template_args.arg("/*warps_per_block=*/")
.append(std::to_string(warps_per_block));
template_args.arg("/*is_all_reduce=*/")
.append(cluster_reduction->isAllreduce() ? "true" : "false");

ArgumentBuilder func_args;
func_args.arg(gen(output));
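With this change the generated kernel source carries the all-reduce flag as an explicit template argument. A minimal sketch of what the emitted call might look like, assuming the generator targets clusterReduce from runtime/cluster.cu; the first template-argument comment, the value names, and the reduction functor below are illustrative, not taken from this PR:

    cluster::clusterReduce<
        /*cluster_size=*/2,
        /*warps_per_block=*/4,
        /*is_all_reduce=*/false>(
        out_val,          // reduction output (register)
        in_val,           // per-thread input value
        0.f,              // init value of the reduction op
        mbarrier_addr,    // shared-memory mbarrier set up by the allocation pass
        reduction_buffer, // smem buffer of cluster_size * warps_per_block elements
        AddOp<float>());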
6 changes: 1 addition & 5 deletions csrc/device_lower/pass/index.cpp
@@ -675,10 +675,6 @@ void IndexLowering::handleClusterReduction(
Val* in) {
NVF_ERROR(ir_utils::isTvOp(rop));

// cluster reduction is only supported for all-reduce
NVF_ERROR(
rop->isAllreduce(), "Cluster reduction is only supported for all-reduce");

// Get mbarrier allocated during allocation pass
auto cluster_mbarrier_tv = GpuLower::current()->clusterReductionMBarrier();
NVF_CHECK(
@@ -701,7 +697,7 @@
rop->getReductionOpType(),
lowerSrcIndex(rop->init(), rop->out()),
mbarrier_addr,
/*is_all_reduce=*/true);
rop->isAllreduce());

pushBack(cluster_reduction);
GpuLower::current()->propagateExprInfo(rop, back());
1 change: 0 additions & 1 deletion csrc/kernel_ir.cpp
@@ -2108,7 +2108,6 @@ ClusterReductionOp::ClusterReductionOp(
NVF_ERROR(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
NVF_ERROR(is_all_reduce, "ClusterReductionOp only supports all-reduce");
addInput(mbarrier);
}

167 changes: 132 additions & 35 deletions runtime/cluster.cu
@@ -139,15 +139,93 @@ __device__ __forceinline__ T warpReduce(T val, Func reduction_op) {
return reduce_val;
}

// Cluster reduction in x direction
// Helper function to perform final reduction from shared memory buffer
template <int cluster_size, int warps_per_block, typename T, typename Func>
__device__ __forceinline__ T finalBufferReduce(
T init,
T* reduction_buffer,
uint32_t lane_idx,
Func reduction_op) {
T block_reduce_val = init;
constexpr int num_iter = (warps_per_block * cluster_size + 31) / 32;
#pragma unroll
for (int i = 0; i < num_iter; i++) {
int idx = lane_idx + i * 32;
if (idx < cluster_size * warps_per_block) {
reduction_op(block_reduce_val, reduction_buffer[idx]);
}
}
return warpReduce(block_reduce_val, reduction_op);
}

// Helper function to set up the mbarrier with the expected transfer bytes
template <int cluster_size, int warps_per_block, typename T>
__device__ __forceinline__ void setupBarrierExpectTX(
uint32_t barrier_smem_addr,
uint32_t warp_idx) {
if (warp_idx == 0 && Hopper::electSync(4294967295U)) {
uint32_t expected_bytes = warps_per_block * cluster_size * sizeof(T);
mbarrier::arriveExpectTX(barrier_smem_addr, expected_bytes);
}
}

// Helper function to store warp reduction result to distributed shared memory
template <int warps_per_block, typename T>
__device__ __forceinline__ void storeWarpResult(
T warp_sum,
uint32_t my_block_rank,
uint32_t warp_idx,
uint32_t peer_cta_rank_in_cluster,
T* reduction_buffer,
uint32_t barrier_smem_addr) {
uint32_t buffer_offset = my_block_rank * warps_per_block + warp_idx;
uint32_t buffer_addr = toSmem(&reduction_buffer[buffer_offset]);
storeSharedRemote<T>(
warp_sum, buffer_addr, barrier_smem_addr, peer_cta_rank_in_cluster);
}

// Unified cluster reduction function supporting both all-reduce and reduce
// operations
//
// Template Parameters:
// cluster_size: Number of CTAs in the cluster (e.g., 2, 4, 8)
// warps_per_block: Number of warps per block (e.g. 4, 8, 16)
// is_all_reduce: true for all-reduce (all blocks get result), false for
//     reduce (only last block gets result)
// T: Data type (float, double, etc.)
// Func: Reduction operator (e.g., AddOp, MaxOp)
//
// Algorithm:
// 1. Each warp does a warp reduction
// 2. All warps async store reduction results to its and clustered CTA's shared
// memories
// 3. All warps read from its CTA's shared memory and do a warp reduction
// 1. Each warp performs a warp reduction
// 2. All warps async store reduction results to distributed shared memory
// - If is_all_reduce=true: store to all CTAs in cluster
// - If is_all_reduce=false: store only to the last block in cluster
// 3. Finish reduction with a warp reduction
// - If is_all_reduce=true: all blocks participate and get the result
// - If is_all_reduce=false: only warp-0 in the last block computes the final
// result
//
// Usage Examples:
// // All-reduce: all blocks get the result
// clusterReduce<2, 4, true>(result, input, 0.0f, barrier_addr, buffer,
// AddOp<float>());
//
// // Reduce: only last block gets the result
// clusterReduce<2, 4, false>(result, input, 0.0f, barrier_addr, buffer,
// AddOp<float>());
//
// Requirements:
// - barrier_smem_addr: Initialized mbarrier in shared memory
// - reduction_buffer: Shared memory buffer of size [cluster_size *
// warps_per_block]
//
// TODO: we can represent this cluster reduction in fusion IR after we have new
// parallel types to represent warp reduction.
template <int CLUSTER_SIZE, int WARPS_PER_BLOCK, typename T, typename Func>
template <
int cluster_size,
int warps_per_block,
bool is_all_reduce,
Collaborator:
nit: I know the runtime functions don't follow the style guide very well, but this just looks inconsistent as the first two are all caps. Looks like the guide suggests to use the same naming as normal parameters, so we should probably use cluster_size and warps_per_block.

Collaborator (Author):
changed.

typename T,
typename Func>
__device__ __forceinline__ void clusterReduce(
T& res,
T inp,
@@ -164,41 +242,60 @@ __device__ __forceinline__ void clusterReduce(
// 1. Perform warp reduction
T warp_sum = warpReduce(thread_val, reduction_op);

// 2. All warps store their results to distributed shared memory
// Each warp uses N threads to write to N CTAs, e.g. thread-i write to CTA-i
// Buffer layout: reduction_buffer[CLUSTER_SIZE][WARPS_PER_BLOCK]
if (warp_idx == 0 && Hopper::electSync(4294967295U)) {
uint32_t expected_bytes = WARPS_PER_BLOCK * CLUSTER_SIZE * sizeof(T);
mbarrier::arriveExpectTX(barrier_smem_addr, expected_bytes);
}
if (lane_idx < CLUSTER_SIZE) {
uint32_t peer_cta_rank_in_cluster = lane_idx;
uint32_t buffer_offset = my_block_rank * WARPS_PER_BLOCK + warp_idx;
uint32_t buffer_addr = toSmem(&reduction_buffer[buffer_offset]);
storeSharedRemote<T>(
warp_sum, buffer_addr, barrier_smem_addr, peer_cta_rank_in_cluster);
// 2. Store warp reduction results to distributed shared memory
if constexpr (is_all_reduce) {
// All-reduce: Each warp uses N threads to write to N CTAs, e.g. thread-i
// writes to CTA-i. Buffer layout:
// reduction_buffer[cluster_size][warps_per_block]
setupBarrierExpectTX<cluster_size, warps_per_block, T>(
barrier_smem_addr, warp_idx);
if (lane_idx < cluster_size) {
storeWarpResult<warps_per_block>(
warp_sum,
my_block_rank,
warp_idx,
/*peer_cta_rank_in_cluster=*/lane_idx,
reduction_buffer,
barrier_smem_addr);
}
} else {
// Reduce: Each warp selects a thread to store its warp reduction result to
// the shared memory of the last block in the cluster
if (my_block_rank == cluster_size - 1) {
setupBarrierExpectTX<cluster_size, warps_per_block, T>(
barrier_smem_addr, warp_idx);
}
if (Hopper::electSync(4294967295U)) {
storeWarpResult<warps_per_block>(
warp_sum,
my_block_rank,
warp_idx,
/*peer_cta_rank_in_cluster=*/cluster_size - 1,
reduction_buffer,
barrier_smem_addr);
}
}

// mbarrier is not repeatedly used, parity phase is set to 0. Otherwise,
// should flip parity phase, e.g. when used in persistent CTA kernels.
mbarrier::waitParity(barrier_smem_addr, 0);
if constexpr (is_all_reduce) {
// All-reduce: All blocks participate in final reduction
// mbarrier is not repeatedly used, parity phase is set to 0. Otherwise,
// should flip parity phase, e.g. when used in persistent CTA kernels.
mbarrier::waitParity(barrier_smem_addr, 0);

// 3. Each CTA has a copy of the warp reduction results from all warps in the
// cluster
// Finish reduction with a warp reduction
T block_reduce_val = init;
constexpr int num_iter = (WARPS_PER_BLOCK * CLUSTER_SIZE + 31) / 32;
#pragma unroll
for (int i = 0; i < num_iter; i++) {
int idx = lane_idx + i * 32;
if (idx < CLUSTER_SIZE * WARPS_PER_BLOCK) {
reduction_op(block_reduce_val, reduction_buffer[idx]);
// 3. Each CTA has a copy of the warp reduction results from all warps in
// the cluster. Finish reduction with a warp reduction
res = finalBufferReduce<cluster_size, warps_per_block>(
init, reduction_buffer, lane_idx, reduction_op);
} else {
// Reduce: only warp-0 in the last block is required to finish the reduction
if (my_block_rank == cluster_size - 1 && warp_idx == 0) {
mbarrier::waitParity(barrier_smem_addr, 0);
res = finalBufferReduce<cluster_size, warps_per_block>(
init, reduction_buffer, lane_idx, reduction_op);
}
}
// 4. Each CTA performs a warp reduction on its shared memory
// Get final result using warp reduction
res = warpReduce(block_reduce_val, reduction_op);
}

#endif // Arch 90
} // namespace cluster
} // namespace nvf
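For context, clusterReduce only behaves as documented when the kernel is launched with a CTA cluster whose x-extent matches cluster_size. A minimal host-side sketch using the standard CUDA cluster-launch API; the kernel name myClusterReduceKernel, the sizes, and the wrapper's responsibilities are assumptions, not part of this PR:

    #include <cuda_runtime.h>

    constexpr int kClusterSize = 2;    // must match the cluster_size template argument
    constexpr int kWarpsPerBlock = 4;  // must match warps_per_block
    constexpr int kNumClusters = 8;

    // Hypothetical __global__ wrapper that initializes the shared-memory mbarrier
    // and reduction buffer, then calls clusterReduce.
    __global__ void myClusterReduceKernel(float* input, float* output);

    void launchWithCluster(float* input, float* output) {
      cudaLaunchConfig_t config = {};
      config.gridDim = dim3(kClusterSize * kNumClusters);  // total CTAs in the grid
      config.blockDim = dim3(kWarpsPerBlock * 32);         // 32 threads per warp

      cudaLaunchAttribute attrs[1];
      attrs[0].id = cudaLaunchAttributeClusterDimension;
      attrs[0].val.clusterDim.x = kClusterSize;  // CTAs per cluster in x
      attrs[0].val.clusterDim.y = 1;
      attrs[0].val.clusterDim.z = 1;
      config.attrs = attrs;
      config.numAttrs = 1;

      cudaLaunchKernelEx(&config, myClusterReduceKernel, input, output);
    }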
46 changes: 32 additions & 14 deletions tests/cpp/cluster_runtime_test/cluster_test_helper.cpp
@@ -52,25 +52,43 @@ void validateClusterStoreResult(

void validateClusterReduceResult(
at::Tensor input_tensor,
at::Tensor output_tensor) {
const int64_t in_dim = static_cast<int64_t>(input_tensor.dim());
const int64_t out_dim = static_cast<int64_t>(output_tensor.dim());
ASSERT_TRUE(in_dim == out_dim);

const int64_t in_numel = input_tensor.numel();
const int64_t out_numel = output_tensor.numel();
ASSERT_TRUE(in_numel == out_numel);

at::Tensor output_tensor,
bool is_all_reduce,
int threads_per_block) {
auto input_cpu = input_tensor.cpu();
auto output_cpu = output_tensor.cpu();

// Expect every element to be the global sum of the input tensor
// Expect the result to be the global sum of the input tensor
const double expected_scalar = input_cpu.sum().item<double>();
auto expected = at::empty_like(input_cpu);
expected.fill_(expected_scalar);

ASSERT_TRUE(at::allclose(output_cpu, expected, /*rtol=*/1e-6, /*atol=*/1e-7))
<< "Cluster reduce validation failed: output is not the global sum";
if (is_all_reduce) {
// All-reduce: output should be full tensor with every element containing
// the global sum
const int64_t in_dim = static_cast<int64_t>(input_tensor.dim());
const int64_t out_dim = static_cast<int64_t>(output_tensor.dim());
ASSERT_TRUE(in_dim == out_dim);

const int64_t in_numel = input_tensor.numel();
const int64_t out_numel = output_tensor.numel();
ASSERT_TRUE(in_numel == out_numel);

auto expected = at::empty_like(input_cpu);
expected.fill_(expected_scalar);

ASSERT_TRUE(
at::allclose(output_cpu, expected, /*rtol=*/1e-6, /*atol=*/1e-7))
<< "Cluster all-reduce validation failed: output is not the global sum";
} else {
// Reduce: output should be a single scalar containing the global sum
ASSERT_TRUE(output_tensor.numel() == 1)
<< "Reduce output should be a scalar (single element), got "
<< output_tensor.numel() << " elements";

const double actual_scalar = output_cpu.item<double>();
ASSERT_TRUE(std::abs(actual_scalar - expected_scalar) < 1e-7)
<< "Cluster reduce validation failed: output is not the global sum. "
<< "Expected: " << expected_scalar << ", Got: " << actual_scalar;
}
}

} // namespace nvfuser
6 changes: 4 additions & 2 deletions tests/cpp/cluster_runtime_test/cluster_test_helper.h
@@ -22,11 +22,13 @@ void validateClusterStoreResult(
at::Tensor output_tensor,
int cluster_size);

template <typename T, int BLOCK_SIZE, int CLUSTER_SIZE>
template <typename T, int BLOCK_SIZE, int CLUSTER_SIZE, bool is_all_reduce>
void launchClusterReduceTestKernel(T* input, T* output);

void validateClusterReduceResult(
at::Tensor input_tensor,
at::Tensor output_tensor);
at::Tensor output_tensor,
bool is_all_reduce,
int threads_per_block = 128);

} // namespace nvfuser
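Putting the two helpers together, a test for the new non-all-reduce path might look roughly like the sketch below. The tensor sizes, and the assumption that the test kernel reduces the whole input into a single scalar, follow the validation logic above but are not taken verbatim from this PR:

    constexpr int kBlockSize = 128;
    constexpr int kClusterSize = 2;
    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    auto input = at::randn({kClusterSize * kBlockSize}, options);
    auto output = at::zeros({1}, options);  // reduce mode produces a single scalar

    launchClusterReduceTestKernel<float, kBlockSize, kClusterSize, /*is_all_reduce=*/false>(
        input.data_ptr<float>(), output.data_ptr<float>());
    validateClusterReduceResult(input, output, /*is_all_reduce=*/false, kBlockSize);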