Update aten_index_put implementation #2712
Conversation
Signed-off-by: Ganesan Ramalingam <[email protected]>
Codecov Report

❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2712      +/-   ##
==========================================
- Coverage   70.11%   70.06%   -0.05%
==========================================
  Files         226      226
  Lines       27230    27266      +36
  Branches     2748     2754       +6
==========================================
+ Hits        19092    19105      +13
- Misses       7193     7210      +17
- Partials      945      951       +6

☔ View full report in Codecov by Sentry.
Signed-off-by: Ganesan Ramalingam <[email protected]>
@xadupre It looks like it's not merged yet?
Signed-off-by: Justin Chu <[email protected]>
@titaiwangms @xadupre PTAL
User: Explain this implementation in detail, with examples

GitHub Copilot: I'll explain the `aten_index_put` implementation.

Overview

`aten_index_put` lowers PyTorch's advanced-indexing assignment, `tensor[indices] = values`, to ONNX `ScatterND`, handling index normalization, broadcasting, dimension reordering, and an optional accumulate mode.

Key Concepts

Advanced Indexing: When indices are tensors (not scalars), PyTorch uses "advanced indexing", which has special broadcasting and dimension placement rules.
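For intuition, the placement rules can be checked directly in PyTorch. This shape-only snippet is illustrative and not part of the implementation:

```python
import torch

t = torch.zeros(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]])  # shape [2, 2]
idx2 = torch.tensor([[1], [2]])        # shape [2, 1], broadcasts with idx1 to [2, 2]

# Non-contiguous advanced indices (a slice sits between them):
# the broadcast index dims move to the front of the result.
assert t[idx1, :, idx2].shape == (2, 2, 4, 6)

# Contiguous advanced indices: the broadcast index dims replace
# the indexed dims in place.
assert t[:, idx1, idx2].shape == (7, 2, 2, 6)
```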
Implementation Breakdown

1. Index Normalization (Lines 4541-4543)

```python
self_rank = len(self.shape)
if len(indices) < self_rank:
    indices = list(indices) + [None] * (self_rank - len(indices))
```

Example: If the tensor has rank 4 and only a single index `idx` is given, `indices` is padded to `[idx, None, None, None]`.

2. Identifying Index Types (Lines 4548-4565)

```python
def is_advanced_index(index):
    return index is not None

advanced_indices: list[int] = []  # Positions with tensor indices
none_indices: list[int] = []      # Positions with None (slices)
```

Example: For `indices = [idx1, None, idx2, None]`, `advanced_indices = [0, 2]` and `none_indices = [1, 3]`.
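A self-contained sketch of these two steps makes the bookkeeping explicit; `normalize_and_classify` is a hypothetical helper for illustration, not a function in torchlib:

```python
# Simplified sketch of steps 1-2 (illustrative only).
def normalize_and_classify(indices, self_rank):
    indices = list(indices) + [None] * (self_rank - len(indices))
    advanced = [i for i, ix in enumerate(indices) if ix is not None]
    none_pos = [i for i, ix in enumerate(indices) if ix is None]
    return indices, advanced, none_pos

# tensor[idx1, :, idx2] on a rank-4 tensor:
_, advanced, none_pos = normalize_and_classify(["idx1", None, "idx2"], self_rank=4)
assert advanced == [0, 2] and none_pos == [1, 3]
```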
3. Broadcasting Advanced Indices (Lines 4570-4600)

When multiple advanced indices exist, they must be broadcast to a common shape:

```python
# Broadcast advanced indices to a common shape
advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
# ... padding and Max operation to get common shape
```

Example:

```python
# tensor shape: [7, 3, 4, 5]
# indices = [idx1, None, idx2, None]
# idx1.shape = [2, 3], idx2.shape = [2, 1]
# After broadcasting: both become [2, 3]
```

4. Creating ONNX ScatterND Index (Lines 4602-4618)

ONNX's ScatterND expects the advanced indices stacked along a trailing dimension, one coordinate per indexed dimension:

```python
advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
onnx_index = op.Concat(*advanced_index_values, axis=-1)
```

Example:

```python
# idx1.shape = [2, 3], idx2.shape = [2, 3]
# After unsqueeze: [2, 3, 1] and [2, 3, 1]
# After concat: [2, 3, 2] (last dim = number of advanced indices)
```
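To build intuition for this index layout, here is a small numpy illustration (numpy stands in for the ONNX runtime; this is not the op itself). The last dimension of the stacked index holds one coordinate per indexed dimension:

```python
import numpy as np

data = np.arange(4 * 5).reshape(4, 5)

# Stack per-dimension coordinates along the last axis, GatherND/ScatterND style:
# onnx_index[..., 0] indexes dim 0, onnx_index[..., 1] indexes dim 1.
onnx_index = np.stack(
    [np.array([[0, 1, 2], [1, 2, 3]]),   # coordinates for dim 0
     np.array([[0, 0, 0], [4, 4, 4]])],  # coordinates for dim 1
    axis=-1,
)                                        # shape [2, 3, 2]

gathered = data[onnx_index[..., 0], onnx_index[..., 1]]  # GatherND-style read
assert gathered.shape == (2, 3)
```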
5. Checking Contiguity (Lines 4620-4623)

Determines whether the advanced indices appear consecutively:

```python
contiguous = True
if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
    contiguous = False
```

Example: `[idx1, idx2, None, None]` is contiguous; `[idx1, None, idx2, None]` is not.
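Restated as a tiny standalone check (illustrative only):

```python
# Positions are contiguous exactly when they form an unbroken run.
def is_contiguous(advanced_indices):
    return advanced_indices[-1] - advanced_indices[0] + 1 == len(advanced_indices)

assert is_contiguous([1, 2])       # e.g. tensor[:, idx1, idx2, :]
assert not is_contiguous([0, 2])   # e.g. tensor[idx1, :, idx2, :]
```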
6. Dimension Permutation (Lines 4625-4627)

ONNX ScatterND updates along the leading dimensions of its data input, so the advanced-index dimensions are moved to the front:

```python
perm = advanced_indices + none_indices  # Bring advanced indices to front
transposed = op.Transpose(self, perm=perm)
```

Example:

```python
# Original: self[dim0, dim1, dim2, dim3]
# indices pattern: [idx, None, idx, None]
# perm = [0, 2, 1, 3] (advanced first: 0,2 then none: 1,3)
# Result: transposed[dim0, dim2, dim1, dim3]
```

7. Values Transposition (Lines 4629-4648)

If the indices are contiguous, PyTorch places the advanced-indexing dimensions where the indices appear, so the values tensor must be transposed to move them to the front expected by ScatterND:

```python
if contiguous:
    # values must match where PyTorch places the advanced indexing dimensions
    values_perm = (
        advanced_index_replacement_positions  # Where advanced dims go
        + initial_none_index_positions        # Leading None dims
        + final_none_index_positions          # Trailing None dims
    )
    values = op.Transpose(values, perm=values_perm)
```

Example:

```python
# For tensor[:, idx1, idx2, :] where idx1, idx2 broadcast to [2, 3]
# target_shape = [2, 3, orig_dim0, orig_dim3]
# values needs to be [orig_dim0, 2, 3, orig_dim3] initially
# Transpose to [2, 3, orig_dim0, orig_dim3]
```

8. ScatterND Operation (Lines 4652-4654)

```python
updated = op.ScatterND(
    transposed, onnx_index, expanded_values,
    reduction="add" if accumulate else None
)
```
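As a mental model, the scatter step behaves roughly like the following numpy assignment; with reduction="add" it accumulates instead of overwriting. This is a simplified sketch of ScatterND's semantics (no handling of trailing update dimensions), not the ONNX op itself:

```python
import numpy as np

def scatternd_like(data, index, updates, accumulate=False):
    # index[..., k] holds the coordinate for dimension k of `data`.
    out = data.copy()
    coords = tuple(index[..., k] for k in range(index.shape[-1]))
    if accumulate:
        np.add.at(out, coords, updates)  # like reduction="add"
    else:
        out[coords] = updates            # plain overwrite
    return out

data = np.zeros((4, 5))
index = np.array([[0, 1], [2, 3]])       # two (row, col) coordinate pairs
updates = np.array([10.0, 20.0])
result = scatternd_like(data, index, updates)
assert result[0, 1] == 10.0 and result[2, 3] == 20.0
```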
9. Inverse Transpose (Lines 4658-4661)

Restore the original dimension order:

```python
inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
    inverse_perm[p] = i
result = op.Transpose(updated, perm=inverse_perm)
```

Complete Example

```python
# PyTorch operation:
# tensor[idx1, :, idx2] = values
# Given:
tensor.shape = [7, 4, 5, 6]
idx1 = torch.tensor([[0, 1], [2, 3]]) # shape [2, 2]
idx2 = torch.tensor([[1], [2]]) # shape [2, 1]
values.shape = [2, 2, 4, 6] # matches result shape
# Processing:
# 1. indices = [idx1, None, idx2, None]
# 2. advanced_indices = [0, 2], none_indices = [1, 3]
# 3. Broadcast idx1, idx2 to [2, 2]
# 4. onnx_index.shape = [2, 2, 2] (positions for dims 0 and 2)
# 5. perm = [0, 2, 1, 3] (advanced first)
# 6. transposed.shape = [7, 5, 4, 6]
# 7. target_shape = [2, 2, 4, 6]
# 8. ScatterND updates transposed
# 9. Inverse transpose back to [7, 4, 5, 6]
```

Why This Complexity?

PyTorch's advanced indexing has nuanced rules about dimension ordering that don't map directly to ONNX's GatherND/ScatterND, so the implementation has to normalize, broadcast, reorder, and finally restore dimensions explicitly.
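As a sanity check of the PyTorch semantics the complete example above targets (shapes and round trip only, independent of the ONNX graph):

```python
import torch

tensor = torch.zeros(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]])   # [2, 2]
idx2 = torch.tensor([[1], [2]])         # [2, 1], broadcasts to [2, 2]
values = torch.ones(2, 2, 4, 6)

# Read side: non-contiguous advanced indices put the broadcast dims in front.
assert tensor[idx1, :, idx2].shape == (2, 2, 4, 6)

# Write side: the assignment aten_index_put implements; shape is preserved.
tensor[idx1, :, idx2] = values
assert tensor.shape == (7, 4, 5, 6)
assert torch.equal(tensor[idx1, :, idx2], values)
```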
User: Contrast and compare the implementation with _aten_index_onnx

GitHub Copilot: I'll contrast `aten_index_put` with `_aten_index_onnx`.

Comparison:
| Aspect | _aten_index_onnx | aten_index_put |
|---|---|---|
| Operation | result = tensor[indices] (read/gather) | tensor[indices] = values (write/scatter) |
| ONNX Op | GatherND | ScatterND |
| Returns | New tensor with selected values | Modified tensor |
Key Similarities
1. Index Normalization
Both pad indices to match tensor rank:
```python
# _aten_index_onnx: implicit (handled by caller)
# aten_index_put: explicit
if len(indices) < self_rank:
    indices = list(indices) + [None] * (self_rank - len(indices))
```

2. Index Broadcasting

Both broadcast advanced indices to a common shape:
```python
# Both do:
broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
# Then expand/broadcast each index to this shape
```

3. Dimension Reordering

Both move advanced indices to the front for ONNX compatibility:
```python
# _aten_index_onnx:
reordered_positions = sorted(range(len(indices)),
                             key=lambda i: (indices[i] is None, i))

# aten_index_put:
perm = advanced_indices + none_indices  # Simpler, same effect
```

4. Contiguity Check
Both check if advanced indices are consecutive to determine dimension placement:
```python
# Both use: _has_none_in_middle() or similar logic
# This affects the final transpose operation
```

Key Differences
1. Operation Direction
_aten_index_onnx - Extracts values:
```python
self = op.GatherND(self, final_index, batch_dims=0)
# Returns subset of original tensor
```

aten_index_put - Inserts values:
```python
updated = op.ScatterND(transposed, onnx_index, expanded_values,
                       reduction="add" if accumulate else None)
# Returns full tensor with updates
```

2. Values Handling
_aten_index_onnx: No values input
- Only needs to know where to read from
aten_index_put: Must handle values tensor
```python
# Lines 4629-4648: Complex values transposition logic
if contiguous:
    values_perm = (advanced_index_replacement_positions
                   + initial_none_index_positions
                   + final_none_index_positions)
    values = op.Transpose(values, perm=values_perm)
expanded_values = op.Expand(values, target_shape)
```
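op.Expand broadcasts values to target_shape using numpy-style (multidirectional) broadcasting; np.broadcast_to is a convenient stand-in for reasoning about the shapes involved (illustrative only):

```python
import numpy as np

# Expand follows multidirectional (numpy-style) broadcasting rules.
values = np.ones((2, 3, 1, 1))
target_shape = (2, 3, 4, 6)
expanded = np.broadcast_to(values, target_shape)
assert expanded.shape == target_shape
```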
3. Shape Computation

_aten_index_onnx:
- Output shape = broadcasted index shape + remaining None dimensions
- Simpler calculation
aten_index_put:
```python
# Lines 4603-4607: Must compute target shape for values expansion
none_indices_constant = op.Constant(value_ints=none_indices)
none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
```

4. Inverse Transpose
_aten_index_onnx: One-way transpose
```python
# Transpose input, gather, potentially transpose result
return op.Transpose(self, perm=perm)
```

aten_index_put: Round-trip transpose
```python
# Must transpose back to original dimension order
inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
    inverse_perm[p] = i
result = op.Transpose(updated, perm=inverse_perm)
```

5. Accumulation Option
_aten_index_onnx: N/A
aten_index_put: Supports accumulation
```python
updated = op.ScatterND(..., reduction="add" if accumulate else None)
# Can add to existing values instead of replacing
```
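On the PyTorch side this corresponds to the accumulate argument of torch.Tensor.index_put_; a small demo of the difference:

```python
import torch

t = torch.zeros(5)
idx = (torch.tensor([0, 0, 3]),)
vals = torch.tensor([1.0, 2.0, 7.0])

replaced = t.clone().index_put_(idx, vals, accumulate=False)
accumulated = t.clone().index_put_(idx, vals, accumulate=True)

# With accumulate=True, writes to the duplicated index 0 are summed (1 + 2);
# with accumulate=False, one of the writes wins.
assert accumulated[0] == 3.0 and accumulated[3] == 7.0
```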
Example Comparison

```python
# Setup
tensor = torch.randn(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]]) # [2, 2]
idx2 = torch.tensor([[1], [2]]) # [2, 1]
# READ (_aten_index_onnx):
result = tensor[idx1, :, idx2]
# Steps:
# 1. Transpose: [7, 5, 4, 6] (advanced indices first: dims 0,2)
# 2. Broadcast indices to [2, 2]
# 3. GatherND → [2, 2, 4, 6] (reads values at those positions)
# 4. Indices are not contiguous here, so the result keeps the broadcast dims in front: [2, 2, 4, 6]
# WRITE (aten_index_put):
tensor[idx1, :, idx2] = values  # values.shape = [2, 2, 4, 6]
# Steps:
# 1. Transpose tensor: [7, 5, 4, 6]
# 2. Broadcast indices to [2, 2]
# 3. Values already match the scatter layout [2, 2, 4, 6] (a transpose is only needed when the advanced indices are contiguous)
# 4. ScatterND → writes values at those positions
# 5. Inverse transpose: [7, 4, 5, 6] (restore original order)
```

Complexity Comparison

| Aspect | _aten_index_onnx | aten_index_put |
|---|---|---|
| Lines of code | ~100 lines | ~160 lines |
| Transposes | 1-2 | 2-3 (includes inverse) |
| Shape calculations | Simpler | More complex (target_shape) |
| Values handling | None | Extensive (transpose + expand) |
Why More Complexity in aten_index_put?
- Bidirectional operation: Must preserve original tensor structure
- Values tensor: Additional input requiring careful alignment
- Shape preservation: Output shape = input shape (not reduced)
- Accumulation: Optional addition instead of replacement
The read operation (_aten_index_onnx) is fundamentally simpler because it only extracts data, while the write operation (aten_index_put) must carefully place values back into the correct positions while preserving the tensor's structure.
Update aten_index_put implementation in torchlib.