add non_all_reduce version of cluster reduction #5319
Review updated until commit 5c6dd06
!test
template <
    int CLUSTER_SIZE,
    int WARPS_PER_BLOCK,
    bool is_all_reduce,
nit: I know the runtime functions don't follow the style guide very well, but this just looks inconsistent as the first two are all caps. Looks like the guide suggests to use the same naming as normal parameters, so we should probably use cluster_size and warps_per_block.
changed.
Can you make sure a proper predicate is generated for the output of the reduction? We need to generate a predicate that masks off all blocks except the last one. There's logic for grid reductions, and I think that should just work since this also looks like a grid reduction.
Yes, output is correctly predicated. Initially I used |
LGTM
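For reference, the output predicate discussed in this thread can be sketched on the host side as below. This is a minimal illustration, not the predicate nvFuser actually generates: the helper name `writesReductionOutput` and its parameters are hypothetical. It assumes that in the all-reduce case every block holds a valid result, and that in the non-all-reduce case only the last block in the cluster does.

```cpp
#include <cassert>

// Hypothetical helper (not part of nvFuser): decides whether a block may
// write the reduction output.
// - all-reduce: every block in the cluster ends up with the result.
// - non-all-reduce: only the last block in the cluster has a valid result,
//   so all other blocks are masked off.
inline bool writesReductionOutput(
    int block_rank_in_cluster,
    int cluster_size,
    bool is_all_reduce) {
  assert(cluster_size > 0);
  assert(block_rank_in_cluster >= 0 && block_rank_in_cluster < cluster_size);
  const int last_block = cluster_size - 1;
  return is_all_reduce || block_rank_in_cluster == last_block;
}
```

With a cluster of 4 blocks and `is_all_reduce == false`, only rank 3 passes the predicate, which mirrors the "last block" convention used for grid reductions.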
!test
add non_all_reduce version of cluster reduction
Usage: The 2nd reduction in cross entropy loss, which computes `log-sum-exp`, doesn't require an all-reduce.
Algorithm:
(1) All warps do a warp reduce and write their results to the `last block`'s shared memory.
(2) The `last block` waits until all data are received, then its `warp-0` finishes the reduction.
(3) After the reduction, `warp-0` in the `last block` of this cluster has the valid result.
`last block` is used instead of `first block` to stay consistent with grid reduction.
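The data flow of the three steps above can be modeled on the CPU as follows. This is only a sketch under stated assumptions, not the runtime code added by this PR: `clusterReduceNonAll`, `BlockResult`, and `kWarpSize` are hypothetical names, the reduction op is assumed to be a sum, and the "wait until all data are received" step is modeled with a simple arrival counter instead of a hardware barrier.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

constexpr int kWarpSize = 32;  // lanes per warp

// Per-block outcome: only the last block carries a valid value.
struct BlockResult {
  bool valid;
  float value;
};

// Hypothetical CPU model of the non-all-reduce cluster reduction (sum).
// `input` holds one value per thread, laid out as [block][warp][lane].
inline std::vector<BlockResult> clusterReduceNonAll(
    const std::vector<float>& input,
    int cluster_size,
    int warps_per_block) {
  const int num_slots = cluster_size * warps_per_block;
  assert(input.size() == static_cast<std::size_t>(num_slots) * kWarpSize);

  // Models the last block's shared memory: one slot per warp in the cluster.
  std::vector<float> last_block_smem(num_slots, 0.0f);
  int arrived = 0;

  // Step 1: every warp reduces its lanes and writes the partial result
  // into the last block's shared memory.
  for (int block = 0; block < cluster_size; ++block) {
    for (int warp = 0; warp < warps_per_block; ++warp) {
      const int slot = block * warps_per_block + warp;
      const float* lanes = input.data() + slot * kWarpSize;
      last_block_smem[slot] = std::accumulate(lanes, lanes + kWarpSize, 0.0f);
      ++arrived;
    }
  }

  // Step 2: the last block proceeds only once all partials have arrived,
  // then its warp-0 finishes the reduction over the slots.
  assert(arrived == num_slots);
  const float total =
      std::accumulate(last_block_smem.begin(), last_block_smem.end(), 0.0f);

  // Step 3: only the last block of the cluster holds the valid result.
  std::vector<BlockResult> results(cluster_size, BlockResult{false, 0.0f});
  results[cluster_size - 1] = BlockResult{true, total};
  return results;
}
```

For example, with `cluster_size = 2`, `warps_per_block = 4`, and every input equal to `1.0f`, block 1 holds the sum over all 256 threads while block 0 is marked invalid. An all-reduce variant would additionally broadcast the total back to every block, which is the extra work this version avoids.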