Skip to content

Conversation

@liqiangxl
Copy link
Collaborator

@liqiangxl liqiangxl commented Oct 3, 2025

add non_all_reduce version of cluster reduction

Usage: The 2nd reduction in cross entropy loss to compute log-sum-exp doesn't require all reduce.
Algorithm:
(1) All warps do warp reduce and write result to last block's shared memory
(2) last block waits until all data are received, then its warp-0 finish the reduction.
(3) After reduction warp-0 in last block of this cluster has the valid result.
last block is used instead of first block to keep consistent with grid reduction.

@github-actions
Copy link

github-actions bot commented Oct 3, 2025

Review updated until commit 5c6dd06

Description

  • Add non-all-reduce cluster reduction support

  • Update validation logic for scalar outputs

  • Extend tests for both all-reduce and reduce

  • Refactor cluster reduction kernel for flexibility


Changes walkthrough 📝

Relevant files
Enhancement
4 files
codegen.cpp
Pass is_all_reduce flag to template args                                 
+2/-0     
cluster.cu
Unified clusterReduce with is_all_reduce                                 
+132/-35
cluster_test_kernels.cu
Template kernel for all/reduce modes                                         
+45/-15 
cluster_test_helper.h
Update launch function signature                                                 
+4/-2     
Bug fix
2 files
index.cpp
Remove all-reduce restriction in lowering                               
+1/-5     
kernel_ir.cpp
Remove is_all_reduce assertion constraint                               
+0/-1     
Tests
3 files
cluster_test_helper.cpp
Add is_all_reduce support in validation                                   
+32/-14 
test_cluster_device_func.cpp
Add non-all-reduce test cases                                                       
+51/-6   
test_cluster.cpp
Add SimpleFusionNotAllReduce test                                               
+51/-1   

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Barrier Setup

The setupBarrierExpectTX function is called conditionally in the reduce path based on my_block_rank, but only by warp 0. This could lead to inconsistent barrier initialization if not all participating blocks properly set up the barrier.

  setupBarrierExpectTX<cluster_size, warps_per_block, T>(
      barrier_smem_addr, warp_idx);
}
Output Handling

In the clusterReduceTestKernel, when is_all_reduce is false, only the first thread of the last block writes to output[0]. This assumes that the output tensor has at least one element, but this is not validated.

output[0] = result;

@liqiangxl
Copy link
Collaborator Author

!test

@liqiangxl
Copy link
Collaborator Author

!test

@liqiangxl
Copy link
Collaborator Author

!test

@liqiangxl
Copy link
Collaborator Author

!test

@liqiangxl liqiangxl marked this pull request as ready for review October 6, 2025 13:02
template <
int CLUSTER_SIZE,
int WARPS_PER_BLOCK,
bool is_all_reduce,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I know the runtime functions don't follow the style guide very well, but this just looks inconsistent as the first two are all caps. Looks like the guide suggests to use the same naming as normal parameters, so we should probably use cluster_size and warps_per_block.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed.

@naoyam
Copy link
Collaborator

naoyam commented Oct 6, 2025

Can you make sure a proper predicate is generated for the output of the reduction? We need to generate a predicate that masks off all blocks except the last one. There's logic for grid reductions, and I think that should just work since this also looks like a grid reduction.

@liqiangxl
Copy link
Collaborator Author

Can you make sure a proper predicate is generated for the output of the reduction? We need to generate a predicate that masks off all blocks except the last one. There's logic for grid reductions, and I think that should just work since this also looks like a grid reduction.

Yes, output is correctly predicated. Initially I used block-0 to finish the reduction, test captures error and I switched to use last block to reuse the predicate.

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@liqiangxl
Copy link
Collaborator Author

!test

@liqiangxl liqiangxl merged commit 7e66407 into main Oct 8, 2025
55 checks passed
@liqiangxl liqiangxl deleted the llu/add_non_all_reduce_cluster_reduction branch October 8, 2025 12:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants