diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp
index df8eb569bc9..479b53caa84 100644
--- a/csrc/ir/internal_base_nodes.cpp
+++ b/csrc/ir/internal_base_nodes.cpp
@@ -1025,20 +1025,28 @@ std::pair<IterDomain*, RaggedIterDomain*> RaggedIterDomain::partition(
       "partition: extents must have Index type, got ",
       extents->dtype());
 
-  const auto& extents_domain = extents->getLogicalDomain();
-  NVF_ERROR_EQ(
-      extents_domain.size(),
-      1,
-      "partition: extents tensor must be 1D, got ",
-      extents_domain.size(),
-      "D tensor. Multi-dimensional extents not yet supported.");
+  // Filter out reduction dimensions from extents tensor
+  auto extents_no_reduction =
+      extents->getLogicalDomain() | TensorDomain::kNoReductions;
+  auto extents_ndim = std::ranges::distance(extents_no_reduction);
+  NVF_ERROR_GT(
+      extents_ndim,
+      0,
+      "partition: extents tensor must have at least one non-reduction "
+      "dimension");
 
   auto container = in->container();
 
   // Create component IterDomain
-  // Component extent = number of components = length of extents tensor
+  // Component extent = number of components = size of last dimension of extents
+  // For 1D extents [K]: component_extent = K
+  // For 2D extents [D, K]: component_extent = K (last dim)
+  // For N-D extents [..., K]: component_extent = K (last dim)
+  // The outer dimensions of extents correspond to outer dimensions of the
+  // tensor being partitioned, allowing non-uniform partitions across instances.
   auto zero = container->zeroVal(DataType::Index);
-  auto component_extent = extents_domain.at(0)->extent();
+  auto component_extent =
+      (*std::ranges::prev(extents_no_reduction.end()))->extent();
   auto component_id = IterDomainBuilder(zero, component_extent)
                           .parallel_type(ParallelType::Serial)
                           .iter_type(IterType::Iteration)
