Fix inlining positions with constrained ops #5317
{tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
for (const auto loop_id : tv->getLoopDomain()) {
  if (std::ranges::find(all_constrained_ids, loop_id) !=
      all_constrained_ids.end()) {
Why do we need to exclude constrained_logical?
With manual scheduling, the logical domain and the loop domain could share IDs, and this would artificially exclude those shared IDs.
Can you clarify your question? Exclude from inlining? Or exclude from uninlinable_ids?
They should be included in all_constrained_ids, so any constrained loop ID, whether or not it is also a logical ID, ends up in uninlinable_ids. In other words, all constrained loop IDs are excluded from inlining regardless of whether they are also logical IDs.
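To make the filtering concrete, here is a minimal standalone sketch of the check discussed above. It is not nvFuser code: `ToyId` and `collectUninlinableIds` are hypothetical names introduced only for illustration, and `std::find` stands in for the `std::ranges::find` call in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an iteration domain; not an nvFuser type.
struct ToyId {
  std::string name;
};

// Returns every loop ID that also appears in all_constrained_ids.
// An ID shared between the logical and loop domains is still a loop ID,
// so it is picked up here like any other constrained loop ID.
std::vector<const ToyId*> collectUninlinableIds(
    const std::vector<const ToyId*>& loop_domain,
    const std::vector<const ToyId*>& all_constrained_ids) {
  std::vector<const ToyId*> uninlinable_ids;
  for (const ToyId* loop_id : loop_domain) {
    if (std::find(
            all_constrained_ids.begin(),
            all_constrained_ids.end(),
            loop_id) != all_constrained_ids.end()) {
      uninlinable_ids.push_back(loop_id);
    }
  }
  return uninlinable_ids;
}
```

In this toy model, whether a given ID is also part of the logical domain never enters the check; only membership in the constrained set matters.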
Ah you are right. For some reason I read getAllValsBetween and figured that the dependencies wouldn't be included.
// Grab all values that exist between and including provided
// vals. Returned values are topologically ordered, and unique.
NVF_API static std::vector<Val*> getAllValsBetween(
    const std::unordered_set<Val*>& dependencies,
    const std::vector<Val*>& of);
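As a rough illustration of what "between and including" means in the comment quoted above, the sketch below walks a toy DAG and keeps every value on a path from a dependency to a target, endpoints included. It is an assumption-laden toy, not the nvFuser implementation: `ToyVal`, `toyVisit`, and `toyGetAllValsBetween` are hypothetical names, and the traversal strategy is mine.

```cpp
#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a Val; producers are the values it is derived from.
struct ToyVal {
  int id;
  std::vector<ToyVal*> producers;
};

// Post-order DFS toward producers. A value is kept if it is a dependency
// itself or if at least one of its producers is kept.
bool toyVisit(
    ToyVal* val,
    const std::unordered_set<ToyVal*>& dependencies,
    std::unordered_map<ToyVal*, bool>& on_path,
    std::vector<ToyVal*>& ordered) {
  auto it = on_path.find(val);
  if (it != on_path.end()) {
    return it->second; // already visited, keeps the result unique
  }
  bool keep = dependencies.count(val) > 0;
  for (ToyVal* producer : val->producers) {
    // Visit every producer (no short-circuit) so all paths are recorded.
    if (toyVisit(producer, dependencies, on_path, ordered)) {
      keep = true;
    }
  }
  on_path[val] = keep;
  if (keep) {
    ordered.push_back(val); // producers were pushed first: topological order
  }
  return keep;
}

std::vector<ToyVal*> toyGetAllValsBetween(
    const std::unordered_set<ToyVal*>& dependencies,
    const std::vector<ToyVal*>& of) {
  std::unordered_map<ToyVal*, bool> on_path;
  std::vector<ToyVal*> ordered;
  for (ToyVal* val : of) {
    toyVisit(val, dependencies, on_path, ordered);
  }
  return ordered;
}
```

The point relevant to this thread is that both the dependencies and the targets themselves appear in the result, so an ID that is simultaneously a logical ID and a loop ID is not dropped.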
This was an oversight from when the greedy scheduler was extended with batching. The uninlinable IDs need to be loop IDs, whereas before this PR they stayed at the logical domain. The generated code didn't produce any errors because the affected loop IDs are parallelized with either TIDx or Group, but limiting the inlining position to the left of the constrained IDs still makes more sense.
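One way to picture "limiting the inlining position to the left of constrained IDs" is the toy function below. It assumes the inlining position is an index into the loop domain and that it is capped at the first uninlinable loop ID; that reading, the string-based IDs, and the name `maxInlinePosition` are all illustrative assumptions rather than nvFuser's actual logic.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Returns the largest position p such that loop_domain[0..p) contains no
// uninlinable ID. IDs are plain strings purely for illustration.
std::size_t maxInlinePosition(
    const std::vector<std::string>& loop_domain,
    const std::vector<std::string>& uninlinable_ids) {
  for (std::size_t pos = 0; pos < loop_domain.size(); ++pos) {
    if (std::find(
            uninlinable_ids.begin(),
            uninlinable_ids.end(),
            loop_domain[pos]) != uninlinable_ids.end()) {
      return pos; // stop just left of the first constrained loop ID
    }
  }
  return loop_domain.size(); // nothing constrained: every position is allowed
}
```

Note that if the uninlinable IDs were recorded as logical IDs that never appear in the loop domain, a lookup like this would never match and the cap would not be applied.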