diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 58e8f1ee000..ea13cd111e4 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -4508,6 +4508,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { .append(std::to_string(blocks_per_cluster)); template_args.arg("/*warps_per_block=*/") .append(std::to_string(warps_per_block)); + template_args.arg("/*is_all_reduce=*/") + .append(cluster_reduction->isAllreduce() ? "true" : "false"); ArgumentBuilder func_args; func_args.arg(gen(output)); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index fdee8ac3925..e208778ece6 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -675,10 +675,6 @@ void IndexLowering::handleClusterReduction( Val* in) { NVF_ERROR(ir_utils::isTvOp(rop)); - // cluster reduction is only supported for all-reduce - NVF_ERROR( - rop->isAllreduce(), "Cluster reduction is only supported for all-reduce"); - // Get mbarrier allocated during allocation pass auto cluster_mbarrier_tv = GpuLower::current()->clusterReductionMBarrier(); NVF_CHECK( @@ -701,7 +697,7 @@ void IndexLowering::handleClusterReduction( rop->getReductionOpType(), lowerSrcIndex(rop->init(), rop->out()), mbarrier_addr, - /*is_all_reduce=*/true); + rop->isAllreduce()); pushBack(cluster_reduction); GpuLower::current()->propagateExprInfo(rop, back()); diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index cbdf436653e..4c138610fa6 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -2108,7 +2108,6 @@ ClusterReductionOp::ClusterReductionOp( NVF_ERROR( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); - NVF_ERROR(is_all_reduce, "ClusterReductionOp only supports all-reduce"); addInput(mbarrier); } diff --git a/runtime/cluster.cu b/runtime/cluster.cu index 6954c199de6..606885784cb 100644 --- a/runtime/cluster.cu +++ b/runtime/cluster.cu @@ -139,15 +139,93 @@ __device__ __forceinline__ T warpReduce(T val, Func reduction_op) { return reduce_val; } -// Cluster reduction in x direction +// Helper function to perform final reduction from shared memory buffer +template +__device__ __forceinline__ T finalBufferReduce( + T init, + T* reduction_buffer, + uint32_t lane_idx, + Func reduction_op) { + T block_reduce_val = init; + constexpr int num_iter = (warps_per_block * cluster_size + 31) / 32; +#pragma unroll + for (int i = 0; i < num_iter; i++) { + int idx = lane_idx + i * 32; + if (idx < cluster_size * warps_per_block) { + reduction_op(block_reduce_val, reduction_buffer[idx]); + } + } + return warpReduce(block_reduce_val, reduction_op); +} + +// Helper function to setup barrier with expected transfer bytes +template +__device__ __forceinline__ void setupBarrierExpectTX( + uint32_t barrier_smem_addr, + uint32_t warp_idx) { + if (warp_idx == 0 && Hopper::electSync(4294967295U)) { + uint32_t expected_bytes = warps_per_block * cluster_size * sizeof(T); + mbarrier::arriveExpectTX(barrier_smem_addr, expected_bytes); + } +} + +// Helper function to store warp reduction result to distributed shared memory +template +__device__ __forceinline__ void storeWarpResult( + T warp_sum, + uint32_t my_block_rank, + uint32_t warp_idx, + uint32_t peer_cta_rank_in_cluster, + T* reduction_buffer, + uint32_t barrier_smem_addr) { + uint32_t buffer_offset = my_block_rank * warps_per_block + warp_idx; + uint32_t buffer_addr = toSmem(&reduction_buffer[buffer_offset]); + storeSharedRemote( + warp_sum, buffer_addr, barrier_smem_addr, peer_cta_rank_in_cluster); +} + +// Unified cluster reduction 
function supporting both all-reduce and reduce +// operations +// +// Template Parameters: +// cluster_size: Number of CTAs in the cluster (e.g., 2, 4, 8) +// warps_per_block: Number of warps per block (e.g. 4, 8, 16) +// is_all_reduce: true for all-reduce (all blocks get result), false for +// reduce (only last block gets result) T: Data type (float, double, etc.) +// Func: Reduction operator (e.g., AddOp, MaxOp) +// // Algorithm: -// 1. Each warp does a warp reduction -// 2. All warps async store reduction results to its and clustered CTA's shared -// memories -// 3. All warps read from its CTA's shared memory and do a warp reduction +// 1. Each warp performs a warp reduction +// 2. All warps async store reduction results to distributed shared memory +// - If is_all_reduce=true: store to all CTAs in cluster +// - If is_all_reduce=false: store only to the last block in cluster +// 3. Finish reduction with a warp reduction +// - If is_all_reduce=true: all blocks participate and get the result +// - If is_all_reduce=false: only warp-0 in the last block computes the final +// result +// +// Usage Examples: +// // All-reduce: all blocks get the result +// clusterReduce<2, 4, true>(result, input, 0.0f, barrier_addr, buffer, +// AddOp()); +// +// // Reduce: only last block gets the result +// clusterReduce<2, 4, false>(result, input, 0.0f, barrier_addr, buffer, +// AddOp()); +// +// Requirements: +// - barrier_smem_addr: Initialized mbarrier in shared memory +// - reduction_buffer: Shared memory buffer of size [cluster_size * +// warps_per_block] +// // TODO: we can represent this cluster reduction in fusion IR after we have new // parallel types to represent warp reduction. -template +template < + int cluster_size, + int warps_per_block, + bool is_all_reduce, + typename T, + typename Func> __device__ __forceinline__ void clusterReduce( T& res, T inp, @@ -164,41 +242,60 @@ __device__ __forceinline__ void clusterReduce( // 1. Perform warp reduction T warp_sum = warpReduce(thread_val, reduction_op); - // 2. All warps store their results to distributed shared memory - // Each warp uses N threads to write to N CTAs, e.g. thread-i write to CTA-i - // Buffer layout: reduction_buffer[CLUSTER_SIZE][WARPS_PER_BLOCK] - if (warp_idx == 0 && Hopper::electSync(4294967295U)) { - uint32_t expected_bytes = WARPS_PER_BLOCK * CLUSTER_SIZE * sizeof(T); - mbarrier::arriveExpectTX(barrier_smem_addr, expected_bytes); - } - if (lane_idx < CLUSTER_SIZE) { - uint32_t peer_cta_rank_in_cluster = lane_idx; - uint32_t buffer_offset = my_block_rank * WARPS_PER_BLOCK + warp_idx; - uint32_t buffer_addr = toSmem(&reduction_buffer[buffer_offset]); - storeSharedRemote( - warp_sum, buffer_addr, barrier_smem_addr, peer_cta_rank_in_cluster); + // 2. Store warp reduction results to distributed shared memory + if constexpr (is_all_reduce) { + // All-reduce: Each warp uses N threads to write to N CTAs, e.g. 
thread-i + // write to CTA-i Buffer layout: + // reduction_buffer[cluster_size][warps_per_block] + setupBarrierExpectTX( + barrier_smem_addr, warp_idx); + if (lane_idx < cluster_size) { + storeWarpResult( + warp_sum, + my_block_rank, + warp_idx, + /*peer_cta_rank_in_cluster=*/lane_idx, + reduction_buffer, + barrier_smem_addr); + } + } else { + // Reduce: Each warp selects a thread to store warp reduction result to + // shared memory of the last block in cluster + if (my_block_rank == cluster_size - 1) { + setupBarrierExpectTX( + barrier_smem_addr, warp_idx); + } + if (Hopper::electSync(4294967295U)) { + storeWarpResult( + warp_sum, + my_block_rank, + warp_idx, + /*peer_cta_rank_in_cluster=*/cluster_size - 1, + reduction_buffer, + barrier_smem_addr); + } } - // mbarrier is not repeatedly used, parity phase is set to 0. Otherwise, - // should flip parity phase, e.g. when used in persistent CTA kernels. - mbarrier::waitParity(barrier_smem_addr, 0); + if constexpr (is_all_reduce) { + // All-reduce: All blocks participate in final reduction + // mbarrier is not repeatedly used, parity phase is set to 0. Otherwise, + // should flip parity phase, e.g. when used in persistent CTA kernels. + mbarrier::waitParity(barrier_smem_addr, 0); - // 3. Each CTA has a copy of the warp reduction results from all warps in the - // cluster - // Finish reduction with a warp reduction - T block_reduce_val = init; - constexpr int num_iter = (WARPS_PER_BLOCK * CLUSTER_SIZE + 31) / 32; -#pragma unroll - for (int i = 0; i < num_iter; i++) { - int idx = lane_idx + i * 32; - if (idx < CLUSTER_SIZE * WARPS_PER_BLOCK) { - reduction_op(block_reduce_val, reduction_buffer[idx]); + // 3. Each CTA has a copy of the warp reduction results from all warps in + // the cluster. Finish reduction with a warp reduction + res = finalBufferReduce( + init, reduction_buffer, lane_idx, reduction_op); + } else { + // Reduce: only warp-0 in the last block is required to finish the reduction + if (my_block_rank == cluster_size - 1 && warp_idx == 0) { + mbarrier::waitParity(barrier_smem_addr, 0); + res = finalBufferReduce( + init, reduction_buffer, lane_idx, reduction_op); } } - // 4. 
Each CTA performs a warp reduction on its shared memory - // Get final result using warp reduction - res = warpReduce(block_reduce_val, reduction_op); } + #endif // Arch 90 } // namespace cluster } // namespace nvf diff --git a/tests/cpp/cluster_runtime_test/cluster_test_helper.cpp b/tests/cpp/cluster_runtime_test/cluster_test_helper.cpp index 58043929052..0792607910e 100644 --- a/tests/cpp/cluster_runtime_test/cluster_test_helper.cpp +++ b/tests/cpp/cluster_runtime_test/cluster_test_helper.cpp @@ -52,25 +52,43 @@ void validateClusterStoreResult( void validateClusterReduceResult( at::Tensor input_tensor, - at::Tensor output_tensor) { - const int64_t in_dim = static_cast(input_tensor.dim()); - const int64_t out_dim = static_cast(output_tensor.dim()); - ASSERT_TRUE(in_dim == out_dim); - - const int64_t in_numel = input_tensor.numel(); - const int64_t out_numel = output_tensor.numel(); - ASSERT_TRUE(in_numel == out_numel); - + at::Tensor output_tensor, + bool is_all_reduce, + int threads_per_block) { auto input_cpu = input_tensor.cpu(); auto output_cpu = output_tensor.cpu(); - // Expect every element to be the global sum of the input tensor + // Expect the result to be the global sum of the input tensor const double expected_scalar = input_cpu.sum().item(); - auto expected = at::empty_like(input_cpu); - expected.fill_(expected_scalar); - ASSERT_TRUE(at::allclose(output_cpu, expected, /*rtol=*/1e-6, /*atol=*/1e-7)) - << "Cluster reduce validation failed: output is not the global sum"; + if (is_all_reduce) { + // All-reduce: output should be full tensor with every element containing + // the global sum + const int64_t in_dim = static_cast(input_tensor.dim()); + const int64_t out_dim = static_cast(output_tensor.dim()); + ASSERT_TRUE(in_dim == out_dim); + + const int64_t in_numel = input_tensor.numel(); + const int64_t out_numel = output_tensor.numel(); + ASSERT_TRUE(in_numel == out_numel); + + auto expected = at::empty_like(input_cpu); + expected.fill_(expected_scalar); + + ASSERT_TRUE( + at::allclose(output_cpu, expected, /*rtol=*/1e-6, /*atol=*/1e-7)) + << "Cluster all-reduce validation failed: output is not the global sum"; + } else { + // Reduce: output should be a single scalar containing the global sum + ASSERT_TRUE(output_tensor.numel() == 1) + << "Reduce output should be a scalar (single element), got " + << output_tensor.numel() << " elements"; + + const double actual_scalar = output_cpu.item(); + ASSERT_TRUE(std::abs(actual_scalar - expected_scalar) < 1e-7) + << "Cluster reduce validation failed: output is not the global sum. 
" + << "Expected: " << expected_scalar << ", Got: " << actual_scalar; + } } } // namespace nvfuser diff --git a/tests/cpp/cluster_runtime_test/cluster_test_helper.h b/tests/cpp/cluster_runtime_test/cluster_test_helper.h index b5f98c02d4e..2a1de1b1e5c 100644 --- a/tests/cpp/cluster_runtime_test/cluster_test_helper.h +++ b/tests/cpp/cluster_runtime_test/cluster_test_helper.h @@ -22,11 +22,13 @@ void validateClusterStoreResult( at::Tensor output_tensor, int cluster_size); -template +template void launchClusterReduceTestKernel(T* input, T* output); void validateClusterReduceResult( at::Tensor input_tensor, - at::Tensor output_tensor); + at::Tensor output_tensor, + bool is_all_reduce, + int threads_per_block = 128); } // namespace nvfuser diff --git a/tests/cpp/cluster_runtime_test/cluster_test_kernels.cu b/tests/cpp/cluster_runtime_test/cluster_test_kernels.cu index c7832fec338..b3f4e994285 100644 --- a/tests/cpp/cluster_runtime_test/cluster_test_kernels.cu +++ b/tests/cpp/cluster_runtime_test/cluster_test_kernels.cu @@ -85,7 +85,12 @@ struct AddOp { }; // Reduce BLOCK_SIZE x CLUSTER_SIZE values across a cluster of CLUSTER_SIZE CTAs -template +template < + typename T, + int BLOCK_SIZE, + int CLUSTER_SIZE, + int WARPS_PER_BLOCK, + bool is_all_reduce> __global__ void clusterReduceTestKernel(T* input, T* output) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -101,16 +106,27 @@ __global__ void clusterReduceTestKernel(T* input, T* output) { nvf::cluster::clusterSync(); T result; - nvf::cluster::clusterReduce>( - result, - value, - static_cast(0), - mbarrier_addr, - reduction_buffer, - AddOp()); - - // After clusterReduce, each thread in each block has the same reduced value - output[global_tid] = result; + nvf::cluster:: + clusterReduce>( + result, + value, + static_cast(0), + mbarrier_addr, + reduction_buffer, + AddOp()); + + if constexpr (is_all_reduce) { + // All-reduce: each thread writes the result to its corresponding output + // element + output[global_tid] = result; + } else { + // Reduce: only the first thread of the last block writes the scalar result + constexpr uint32_t last_block_rank = CLUSTER_SIZE - 1; + uint32_t my_block_rank = nvf::cluster::blockIdInCluster().x; + if (my_block_rank == last_block_rank && threadIdx.x == 0) { + output[0] = result; + } + } #endif } @@ -141,7 +157,7 @@ void launchStoreSharedRemoteTestKernel(T* input, T* output) { output)); } -template +template void launchClusterReduceTestKernel(T* input, T* output) { constexpr int WARPS_PER_BLOCK = (BLOCK_SIZE + 31) / 32; @@ -157,9 +173,15 @@ void launchClusterReduceTestKernel(T* input, T* output) { config.attrs = &cluster_attr; config.numAttrs = 1; + // Use the unified kernel for both all-reduce and reduce cases NVFUSER_CUDA_RT_SAFE_CALL(cudaLaunchKernelEx( &config, - clusterReduceTestKernel, + clusterReduceTestKernel< + T, + BLOCK_SIZE, + CLUSTER_SIZE, + WARPS_PER_BLOCK, + is_all_reduce>, input, output)); } @@ -173,11 +195,19 @@ template void launchStoreSharedRemoteTestKernel( double* input, double* output); -template void launchClusterReduceTestKernel( +template void launchClusterReduceTestKernel( + float* input, + float* output); + +template void launchClusterReduceTestKernel( + double* input, + double* output); + +template void launchClusterReduceTestKernel( float* input, float* output); -template void launchClusterReduceTestKernel( +template void launchClusterReduceTestKernel( double* input, double* output); } // namespace nvfuser diff 
--git a/tests/cpp/cluster_runtime_test/test_cluster_device_func.cpp b/tests/cpp/cluster_runtime_test/test_cluster_device_func.cpp index b669aeb8c03..782d497881a 100644 --- a/tests/cpp/cluster_runtime_test/test_cluster_device_func.cpp +++ b/tests/cpp/cluster_runtime_test/test_cluster_device_func.cpp @@ -69,7 +69,7 @@ TEST_F(ClusterDeviceFuncTest, BasicStoreSharedRemoteDouble) { } // Cluster reduction test for float -TEST_F(ClusterDeviceFuncTest, ClusterReduceFloat) { +TEST_F(ClusterDeviceFuncTest, ClusterReduceFloatAllReduce) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); constexpr int num_blocks = 2; constexpr int threads_per_block = 128; @@ -84,14 +84,14 @@ TEST_F(ClusterDeviceFuncTest, ClusterReduceFloat) { {total_elements}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0)); - launchClusterReduceTestKernel( + launchClusterReduceTestKernel( input_tensor.data_ptr(), output_tensor.data_ptr()); - validateClusterReduceResult(input_tensor, output_tensor); + validateClusterReduceResult(input_tensor, output_tensor, true); } // Cluster reduction test for double -TEST_F(ClusterDeviceFuncTest, ClusterReduceDouble) { +TEST_F(ClusterDeviceFuncTest, ClusterReduceDoubleAllReduce) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); constexpr int num_blocks = 2; constexpr int threads_per_block = 128; @@ -106,9 +106,54 @@ TEST_F(ClusterDeviceFuncTest, ClusterReduceDouble) { {total_elements}, at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0)); - launchClusterReduceTestKernel( + launchClusterReduceTestKernel( input_tensor.data_ptr(), output_tensor.data_ptr()); - validateClusterReduceResult(input_tensor, output_tensor); + validateClusterReduceResult(input_tensor, output_tensor, true); } + +// Cluster reduction test for float - returns single scalar +TEST_F(ClusterDeviceFuncTest, ClusterReduceFloatNotAllReduce) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + constexpr int num_blocks = 2; + constexpr int threads_per_block = 128; + constexpr int total_elements = num_blocks * threads_per_block; + + auto input_tensor = at::arange( + 1.0f, + total_elements + 1.0f, + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0)); + + // Output is now a single scalar value + auto output_scalar = at::empty( + {1}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0)); + + launchClusterReduceTestKernel( + input_tensor.data_ptr(), output_scalar.data_ptr()); + + validateClusterReduceResult(input_tensor, output_scalar, false); +} + +// Cluster reduction test for double - returns single scalar +TEST_F(ClusterDeviceFuncTest, ClusterReduceDoubleNotAllReduce) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + constexpr int num_blocks = 2; + constexpr int threads_per_block = 128; + constexpr int total_elements = num_blocks * threads_per_block; + + auto input_tensor = at::arange( + 1.0, + total_elements + 1.0, + at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0)); + + // Output is now a single scalar value + auto output_scalar = at::empty( + {1}, at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0)); + + launchClusterReduceTestKernel( + input_tensor.data_ptr(), output_scalar.data_ptr()); + + validateClusterReduceResult(input_tensor, output_scalar, false); +} + } // namespace nvfuser diff --git a/tests/cpp/test_cluster.cpp b/tests/cpp/test_cluster.cpp index 2f77b4f4c2c..c9e211d6392 100644 --- a/tests/cpp/test_cluster.cpp +++ b/tests/cpp/test_cluster.cpp @@ -32,7 +32,7 @@ class ClusterReductionTest } }; -TEST_P(ClusterReductionTest, ManualScheduledSimpleFusion) { +TEST_P(ClusterReductionTest, SimpleFusionAllReduce) { auto 
fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -86,6 +86,56 @@ TEST_P(ClusterReductionTest, ManualScheduledSimpleFusion) { EXPECT_TRUE(ke.compiledKernel()->kernel()->summary().has_cluster_reduction); testValidate(&unscheduled_fusion_copy, outputs, {t0}); } + +TEST_P(ClusterReductionTest, SimpleFusionNotAllReduce) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + const int64_t vect = 2, bdimx = 128, serial = 2; + const auto [cluster_size, dtype] = GetParam(); + const int64_t reduction_size = vect * bdimx * serial * cluster_size; + auto tv0 = makeContigTensor(2, dtype); + fusion.addInput(tv0); + auto tv1 = set(tv0); + const DataType compute_dtype = + (dtype == DataType::Double) ? DataType::Double : DataType::Float; + tv1 = maybeCastOp(compute_dtype, tv1); + auto tv2 = sum(tv1, {1}); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + auto unscheduled_fusion_copy = fusion; + // [I, R] + tv2->split(1, vect); + // [I, R/vect, vect] + tv2->split(1, bdimx); + // [I, R/vect/bdimx, bdimx, vect] + tv2->split(1, serial, false); + // [I, serial, R/vect/bdimx/serial, bdimx, vect] + // [BIDy, Serial, BIDx(cluster), TIDx, Vectorize or Serial] + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv2->axis(-3)->parallelize(ParallelType::BIDx); + // set clustered blocks to use cluster reduction + tv2->axis(-3)->setClusteredBlocks(); + tv2->axis(0)->parallelize(ParallelType::BIDy); + + auto reference = tv2->rFactor({-1, -4}); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + scheduler_utils::parallelizeAllLike(reference); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t0 = at::randn({256, reduction_size}, options); + + KernelExecutor ke; + ke.compile(fusion_ptr.get(), {t0}); + auto outputs = ke.run({t0}); + EXPECT_TRUE(ke.compiledKernel()->kernel()->summary().has_cluster_reduction); + testValidate(&unscheduled_fusion_copy, outputs, {t0}); +} INSTANTIATE_TEST_SUITE_P( , ClusterReductionTest,
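
For reference, below is a minimal standalone sketch (not part of this patch) of the data flow that the new is_all_reduce=false path implements: every CTA reduces its own slice, ships its partial into the last CTA's shared memory, and only the last CTA folds the partials and writes the scalar. To keep the sketch self-contained it uses cooperative_groups' map_shared_rank() and cluster.sync() in place of the storeSharedRemote / mbarrier arriveExpectTX / waitParity handshake in runtime/cluster.cu, and a plain shared-memory tree reduction stands in for warpReduce; the kernel and variable names are illustrative, not from this repository.

// cluster_reduce_last_block_sketch.cu -- illustrative only; compile with:
//   nvcc -arch=sm_90a cluster_reduce_last_block_sketch.cu
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <cstdio>

namespace cg = cooperative_groups;

// Sketch of the is_all_reduce=false flow: each CTA reduces its slice, one
// thread per CTA stores the partial into the last CTA's shared buffer, and
// only the last CTA finishes the reduction and writes the scalar.
template <int CLUSTER_SIZE, int BLOCK_SIZE>
__global__ void clusterReduceLastBlockKernel(float* input, float* output) {
  __shared__ float partials[CLUSTER_SIZE]; // only read by the last CTA
  __shared__ float scratch[BLOCK_SIZE];
  cg::cluster_group cluster = cg::this_cluster();
  const unsigned rank = cluster.block_rank();

  // 1. Per-CTA reduction (a tree reduction stands in for warpReduce).
  scratch[threadIdx.x] = input[blockIdx.x * BLOCK_SIZE + threadIdx.x];
  __syncthreads();
  for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) {
      scratch[threadIdx.x] += scratch[threadIdx.x + stride];
    }
    __syncthreads();
  }

  // 2. One thread per CTA writes its partial into the LAST CTA's shared
  //    memory (the runtime path does this with storeSharedRemote + mbarrier).
  if (threadIdx.x == 0) {
    float* remote = cluster.map_shared_rank(partials, CLUSTER_SIZE - 1);
    remote[rank] = scratch[0];
  }
  cluster.sync();

  // 3. Only the last CTA folds the partials and writes the scalar result.
  if (rank == CLUSTER_SIZE - 1 && threadIdx.x == 0) {
    float total = 0.0f;
    for (int i = 0; i < CLUSTER_SIZE; ++i) {
      total += partials[i];
    }
    output[0] = total;
  }
}

int main() {
  constexpr int kClusterSize = 2;
  constexpr int kBlockSize = 128;
  constexpr int kNumel = kClusterSize * kBlockSize;

  float host_in[kNumel];
  for (int i = 0; i < kNumel; ++i) {
    host_in[i] = 1.0f; // expected sum: kNumel
  }

  float *dev_in = nullptr, *dev_out = nullptr;
  cudaMalloc(&dev_in, kNumel * sizeof(float));
  cudaMalloc(&dev_out, sizeof(float));
  cudaMemcpy(dev_in, host_in, kNumel * sizeof(float), cudaMemcpyHostToDevice);

  // Launch one cluster of kClusterSize CTAs, mirroring the test harness.
  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(kClusterSize, 1, 1);
  config.blockDim = dim3(kBlockSize, 1, 1);
  cudaLaunchAttribute cluster_attr;
  cluster_attr.id = cudaLaunchAttributeClusterDimension;
  cluster_attr.val.clusterDim.x = kClusterSize;
  cluster_attr.val.clusterDim.y = 1;
  cluster_attr.val.clusterDim.z = 1;
  config.attrs = &cluster_attr;
  config.numAttrs = 1;
  cudaLaunchKernelEx(
      &config,
      clusterReduceLastBlockKernel<kClusterSize, kBlockSize>,
      dev_in,
      dev_out);

  float result = 0.0f;
  cudaMemcpy(&result, dev_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("cluster reduce result: %f (expected %d)\n", result, kNumel);
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}

Compared with this sketch, which makes every CTA wait at cluster.sync(), the runtime path in this patch lets the non-last CTAs skip the wait entirely: setupBarrierExpectTX is invoked only when my_block_rank == cluster_size - 1, and only warp-0 of that last CTA blocks on mbarrier::waitParity before running finalBufferReduce.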