[RFC] RaggedIterDomain for nested tensors #5550
Review updated until commit 8cfce69
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- 🔒 Security concerns: No - This is a design document for adding new tensor domain support. No security implications identified.
- ⚡ Recommended focus areas for review: Design Completeness
| Property | Uniformity | Rationale |
|---|---|---|
| extent | VARIABLE | Core ragged characteristic - each component has different length |
| ParallelType | UNIFORM | GPU execution model requires consistent thread mapping |
| IterType | UNIFORM | All components perform same operation (iteration/reduction) |
| start | UNIFORM (=0) | Simplifies offset computation; all components start at 0 |
| is_rfactor_domain | UNIFORM | Reduction transformation applies uniformly |
The following properties are out of scope for this initial buildout: `is_padded_dimension`, `is_clustered_dimension`, `padded_to_size`.
The constructor validates that:
- Uniform properties (ParallelType, IterType, start, is_rfactor_domain) are consistent across all nested domains
- Out-of-scope properties are not set (must be false/nullptr for all nested domains)
If any validation fails, an error is thrown.
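As a rough illustration of these checks, a validation helper might look like the following sketch (the helper name and exact accessor usage are assumptions, not the actual implementation):

```cpp
#include <vector>

// Sketch of the uniformity validation described above (names are illustrative).
void validateNestedDomains(const std::vector<IterDomain*>& nested_domains) {
  NVF_ERROR(!nested_domains.empty(), "Expected at least one nested domain");
  IterDomain* first = nested_domains.front();
  for (IterDomain* id : nested_domains) {
    // Uniform properties must be consistent across all nested domains.
    NVF_ERROR(id->getParallelType() == first->getParallelType());
    NVF_ERROR(id->getIterType() == first->getIterType());
    NVF_ERROR(id->start()->isZeroInt(), "All nested domains must start at 0");
    NVF_ERROR(id->isRFactorProduct() == first->isRFactorProduct());
    // Out-of-scope properties (is_padded_dimension, is_clustered_dimension,
    // padded_to_size) must not be set; those checks are omitted in this sketch.
  }
}
```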
6.4 Key Operations
Partition Operation (IterDomain-level)
The partition operation is the fundamental primitive for creating ragged dimensions. It splits a regular IterDomain into a batch IterDomain and a RaggedIterDomain based on variable-length segments defined by offsets.
TensorView API:
// TensorView API (csrc/ir/interface_nodes.h)
class TensorView : public Val {
// Partition dimension 'dim' using the provided offsets
// Returns new TensorView with partitioned dimension replaced by (batch_id, ragged_id)
TensorView* partition(int dim, TensorView* offsets);
};

Example:
// Input: [token=325, hidden=512]
auto tokens = makeContigTensor(2);
// Partition into 3 experts with token counts [127, 0, 198]
// offsets = [0, 127, 127, 325]
auto partitioned = tokens->partition(/*dim=*/0, offsets_tv);
// Result: [expert=3, tokens_per_expert=[127,0,198], hidden=512]
// Dimension 0 is replaced by (expert_id, ragged_tokens_id)

See Section 6.1 for detailed semantics of the underlying IterDomain::partition operation.
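For reference, the IterDomain-level primitive referenced above is declared along these lines in the design (see Section 6.1); the parameter comments are illustrative:

```cpp
// Static method in the IterDomain class: splits `in` into a batch IterDomain
// and a RaggedIterDomain whose component extents come from `offsets`.
static std::pair<IterDomain*, RaggedIterDomain*> partition(
    IterDomain* in,        // Input IterDomain to partition
    TensorView* offsets);  // Offset tensor defining partition boundaries
```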
Creating Nested Tensors: asNested (Tensor-level)
The asNested operation is a tensor-level convenience operation (similar to reshape) that creates nested tensors from data and offset tensors. It is implemented using partition internally.
// In csrc/ops/alias.h
NVF_API TensorView* asNested(
TensorView* data, // Data tensor with contiguous ragged storage
TensorView* offsets, // Offset tensor [num_components + 1]
int64_t ragged_dim // Which dimension of data is ragged
);

Usage Example:
// Input tensors
auto data_tv = ...; // [325, 512] - flattened ragged dimension
auto offsets_tv = ...; // [4] - offsets [0, 127, 127, 325]
// Create nested tensor
auto nested_tv = asNested(data_tv, offsets_tv, /*ragged_dim=*/0);
// Result: TensorView with shape [batch=3, ragged=[127,0,198], hidden=512]

Semantics:
- `asNested` is a tensor-level operation (like `reshape`) that creates an output tensor with a ragged dimension
- Implementation: Uses `partition` as a transform operation between the root and logical domains
- Similar to how `reshape` uses splits/merges between root and logical domains
- The offset tensor provides the boundaries for each component (see Section 6.7 for offset tensor format)
- Equivalent to calling `data->partition(ragged_dim, offsets)` but provides a more intuitive API name
Merging Nested Tensors: asFlattened
To convert a nested tensor back to a regular flattened tensor, use the asFlattened operation:
// In csrc/ops/alias.h
NVF_API TensorView* asFlattened(
TensorView* nested_tensor, // Nested tensor to flatten
int64_t batch_dim, // Batch dimension index
int64_t ragged_dim // Ragged dimension index to flatten
);

Usage Example:
// Input: Nested tensor [batch=3, ragged=[127,0,198], hidden=512]
auto nested_tv = ...;
// Flatten back to regular tensor
auto flattened_tv = asFlattened(nested_tv, /*batch_dim=*/0, /*ragged_dim=*/1);
// Result: TensorView with shape [token=325, hidden=512]

Semantics:
- `asFlattened` is a tensor-level operation that flattens a ragged dimension back to regular
- Internally, it uses `merge` as a transform operation between the root and logical domains
- The batch and ragged dimensions are merged into a single regular dimension
- This is the inverse of `asNested`
Typical Usage Pattern:
// Expert parallelism workflow:
// 1. Start with flattened tokens
auto tokens = ...; // [325, 512]
// 2. Create nested structure for expert processing
auto nested = asNested(tokens, expert_offsets, /*ragged_dim=*/0);
// [expert=3, tokens_per_expert=[127,0,198], hidden=512]
// 3. Process with experts (some operations on nested tensor)
auto processed = expert_processing(nested);
// 4. Flatten back to regular tensor
auto result = asFlattened(processed, /*batch_dim=*/0, /*ragged_dim=*/1);
// [325, 512]

Transformations
Split: Splitting a RaggedIterDomain splits each nested component individually, producing an outer RaggedIterDomain and an inner regular IterDomain.
auto split_result = IterDomain::split(ragged, 2);
auto outer = split_result.first; // extents = [2, 3, 1], ragged dimension
auto inner = split_result.second; // extent = 2, regular dimension

Merge: Merge a RaggedIterDomain with a regular IterDomain.
There are two types of merge operations involving RaggedIterDomain:
- Merge RaggedIterDomain with its batch dimension → Regular IterDomain (described in Section 6.1)
- Merge RaggedIterDomain with another regular IterDomain → RaggedIterDomain (two cases):
Case 2a: Merge ragged with non-ragged (element-wise product)
auto inner = IrBuilder::create<IterDomain>(0, 4); // extent 4
// Merge: ragged dimension becomes outer, inner becomes feature
auto merged = IterDomain::merge(ragged, inner);
// Result: RaggedIterDomain with nested extents [3*4, 5*4, 2*4] = [12, 20, 8]

Case 2b: Merge regular IterDomain with RaggedIterDomain (reduction along dimension)
When the regular IterDomain corresponds to the outer dimension of the RaggedIterDomain's extent tensor, the merge reduces along that dimension:
// Input: RaggedIterDomain with 2D extent tensor shape [gpu=2, expert=4]
// extents[0] = [30, 0, 40, 30] (GPU 0)
// extents[1] = [25, 35, 25, 15] (GPU 1)
// Merge the 'gpu' dimension with this RaggedIterDomain
auto merged = IterDomain::merge(gpu_dim, ragged_dim);
// Result: RaggedIterDomain with 1D extent tensor shape [expert=4]
// merged_extents = [30+25, 0+35, 40+25, 30+15] = [55, 35, 65, 45]
// The extents are summed along the merged dimension

Implementation Detail: The reduction is defined using nvFuser's ReductionOp. Specifically, the new extent tensor is computed as:
// Reduce the extent tensor along dimension 0 (the merged dimension)
new_extents = sum(ragged_dim->extents(), {0});

This creates a ReductionOp that sums the 2D extent tensor [gpu=2, expert=4] along dimension 0, producing a 1D extent tensor [expert=4].
Important Constraint: The total extent of the resulting RaggedIterDomain depends on the reduction result of the extent tensor. Since nvFuser assumes all IterDomain extents are known at kernel launch time, the extent reduction cannot be done within the same kernel that uses the merged RaggedIterDomain. Until the reduction is complete, we don't know the size of the ragged dimension, which blocks operations like memory allocation.
Handling Strategies (TBD): How this should be handled is still unclear. Possible approaches include:
- Fusion Segmentation: Segment the fusion so that the extent reduction is done in a separate kernel before the main kernel is executed. This ensures the merged extents are available on the device before launching the kernel that needs them.
- Conservative Allocation: Pre-allocate tensors large enough to accommodate the maximum possible extent. This is commonly done in optimized MoE implementations where an upper bound on token counts is known or enforced. This avoids the need for dynamic extent computation but may waste memory.
Implementation notes:
- Case 2a (element-wise): multiply each nested extent by non-ragged extent
- Case 2b (reduction): sum extents along the dimension corresponding to the merged IterDomain
- Split on ragged: split each nested IterDomain individually, creating a new RaggedIterDomain with split components (see the extent sketch after this list)
- All transformations preserve uniform property requirements
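A minimal sketch of the per-component extent arithmetic for the split case, matching the [3, 5, 2] example above (a standalone helper for illustration, not the actual implementation):

```cpp
#include <cstdint>
#include <vector>

// Splitting a ragged dimension by `factor` splits each component individually:
// outer extent_i = ceilDiv(extent_i, factor); the inner extent is `factor`.
std::vector<int64_t> splitOuterExtents(
    const std::vector<int64_t>& extents,
    int64_t factor) {
  std::vector<int64_t> outer;
  outer.reserve(extents.size());
  for (int64_t e : extents) {
    outer.push_back((e + factor - 1) / factor); // ceilDiv
  }
  return outer; // e.g., extents [3, 5, 2] with factor 2 -> [2, 3, 1]
}
```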
Parallelization
The parallelize() method applies uniformly to all nested domains:
ragged->parallelize(ParallelType::TIDx);
// All nested domains now have ParallelType::TIDx

Select Operation (Optional - May Not Be Included in Initial Implementation)
The select operation extracts a specific component from a ragged dimension, converting it to a regular IterDomain. This is not critical for expert parallelism use cases and may be deferred to future work.
If implemented, it would work as follows:
- Select on the batch dimension to choose which component
- The ragged dimension automatically becomes a regular IterDomain with that component's extent
// Starting tensor: [batch=3, ragged=[3,5,2], feature=4]
auto tv = makeContigTensor(3); // Assume ragged dimension is at position 1
// Select batch component 1
auto selected = tv->select(/*batch_dim=*/0, /*index=*/1);
// Result: [ragged_extent=5, feature=4]
// The ragged dimension collapsed to a regular IterDomain with extent 5

Implementation notes:
- Select on batch dimension causes the RaggedIterDomain to resolve to the corresponding nested IterDomain
- The nested IterDomain at the selected index replaces the RaggedIterDomain in the output TensorDomain
- This enables direct access to individual components, similar to PyTorch's `nested_tensor[i]` indexing
Complete Examples
This section shows complete examples combining multiple operations for real-world use cases.
Example 1: Simple partitioning
// Starting with flattened tokens: [token=325, hidden=512]
auto tv = makeContigTensor(2);
// Partition into 3 experts with token counts [127, 0, 198]
// offsets = [0, 127, 127, 325]
auto partitioned = tv->partition(/*dim=*/0, offsets);
// Result: [expert=3, tokens_per_expert=[127,0,198], hidden=512]
// Dimension 0 is replaced by (expert_id, ragged_tokens_id)

Example 2: Expert parallelism with partition and merge
// Input: Distributed tokens across D=2 GPUs, each GPU has S/D tokens
// Shape: [D=2, S/D=100, hidden=512]
// Total S=200 tokens evenly distributed: 100 tokens per GPU
auto tokens = makeContigTensor(3); // [2, 100, 512]
// Step 1: Partition S/D dimension by expert
// Tokens from each GPU are routed to E=4 experts with different counts per expert per GPU.
// expert_offsets: 2D tensor [D=2, E+1=5] with per-GPU offsets for E=4 experts:
// expert_offsets[0] = [0, 30, 30, 70, 100] - GPU 0: [30, 0, 40, 30] tokens per expert
// expert_offsets[1] = [0, 25, 60, 85, 100] - GPU 1: [25, 35, 25, 15] tokens per expert
tokens->partition(/*dim=*/1, expert_offsets);
tokens->setLoopAsLogical();
// Result: [gpu=2, expert=4, tokens_per_expert=[[30,0,40,30],[25,35,25,15]], hidden=512]
// Now we have nested ragged: outer gpu dimension, inner ragged tokens per expert
auto shuffled_tokens = set(tokens);
// logical: [gpu=2, expert=4, tokens_per_expert=[[30,0,40,30],[25,35,25,15]], hidden=512]
// Step 2: Shuffle to expert-first layout and distribute across GPUs
// The merge operation represents the shuffling that reorganizes from
// [gpu, expert, ragged] to expert-first layout. The actual implementation
// performs the communication to change the data layout.
shuffled_tokens->merge(0, 2);
// This creates: [expert=4, merged_ragged=[55,35,65,45], hidden=512]
// Where merged tokens per expert = sum across source GPUs:
// Expert 0: 30 (from GPU 0) + 25 (from GPU 1) = 55 tokens
// Expert 1: 0 (from GPU 0) + 35 (from GPU 1) = 35 tokens
// Expert 2: 40 (from GPU 0) + 25 (from GPU 1) = 65 tokens
// Expert 3: 30 (from GPU 0) + 15 (from GPU 1) = 45 tokens
//
// How the merged RaggedIterDomain's extents are computed:
// Before merge: tokens_per_expert RaggedIterDomain has 2D extent tensor shape [gpu=2, expert=4]
// extents[0] = [30, 0, 40, 30] (GPU 0)
// extents[1] = [25, 35, 25, 15] (GPU 1)
// The merge operation merges the 'gpu' dimension (dimension 0 of the TensorView) with the
// RaggedIterDomain. Since 'gpu' corresponds to the outer dimension of the extent tensor,
// we reduce along that dimension by summing extents:
// merged_extents[expert_i] = sum over gpus of extents[gpu][expert_i]
// merged_extents = [30+25, 0+35, 40+25, 30+15] = [55, 35, 65, 45]
// The resulting RaggedIterDomain has a 1D extent tensor with shape [expert=4]
//
// NOTE: This extent reduction must complete before any kernel that uses shuffled_tokens
// can be launched, as the total extent of the merged ragged dimension is not known until
// the reduction completes. This may require fusion segmentation.
// Then split experts across GPUs for parallel processing (2 experts per GPU)
shuffled_tokens->split(0, /*factor=*/2);
// Result: [gpu=2, expert_per_gpu=2, merged_ragged=[[55,35],[65,45]], hidden=512]
// GPU 0 processes experts 0-1 with [55, 35] tokens respectively
// GPU 1 processes experts 2-3 with [65, 45] tokens respectively
// Summary:
// - Input: [D=2, S/D=100] uniform tokens per GPU
// - After partition: [D=2, E=4, ragged] non-uniform tokens per (GPU, expert)
// - After merge+split: [D=2, E/D=2, ragged] expert-first, distributed for processing

6.5 Indexing and Code Generation
Note: Indexing and code generation for ragged dimensions are not part of the initial scope of this design. The details below describe the intended eventual behavior but are intentionally left out for now. The initial implementation will focus on the IR representation and basic infrastructure.
Offset-Based Indexing (Future Work)
For ragged iteration, global indices will be computed as:
global_index = offset[component_idx] + local_index
Where component_idx is the batch index and local_index iterates from 0 up to (but not including) extent[component_idx]. The offset array provides the starting position of each component in contiguous storage.
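For example, with extents [3, 5, 2] and offsets [0, 3, 8, 10] (as in Section 6.6), the element at component_idx = 1, local_index = 2 maps to global_index = offset[1] + 2 = 3 + 2 = 5.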
Loop Structure (Future Work)
Generated CUDA code will follow this pattern:
for (int batch = 0; batch < num_components; batch++) {
int offset = offsets[batch];
int extent = extents[batch];
// Uniform parallelization (e.g., threadIdx.x)
for (int tid = threadIdx.x; tid < extent; tid += blockDim.x) {
int global_idx = offset + tid;
// Process element at global_idx
}
}

Indexer Strategy (Future Work)
RaggedIterDomain will integrate with the IdModel-based indexing system. This requires extending IdModel to handle ragged dimensions, including new expression types for ragged transformations and modifications to ValGraph handling.
6.6 Memory Layout
Ragged data is stored contiguously in memory with components placed sequentially:
Ragged dimension with extents [3, 5, 2]:
Storage: [c0_0, c0_1, c0_2, c1_0, c1_1, c1_2, c1_3, c1_4, c2_0, c2_1]
└─ Component 0 ─┘ └─────── Component 1 ───────┘ └─ Comp 2 ─┘
Offsets: [0, 3, 8, 10]
- Component 0: indices [0, 1, 2]
- Component 1: indices [3, 4, 5, 6, 7]
- Component 2: indices [8, 9]
Properties:
- extent = 10 (total number of elements across all components)
This layout enables efficient sequential access and avoids padding overhead.
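The offsets above are simply an exclusive prefix sum of the component extents, with a trailing entry for the total extent. A minimal standalone sketch of that computation (illustrative helper, not the actual implementation):

```cpp
#include <cstdint>
#include <vector>

// Offsets are the exclusive prefix sum of the component extents, plus one
// trailing entry holding the total extent of the ragged dimension.
std::vector<int64_t> offsetsFromExtents(const std::vector<int64_t>& extents) {
  std::vector<int64_t> offsets(extents.size() + 1, 0);
  for (size_t i = 0; i < extents.size(); ++i) {
    offsets[i + 1] = offsets[i] + extents[i];
  }
  return offsets; // e.g., extents [3, 5, 2] -> offsets [0, 3, 8, 10]
}
```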
6.7 Extent and Offset Tensor Management
In generated CUDA code, non-nested tensors are represented using the Tensor struct, which carries the extents as a property. For nested tensors, the extents of the nested domains cannot be embedded in the struct because they may be dynamically computed.
Consider the mixture of experts (MoE) use case where a kernel dynamically creates a nested tensor output:
# Input: tokens [num_tokens, hidden_dim], routing decisions per token
# Output: nested tensor where each component corresponds to tokens for one expert
# At kernel launch time:
# - Total number of tokens: KNOWN (e.g., 1024)
# - Number of experts: KNOWN (e.g., 8)
# - Tokens per expert: UNKNOWN (depends on routing computation inside kernel)
# Inside the kernel:
# 1. Compute routing: which tokens go to which expert
# 2. Count tokens per expert: [127, 0, 198, 64, 412, 89, 103, 31]
# 3. Reorder token data: group tokens by expert assignment
# 4. Write nested tensor output with ragged dimension
# Result: nested tensor [num_experts=8, ragged_tokens=[127,0,198,...], hidden_dim]

Key Observation: The nested domain extents are computed inside the kernel and are not known at kernel launch time.
Implication: We cannot bundle extent/offset information with the nested tensor itself.
This problem can be addressed by managing extent/offset information as a separate tensor that can be computed dynamically on GPU and passed between kernels. That effectively means a logical nested tensor consists of two Vals: one tensor for the nested tensor itself and another tensor for the extent/offset information. More concretely, here's a fusion that creates a nested tensor with asNested as an output:
// User-defined Fusion
Fusion fusion;
FusionGuard fg(&fusion);
// User provides data and offsets as separate inputs
auto tv_data = TensorViewBuilder()
.ndims(2)
.shape({-1, 512}) // [total_tokens, hidden]
.dtype(DataType::Float)
.build();
fusion.addInput(tv_data);
auto tv_offsets = TensorViewBuilder()
.ndims(1)
.shape({9}) // [num_experts + 1]
.dtype(DataType::Int)
.build();
fusion.addInput(tv_offsets);
// User explicitly creates nested tensor
auto tv_nested = asNested(tv_data, tv_offsets, /*ragged_dim=*/0);
// tv_nested has shape [batch=8, ragged_tokens, hidden=512]
// Operations on the nested tensor
auto tv_result = some_operation(tv_nested);
fusion.addOutput(tv_result);

The output tensor, tv_result, is a nested tensor. The extents of the nested domains are given as a fusion input, but in general, they are not known until the fusion is executed. Thus, if the nested tensor struct were defined like:
template <typename DT, int rank>
struct NestedTensor {
DT* ptr;
int64_t extents[rank];
int64_t nested_domain_extents[ragged_dimension_rank];
};

The value of nested_domain_extents is not available until the completion of the kernel, which would block the launch of the subsequent kernel.
Instead, we would like the fusion to be defined as follows:
fusion.addInput(tv_data); // Original data input (unchanged)
fusion.addInput(tv_offsets); // Original offset input (unchanged)
auto tv_nested = asNested(tv_data, tv_offsets, /*ragged_dim=*/0);
auto tv_result = some_operation(tv_nested);
auto tv_result_offsets = /* extract/compute offset part of tv_result */;
fusion.addOutput(tv_result); // Data tensor output
fusion.addOutput(tv_result_offsets); // Offset tensor output (injected)

Here, tv_result would use the same Tensor struct as a normal tensor. The offset tensor would be a 1D tensor whose ptr refers to the vector holding the offsets in device memory. In this case, nothing blocks the launch of the subsequent kernel, as the offset vector remains in device memory.
Since it is an implementation detail, the offset tensor should be hidden behind the nested tensor in the user-facing Fusion definition. When a user uses asNested to create a nested tensor, it should still create a single nested tensor Val, as illustrated in the first case above. The translation to the second pattern should be done automatically, e.g., by a new preseg pass.
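A rough sketch of what such a preseg pass could do, assuming the ragged IterDomain exposes its offset TensorView through an accessor (the accessor name `offsets()` and the overall pass structure are assumptions, not the actual design):

```cpp
#include <vector>

// Hypothetical sketch: for each fusion output whose logical domain contains a
// RaggedIterDomain, also register the associated offset TensorView as an
// output so the offsets stay in device memory for subsequent kernels.
void injectOffsetOutputs(Fusion* fusion) {
  std::vector<TensorView*> offset_tvs;
  for (Val* out : fusion->outputs()) {
    auto* tv = dynamic_cast<TensorView*>(out);
    if (tv == nullptr) {
      continue;
    }
    for (IterDomain* id : tv->getLogicalDomain()) {
      if (auto* ragged = dynamic_cast<RaggedIterDomain*>(id)) {
        offset_tvs.push_back(ragged->offsets()); // assumed accessor
      }
    }
  }
  for (TensorView* offsets : offset_tvs) {
    fusion->addOutput(offsets);
  }
}
```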
Runtime Binding (Initial Version)
Challenge: How should ExpressionEvaluator bind a ragged TensorView to actual tensor objects at runtime?
Initial Approach: In this initial version, we take a simplified approach that avoids the complexity of custom tensor types:
1. Input Nested Tensors: Must be passed as two separate parameters:
   - A normal `at::Tensor` containing the flattened data
   - Another `at::Tensor` parameter representing the offsets (or extents)
2. Output Nested Tensors: Returned as two separate tensors:
   - A normal `at::Tensor` containing the flattened data
   - An offset (or extent) tensor
This approach treats nested tensors as regular tensors plus separate offset/extent metadata at the API boundary. While the Fusion IR internally represents them as nested TensorViews with RaggedIterDomain, the runtime binding uses plain tensors.
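As an illustration of what this looks like on the caller's side for the Section 6.7 fusion, both pieces are just ordinary tensors; the executor invocation below is indicative only, not the exact nvFuser API:

```cpp
#include <torch/torch.h>

// The flattened data and the offsets are passed as two plain tensors.
torch::Tensor data = torch::randn({325, 512}, torch::kCUDA);          // [total_tokens, hidden]
torch::Tensor offsets =
    torch::tensor({0, 127, 127, 325}, torch::kLong).to(torch::kCUDA); // [num_components + 1]

// Both tensors are bound to the fusion inputs declared via TensorViewBuilder;
// outputs likewise come back as a plain data tensor plus an offset/extent tensor.
// auto outputs = executor_cache.runFusionWithInputs({data, offsets});
```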
Limitations: This approach is not sufficient for actual production use cases because:
- Users must manually manage the data/offset tensor pairs
- No type safety to ensure data and offsets stay synchronized
- Doesn't integrate with PyTorch's NestedTensor API
- More cumbersome API compared to a unified nested tensor type
Future Work: A more complete solution would involve:
- Creating a custom NestedTensor class for nvFuser runtime
- Adding it to PolymorphicValue for storage in KernelArgumentHolder
- Supporting direct binding to/from a unified nested tensor object
- This is left as future work beyond the initial implementation.
7. System Integration
IR Layer
Type System:
- Add `RaggedIterDomain` to the `ValType` enum in `csrc/type.h`
- Update the `DISPATCH_FOR_ALL_VALS` macro in `csrc/dispatch.h`
Class Implementation:
- Class declaration in `csrc/ir/internal_base_nodes.h` after IterDomain
- Implementation in `csrc/ir/nodes.cpp` with validation logic
Dispatch and Visitors:
- Add dispatch handlers in `csrc/dispatch.cpp`
- Update visitor patterns to traverse nested domains
Cloning and Printing:
- Implement cloning constructor with `NVFUSER_DEFINE_CLONE`
- Add printing support in `csrc/ir/iostream.cpp` showing nested structure
Indexing Layer
Note: Not part of initial scope. Details intentionally left out for now.
IdModel Extensions (Future Work):
- Extend IdModel to handle RaggedIterDomain in ValGraph
- Modify `TensorIndexer` to compute offset-based indices
- Update loop promotion logic for ragged dimensions
Detection (Future Work):
- Detect RaggedIterDomain during graph building
- Route to ragged-aware indexing logic
Lowering and CodeGen
Note: Not part of initial scope. Details intentionally left out for now.
Device Lowering (Future Work):
- Handle RaggedIterDomain in allocation passes
- Generate offset array computations
- Create predicates for ragged bounds checking
Code Generation (Future Work):
- Emit nested loop structure with offset-based indexing
- Generate uniform parallelization across components
- Handle extent and offset lookups in generated CUDA code
8. Implementation Phases
Phase 1: Core Infrastructure (Initial Scope)
- Type system updates (ValType enum, dispatch macros)
- RaggedIterDomain class declaration and implementation
- Basic validation, accessors, printing
- Goal: Can create and inspect RaggedIterDomain instances
Phase 2: IdModel Integration (Future Work - Not in Initial Scope)
- Extend IdModel ValGraph to handle RaggedIterDomain
- Modify TensorIndexer for offset-based indexing
- Add new expression types for ragged operations
- Predicate generation for ragged bounds
- Goal: Can compile and execute simple ragged operations
Phase 3: Transformations (Initial Scope)
- Implement split operations on ragged dimensions
- Implement merge operations (ragged with regular IterDomain)
- Add parallelize override
- Additional operations (flatten, nest/unnest) as determined by expert parallelism requirements
- Goal: Can create and transform ragged dimensions
Phase 4: Full Integration (Future Work - Not in Initial Scope)
- TensorView integration (allow RaggedIterDomain in TensorDomain)
- Device lowering passes for ragged
- CUDA code generation with offset-based indexing
- Comprehensive end-to-end tests
- Goal: Production-ready ragged tensor support
9. Future Work
- Multi-level Nesting: Support for ragged within ragged dimensions (RaggedIterDomain containing other RaggedIterDomains as nested domains). While not required for the initial expert parallelism use case, this could enable more complex nested data structures in the future.
- Python Frontend: Expose RaggedIterDomain to Python API for direct construction
- Ragged-Aware Schedulers: Specialized pointwise, reduction, and matmul schedulers for ragged patterns
- Broadcast Operations: Support broadcasting to/from ragged dimensions
10. References
- Add select operation section explaining how to extract individual components
- Fix section numbering (removed gaps in section sequence)
- Fix broken link syntax for PyTorch documentation
- Clarify flatten operation as TBD based on expert parallelism requirements
- Remove references to non-existent appendices
- Standardize terminology to "nested domains" throughout
- Clarify DID as "distributed device (multi-GPU) parallelization"
- Remove metadata header (Status, Author, Date fields)
- Update PyTorch description to "one dimension" instead of "one or more"

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Greptile Overview

Greptile Summary

This PR introduces a comprehensive RFC for RaggedIterDomain for nested tensors.

Key Design Elements:
Issues Identified:

Confidence Score: 4/5
Important Files Changed

File Analysis
Sequence Diagram

sequenceDiagram
participant User
participant TensorView
participant IterDomain
participant RaggedIterDomain
Note over User,RaggedIterDomain: Partition Operation Flow
User->>TensorView: partition(position, offsets)
TensorView->>IterDomain: partition(domain, offsets)
IterDomain->>RaggedIterDomain: construct with nested domains and extents
RaggedIterDomain-->>IterDomain: return pair of domains
IterDomain-->>TensorView: partitioned tensor
TensorView-->>User: result with ragged dimension
Note over User,RaggedIterDomain: Merge Operation Flow
User->>TensorView: merge(batch pos, ragged pos)
TensorView->>IterDomain: merge(batch, ragged)
IterDomain->>RaggedIterDomain: compute extent sum
RaggedIterDomain-->>IterDomain: merged domain
IterDomain-->>TensorView: flattened tensor
TensorView-->>User: regular tensor result
1 file reviewed, 2 comments
// This overrides IterDomain::parallelize and calls nested_domains[i]->parallelize(pt) for all nested domains
void parallelize(ParallelType pt);
syntax: The parallelize() method signature should be void parallelize(ParallelType pt) override; to properly override the base class method
struct NestedTensor {
DT* ptr;
int64_t extents[rank];
int64_t nested_domain_extents[ragged_dimension_rank];
};
style: Inconsistent indentation - mix of tabs and spaces in the struct definition. Use consistent indentation (spaces preferred in C++ code)
template <typename DT, int rank>
struct NestedTensor {
Do you expect this to become a PolymorphicValue for nvFuser's runtime integration? KernelArgumentHolder only keeps PolymorphicValues at this moment.
I heard but never tried to subclass an at::Tensor. Would that be a good idea so we don't break too many assumptions?
This was just one option to represent nested tensors in CUDA kernels, much as normal tensors are represented using the Tensor struct.
For the host runtime, a nested tensor with a RaggedIterDomain is still just a TensorView, so I don't think we would need to do anything special for nested tensors except for passing some more properties about the ragged ID.
At run time, do you expect ExpressionEvaluator to bind a ragged TensorView to a torch.nested._internal.nested_tensor.NestedTensor or just an at::Tensor?
If the former, NestedTensor appears to be in Python only. It's also unclear whether NestedTensor requires its extents or its offsets to be on CPU -- all the tutorials appear so.
If the latter, is a TensorView of logical domain [g, j?] bound to a 2D at::Tensor or 1D? If 2D, what's its size(1)? If 1D, how are we going to fix the existing assumption that tensor.dim() == std::ranges::distance(tv->getLogicalDomain() | TensorDomain::kNoReductions)?
If the PyTorch nested tensor requires the offsets always available on CPU, would that mean it is not suitable for our use cases? If so, should we just always view it as a normal tensor and somehow manage offsets on device memory?
Therefore, I suspect we need to write our own NestedTensor class and add it to PolymorphicValue. A normal at::Tensor is not sufficient either -- we need the offsets/counts stored somewhere.
Yes, now I see your concern. I'm planning to leave it unaddressed in the initial version. I think as long as flattened tensors are used as fusion inputs or outputs, this would not be a problem. Obviously, that's not ideal, but I'll leave it as an issue for follow-up work. d8e9aae
…merge operations

Major updates:
- Add support for multi-level nesting (ragged within ragged)
- Add support for multiple independent ragged dimensions per tensor
- Introduce partition operation as core primitive for creating nested structures
- Support 1D offsets for partitioning regular IterDomain
- Support 2D offsets for partitioning RaggedIterDomain
- Introduce merge operation as inverse of partition
- Add asNested tensor-level operation (uses partition internally)
- Add asFlattened tensor-level operation (uses merge internally)
- Update expert parallelism examples with concrete use cases
- Clarify all_to_all communication and rank-first to expert-first transformation
- Emphasize distinction between inherent data properties vs scheduling decisions
- Fix examples to always include batch IterDomains in TensorDomain

Key design decisions:
- partition and merge are IterDomain-level operations (like split/merge)
- asNested and asFlattened are tensor-level operations (like reshape)
- Offset tensors can be 1D or 2D depending on input type
- Multi-level nesting enables expert parallelism with load balancing

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Changes to Example 2:
- Start with distributed input: [D=2, S/D=100, hidden=512] with uniform tokens per GPU
- Use 2D offsets for partition: each GPU routes its tokens to E=4 experts non-uniformly
- Clarify that merge operation represents the shuffling transformation
- Merge creates expert-first layout: [E=4, ragged] with total tokens per expert
- Split distributes experts across GPUs: [D=2, E/D=2, ragged] for processing
- Combine shuffle and GPU distribution into single Step 2

Key clarifications:
- Input is already distributed (not single flattened tensor)
- First partition uses 2D offsets because each GPU has different token distribution
- Merge represents shuffle operation that performs communication
- Use 4 experts (not 3) to enable even split across 2 GPUs
- Emphasize that merge changes data layout from [gpu, expert, ragged] to expert-first

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
1 file reviewed, 3 comments
);
// ...
syntax: Parameter type inconsistency: Line 198 shows TensorView* offsets but line 587 shows Val* offsets. The correct type should be TensorView* since offsets need to be a tensor (mentioned as "1D TensorView" in documentation at line 605).
);
// ...
Val* offsets // Offset values defining partition boundaries

should be:

);
// ...
TensorView* offsets // Offset values defining partition boundaries
// performs the communication to change the data layout.
auto merged = IterDomain::merge(gpu_dim, tokens_per_expert);
// This creates: [expert=4, merged_ragged=[55,35,65,45]]
logic: Variable scope issue: gpu_dim is not defined in the accessible scope. After partitioning at line 633, the result by_expert has dimensions [gpu=2, expert=4, ragged, hidden=512], but gpu_dim and tokens_per_expert are not extracted from this result. Either extract these dimensions from by_expert first, or clarify this is pseudocode showing the conceptual operation rather than actual executable code.
// Then split experts across GPUs for parallel processing (2 experts per GPU)
auto [gpu_out, expert_per_gpu] = IterDomain::split(expert_dim, /*factor=*/2);
// Result: [gpu=2, expert_per_gpu=2, merged_ragged=[[55,35],[65,45]], hidden=512]
logic: Undefined variable expert_dim: This variable is not defined in the example. After the merge operation at line 640, the result should have dimensions related to the merged structure, but expert_dim is not extracted or defined. The split operation needs a valid IterDomain reference.
Changes:
- Remove multi-level nesting from Goals (not required for expert parallelism)
- Update PyTorch Restriction to only mention multiple independent ragged dimensions
- Remove "Multi-Level Nesting Example" section
- Remove multi-level nesting from Tensor Domain Structure features
- Remove multi-level nesting from Nested Domains description
- Rename "Multi-level partitioning" example to "Expert parallelism with partition and merge"
- Update implementation notes to remove multi-level partitioning references
- Add multi-level nesting to Future Work section

Rationale: The expert parallelism example demonstrates that multi-level nesting (ragged within ragged) is not needed. The workflow uses multiple independent ragged dimensions with merge/partition operations, not nested RaggedIterDomains. This simplifies the initial implementation scope while still supporting the required use cases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
1 file reviewed, 2 comments
// The merge operation represents the shuffling that reorganizes from
// [gpu, expert, ragged] to expert-first layout. The actual implementation
// performs the communication to change the data layout.
auto merged = IterDomain::merge(gpu_dim, tokens_per_expert);
logic: Variables gpu_dim and tokens_per_expert are not extracted from the by_expert result. After the partition at line 563, you need to extract dimensions from the result before using them.
auto merged = IterDomain::merge(gpu_dim, tokens_per_expert);

should be:

// Extract dimensions from by_expert for subsequent operations
auto gpu_dim = by_expert->axis(0); // gpu dimension
auto tokens_per_expert = by_expert->axis(2); // ragged tokens dimension
auto merged = IterDomain::merge(gpu_dim, tokens_per_expert);
// Expert 3: 30 (from GPU 0) + 15 (from GPU 1) = 45 tokens

// Then split experts across GPUs for parallel processing (2 experts per GPU)
auto [gpu_out, expert_per_gpu] = IterDomain::split(expert_dim, /*factor=*/2);
logic: Variable expert_dim is not defined. After the merge operation, you need to extract the expert dimension from the merged result.
auto [gpu_out, expert_per_gpu] = IterDomain::split(expert_dim, /*factor=*/2);

should be:

// Extract expert dimension from merged result
auto expert_dim = merged->axis(0); // or appropriate axis index
auto [gpu_out, expert_per_gpu] = IterDomain::split(expert_dim, /*factor=*/2);
Changes:
- Remove Case 2 from partition operation semantics in Section 6.1
- Simplify to only support partitioning regular IterDomain (Case 1)
- Remove 2D offset requirement for partitioning RaggedIterDomain
- Remove constraint about uniform K partitions across components

Rationale: Since multi-level nesting is deferred to future work, we don't need to support partitioning a RaggedIterDomain to create nested ragged structures. The partition operation now only takes regular IterDomain as input with 1D offsets, which is sufficient for the expert parallelism use case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
…ents

Key changes:
- Change internal representation from offsets to extents in RaggedIterDomain
- Store extents_ internally, compute offsets on-demand via cumulative sum
- Keep user-facing API using offsets for compatibility with PyTorch conventions
- Mark indexing and code generation as out of scope for initial implementation
- Clarify implementation phases: Phase 1 and 3 are initial scope, Phase 2 and 4 are future work
- Add rationale for extent-based internal representation

This minimalistic version focuses on IR representation and basic infrastructure, with detailed indexing/lowering intentionally deferred to future work.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
…eduction

Add comprehensive documentation for merge operations involving RaggedIterDomain:

1. Clarify three types of merge operations:
   - Case 1: Merge RaggedIterDomain with batch dimension → Regular IterDomain
   - Case 2a: Merge ragged with non-ragged (element-wise product) → RaggedIterDomain
   - Case 2b: Merge regular with ragged (reduction along dimension) → RaggedIterDomain
2. Document Case 2b in detail:
   - When regular IterDomain corresponds to outer dimension of extent tensor
   - Merge reduces by summing extents along that dimension
   - Example: 2D extent tensor [gpu=2, expert=4] → 1D extent tensor [expert=4]
3. Add important constraint:
   - Extent reduction must complete before kernel launch
   - Total extent unknown until reduction completes
   - Blocks memory allocation and other operations
4. Document handling strategies:
   - Fusion segmentation: separate kernel for extent reduction
   - Conservative allocation: pre-allocate with maximum extent (common in optimized MoEs)
5. Enhanced expert-parallel example with detailed explanation of:
   - How merged extents are computed from 2D to 1D extent tensor
   - Note about segmentation requirement for extent reduction

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
1 file reviewed, no comments
Add section 6.7 "Runtime Binding (Initial Version)" addressing how ExpressionEvaluator should bind ragged TensorViews at runtime.

Initial Approach:
- Input nested tensors: passed as two separate at::Tensor parameters (data tensor + offset/extent tensor)
- Output nested tensors: returned as two separate tensors
- Treats nested tensors as regular tensors plus separate metadata at API boundary
- Internal Fusion IR uses RaggedIterDomain, but runtime binding uses plain tensors

Limitations documented:
- Manual management of data/offset tensor pairs required
- No type safety for data and offset synchronization
- No integration with PyTorch's NestedTensor API
- More cumbersome API than unified nested tensor type

Future Work:
- Custom NestedTensor class for nvFuser runtime
- PolymorphicValue integration for KernelArgumentHolder
- Direct binding to/from unified nested tensor objects

This addresses the runtime binding question raised in PR review comments about ExpressionEvaluator binding and PolymorphicValue integration.

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
Clarify that the extent reduction in Case 2b merge is implemented using
nvFuser's standard ReductionOp:
```cpp
new_extents = sum(ragged_dim->extents(), {0});
```
This creates a ReductionOp that:
- Operates on the extent tensor from ragged_dim->extents()
- Reduces along dimension 0 (the merged dimension)
- Produces a new extent tensor with reduced dimensionality
- Example: 2D tensor [gpu=2, expert=4] → 1D tensor [expert=4]
This makes the implementation approach explicit and ties it to nvFuser's
existing reduction infrastructure.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
1 file reviewed, 3 comments
// expert_offsets: 2D tensor [D=2, E+1=5] with per-GPU offsets for E=4 experts:
// expert_offsets[0] = [0, 30, 30, 70, 100] - GPU 0: [30, 0, 40, 30] tokens per expert
// expert_offsets[1] = [0, 25, 60, 85, 100] - GPU 1: [25, 35, 25, 15] tokens per expert
tokens->partition(/*dim=*/1, expert_offsets);
style: expert_offsets not defined as a variable. While described in comments (lines 595-597), for a complete code example, add variable declaration like auto expert_offsets = ...; before usage.
**Implementation Detail**: The reduction is defined using nvFuser's `ReductionOp`. Specifically, the new extent tensor is computed as:
```cpp
// Reduce the extent tensor along dimension 0 (the merged dimension)
new_extents = sum(ragged_dim->extents(), {0});
```
This creates a `ReductionOp` that sums the 2D extent tensor `[gpu=2, expert=4]` along dimension 0, producing a 1D extent tensor `[expert=4]`.
style: The ReductionOp implementation detail is helpful, but clarify whether ragged_dim->extents() returns a TensorView*. The comment should explicitly show the type: TensorView* extent_tensor = ragged_dim->extents(); for clarity.
auto tv_nested = asNested(tv_data, tv_offsets, /*ragged_dim=*/0);
auto tv_result = some_operation(tv_nested);

auto tv_result_offsets = /* extract/compute offset part of tv_result */;
style: Comment says "extract/compute offset part of tv_result" but doesn't show the actual extraction logic. Consider adding: auto tv_result_offsets = extractOffsets(tv_result); or similar to make the conceptual transformation clearer.
Corrected the partition operation semantics to state that the offset tensor can be 1D, 2D, or N-D, not just 1D as previously stated.

Changes:
- Updated description: "TensorView (1D, 2D, or N-D) defining partition boundaries"
- Clarified that offset tensor shape determines nesting structure
- Added reference to Example 2 in Section 6.4 for concrete 2D offset example

This aligns with the expert parallelism use case where 2D offset tensors are used (e.g., [gpu=2, expert+1=5] shape for partitioning tokens by GPU and expert).

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
1 file reviewed, no comments
wujingyue
left a comment
Much better!
**Initial Approach**: In this initial version, we take a simplified approach that avoids the complexity of custom tensor types:

1. **Input Nested Tensors**: Must be passed as **two separate parameters**:
   - A normal `at::Tensor` containing the flattened data
Can you write some examples? The word "flattened" is slightly overloaded. I don't think you actually mean a 1D tensor here for things like [e, d, j1, h].
For example, suppose we use a nested tensor of [i, j], where j is a ragged iter domain and i corresponds to the normal iter domain produced when generating j by the partition op. Specifically, there should be a normal iter domain of k, that would be the producer of both i and j such that i, j = partition(k, offsets). In this case, we would require the fusion to use a 1D tensor of [k] and another 1D tensor of offsets as its inputs. Inside the fusion, we would define the nested tensor using asNested([k], offsets, 0).
Similarly, for outputs, instead of returning a nested tensor, we would return a tensor without any ragged iter domain by using asFlattened.
Both limitations should be addressed in follow-up work.
So by "flattened" you do mean 1D? That's fine for the first version, but I'm sure you understand that won't work with GroupedMmaOp or EP because tensors in MoE are usually 2D or 3D, or even 4D if batched.
No. It just merges the partitioned sub-iterdomains back into a standard iter domain. TensorDomain can be multidimensional.
I think I understand you now: in the first version, we never expose a ragged TensorView as a segment input/output, so ExpressionEvaluator doesn't need to bind ragged TensorViews at all. In other words, each segment involving ragged TensorViews will have asNesteds at the beginning and asFlattened at the end.
That sounds OK. I don't know how much that buys us though -- it may not be harder to just change ExpressionEvaluator to bind a RaggedTensorView to a plain at::Tensor that's one-less rank (due to the RaggedIterDomain in logical).
Maybe trivial, then that'd be good too. I just wanted to focus things inside a kernel as a first step.
1. **Input Nested Tensors**: Must be passed as **two separate parameters**:
   - A normal `at::Tensor` containing the flattened data
   - Another `at::Tensor` parameter representing the offsets (or extents)
tokens_per_expert and tokens_per_expert_per_device are actual TensorViews in the graph so they'll be bound to extent tensors anyway. A nested TensorView (i.e. a TensorView whose logical domain contains a RaggedIterDomain) doesn't need to be bound to two tensors. I think this will simplify the implementation a lot.
I'm not sure what two tensors are referred to. RaggedIterDomain would have offsets as a TensorView. That's the only TensorView that would be connected from the ragged iter domain.
bind a ragged TensorView to actual tensor objects at runtime?
By this, don't you mean one TensorView is bound to two tensor objects? That's what I was referring to and disagree.
Indirectly, yes. The RaggedIterDomain will be linked to the tensor object for the offsets. I don't think if we could do without having a link to the offset tensor from the ragged iter domain. We would need to maintain some connection to find its offsets.
Indirectly, yes.
Agreed.
I asked this question before knowing your plan to first hide ragged TensorVies from ExpressionEvaluator. Since I understood that, eventually, we should still maintain one-to-one between TVs and tensors. For example,
TV: [e, j1, h] => tensor: [s, h]
TV: j1's extent => tensor: [e]
Correct?
I think so. I'm still unclear as there's no aten representation that could be used as is, but that'll be a topic for follow-up work.
there's no aten representation that could be used as is
Why not?
TV: [e, j1, h] => tensor: [s, h]
TV: j1's extent => tensor: [e]
the two tensors on the right are existing, plain, dense at::Tensors.
Right, I just meant there's no single Aten type that could be used to represent an nvFuser nested tensor. As we discussed before, PyTorch's nested tensor is not sufficient for us as it would mean offsets are always copied between host and device.
// Static method in IterDomain class
static std::pair<IterDomain*, RaggedIterDomain*> partition(
    IterDomain* in, // Input IterDomain to partition
    TensorView* offsets // Offset tensor defining partition boundaries
We may be missing information here. Consider MoE takes a batch of sequences. Therefore,
dispatch input TV:
s/d
/ \
[e, b, j1, h, d]
tokens_per_expert_per_device TV:
[b, d, e]
How do we know which dimension of tokens_per_expert_per_device holds the component extents? How do we know that the two bs should be mapped and the two ds should be mapped so we find the correct extent vector to partition?
That's a good question. What I'm currently thinking about is to infer that from the TensorView domain. When this input TV, say input_tv, is partitioned, its domain should look like [b, d, s/d] (not sure about the order). Thus, when partition is called on this tensor, the call would look like input_tv->partition(/*dimension_to_partition=*/2, /*offsets=*/[b, d, e]). From this call, we would see that the b of the offset tensor corresponds to the b dimension of input_tv, etc. I think this would work without ambiguity, but I may be wrong, and if so, I'd make this call interface more explicit (e.g., specifying which dimension of the offset tensor corresponds to which dimension of the tensor to partition).
I think this would work without ambiguity, but I may be wrong,
I believe it's ambiguous -- b and d may have the same size or even symbolic sizes. PairwiseLogicalDomainMap would have no way to tell how they are mapped.
One potential way to fix this is to enforce a certain logical order, similar to what torch.matmul does. For example, input_tv's logical can be [b, d, e, j1] and tokens_per_expert_per_device can be [b, d, e].
However, another problem is that given [b, d, e, j1] how do we know which dimension represents the number of components. We can't rely on the partition/split IterDomain because it won't appear in all TensorViews. This is related to my other comment: https://github.com/NVIDIA/Fuser/pull/5550/files#r2591385758
One potential way to fix this is to enforce a certain logical order, similar to what torch.matmul does. For example, input_tv's logical can be [b, d, e, j1] and tokens_per_expert_per_device can be [b, d, e].
That's what I'm planning to do. Positional matching should give us enough information.
However, another problem is that given [b, d, e, j1] how do we know which dimension represents the number of components. We can't rely on the partition/split IterDomain because it won't appear in all TensorViews
The exact graph can be used.
Positional matching should give us enough information.
Sounds good!
The exact graph can be used.
We could -- it's kind of expensive because it has to analyze across TensorViews. Can we simply have RaggedIterDomain point to the IterDomain (in the same TensorView) that represents the number of components?
Yes, the cost is concerning, but technically, for example, it should be possible to reshape the component dimension and split it to two dimensions, and merge them with the ragged dimension later, so simple tracking may not be enough. Probably not a concern for initial simple use cases, so I may proceed with simpler methods like what you suggested.
class TensorView : public Val {
  // Partition dimension 'dim' using the provided offsets
  // Returns new TensorView with partitioned dimension replaced by (batch_id, ragged_id)
  TensorView* partition(int dim, TensorView* offsets);
Why is `partition` a member method of `TensorView` instead of a free function `TensorView* partition(TensorView* in, int dim, TensorView* offsets)` like other TensorView operations?
OK, I understood why -- you consider partition to be an IterDomain op.
Since a RaggedIterDomain's extent is a TensorView, we can instead simply use a regular split as the IterDomain operation, e.g., `(IterDomain*, RaggedIterDomain*) = split(IterDomain*)`. This is more convenient when we have a series of pointwise operations applied on a nested TensorView -- we don't need each TensorView to have a partition from root to logical. This way, PairwiseLogicalDomainMap can keep mapping in's logical and out's root.
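A rough sketch of this suggestion, with splitToRagged used as a hypothetical name for whatever builder ends up producing the two outputs:

```cpp
// Reuse the split form: one input IterDomain, two outputs, where the inner
// output is a RaggedIterDomain whose extent is a TensorView. The offsets are
// consumed at the TensorView API level and stored as the ragged output's
// extent; they are not inputs of the IterDomain op itself.
std::pair<IterDomain*, RaggedIterDomain*> splitToRagged(
    IterDomain* in,        // the dimension being partitioned
    TensorView* offsets);  // hypothetical parameter; embedded as the extent
```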
Yes, it's meant to be an IterDomain op, much like split and merge. To me, it's different enough to deserve a separate name from split.
> This is more convenient when we have a series of pointwise operations applied on a nested TensorView -- we don't need each TensorView to have a partition from root to logical.
I think that'd be the case no matter if it's partition or split. What am I missing?
I was trying to say that partition or split, the IterDomain op, should not appear in every nested TensorView.
```
         s
        / \   split
in:  [e,  j1]
        |
        | cos
        v
out: [e,  j1]   # no `s` in the root domain
```
This way, we can do the following merging reshape. Note that j2's definition is a merge, not a partition/split, even though j2's extent is an extent TensorView.
```
       s
      / \   split
   [e,  j1]
      |
      | cos
      v
   [e,  j1]
      |
      | reshape
      v
    e,  j1
     \  /
     [j2]
```
I think this was probably unclear:
> we don't need each TensorView to have a partition from root to logical
This is true. It's an iter domain op, so just having it in one TensorView should be sufficient.
Great -- I think we are on the same page.
IIUC, Partition, an IterDomain op, takes one IterDomain and produces two IterDomains. It does not take a TensorView -- that's embedded (as extent) in the second output, which is a RaggedIterDomain. This is why we can reuse Split, but I'm fine with a different op. Btw, would you need a different op for merging a regular IterDomain and a RaggedIterDomain -- maybe called combine?
Yeah, that feels a little inconsistent -- a new name for the forward operation but overloading the existing name for the reverse one. My feeling is that split and partition have different input parameters (a scalar factor for split and an offset tensor for partition), but merge and combine would have the same set of inputs, i.e., two IterDomain pointers.
For the sake of consistency, I may introduce combine instead of extending merge.
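A sketch of that naming, with an illustrative signature for the proposed reverse op:

```cpp
// Proposed (not yet implemented): combine is the inverse of partition.
// Like merge, it takes two IterDomain pointers; unlike merge, one of them is
// a RaggedIterDomain, and the result is a regular IterDomain whose extent is
// the total number of elements across all components.
IterDomain* combine(IterDomain* component_id, RaggedIterDomain* ragged_id);
```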
> a scalar factor for split and an offset tensor for partition
There's the API and there's the op itself. IIUC, that offset_tensor parameter is for API only -- it'll be embedded as the extent of a RaggedIterDomain so the Partition op doesn't keep the offset tensor in Expr::inputs(). Therefore, TensorView::partition, as the API, makes a lot of sense to me. But Partition as an IterDomain op is a bit wasteful -- it could just be a Split.
Could be; I may change my plan. That said, I'd start with its own class, since it would work more naturally with OptOutDispatcher and I'd expect we would usually do something different for partitioning than for splitting.
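A minimal sketch of what a dedicated op class might look like under that plan; the class shape is illustrative only and is meant to show why a separate type gives dispatch a natural hook:

```cpp
// Hypothetical IterDomain op: one input IterDomain, two outputs. The offset
// tensor is not an Expr input; it is reachable through the ragged output's
// extent TensorView. A distinct class lets an OptOutDispatch-style visitor
// handle partitioning differently from splitting.
class Partition : public Expr {
 public:
  IterDomain* in() const;           // original dimension
  IterDomain* outer() const;        // component dimension (regular)
  RaggedIterDomain* inner() const;  // ragged dimension; extent is a TensorView
};
```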
1 file reviewed, 3 comments
```cpp
Val* extent() const override;
// This overrides IterDomain::parallelize and calls nested_domains[i]->parallelize(pt) for all nested domains
void parallelize(ParallelType pt);
```
syntax: Missing override keyword in method signature
```diff
- void parallelize(ParallelType pt);
+ void parallelize(ParallelType pt) override;
```
> **Semantics:**
> - Input: Regular IterDomain with total extent N
> - Offsets: TensorView (can be 1D, 2D, or N-D) defining partition boundaries. For the simple 1D case with K partitions: `[0, offset_1, ..., offset_K=N]`. For multi-dimensional cases, the offset tensor shape determines the structure of the extent tensor in the resulting RaggedIterDomain (see Example 2 in Section 6.4 for a concrete expert parallelism example with 2D offsets).
style: Clarify offset tensor shape for multi-dimensional cases: The text mentions "offset tensor shape determines the structure of the extent tensor" but doesn't explain the precise relationship. For a 2D offset tensor [D, E+1], does it produce a 2D extent tensor [D, E], or is the extent structure different?
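One plausible reading, stated here as an assumption rather than as something the document currently specifies, is that the extent tensor is the adjacent difference of the offsets along the last dimension, so a [D, E+1] offset tensor yields a [D, E] extent tensor:

```cpp
#include <cstdint>
#include <vector>

// Illustration of the assumed offsets-to-extents relationship: for each
// leading index (e.g., device d),
//   extents[d][e] = offsets[d][e + 1] - offsets[d][e]
std::vector<std::vector<int64_t>> extentsFromOffsets(
    const std::vector<std::vector<int64_t>>& offsets) {
  std::vector<std::vector<int64_t>> extents;
  for (const auto& row : offsets) {
    std::vector<int64_t> counts;
    for (size_t e = 0; e + 1 < row.size(); ++e) {
      counts.push_back(row[e + 1] - row[e]);
    }
    extents.push_back(counts);
  }
  return extents;
}
```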
```cpp
struct NestedTensor {
  DT* ptr;
  int64_t extents[rank];
  int64_t nested_domain_extents[ragged_dimension_rank];
};
```
style: Inconsistent indentation - mix of tabs and spaces in the struct definition
1 file reviewed, no comments
```cpp
// Batch of sequences with variable lengths
auto batch = IrBuilder::create<IterDomain>(0, 3);
auto ragged_seq = IrBuilder::create<RaggedIterDomain>({seq0, seq1, seq2});
```
To save your work, I think for GroupedMmaOp and EP all we need is asNested to "split" a dense tensor into a nested tensor. This API to create a RaggedIterDomain from sequences is likely unnecessary.
Updated: ce7886f
> #### Select Operation
>
> The `select` operation extracts a specific component from a ragged dimension, converting it to a regular IterDomain. This requires two steps:
Ditto. I don't need this
Updated: ce7886f
1 file reviewed, 5 comments
```cpp
auto split_result = IterDomain::split(ragged, 2);
auto outer = split_result.first;  // extents = [2, 3, 1], ragged dimension
auto inner = split_result.second; // extent = 2, regular dimension
```
style: The split example shows splitting a RaggedIterDomain named ragged with factor 2, but doesn't show what the input extents are. For clarity, specify the input: e.g., "Given ragged with extents [3, 5, 2], splitting by factor 2 produces..."
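For example, the quoted snippet could be restated with the input made explicit; the extents below are the ones suggested in this comment and assume a ceil-divide applied to each component:

```cpp
// Assuming `ragged` has per-component extents [3, 5, 2], splitting by a
// factor of 2 with ceilDiv on each component gives the quoted results:
auto split_result = IterDomain::split(ragged, 2);
auto outer = split_result.first;   // ragged: extents = [2, 3, 1] (ceilDiv of 3, 5, 2 by 2)
auto inner = split_result.second;  // regular: extent = 2
```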
```cpp
template <typename DT, int rank>
struct NestedTensor {
  DT* ptr;
  int64_t extents[rank];
  int64_t nested_domain_extents[ragged_dimension_rank];
};
```
style: The NestedTensor struct definition uses inconsistent indentation (tabs in some lines). Consider standardizing to spaces for better readability.
```cpp
// This overrides IterDomain::parallelize and calls nested_domains[i]->parallelize(pt) for all nested domains
void parallelize(ParallelType pt);
```
syntax: Missing override keyword in the method signature. Should be void parallelize(ParallelType pt) override; to properly override the base class method.
> **Semantics:**
> - Input: Regular IterDomain with total extent N
> - Offsets: TensorView (can be 1D, 2D, or N-D) defining partition boundaries. For the simple 1D case with K partitions: `[0, offset_1, ..., offset_K=N]`. For multi-dimensional cases, the offset tensor shape determines the structure of the extent tensor in the resulting RaggedIterDomain (see Example 2 in Section 6.4 for a concrete expert parallelism example with 2D offsets).
style: Clarify the relationship between offset tensor shape and extent tensor structure. The text mentions "the offset tensor shape determines the structure of the extent tensor" but the precise mapping isn't clear. For a 2D offset tensor [D, E+1], does it produce a 2D extent tensor [D, E] (by taking differences along the last dimension)?
```cpp
auto tv_nested = asNested(tv_data, tv_offsets, /*ragged_dim=*/0);
auto tv_result = some_operation(tv_nested);

auto tv_result_offsets = /* extract/compute offset part of tv_result */;
```
style: The comment "extract/compute offset part of tv_result" is vague. Add a concrete example of how offsets are extracted or computed, e.g., auto tv_result_offsets = extractOffsets(tv_result); or describe the mechanism more explicitly.
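One way to make that line concrete; extractOffsets is a hypothetical helper used only to show the shape of the call, not an existing API:

```cpp
auto tv_nested = asNested(tv_data, tv_offsets, /*ragged_dim=*/0);
auto tv_result = some_operation(tv_nested);

// Hypothetical helper: recover the offsets TensorView describing tv_result's
// ragged dimension so it can be returned alongside the flattened data.
auto tv_result_offsets = extractOffsets(tv_result, /*ragged_dim=*/0);
```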
1 file reviewed, no comments
This is a design document for a new type of iteration domain called
RaggedIterDomain. The motivation is to support nested tensors for expert parallelism.