@@ -1121,12 +1129,18 @@ IterDomain* RaggedIterDomain::combine(
   TensorView* extents_tv = ragged->extents();
   NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null");
 
-  // It is still assumed the extents tensor is just 1D
+  // Filter out reduction dimensions before checking
+  auto extents_no_reduction =
+      extents_tv->getLogicalDomain() | TensorDomain::kNoReductions;
+  // Multi-dimensional extents are not yet supported in combine
+  auto extents_ndim = std::ranges::distance(extents_no_reduction);
   NVF_ERROR_EQ(
-      std::ranges::distance(
-          extents_tv->getLogicalDomain() | TensorDomain::kNoReductions),
+      extents_ndim,
       1,
-      "Unexpected rank of extent tensor: ",
+      "combine: Multi-dimensional extents are not yet supported. ",
+      "Expected 1D extents tensor, got ",
+      extents_ndim,
+      "D extents: ",
       extents_tv->toString());
 
   auto container = component->container();
diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h
index 2d1872563e5..67f66b68782 100644
--- a/csrc/ir/internal_base_nodes.h
+++ b/csrc/ir/internal_base_nodes.h
@@ -486,15 +486,17 @@ class NVF_API RaggedIterDomain : public IterDomain {
   //! Creates a component IterDomain and a RaggedIterDomain based on extents
   //!
   //! \param in Input IterDomain to partition (must be regular IterDomain)
-  //! \param extents Extents tensor defining the size of each component (must be
-  //! 1D)
-  //! Shape: [num_components], values: [extent0, extent1, ...,
-  //! extent(n-1)]
-  //! \return Pair of (component_id, ragged_id)
-  //! component_id: IterDomain with extent = num_components
-  //! ragged_id: RaggedIterDomain with the provided extents
+  //! \param extents Extents tensor defining the size of each component
+  //! 1D example: Shape [num_components], values [extent0, extent1, ...]
+  //! 2D example: Shape [outer_dim, num_components], e.g.,
+  //! [num_gpus, num_experts]. For N-D extents, the last dimension defines
+  //! the number of components, and outer dimensions correspond to outer
+  //! dimensions of the tensor.
   //!
-  //! TODO: Support multi-dimensional extents for nested ragged structures
+  //! \return Pair of (component_id, ragged_id)
+  //! component_id: IterDomain with extent = num_components (from the
+  //! last dim of extents)
+  //! ragged_id: RaggedIterDomain with the N-D extents tensor
   static std::pair<IterDomain*, RaggedIterDomain*> partition(
       IterDomain* in,
       TensorView* extents);
diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp
index d5b1990fa80..24336211b2e 100644
--- a/csrc/ops/alias.cpp
+++ b/csrc/ops/alias.cpp
@@ -1316,13 +1316,6 @@ TensorView* asNested(
   NVF_ERROR(data != nullptr, "asNested: data tensor is null");
   NVF_ERROR(extents != nullptr, "asNested: extents tensor is null");
 
-  // Only 1D extents tensors are currently supported
-  NVF_ERROR_EQ(
-      std::ranges::distance(
-          extents->getLogicalDomain() | TensorDomain::kNoReductions),
-      1,
-      "asNested currently only supports 1D extents tensors");
-
   NVF_CHECK(
       !data->domain()->hasRaggedIterDomain(),
       "Multiple level of nesting is not supported: ",
@@ -1341,6 +1334,33 @@ TensorView* asNested(
 
   ragged_dim = wrapDim(ragged_dim, inp_logical_size);
 
+  // Filter out reduction dimensions from extents tensor
+  auto extents_no_reduction =
+      extents->getLogicalDomain() | TensorDomain::kNoReductions;
+  auto extents_ndim = std::ranges::distance(extents_no_reduction);
+  NVF_ERROR_GT(
+      extents_ndim,
+      0,
+      "asNested: extents tensor must have at least one non-reduction "
+      "dimension");
+
+  // Validate shape correspondence for multi-dimensional extents.
+  // For N-D extents, outer dimensions must match outer dimensions of the
+  // input tensor. Rule: extents.ndim - 1 == ragged_dim (except 1D extents,
+  // which are always valid).
+  if (extents_ndim > 1) {
+    NVF_ERROR_EQ(
+        extents_ndim - 1,
+        ragged_dim,
+        "asNested: Multi-dimensional extents require shape ",
+        "[d0, d1, ..., d(axis-1), num_components]. ",
+        "Got ",
+        extents_ndim,
+        "D extents for partitioning axis ",
+        ragged_dim);
+  }
+  // Note: 1D extents are always valid for any axis (uniform partition)
+
   // Partition the specified dimension in root domain
   // This replaces one IterDomain with (component_id, ragged_id)
   auto [component_id, ragged_id] =
diff --git a/tests/cpp/test_ragged_iter_domain.cpp b/tests/cpp/test_ragged_iter_domain.cpp
index 6e2b88b6690..4f974a3f86c 100644
--- a/tests/cpp/test_ragged_iter_domain.cpp
+++ b/tests/cpp/test_ragged_iter_domain.cpp
@@ -278,12 +278,7 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) {
   fusion.addInput(float_extents);
   EXPECT_THROW(RaggedIterDomain::partition(input_id, float_extents), nvfError);
 
-  // Test 4: Multi-dimensional extents should fail
-  auto extents_2d = makeSymbolicTensor(2, DataType::Index);
-  fusion.addInput(extents_2d);
-  EXPECT_THROW(RaggedIterDomain::partition(input_id, extents_2d), nvfError);
-
-  // Test 5: Non-Iteration IterType should fail
+  // Test 4: Non-Iteration IterType should fail
   auto reduction_id = IterDomainBuilder(
                           fusion.zeroVal(),
                           IrBuilder::create<Val>(10L, DataType::Index))
@@ -291,7 +286,7 @@ TEST_F(RaggedIterDomainTest, PartitionValidation) {
                           .build();
   EXPECT_THROW(RaggedIterDomain::partition(reduction_id, extents), nvfError);
 
-  // Test 6: Cannot partition RaggedIterDomain
+  // Test 5: Cannot partition RaggedIterDomain
   auto extents2 = makeSymbolicTensor(1, DataType::Index);
   fusion.addInput(extents2);
   auto ragged_id = IrBuilder::create<RaggedIterDomain>(
@@ -638,19 +633,44 @@ TEST_F(RaggedIterDomainTest, AsNestedValidationNullExtents) {
   EXPECT_THROW(asNested(data, nullptr, 0), nvfError);
 }
 
-// asNested validation - multi-dimensional extents (not yet supported)
-TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimExtents) {
+// asNested with 2D extents (should partition axis 1, not axis 0)
+TEST_F(RaggedIterDomainTest, AsNested2DOffsets) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  auto data = makeSymbolicTensor(2, DataType::Float);
+  // Create a 3D TensorView: [D=2, tokens=100, hidden=512]
+  // This represents 2 GPUs, each with tokens, and hidden dimension
+  auto data = makeSymbolicTensor(3, DataType::Float);
   fusion.addInput(data);
 
-  // 2D extents should fail (only 1D supported currently)
+  // Create 2D extents [D=2, num_experts=4]
+  // This represents per-GPU token counts for experts
   auto extents_2d = makeSymbolicTensor(2, DataType::Index);
   fusion.addInput(extents_2d);
 
-  EXPECT_THROW(asNested(data, extents_2d, 0), nvfError);
+  // Create nested tensor partitioning dimension 1 (tokens)
+  auto nested = asNested(data, extents_2d, 1);
+
+  fusion.addOutput(nested);
+
+  // Verify the output dimensions
+  // Should be: [D=2, component=4, ragged, hidden=512]
+  EXPECT_EQ(nested->nDims(), 4);
+
+  // First axis is unchanged (D=2)
+  EXPECT_TRUE(nested->axis(0)->isStrictlyA<IterDomain>());
+
+  // Second axis is component (num_experts from last dim)
+  EXPECT_TRUE(nested->axis(1)->isStrictlyA<IterDomain>());
+  EXPECT_FALSE(nested->axis(1)->isA<RaggedIterDomain>());
+
+  // Third axis is ragged with 2D extents
+  EXPECT_TRUE(nested->axis(2)->isA<RaggedIterDomain>());
+  auto ragged_id = nested->axis(2)->as<RaggedIterDomain>();
+  EXPECT_EQ(ragged_id->extents()->nDims(), 2);
+
+  // Fourth axis is the original hidden dimension
+  EXPECT_TRUE(nested->axis(3)->isStrictlyA<IterDomain>());
 }
 
 TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) {
@@ -1136,4 +1156,21 @@ TEST_F(RaggedIterDomainTest, PadRaggedDimensionError) {
   EXPECT_THROW(pad(nested, pad_widths), nvfError);
 }
 
+// Test asNested invalid shape
+TEST_F(RaggedIterDomainTest, AsNestedInvalidShape) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Create a 3D TensorView: [D=2, S=100, hidden=512]
+  auto tokens = makeSymbolicTensor(3, DataType::Float);
+  fusion.addInput(tokens);
+
+  // Create 3D extents (wrong dimensionality for axis 1)
+  auto extents_3d = makeSymbolicTensor(3, DataType::Index);
+  fusion.addInput(extents_3d);
+
+  // This should throw: 3D extents for axis 1 requires extents.ndim - 1 == 1
+  EXPECT_THROW(asNested(tokens, extents_3d, 1), nvfError);
+}
+
 } // namespace nvfuser