Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
d87e6d7
Initial introduction of RaggedIterDomain
naoyam Dec 12, 2025
77c6a07
Merge remote-tracking branch 'origin/main' into raggediterdomain_init…
naoyam Dec 12, 2025
f16fc4d
cleanup
naoyam Dec 12, 2025
23d55f1
fix
naoyam Dec 12, 2025
8392332
fix
naoyam Dec 12, 2025
787dfec
unit test
naoyam Dec 12, 2025
a0b40a3
cleanup
naoyam Dec 12, 2025
dbdd917
Fix IterVisitor
naoyam Dec 12, 2025
cdbd81e
cleanup
naoyam Dec 12, 2025
d4c8d7f
WIP: partition
naoyam Dec 12, 2025
9575a13
Partition expr
naoyam Dec 13, 2025
a054ae0
TensorView::partition
naoyam Dec 13, 2025
69dbe0f
cleanup
naoyam Dec 13, 2025
db3b359
Merge remote-tracking branch 'origin/main' into raggediterdomain_part…
naoyam Dec 13, 2025
2348dde
cleanup
naoyam Dec 13, 2025
7090b9c
WIP: asNested
naoyam Dec 13, 2025
b07e285
cleanup
naoyam Dec 13, 2025
a2c504b
asNested
naoyam Dec 15, 2025
b1d8cf4
warpdim
naoyam Dec 15, 2025
201c148
Make sure RaggedIterDomain is propagated to output tensors
naoyam Dec 17, 2025
9e0b161
Extend ops to be aware with RaggediterDomain
naoyam Dec 17, 2025
60a2dd5
RaggedIterDomain and reduction
naoyam Dec 17, 2025
566d63d
WIP
naoyam Dec 18, 2025
144b206
WIP
naoyam Dec 18, 2025
e2efe75
cleanup
naoyam Dec 18, 2025
0b68d6b
cleanup
naoyam Dec 18, 2025
8a73bb2
cleanup
naoyam Dec 18, 2025
550e0c5
Merge branch 'raggediterdomain_partition' into raggediterdomain-asnested
naoyam Dec 18, 2025
82bd85e
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam Dec 18, 2025
f215f07
Use extents as a parameter
naoyam Dec 18, 2025
5b99432
Merge remote-tracking branch 'origin/main' into raggediterdomain-asne…
naoyam Dec 18, 2025
2dd9287
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam Dec 18, 2025
c3aebec
combine
naoyam Dec 19, 2025
a22bb1f
Add tests
naoyam Dec 19, 2025
f521c38
WIP
naoyam Dec 19, 2025
8d0d9cb
don't hold component ID in RaggedIterDomain
naoyam Dec 19, 2025
67aac1b
Add design doc
naoyam Dec 19, 2025
3a80926
license
naoyam Dec 19, 2025
f75ecb6
Merge branch 'main' into raggediterdomain-asnested
naoyam Jan 7, 2026
8aa854e
feedback
naoyam Jan 7, 2026
72ae14f
fix
naoyam Jan 7, 2026
85d48df
Merge branch 'raggediterdomain-asnested' into raggediterdomain_clone
naoyam Jan 7, 2026
5f86d9c
Merge branch 'main' into raggediterdomain_clone
naoyam Jan 7, 2026
bf5b627
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 7, 2026
bec4c09
Merge remote-tracking branch 'origin/main' into raggediterdomain_clone
naoyam Jan 7, 2026
4d8acab
cleanup
naoyam Jan 9, 2026
3b082ba
cleanup
naoyam Jan 9, 2026
72dbc41
Merge branch 'raggediterdomain_clone' into ragged_combine
naoyam Jan 13, 2026
5002407
expand doc
naoyam Jan 13, 2026
be0e2ea
cleanup
naoyam Jan 13, 2026
cf346ad
WIP
naoyam Jan 13, 2026
bee8691
Merge branch 'ragged_combine' into raggediterdomain_nd_partition
naoyam Jan 13, 2026
19737a8
Partition with multi-dim extents
naoyam Jan 14, 2026
5a1e325
WIP
naoyam Jan 14, 2026
1400baf
cleanup
naoyam Jan 14, 2026
d0a359b
format
naoyam Jan 14, 2026
cbe1850
Error check
naoyam Jan 14, 2026
0abf6d2
Empty commit
naoyam Jan 15, 2026
6b83715
Merge branch 'main' into raggediterdomain_nd_partition
naoyam Jan 16, 2026
aec9385
cleanup
naoyam Jan 16, 2026
1fde60b
cleanup
naoyam Jan 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions csrc/ir/internal_base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,20 +1025,28 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
"partition: extents must have Index type, got ",
extents->dtype());

const auto& extents_domain = extents->getLogicalDomain();
NVF_ERROR_EQ(
extents_domain.size(),
1,
"partition: extents tensor must be 1D, got ",
extents_domain.size(),
"D tensor. Multi-dimensional extents not yet supported.");
// Filter out reduction dimensions from extents tensor
auto extents_no_reduction =
extents->getLogicalDomain() | TensorDomain::kNoReductions;
auto extents_ndim = std::ranges::distance(extents_no_reduction);
NVF_ERROR_GT(
extents_ndim,
0,
"partition: extents tensor must have at least one non-reduction "
"dimension");

auto container = in->container();

// Create component IterDomain
// Component extent = number of components = length of extents tensor
// Component extent = number of components = size of last dimension of extents
// For 1D extents [K]: component_extent = K
// For 2D extents [D, K]: component_extent = K (last dim)
// For N-D extents [..., K]: component_extent = K (last dim)
// The outer dimensions of extents correspond to outer dimensions of the
// tensor being partitioned, allowing non-uniform partitions across instances.
auto zero = container->zeroVal(DataType::Index);
auto component_extent = extents_domain.at(0)->extent();
auto component_extent =
(*std::ranges::prev(extents_no_reduction.end()))->extent();
auto component_id = IterDomainBuilder(zero, component_extent)
.parallel_type(ParallelType::Serial)
.iter_type(IterType::Iteration)
Expand Down Expand Up @@ -1121,12 +1129,18 @@ IterDomain* RaggedIterDomain::combine(
TensorView* extents_tv = ragged->extents();
NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null");

// It is still assumed the extents tensor is just 1D
// Filter out reduction dimensions before checking
auto extents_no_reduction =
extents_tv->getLogicalDomain() | TensorDomain::kNoReductions;
// Multi-dimensional extents are not yet supported in combine
auto extents_ndim = std::ranges::distance(extents_no_reduction);
NVF_ERROR_EQ(
std::ranges::distance(
extents_tv->getLogicalDomain() | TensorDomain::kNoReductions),
extents_ndim,
1,
"Unexpected rank of extent tensor: ",
"combine: Multi-dimensional extents are not yet supported. ",
"Expected 1D extents tensor, got ",
extents_ndim,
"D extents: ",
extents_tv->toString());

auto container = component->container();
Expand Down
18 changes: 10 additions & 8 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,17 @@ class NVF_API RaggedIterDomain : public IterDomain {
//! Creates a component IterDomain and a RaggedIterDomain based on extents
//!
//! \param in Input IterDomain to partition (must be regular IterDomain)
//! \param extents Extents tensor defining the size of each component (must be
//! 1D)
//! Shape: [num_components], values: [extent0, extent1, ...,
//! extent(n-1)]
//! \return Pair of (component_id, ragged_id)
//! component_id: IterDomain with extent = num_components
//! ragged_id: RaggedIterDomain with the provided extents
//! \param extents Extents tensor defining the size of each component
//! 1D example: Shape [num_components], values [extent0, extent1, ...]
//! 2D example: Shape [outer_dim, num_components], e.g.,
//! [num_gpus, num_experts]
//! For N-D extents, the last dimension defines the number of components,
//! and outer dimensions correspond to outer dimensions of the tensor.
//!
//! TODO: Support multi-dimensional extents for nested ragged structures
//! \return Pair of (component_id, ragged_id)
//! component_id: IterDomain with extent = num_components (from the last
//! dim of extents)
//! ragged_id: RaggedIterDomain with N-D extents tensor
static std::pair<IterDomain*, RaggedIterDomain*> partition(
IterDomain* in,
TensorView* extents);
Expand Down
34 changes: 27 additions & 7 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1316,13 +1316,6 @@ TensorView* asNested(
NVF_ERROR(data != nullptr, "asNested: data tensor is null");
NVF_ERROR(extents != nullptr, "asNested: extents tensor is null");

// Only 1D extents tensors are currently supported
NVF_ERROR_EQ(
std::ranges::distance(
extents->getLogicalDomain() | TensorDomain::kNoReductions),
1,
"asNested currently only supports 1D extents tensors");

NVF_CHECK(
!data->domain()->hasRaggedIterDomain(),
"Multiple level of nesting is not supported: ",
Expand All @@ -1341,6 +1334,33 @@ TensorView* asNested(

ragged_dim = wrapDim(ragged_dim, inp_logical_size);

// Filter out reduction dimensions from extents tensor
auto extents_no_reduction =
extents->getLogicalDomain() | TensorDomain::kNoReductions;
auto extents_ndim = std::ranges::distance(extents_no_reduction);
NVF_ERROR_GT(
extents_ndim,
0,
"asNested: extents tensor must have at least one non-reduction "
"dimension");

// Validate shape correspondence for multi-dimensional extents
// For N-D extents, outer dimensions must match outer dimensions of input
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you touched PairwiseLogicalDomainMap yet? These "batch" dimensions should be mapped somehow so sharding propagation and sharding decomposition can work.

// tensor. Rule: extents.ndim - 1 == ragged_dim (except 1D extents, which are
// always valid).
if (extents_ndim > 1) {
NVF_ERROR_EQ(
extents_ndim - 1,
ragged_dim,
"asNested: Multi-dimensional extents require shape ",
"[d0, d1, ..., d(axis-1), num_components]. ",
"Got ",
extents_ndim,
"D extents for partitioning axis ",
ragged_dim);
}
// Note: 1D extents are always valid for any axis (uniform partition)

// Partition the specified dimension in root domain
// This replaces one IterDomain with (component_id, ragged_id)
auto [component_id, ragged_id] =
Expand Down
61 changes: 49 additions & 12 deletions tests/cpp/test_ragged_iter_domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,20 +278,15 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) {
fusion.addInput(float_extents);
EXPECT_THROW(RaggedIterDomain::partition(input_id, float_extents), nvfError);

// Test 4: Multi-dimensional extents should fail
auto extents_2d = makeSymbolicTensor(2, DataType::Index);
fusion.addInput(extents_2d);
EXPECT_THROW(RaggedIterDomain::partition(input_id, extents_2d), nvfError);

// Test 5: Non-Iteration IterType should fail
// Test 4: Non-Iteration IterType should fail
auto reduction_id =
IterDomainBuilder(
fusion.zeroVal(), IrBuilder::create<Val>(10L, DataType::Index))
.iter_type(IterType::Reduction)
.build();
EXPECT_THROW(RaggedIterDomain::partition(reduction_id, extents), nvfError);

// Test 6: Cannot partition RaggedIterDomain
// Test 5: Cannot partition RaggedIterDomain
auto extents2 = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(extents2);
auto ragged_id = IrBuilder::create<RaggedIterDomain>(
Expand Down Expand Up @@ -638,19 +633,44 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationNullExtents) {
EXPECT_THROW(asNested(data, nullptr, 0), nvfError);
}

// asNested validation - multi-dimensional extents (not yet supported)
TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimExtents) {
// asNested with 2D extents (should partition axis 1, not axis 0)
TEST_F(RaggedIterDomainTest, AsNested2DOffsets) {
Fusion fusion;
FusionGuard fg(&fusion);

auto data = makeSymbolicTensor(2, DataType::Float);
// Create a 3D TensorView: [D=2, tokens=100, hidden=512]
// This represents 2 GPUs, each with tokens, and hidden dimension
auto data = makeSymbolicTensor(3, DataType::Float);
fusion.addInput(data);

// 2D extents should fail (only 1D supported currently)
// Create 2D extents [D=2, num_experts=4]
// This represents per-GPU token counts for experts
auto extents_2d = makeSymbolicTensor(2, DataType::Index);
fusion.addInput(extents_2d);

EXPECT_THROW(asNested(data, extents_2d, 0), nvfError);
// Create nested tensor partitioning dimension 1 (tokens)
auto nested = asNested(data, extents_2d, 1);

fusion.addOutput(nested);

// Verify the output dimensions
// Should be: [D=2, component=4, ragged, hidden=512]
EXPECT_EQ(nested->nDims(), 4);

// First axis is unchanged (D=2)
EXPECT_TRUE(nested->axis(0)->isStrictlyA<IterDomain>());

// Second axis is component (num_experts from last dim)
EXPECT_TRUE(nested->axis(1)->isStrictlyA<IterDomain>());
EXPECT_FALSE(nested->axis(1)->isA<RaggedIterDomain>());

// Third axis is ragged with 2D extents
EXPECT_TRUE(nested->axis(2)->isA<RaggedIterDomain>());
auto ragged_id = nested->axis(2)->as<RaggedIterDomain>();
EXPECT_EQ(ragged_id->extents()->nDims(), 2);

// Fourth axis is the original hidden dimension
EXPECT_TRUE(nested->axis(3)->isStrictlyA<IterDomain>());
}

TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) {
Expand Down Expand Up @@ -1136,4 +1156,21 @@ TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) {
EXPECT_THROW(pad(nested, pad_widths), nvfError);
}

// asNested must reject extents whose rank does not fit the partitioned axis
TEST_F(RaggedIterDomainTest, AsNestedInvalidShape) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Rank shared by the data tensor and the (deliberately mismatched) extents
  constexpr int kRank = 3;
  // Axis of the data tensor that the test attempts to partition
  constexpr int kRaggedAxis = 1;

  // Data tensor, conceptually laid out as [D=2, S=100, hidden=512]
  auto data_tv = makeSymbolicTensor(kRank, DataType::Float);
  fusion.addInput(data_tv);

  // Extents tensor with one dimension too many for partitioning axis 1
  auto mismatched_extents = makeSymbolicTensor(kRank, DataType::Index);
  fusion.addInput(mismatched_extents);

  // Multi-dimensional extents need extents.ndim - 1 == axis; here 3 - 1 != 1,
  // so the call is expected to raise nvfError
  EXPECT_THROW(asNested(data_tv, mismatched_extents, kRaggedAxis), nvfError);
}

} // namespace nvfuser