Prevent TMA usage for tensors with broadcast dimensions #5735
!test
Review updated until commit 5bc0cb8
Description
| Relevant files |
|---|
| Bug fix |
| Tests |
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| 🔒 No security concerns identified |
| ⚡ Recommended focus areas for review: Early return logic |
Greptile Summary
Prevents TMA (Tensor Memory Accelerator) usage for input tensors with broadcast dimensions in their logical domain by adding an early validation check in
Key changes:
Rationale:
Confidence Score: 5/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Scheduler as Pointwise Scheduler
participant Heuristics as getPointwiseHeuristics
participant Check as isTvSuitableForTma
participant Bits as getInputBitsPerElement
participant Domain as TensorView Logical Domain
Scheduler->>Heuristics: Request TMA scheduling
Heuristics->>Heuristics: Determine break point
Note over Heuristics: Step 0: break point computed
Heuristics->>Bits: Calculate bits_per_element
Bits->>Check: Check each input TV
Check->>Domain: Get logical domain
Domain-->>Check: Return IterDomains
Check->>Check: Check for broadcast dimensions
alt Has broadcast dimension
Check-->>Bits: Return false (not suitable)
Note over Check: NEW: Reject TVs with<br/>broadcast in logical domain
else No broadcast dimension
Check-->>Bits: Return true (suitable)
end
Bits-->>Heuristics: Return total bits_per_element
alt bits_per_element == 0
Heuristics-->>Scheduler: Return nullptr (no TMA)
Note over Heuristics: NEW: Early validation<br/>before Step 1
else bits_per_element > 0
Heuristics->>Heuristics: Step 1: Compute TMA domain
Heuristics->>Heuristics: Step 2: Determine elements per CTA
Heuristics->>Heuristics: Step 3+: Configure tiles
Heuristics-->>Scheduler: Return TMA params
end
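The flow in the diagram above can be sketched in self-contained C++. This is a minimal illustration under stated assumptions, not the code from this PR: `IterDomainSketch`, `TensorViewSketch`, `TmaParamsSketch`, and every function with a `Sketch` suffix are hypothetical stand-ins named after the diagram's participants, and summing the bit widths of suitable inputs is assumed from the diagram's "total bits_per_element" label.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-ins for the scheduler's IterDomain and TensorView.
// Only the fields needed for the broadcast check are modeled.
struct IterDomainSketch {
  bool is_broadcast;
};

struct TensorViewSketch {
  std::vector<IterDomainSketch> logical_domain;
  int64_t dtype_bits;  // element width in bits
};

// Hypothetical placeholder for the heuristic parameters returned on success.
struct TmaParamsSketch {
  int64_t bits_per_element;
};

// Plays the role of isTvSuitableForTma in the diagram: reject a TV that has
// any broadcast IterDomain in its logical domain.
bool isTvSuitableForTmaSketch(const TensorViewSketch& tv) {
  for (const auto& id : tv.logical_domain) {
    if (id.is_broadcast) {
      return false;
    }
  }
  return true;
}

// Plays the role of getInputBitsPerElement: only suitable inputs contribute.
// Assumption: contributions are summed across inputs.
int64_t getInputBitsPerElementSketch(
    const std::vector<TensorViewSketch>& inputs) {
  int64_t bits = 0;
  for (const auto& tv : inputs) {
    assert(tv.dtype_bits > 0);  // sketch-level sanity check on the input
    if (isTvSuitableForTmaSketch(tv)) {
      bits += tv.dtype_bits;
    }
  }
  return bits;
}

// Early validation: when no input is suitable for TMA, return nullptr so the
// caller can fall back to a non-TMA schedule. Later steps are omitted.
std::unique_ptr<TmaParamsSketch> getPointwiseHeuristicsSketch(
    const std::vector<TensorViewSketch>& inputs) {
  const int64_t bits_per_element = getInputBitsPerElementSketch(inputs);
  if (bits_per_element == 0) {
    return nullptr;
  }
  auto params = std::make_unique<TmaParamsSketch>();
  params->bits_per_element = bits_per_element;
  return params;
}
```

In this sketch, a fusion with one ordinary input and one broadcast input still gets parameters, with only the ordinary input counted. A fusion whose only input has a broadcast dimension gets `nullptr`.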
2 files reviewed, 1 comment
if (bits_per_element == 0) {
  return nullptr;
}
style: redundant check. bits_per_element == 0 is already checked at lines 125-127, so this condition will never be true.
Suggested change
- if (bits_per_element == 0) {
-   return nullptr;
- }
+ // bits_per_element already validated at line 125-127
naoyam left a comment
When a fusion has a tensor with a broadcast, does that cause the fusion to not use TMA at all? Or is it just the broadcast tensor that is affected?
What about if an input tensor doesn't have a broadcast ID but is broadcast inside the fusion?
      executor_cache.fusion(), out_tensors, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(TmaPointwiseTestF, OuterDimOne) {
Can you just leave a quick comment on what each of the tests is meant to validate?
Just the broadcast tv, added a new test
It's already addressed. Current scheduler checks
!test
naoyam left a comment
LGTM
When an input TV has broadcast dimensions, it may cause one of the following issues: