-
Notifications
You must be signed in to change notification settings - Fork 75
Combine for RaggedIterDomain #5716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d87e6d7
77c6a07
f16fc4d
23d55f1
8392332
787dfec
a0b40a3
dbdd917
cdbd81e
d4c8d7f
9575a13
a054ae0
69dbe0f
db3b359
2348dde
7090b9c
b07e285
a2c504b
b1d8cf4
201c148
9e0b161
60a2dd5
566d63d
144b206
e2efe75
0b68d6b
8a73bb2
550e0c5
82bd85e
f215f07
5b99432
2dd9287
c3aebec
a22bb1f
f521c38
8d0d9cb
67aac1b
3a80926
f75ecb6
8aa854e
72ae14f
85d48df
5f86d9c
bf5b627
bec4c09
4d8acab
3b082ba
72dbc41
5002407
be0e2ea
05a6201
d2b5384
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1052,6 +1052,107 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition( | |||||
| return {component_id, ragged_id}; | ||||||
| } | ||||||
|
|
||||||
| IterDomain* RaggedIterDomain::combine( | ||||||
| IterDomain* component, | ||||||
| RaggedIterDomain* ragged) { | ||||||
| NVF_ERROR(component != nullptr, "combine: component IterDomain is null"); | ||||||
| NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null"); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| !component->isA<RaggedIterDomain>(), | ||||||
| "combine: component must be a regular IterDomain, got RaggedIterDomain: ", | ||||||
| component->toString()); | ||||||
|
|
||||||
| // Validate that component and ragged have compatible properties | ||||||
| NVF_ERROR_EQ( | ||||||
| component->getParallelType(), | ||||||
| ParallelType::Serial, | ||||||
| "Combining parallelized IterDomain not supported: ", | ||||||
| component->toString()); | ||||||
|
|
||||||
| NVF_ERROR_EQ( | ||||||
| ragged->getParallelType(), | ||||||
| ParallelType::Serial, | ||||||
| "Combining parallelized RaggedIterDomain not supported: ", | ||||||
| ragged->toString()); | ||||||
|
|
||||||
| NVF_ERROR_EQ( | ||||||
| component->getIterType(), | ||||||
| IterType::Iteration, | ||||||
| "combine: only IterType::Iteration is supported for component, got ", | ||||||
| component->getIterType(), | ||||||
| " for IterDomain: ", | ||||||
| component->toString()); | ||||||
|
|
||||||
| NVF_ERROR_EQ( | ||||||
| ragged->getIterType(), | ||||||
| IterType::Iteration, | ||||||
| "combine: only IterType::Iteration is supported for ragged, got ", | ||||||
| ragged->getIterType(), | ||||||
| " for RaggedIterDomain: ", | ||||||
| ragged->toString()); | ||||||
|
|
||||||
| // Validate component-ragged pairing when Partition definition is available | ||||||
| // (Option 3 of doc/dev/ragged_iter_domain_combine_design_doc.md). | ||||||
| // Only validate when the RaggedIterDomain has a direct Partition definition. | ||||||
| // After propagation (e.g., set() operations), the definition may be nullptr, | ||||||
| // in which case we trust the user to provide the correct component. | ||||||
| if (ragged->definition() != nullptr && | ||||||
| ragged->definition()->isA<Partition>()) { | ||||||
| auto* partition = ragged->definition()->as<Partition>(); | ||||||
| IterDomain* expected_component = partition->component(); | ||||||
|
|
||||||
| NVF_ERROR( | ||||||
| component == expected_component, | ||||||
| "combine: component mismatch. The provided component does not match ", | ||||||
| "the component from the Partition that created this " | ||||||
| "RaggedIterDomain.\n", | ||||||
| " Provided component: ", | ||||||
| component->toString(), | ||||||
| "\n", | ||||||
| " Expected component: ", | ||||||
| expected_component->toString()); | ||||||
| } | ||||||
| // If no Partition definition (after set, in segmented fusion, or external | ||||||
| // input), trust the user and proceed without validation | ||||||
|
|
||||||
| // The combined extent is the sum of all extents in the ragged dimension | ||||||
| // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents) | ||||||
| TensorView* extents_tv = ragged->extents(); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The type already says it. Also, in the context of RaggedIterDomain, extents has to be a TensorView. |
||||||
| NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null"); | ||||||
|
|
||||||
| // It is still assumed the extents tensor is just 1D | ||||||
| NVF_ERROR_EQ( | ||||||
| std::ranges::distance( | ||||||
| extents_tv->getLogicalDomain() | TensorDomain::kNoReductions), | ||||||
| 1, | ||||||
| "Unexpected rank of extent tensor: ", | ||||||
| extents_tv->toString()); | ||||||
|
|
||||||
| auto container = component->container(); | ||||||
| auto zero = container->zeroVal(DataType::Index); | ||||||
|
|
||||||
| // Create a symbolic extent for the combined IterDomain | ||||||
| // This represents the sum of all ragged extents, i.e., | ||||||
| // sum(extents_tv, {0}). We could use the sum output as the extent | ||||||
| // but we would need to extract the scalar value out of the 0-dim | ||||||
| // tensor. For now, we leave it as a symbolic Val. | ||||||
| Val* combined_extent = | ||||||
| IrBuilder::createInContainer<Val>(container, DataType::Index); | ||||||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| // Create the combined IterDomain with the symbolic extent | ||||||
| IterDomain* combined_id = IterDomainBuilder(zero, combined_extent) | ||||||
| .parallel_type(ParallelType::Serial) | ||||||
| .iter_type(IterType::Iteration) | ||||||
| .build(); | ||||||
|
|
||||||
| // Create the Combine expression linking component + ragged -> combined | ||||||
| IrBuilder::createInContainer<Combine>( | ||||||
| container, combined_id, component, ragged); | ||||||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
| return combined_id; | ||||||
| } | ||||||
|
|
||||||
| TensorDomain::TensorDomain( | ||||||
| IrBuilderPasskey passkey, | ||||||
| std::vector<IterDomain*> logical_domain, | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2708,6 +2708,33 @@ std::string Partition::toInlineString(int indent_size) const { | |
|
|
||
| NVFUSER_DEFINE_CLONE_AND_CREATE(Partition) | ||
|
|
||
| Combine::Combine( | ||
| IrBuilderPasskey passkey, | ||
| IterDomain* out, | ||
| IterDomain* component, | ||
wujingyue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| RaggedIterDomain* ragged) | ||
| : Expr(passkey) { | ||
| addOutput(out); | ||
| addInput(component); | ||
| addInput(ragged); | ||
| } | ||
|
|
||
| std::string Combine::toString(int indent_size) const { | ||
| std::stringstream ss; | ||
| ss << "Combine: "; | ||
| ss << "component: " << component()->toString(); | ||
| ss << " + ragged: " << ragged()->toString(); | ||
| ss << " -> " << out()->toString(); | ||
| ss << "\n"; | ||
| return ss.str(); | ||
| } | ||
|
|
||
| std::string Combine::toInlineString(int indent_size) const { | ||
| NVF_CHECK(false, "Combine can not be printed inline"); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not? toString seems to be one line.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually am not quite sure why, but our convention is that inline printing seems to be only for scalar values. For example, |
||
| } | ||
|
|
||
| NVFUSER_DEFINE_CLONE_AND_CREATE(Combine) | ||
|
|
||
| Swizzle::Swizzle( | ||
| IrBuilderPasskey passkey, | ||
| IterDomain* out_x, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.