Combine for RaggedIterDomain #5716
No files reviewed, no comments
!test
We should still be able to validate using the exact graph on the complete fusion. Correct?
wujingyue
left a comment
LGTM otherwise
// The combined extent is the sum of all extents in the ragged dimension
// For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents)
TensorView* extents_tv = ragged->extents();
- TensorView* extents_tv = ragged->extents();
+ TensorView* extents = ragged->extents();
The type already says it. Also, in the context of RaggedIterDomain, extents has to be a TensorView.
}

std::string Combine::toInlineString(int indent_size) const {
  NVF_CHECK(false, "Combine can not be printed inline");
Why not? toString seems to be one line.
I actually am not quite sure why, but our convention is that inline printing seems to be only for scalar values. For example, Split::toInlineString isn't supported either. It isn't just whether it can be printed in a single line. It's more like if it can be recursively called.
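To illustrate the convention described above, here is a minimal sketch on a hypothetical miniature IR. The types Node, Scalar, Add, and SplitLike are invented for this example and are not nvFuser classes. Scalar-like values print inline by recursing into their operands, while a transform-like expr rejects inline printing even though its toString fits on one line.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical miniature IR for illustration only; not nvFuser's classes.
struct Node {
  virtual ~Node() = default;
  virtual std::string toString() const = 0;
  virtual std::string toInlineString() const = 0;
};

// A named scalar: inline printing is just its name.
struct Scalar : Node {
  std::string name;
  explicit Scalar(std::string n) : name(std::move(n)) {}
  std::string toString() const override { return name; }
  std::string toInlineString() const override { return name; }
};

// A scalar expression: inline printing recurses into both operands.
struct Add : Node {
  std::shared_ptr<Node> lhs;
  std::shared_ptr<Node> rhs;
  Add(std::shared_ptr<Node> l, std::shared_ptr<Node> r)
      : lhs(std::move(l)), rhs(std::move(r)) {}
  std::string toString() const override {
    return "Add(" + lhs->toString() + ", " + rhs->toString() + ")";
  }
  std::string toInlineString() const override {
    return "( " + lhs->toInlineString() + " + " + rhs->toInlineString() + " )";
  }
};

// A transform-like expr: one-line toString, but no inline form,
// because it is not something another printer can recurse into.
struct SplitLike : Node {
  std::string toString() const override { return "SplitLike: i0 -> i1, i2"; }
  std::string toInlineString() const override {
    throw std::logic_error("SplitLike can not be printed inline");
  }
};
```

With this sketch, nesting one Add inside another prints as a single recursive expression, whereas calling toInlineString on SplitLike throws.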
Yes. Actually, I'm considering changing this design to support multi-dim combine. I realized we would indeed need to know which iter domains correspond to which extent iter domains for partial combine, like the shuffle pattern in expert parallelism. Please consider the validation part a TODO. I'll likely address it in a later PR.
7f34288 to
d2b5384
!test
👍
This PR introduces the combine operation as discussed in the RaggedIterDomain design doc.
One design decision that I changed from the original design doc is about detecting and validating component iter domains. Previously, I was thinking about using the exact graph to find the corresponding component iter domain for a given ragged iter domain (e.g., #5550 (comment)). However, it won't work, for example, when a fusion is segmented and a segment does not have the corresponding Partition expr for a RaggedIterDomain. For example, when a tensor is used as an input for asNested, followed by some other operations, if the fusion is segmented after some operations, the latter segment won't be able to see the asNested and the Partition operations as they don't exist in the segment. This could be alleviated by providing an exact graph for the whole complete fusion, but more fundamentally, if a fusion has a nested tensor as an input, there doesn't seem to be any reasonable way to attach a Partition expr.
See doc/dev/ragged_iter_domain_combine_design_doc.md for detailed discussions. At this moment, I decided to not worry too much about the validation and assume the correctness is guaranteed by the user.
Note that partitioning is still limited to 1D extents. Multi-dim offsets will be the next step of this series of PRs.
Tracking which iter domains correspond to which extent iter domains actually seems necessary for supporting combine with ragged iter domains produced by multi-dim extent tensors. I'll revisit this as part of the multi-dim combine work, but my current plan is to take Option 4 as described in the design doc.
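For readers unfamiliar with the 1D case, the following is a minimal sketch on plain host data, not on IterDomains. The helper names combinedExtent, partitionFlat, and combineComponents are hypothetical and are not part of nvFuser. It shows that the combined extent is the sum of the 1D extents and that combine acts as the inverse of partition. As in this PR, no validation links the components back to the extents; correctness is assumed to be guaranteed by the caller.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper, not nvFuser API: total extent of the combined
// dimension for a 1D extents tensor [e0, e1, ..., en-1].
int64_t combinedExtent(const std::vector<int64_t>& extents) {
  return std::accumulate(extents.begin(), extents.end(), int64_t{0});
}

// Hypothetical helper: split a flat buffer into ragged components,
// where component i holds extents[i] consecutive elements.
std::vector<std::vector<float>> partitionFlat(
    const std::vector<float>& flat,
    const std::vector<int64_t>& extents) {
  // The flat size must match the sum of the extents.
  assert(static_cast<int64_t>(flat.size()) == combinedExtent(extents));
  std::vector<std::vector<float>> components;
  auto it = flat.begin();
  for (int64_t e : extents) {
    components.emplace_back(it, it + e);
    it += e;
  }
  return components;
}

// Hypothetical helper: concatenate the components back in order.
// No check ties the components to a particular extents tensor.
std::vector<float> combineComponents(
    const std::vector<std::vector<float>>& components) {
  std::vector<float> flat;
  for (const auto& c : components) {
    flat.insert(flat.end(), c.begin(), c.end());
  }
  return flat;
}
```

For extents {2, 3, 1}, partitioning a six-element buffer yields three components, and combining them restores the original buffer.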