-
Notifications
You must be signed in to change notification settings - Fork 70
tma pointwise with broadcast #5555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: llu/pt3_auto2_bcast_utils_mechanical
Are you sure you want to change the base?
tma pointwise with broadcast #5555
Conversation
|
Review updated until commit ab08df1 Description
|
| Relevant files | |||||||||
|---|---|---|---|---|---|---|---|---|---|
| Enhancement |
| ||||||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Break point validation
n_valid_dims >= 2 when break_point != 0, but this validation might be too strict or might not cover all valid cases. Need to verify this validation logic handles all expected scenarios correctly, especially for edge cases with broadcast dimensions. |
9447394 to
83a809e
Compare
|
!test |
1 similar comment
|
!test |
Greptile OverviewGreptile SummaryThis PR extends the TMA pointwise scheduler to handle broadcast domains by implementing a break point selection mechanism.
Confidence Score: 4/5
Important Files ChangedFile Analysis
Sequence DiagramsequenceDiagram
participant H as getPointwiseHeuristics
participant BP as getBreakPoint
participant S as schedulePointwise
participant C as commonPointwiseSchedule
H->>BP: getBreakPoint(fusion, prop, is_tma=true)
BP-->>H: BreakPointInfo (break_point, flip_grid_binding)
H->>H: Compute tma_domain_inner/outer
H->>H: Compute tma_tile_inner/outer
H-->>S: PointwiseParams (break_point, tma_tile_*)
S->>C: commonPointwiseSchedule(fusion, break_point)
alt break_point == 0
C->>C: Flatten all dims to [I]
else break_point > 0
C->>C: Merge to [lhs, rhs] structure
end
C-->>S: CommonScheduleInfo (reference_tv)
alt break_point == 0
S->>S: reference_tv->split(0, tma_domain_inner)
else break_point > 0
S->>S: Validate n_valid_dims >= 2
end
S->>S: Split for TMA tiles
S->>S: Parallelize TMA tensors (Bulk)
S->>S: Schedule non-TMA tensors
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 1 comment
csrc/scheduler/pointwise_tma.cpp
Outdated
| // dimension. outer tile size: don't exceed the outer TMA dimension size Both | ||
| // Both are subject to hardware constraints of 256 elements per dimension. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: comment has formatting issue - appears to be an incomplete edit
| // dimension. outer tile size: don't exceed the outer TMA dimension size Both | |
| // Both are subject to hardware constraints of 256 elements per dimension. | |
| // - Inner tile size: ensure at least 2 tiles in the inner TMA dimension | |
| // - Outer tile size: don't exceed the outer TMA dimension size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 4 comments
| if (pparams->break_point == 0) { | ||
| reference_tv->split(0, pparams->tma_domain_inner); | ||
| } else { | ||
| NVF_ERROR( | ||
| n_valid_dims >= 2, | ||
| "Required at least 2 valid dimensions for Tma scheduling, but got ", | ||
| n_valid_dims); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: when break_point > 0, only dimension validation is performed but no actual transformation. Verify the subsequent splits at lines 349-350 work correctly with the [lhs, rhs] structure from commonPointwiseSchedule
csrc/scheduler/pointwise_tma.cpp
Outdated
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: extra blank line
| result.flip_grid_binding = false; | ||
| } | ||
| } else { | ||
| // If TMA is used, priorize break if it saves transfered size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: typo: 'priorize' should be 'prioritize'
| // If TMA is used, priorize break if it saves transfered size | |
| // If TMA is used, prioritize break if it saves transfered size |
| } | ||
| } else { | ||
| // If TMA is used, priorize break if it saves transfered size | ||
| // This ensures we break at broadcast dimensions, then we can optionally |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: typo: 'transfered' should be 'transferred'
| // This ensures we break at broadcast dimensions, then we can optionally | |
| // This ensures we break at broadcast dimensions, then we can optionally |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, no comments
22a4ced to
350b75a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, no comments
Move break point calculation logic into pointwise_utils: - Add pointwise_utils::getBreakPoint() to calculate optimal 2D break point - Add pointwise_utils::getBlockGridConfig() to compute block/grid dimensions - Refactor pointwise_non_tma.cpp to use new utility functions - Add BreakPointInfo and BlockGridConfig structs for return values This is a mechanical refactoring with no functional changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, no comments
61412c0 to
01924fa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, no comments
Following #5553, part of #5366
TMA Pointwise Scheduler: Broadcast Domain Handling
(1) TMA Load vs. General Load (ldg/ld.global)
The current TMA pointwise scheduler does not use TMA load for inputs with concretized broadcast domains.
Example:
Given three inputs:
tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], whereB1andB2are broadcast domains concretized toI1andI2:tv0andtv1will NOT be loaded with TMANote: This is a performance optimization. These inputs can be loaded with TMA, but only using a one-dimensional tile, as demonstrated in the newly added test.
(2) Break Point Selection
When broadcasts are present, the loop domain of the reference tv is merged to
[lhs, rhs]instead of flattening to a single dimension.Example:
Given
tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], the break point is atpos-1, which separates[I1, I2]into[lhs, rhs].Break point selection differs between TMA and non-TMA versions:
Rationale: In the TMA version, we cannot safely merge broadcast and non-broadcast domains when creating 2D TMA domains and schedules, so we always break when broadcasts are present. See restrictions at #5556