
Conversation

@gramalingam
Collaborator

Update aten_index_put implementation in torchlib.

Signed-off-by: Ganesan Ramalingam <[email protected]>
@codecov

codecov bot commented Nov 18, 2025

Codecov Report

❌ Patch coverage is 69.11765% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.06%. Comparing base (5583f96) to head (3641a1e).
⚠️ Report is 1 commit behind head on main.
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| onnxscript/function_libs/torch_lib/ops/core.py | 69.11% | 16 Missing and 5 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2712      +/-   ##
==========================================
- Coverage   70.11%   70.06%   -0.05%     
==========================================
  Files         226      226              
  Lines       27230    27266      +36     
  Branches     2748     2754       +6     
==========================================
+ Hits        19092    19105      +13     
- Misses       7193     7210      +17     
- Partials      945      951       +6     


Signed-off-by: Ganesan Ramalingam <[email protected]>
@titaiwangms
Contributor

titaiwangms commented Dec 1, 2025

@xadupre It looks like this has not been merged yet?

@justinchuby justinchuby enabled auto-merge (squash) December 9, 2025 04:19
Signed-off-by: Justin Chu <[email protected]>
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Dec 11, 2025
@justinchuby
Collaborator

@titaiwangms @xadupre PTAL

@justinchuby justinchuby merged commit a3883a6 into main Dec 11, 2025
32 of 33 checks passed
@justinchuby justinchuby deleted the rama/index-put branch December 11, 2025 20:54
@justinchuby
Collaborator

User: Explain this implementation in detail, with examples

GitHub Copilot: I'll explain the aten_index_put implementation in detail, which handles PyTorch's advanced indexing assignment operation.

Overview

aten_index_put implements the operation tensor[indices] = values in PyTorch. It updates a tensor at specified indices with new values, optionally accumulating (adding) rather than replacing.
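
For reference, the PyTorch-level semantics being translated look like this (a minimal sketch using the public index_put_ API; tensor contents are illustrative):

import torch

t = torch.zeros(4)
idx = (torch.tensor([0, 2]),)
t.index_put_(idx, torch.tensor([1.0, 2.0]))                    # replace: t == [1., 0., 2., 0.]
t.index_put_(idx, torch.tensor([5.0, 5.0]), accumulate=True)   # add:     t == [6., 0., 7., 0.]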

Key Concepts

Advanced Indexing: When indices are tensors (not scalars), PyTorch uses "advanced indexing" which has special broadcasting and dimension placement rules.
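
The placement of the broadcast index dimensions is the subtle part. A small eager-mode illustration (the tensor shape matches the step-1 example below; this is background, not converter code):

import torch

t = torch.randn(7, 3, 4, 5)
i = torch.zeros(2, 3, dtype=torch.long)

print(t[:, i, i].shape)   # torch.Size([7, 2, 3, 5]) - contiguous indices: broadcast dims stay in place
print(t[i, :, i].shape)   # torch.Size([2, 3, 3, 5]) - a slice in between: broadcast dims move to the front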

Implementation Breakdown

1. Index Normalization (Lines 4541-4543)

self_rank = len(self.shape)
if len(indices) < self_rank:
    indices = list(indices) + [None] * (self_rank - len(indices))

Example: If self has shape [7, 3, 4, 5] and indices = [idx1, idx2], it becomes [idx1, idx2, None, None].

2. Identifying Index Types (Lines 4548-4565)

def is_advanced_index(index):
    return index is not None

advanced_indices: list[int] = []  # Positions with tensor indices
none_indices: list[int] = []      # Positions with None (slices)

Example: For tensor[idx1, :, idx2, :], where indices = [idx1, None, idx2, None]:

  • advanced_indices = [0, 2]
  • none_indices = [1, 3]
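
One way to derive these two position lists (an illustrative sketch; the actual loop in core.py may be written differently):

indices = ["idx1", None, "idx2", None]   # stand-ins for the index tensors
advanced_indices = [i for i, idx in enumerate(indices) if idx is not None]   # [0, 2]
none_indices = [i for i, idx in enumerate(indices) if idx is None]           # [1, 3]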

3. Broadcasting Advanced Indices (Lines 4570-4600)

When multiple advanced indices exist, they must be broadcast to a common shape:

# Broadcast advanced indices to a common shape
advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
# ... padding and Max operation to get common shape

Example:

# tensor shape: [7, 3, 4, 5]
# indices = [idx1, None, idx2, None]
# idx1.shape = [2, 3], idx2.shape = [2, 1]
# After broadcasting: both become [2, 3]
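
The common shape is ordinary NumPy-style broadcasting of the index shapes, which can be sanity-checked in eager PyTorch (not converter code):

import torch

print(torch.broadcast_shapes((2, 3), (2, 1)))   # torch.Size([2, 3])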

4. Creating ONNX ScatterND Index (Lines 4602-4618)

ONNX's ScatterND requires indices in a specific format - all advanced indices concatenated along a new last dimension:

advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
onnx_index = op.Concat(*advanced_index_values, axis=-1)

Example:

# idx1.shape = [2, 3], idx2.shape = [2, 3]
# After unsqueeze: [2, 3, 1] and [2, 3, 1]
# After concat: [2, 3, 2]  (last dim = number of advanced indices)
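
In PyTorch terms, this is roughly torch.stack of the broadcast index tensors along a new last axis, so that each entry along that axis is one coordinate tuple (an analogy, not the converter code):

import torch

idx1 = torch.tensor([[0, 1, 2], [3, 4, 5]])   # [2, 3]
idx2 = torch.tensor([[1, 1, 1], [2, 2, 2]])   # [2, 3] (idx2 after broadcasting)
coords = torch.stack([idx1, idx2], dim=-1)
print(coords.shape)    # torch.Size([2, 3, 2])
print(coords[0, 0])    # tensor([0, 1]) -> one (dim0, dim2) coordinate pair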

5. Checking Contiguity (Lines 4620-4623)

Determines if advanced indices appear consecutively:

contiguous = True
if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
    contiguous = False

Example:

  • [0, 1, 2] → contiguous ✓
  • [0, 2, 4] → not contiguous ✗

6. Dimension Permutation (Lines 4625-4627)

ONNX ScatterND updates along the leading dimensions of its data input, so the advanced-indexed dimensions must be moved to the front:

perm = advanced_indices + none_indices  # Bring advanced indices to front
transposed = op.Transpose(self, perm=perm)

Example:

# Original: self[dim0, dim1, dim2, dim3]
# indices pattern: [idx, None, idx, None]
# perm = [0, 2, 1, 3]  (advanced first: 0,2 then none: 1,3)
# Result: transposed[dim0, dim2, dim1, dim3]
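
Concretely, with an arbitrary rank-4 shape and this perm (a quick eager-mode check, not part of the converter):

import torch

t = torch.randn(7, 4, 5, 6)
perm = [0, 2, 1, 3]               # advanced dims (0, 2) first, then the None dims (1, 3)
print(t.permute(*perm).shape)     # torch.Size([7, 5, 4, 6])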

7. Values Transposition (Lines 4629-4648)

If indices are contiguous, values may need reordering to match PyTorch's behavior:

if contiguous:
    # values must match where PyTorch places the advanced indexing dimensions
    values_perm = (
        advanced_index_replacement_positions  # Where advanced dims go
        + initial_none_index_positions        # Leading None dims
        + final_none_index_positions          # Trailing None dims
    )
    values = op.Transpose(values, perm=values_perm)

Example:

# For tensor[:, idx1, idx2, :] where idx1, idx2 broadcast to [2, 3]
# target_shape = [2, 3, orig_dim0, orig_dim3]
# values needs to be [orig_dim0, 2, 3, orig_dim3] initially
# Transpose to [2, 3, orig_dim0, orig_dim3]

8. ScatterND Operation (Lines 4652-4654)

updated = op.ScatterND(
    transposed, onnx_index, expanded_values, 
    reduction="add" if accumulate else None
)
  • accumulate=True → adds values to existing ones
  • accumulate=False → replaces existing values
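
The effect of the reduction flag is easiest to see with duplicate indices (a PyTorch-level sketch of what reduction="add" corresponds to):

import torch

t = torch.zeros(3)
idx = (torch.tensor([1, 1]),)                  # the same position written twice
vals = torch.tensor([2.0, 3.0])
print(t.clone().index_put_(idx, vals, accumulate=True))    # tensor([0., 5., 0.]) - both writes are summed
print(t.clone().index_put_(idx, vals, accumulate=False))   # position 1 keeps only one of the written values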

9. Inverse Transpose (Lines 4658-4661)

Restore original dimension order:

inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
    inverse_perm[p] = i
result = op.Transpose(updated, perm=inverse_perm)
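
As a sanity check, composing the two transposes is a no-op, which can be verified in eager PyTorch (illustrative only):

import torch

perm = [0, 2, 1, 3]
inverse_perm = [0] * len(perm)
for i, p in enumerate(perm):
    inverse_perm[p] = i                 # inverse_perm == [0, 2, 1, 3] for this perm

x = torch.randn(7, 4, 5, 6)
assert torch.equal(x.permute(*perm).permute(*inverse_perm), x)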

Complete Example

# PyTorch operation:
# tensor[idx1, :, idx2] = values

# Given:
tensor.shape = [7, 4, 5, 6]
idx1 = torch.tensor([[0, 1], [2, 3]])  # shape [2, 2]
idx2 = torch.tensor([[1], [2]])         # shape [2, 1]
values.shape = [2, 2, 4, 6]             # matches result shape

# Processing:
# 1. indices = [idx1, None, idx2, None]
# 2. advanced_indices = [0, 2], none_indices = [1, 3]
# 3. Broadcast idx1, idx2 to [2, 2]
# 4. onnx_index.shape = [2, 2, 2]  (positions for dims 0 and 2)
# 5. perm = [0, 2, 1, 3] (advanced first)
# 6. transposed.shape = [7, 5, 4, 6]
# 7. target_shape = [2, 2, 4, 6]
# 8. ScatterND updates transposed
# 9. Inverse transpose back to [7, 4, 5, 6]
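
The same example can be checked end to end in eager PyTorch, which is the behavior the exported ONNX graph must reproduce (shapes as above; the sum is just a cheap sanity check):

import torch

t = torch.zeros(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]])   # [2, 2]
idx2 = torch.tensor([[1], [2]])         # [2, 1] -> broadcasts with idx1 to [2, 2]
values = torch.ones(2, 2, 4, 6)

t[idx1, :, idx2] = values
print(t.shape)        # torch.Size([7, 4, 5, 6])
print(int(t.sum()))   # 96: 4 distinct (dim0, dim2) pairs, each writing a [4, 6] block of ones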

Why This Complexity?

PyTorch's advanced indexing has nuanced rules about dimension ordering that don't map directly to ONNX ScatterND. The implementation handles:

  • Broadcasting multiple index tensors
  • Correct dimension placement (contiguous vs non-contiguous indices)
  • Converting PyTorch's indexing semantics to ONNX's ScatterND format

User: Contrast and compare the implementation with _aten_index_onnx

GitHub Copilot: I'll contrast _aten_index_onnx (reading/indexing) with aten_index_put (writing/assignment).

Read lines 4339 to 4430

Read lines 4430 to 4490

Comparison: _aten_index_onnx (Read) vs aten_index_put (Write)

Both implement PyTorch's advanced indexing but for opposite operations. Here's a detailed contrast:

Purpose

| Aspect | _aten_index_onnx | aten_index_put |
|---|---|---|
| Operation | result = tensor[indices] (read/gather) | tensor[indices] = values (write/scatter) |
| ONNX Op | GatherND | ScatterND |
| Returns | New tensor with selected values | Modified tensor |

Key Similarities

1. Index Normalization

Both pad indices to match tensor rank:

# _aten_index_onnx: implicit (handled by caller)
# aten_index_put: explicit
if len(indices) < self_rank:
    indices = list(indices) + [None] * (self_rank - len(indices))

2. Index Broadcasting

Both broadcast advanced indices to common shape:

# Both do:
broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
# Then expand/broadcast each index to this shape

3. Dimension Reordering

Both move advanced indices to front for ONNX compatibility:

# _aten_index_onnx:
reordered_positions = sorted(range(len(indices)), 
                             key=lambda i: (indices[i] is None, i))

# aten_index_put:
perm = advanced_indices + none_indices  # Simpler, same effect

4. Contiguity Check

Both check if advanced indices are consecutive to determine dimension placement:

# Both use: _has_none_in_middle() or similar logic
# This affects final transpose operation

Key Differences

1. Operation Direction

_aten_index_onnx - Extracts values:

self = op.GatherND(self, final_index, batch_dims=0)
# Returns subset of original tensor

aten_index_put - Inserts values:

updated = op.ScatterND(transposed, onnx_index, expanded_values,
                       reduction="add" if accumulate else None)
# Returns full tensor with updates
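
The read/write mirror is easiest to see with plain coordinate-style indexing in eager PyTorch (a rough analogue of GatherND/ScatterND, not the ONNX ops themselves):

import torch

data = torch.arange(12).reshape(3, 4)
coords = torch.tensor([[0, 1], [2, 3]])            # two (row, col) coordinate tuples

gathered = data[coords[:, 0], coords[:, 1]]        # GatherND-style read  -> tensor([ 1, 11])

scattered = data.clone()
scattered[coords[:, 0], coords[:, 1]] = torch.tensor([-1, -2])   # ScatterND-style write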

2. Values Handling

_aten_index_onnx: No values input

  • Only needs to know where to read from

aten_index_put: Must handle values tensor

# Lines 4629-4648: Complex values transposition logic
if contiguous:
    values_perm = (advanced_index_replacement_positions 
                   + initial_none_index_positions 
                   + final_none_index_positions)
    values = op.Transpose(values, perm=values_perm)
expanded_values = op.Expand(values, target_shape)

3. Shape Computation

_aten_index_onnx:

  • Output shape = broadcasted index shape + remaining None dimensions
  • Simpler calculation

aten_index_put:

# Lines 4603-4607: Must compute target shape for values expansion
none_indices_constant = op.Constant(value_ints=none_indices)
none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
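
In plain Python, with the shapes from the running example, this amounts to (a sketch of the arithmetic the Gather/Concat ops perform on the shape tensor):

broadcast_shape = [2, 2]          # common shape of the advanced indices
self_shape = [7, 4, 5, 6]
none_indices = [1, 3]
target_shape = broadcast_shape + [self_shape[i] for i in none_indices]   # [2, 2, 4, 6]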

4. Inverse Transpose

_aten_index_onnx: One-way transpose

# Transpose input, gather, potentially transpose result
return op.Transpose(self, perm=perm)

aten_index_put: Round-trip transpose

# Must transpose back to original dimension order
inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
    inverse_perm[p] = i
result = op.Transpose(updated, perm=inverse_perm)

5. Accumulation Option

_aten_index_onnx: N/A

aten_index_put: Supports accumulation

updated = op.ScatterND(..., reduction="add" if accumulate else None)
# Can add to existing values instead of replacing

Example Comparison

# Setup
tensor = torch.randn(7, 4, 5, 6)
idx1 = torch.tensor([[0, 1], [2, 3]])  # [2, 2]
idx2 = torch.tensor([[1], [2]])         # [2, 1]

# READ (_aten_index_onnx):
result = tensor[idx1, :, idx2]
# Steps:
# 1. Transpose: [7, 5, 4, 6] (advanced indices first: dims 0,2)
# 2. Broadcast indices to [2, 2]
# 3. GatherND → [2, 2, 4, 6] (reads values at those positions)
# 4. No final transpose here: the indices are non-contiguous, so the broadcast
#    dims stay in front and the result is [2, 2, 4, 6] (with contiguous indices
#    they would be moved back into the middle instead)

# WRITE (aten_index_put):
tensor[idx1, :, idx2] = values  # values.shape = [2, 2, 4, 6]
# Steps:
# 1. Transpose tensor: [7, 5, 4, 6]
# 2. Broadcast indices to [2, 2]
# 3. Expand values to target_shape [2, 2, 4, 6] (no values transpose is needed here;
#    that step only applies when the advanced indices are contiguous)
# 4. ScatterND → writes values at those positions
# 5. Inverse transpose: [7, 4, 5, 6] (restore original order)

Complexity Comparison

| Aspect | _aten_index_onnx | aten_index_put |
|---|---|---|
| Lines of code | ~100 lines | ~160 lines |
| Transposes | 1-2 | 2-3 (includes inverse) |
| Shape calculations | Simpler | More complex (target_shape) |
| Values handling | None | Extensive (transpose + expand) |

Why More Complexity in aten_index_put?

  1. Bidirectional operation: Must preserve original tensor structure
  2. Values tensor: Additional input requiring careful alignment
  3. Shape preservation: Output shape = input shape (not reduced)
  4. Accumulation: Optional addition instead of replacement

The read operation (_aten_index_onnx) is fundamentally simpler because it only extracts data, while the write operation (aten_index_put) must carefully place values back into the correct positions while preserving the tensor's structure.

