Skip to content

Commit 352dcbf

Browse files
authored
Combine for RaggedIterDomain (#5716)
This PR introduces the combine operation as discussed in the RaggedIterDomain [design doc](https://github.com/NVIDIA/Fuser/blob/main/doc/dev/ragged_iter_domain_design_doc.md). One design decision that I changed from the original design doc is about detecting and validating component iter domains. Previously, I was thinking about using the exact graph to find the corresponding component iter domain for a given ragged iter domain (e.g., #5550 (comment)). However, it won't work, for example, when a fusion is segmented and a segment does not have the corresponding `Partition` expr for a `RaggedIterDomain`. For example, when a tensor is used as an input for `asNested`, followed by some other operations, if the fusion is segmented after some operations, the latter segment won't be able to see the `asNested` and the `Partition` operations as they don't exist in the segment. This could be alleviated by providing an exact graph for the whole complete fusion, but more fundamentally, if a fusion has a nested tensor as an input, there doesn't seem to be any reasonable way to attach a `Partition` expr. See doc/dev/ragged_iter_domain_combine_design_doc.md‎ for detailed discussions. At this moment, I decided to not worry too much about the validation and assume the correctness is guaranteed by the user. Note that partitioning is still limited to 1D extents. Multi-dim offsets will be the next step of this series of RPs. * Update * Tracking which iter domains correspond to which extent iter domains seems to be actually necessary for supporting combine with ragged iter domains produced by multi-dim extent tensors. I'll revisit this as part of multi-dim combine work, but my current plan is to take Option 4 as described in the design doc.
1 parent 36c6cfa commit 352dcbf

File tree

7 files changed

+733
-0
lines changed

7 files changed

+733
-0
lines changed

csrc/dispatch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class Val;
116116
f(ScanOp); \
117117
f(Merge); \
118118
f(Partition); \
119+
f(Combine); \
119120
f(Swizzle); \
120121
f(Swizzle2D); \
121122
f(Resize); \

csrc/ir/internal_base_nodes.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,107 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
10521052
return {component_id, ragged_id};
10531053
}
10541054

1055+
IterDomain* RaggedIterDomain::combine(
1056+
IterDomain* component,
1057+
RaggedIterDomain* ragged) {
1058+
NVF_ERROR(component != nullptr, "combine: component IterDomain is null");
1059+
NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null");
1060+
1061+
NVF_ERROR(
1062+
!component->isA<RaggedIterDomain>(),
1063+
"combine: component must be a regular IterDomain, got RaggedIterDomain: ",
1064+
component->toString());
1065+
1066+
// Validate that component and ragged have compatible properties
1067+
NVF_ERROR_EQ(
1068+
component->getParallelType(),
1069+
ParallelType::Serial,
1070+
"Combining parallelized IterDomain not supported: ",
1071+
component->toString());
1072+
1073+
NVF_ERROR_EQ(
1074+
ragged->getParallelType(),
1075+
ParallelType::Serial,
1076+
"Combining parallelized RaggedIterDomain not supported: ",
1077+
ragged->toString());
1078+
1079+
NVF_ERROR_EQ(
1080+
component->getIterType(),
1081+
IterType::Iteration,
1082+
"combine: only IterType::Iteration is supported for component, got ",
1083+
component->getIterType(),
1084+
" for IterDomain: ",
1085+
component->toString());
1086+
1087+
NVF_ERROR_EQ(
1088+
ragged->getIterType(),
1089+
IterType::Iteration,
1090+
"combine: only IterType::Iteration is supported for ragged, got ",
1091+
ragged->getIterType(),
1092+
" for RaggedIterDomain: ",
1093+
ragged->toString());
1094+
1095+
// Validate component-ragged pairing when Partition definition is available
1096+
// (Option 3 of doc/dev/ragged_iter_domain_combine_design_doc.md).
1097+
// Only validate when the RaggedIterDomain has a direct Partition definition.
1098+
// After propagation (e.g., set() operations), the definition may be nullptr,
1099+
// in which case we trust the user to provide the correct component.
1100+
if (ragged->definition() != nullptr &&
1101+
ragged->definition()->isA<Partition>()) {
1102+
auto* partition = ragged->definition()->as<Partition>();
1103+
IterDomain* expected_component = partition->component();
1104+
1105+
NVF_ERROR(
1106+
component == expected_component,
1107+
"combine: component mismatch. The provided component does not match ",
1108+
"the component from the Partition that created this "
1109+
"RaggedIterDomain.\n",
1110+
" Provided component: ",
1111+
component->toString(),
1112+
"\n",
1113+
" Expected component: ",
1114+
expected_component->toString());
1115+
}
1116+
// If no Partition definition (after set, in segmented fusion, or external
1117+
// input), trust the user and proceed without validation
1118+
1119+
// The combined extent is the sum of all extents in the ragged dimension
1120+
// For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents)
1121+
TensorView* extents_tv = ragged->extents();
1122+
NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null");
1123+
1124+
// It is still assumed the extents tensor is just 1D
1125+
NVF_ERROR_EQ(
1126+
std::ranges::distance(
1127+
extents_tv->getLogicalDomain() | TensorDomain::kNoReductions),
1128+
1,
1129+
"Unexpected rank of extent tensor: ",
1130+
extents_tv->toString());
1131+
1132+
auto container = component->container();
1133+
auto zero = container->zeroVal(DataType::Index);
1134+
1135+
// Create a symbolic extent for the combined IterDomain
1136+
// This represents the sum of all ragged extents, i.e.,
1137+
// sum(extents_tv, {0}). We could use the sum output as the extent
1138+
// but we would need to extract the scalar value out of the 0-dim
1139+
// tensor. For now, we leave it as a symbolic Val.
1140+
Val* combined_extent =
1141+
IrBuilder::createInContainer<Val>(container, DataType::Index);
1142+
1143+
// Create the combined IterDomain with the symbolic extent
1144+
IterDomain* combined_id = IterDomainBuilder(zero, combined_extent)
1145+
.parallel_type(ParallelType::Serial)
1146+
.iter_type(IterType::Iteration)
1147+
.build();
1148+
1149+
// Create the Combine expression linking component + ragged -> combined
1150+
IrBuilder::createInContainer<Combine>(
1151+
container, combined_id, component, ragged);
1152+
1153+
return combined_id;
1154+
}
1155+
10551156
TensorDomain::TensorDomain(
10561157
IrBuilderPasskey passkey,
10571158
std::vector<IterDomain*> logical_domain,

csrc/ir/internal_base_nodes.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,22 @@ class NVF_API RaggedIterDomain : public IterDomain {
499499
IterDomain* in,
500500
TensorView* extents);
501501

502+
//! Combine a component IterDomain with a RaggedIterDomain to flatten
503+
//! This is the inverse of partition, creating a regular IterDomain
504+
//!
505+
//! \param component Component IterDomain (extent = num_components)
506+
//! \param ragged RaggedIterDomain with variable extents per component
507+
//! \return Regular IterDomain with extent = sum of all component extents
508+
//!
509+
//! This operation flattens the ragged structure back into a single dimension.
510+
//! Example: component extent=3, ragged extents=[127, 0, 198]
511+
//! -> output extent = 325 (= 127 + 0 + 198)
512+
//!
513+
//! Note: We use "combine" instead of "merge" to differentiate from the
514+
//! regular IterDomain::merge operation which only works with regular
515+
//! IterDomains.
516+
static IterDomain* combine(IterDomain* component, RaggedIterDomain* ragged);
517+
502518
//! Override cloneWithoutRFactor to preserve RaggedIterDomain type
503519
IterDomain* cloneWithoutRFactor(bool map_with_original = false) override;
504520

csrc/ir/internal_nodes.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2708,6 +2708,33 @@ std::string Partition::toInlineString(int indent_size) const {
27082708

27092709
NVFUSER_DEFINE_CLONE_AND_CREATE(Partition)
27102710

2711+
Combine::Combine(
2712+
IrBuilderPasskey passkey,
2713+
IterDomain* out,
2714+
IterDomain* component,
2715+
RaggedIterDomain* ragged)
2716+
: Expr(passkey) {
2717+
addOutput(out);
2718+
addInput(component);
2719+
addInput(ragged);
2720+
}
2721+
2722+
std::string Combine::toString(int indent_size) const {
2723+
std::stringstream ss;
2724+
ss << "Combine: ";
2725+
ss << "component: " << component()->toString();
2726+
ss << " + ragged: " << ragged()->toString();
2727+
ss << " -> " << out()->toString();
2728+
ss << "\n";
2729+
return ss.str();
2730+
}
2731+
2732+
std::string Combine::toInlineString(int indent_size) const {
2733+
NVF_CHECK(false, "Combine can not be printed inline");
2734+
}
2735+
2736+
NVFUSER_DEFINE_CLONE_AND_CREATE(Combine)
2737+
27112738
Swizzle::Swizzle(
27122739
IrBuilderPasskey passkey,
27132740
IterDomain* out_x,

csrc/ir/internal_nodes.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,6 +1910,44 @@ class NVF_API Partition : public Expr {
19101910
}
19111911
};
19121912

1913+
//! Combine a component IterDomain with a RaggedIterDomain to flatten
1914+
//! This is the inverse of Partition, merging component and ragged dimensions
1915+
//! into a single regular IterDomain
1916+
class NVF_API Combine : public Expr {
1917+
public:
1918+
using Expr::Expr;
1919+
1920+
Combine(
1921+
IrBuilderPasskey,
1922+
IterDomain* out,
1923+
IterDomain* component,
1924+
RaggedIterDomain* ragged);
1925+
1926+
NVFUSER_DECLARE_CLONE_AND_CREATE
1927+
1928+
const char* getOpString() const override {
1929+
return "Combine";
1930+
}
1931+
1932+
std::string toString(int indent_size = 0) const override;
1933+
std::string toInlineString(int indent_size = 0) const override;
1934+
1935+
//! Output IterDomain (combined/flattened dimension)
1936+
IterDomain* out() const {
1937+
return output(0)->as<IterDomain>();
1938+
}
1939+
1940+
//! Component dimension input (extent = num_components)
1941+
IterDomain* component() const {
1942+
return input(0)->as<IterDomain>();
1943+
}
1944+
1945+
//! Ragged dimension input (variable extents per component)
1946+
RaggedIterDomain* ragged() const {
1947+
return input(1)->as<RaggedIterDomain>();
1948+
}
1949+
};
1950+
19131951
class Swizzle : public Expr {
19141952
public:
19151953
using Expr::Expr;

0 commit comments

Comments
 (0)