
Add Op (Slice - complex) | feat torchlib #2089

Merged

Conversation

titaiwangms (Contributor)

@titaiwangms requested a review from justinchuby on March 5, 2025 at 23:12

codecov bot commented Mar 5, 2025

Codecov Report

Attention: Patch coverage is 60.00000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 72.25%. Comparing base (4c1cda2) to head (d8b42bc).
Report is 1 commit behind head on main.

| Files with missing lines                       | Patch %  | Lines                      |
|------------------------------------------------|----------|----------------------------|
| onnxscript/function_libs/torch_lib/ops/core.py | 60.00%   | 1 Missing and 1 partial ⚠️ |
Additional details and impacted files
```diff
@@           Coverage Diff           @@
##             main    #2089   +/-   ##
=======================================
  Coverage   72.25%   72.25%
=======================================
  Files         217      217
  Lines       29138    29143    +5
  Branches     3462     3463    +1
=======================================
+ Hits        21053    21058    +5
  Misses       6954     6954
  Partials     1131     1131
```


@justinchuby (Collaborator) left a comment

Do you think negative axes need to be handled like Xavier suggested? Tests didn’t error, which was a little unexpected.

@titaiwangms (Contributor, Author)

> Do you think negative axes need to be handled like Xavier suggested? Tests didn’t error, which was a little unexpected.

In this example,

```python
import torch

class ComplexSliceModel(torch.nn.Module):
    def forward(self, x):
        # Convert input to a complex tensor
        x_complex = x.to(torch.complex64)
        # Apply a slice operation on the complex tensor
        return x_complex[:, :2]

model = ComplexSliceModel()
dummy_input = torch.randn(3, 4)

# Verify the model works as expected
print("Model output:", model(dummy_input))

# This call fails due to the slice op on a complex tensor.
torch.onnx.export(model, dummy_input, "complex_slice.onnx", dynamo=True)
```
The export produces:

```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[3, 4]"):
        # File: /home/titaiwang/pytorch/test_slice.py:6 in forward, code: x_complex = x.to(torch.complex64)
        _to_copy: "c64[3, 4]" = torch.ops.aten._to_copy.default(x, dtype = torch.complex64)
        _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float32);  x = _assert_tensor_metadata = None

        # File: /home/titaiwang/pytorch/test_slice.py:8 in forward, code: return x_complex[:, :2]
        slice_1: "c64[3, 4]" = torch.ops.aten.slice.Tensor(_to_copy, 0, 0, 9223372036854775807);  _to_copy = None
        slice_2: "c64[3, 2]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 2);  slice_1 = None
        return (slice_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}
```

It works because the ExportedProgram uses two slice ops to slice the real and imaginary parts respectively. I think that means the dim is respected?

@justinchuby (Collaborator)

What happens if you did `torch.ops.aten.slice.Tensor(slice_1, -1, 0, 2)`?

Granted, this may be unlikely.
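
A minimal sketch of the concern (not from the PR), assuming torchlib lowers a complex tensor to a real-valued tensor with a trailing size-2 (real, imag) axis, as `torch.view_as_real` does. With that representation, an unadjusted negative dim points at the packed (real, imag) axis instead of the intended one:

```python
import torch

x = torch.randn(3, 4, dtype=torch.complex64)

# Eager slice on the last *complex* dimension.
print(torch.ops.aten.slice.Tensor(x, -1, 0, 2).shape)  # torch.Size([3, 2])

# The real-valued representation gains a trailing (real, imag) axis,
# so dim=-1 now refers to that packed axis, not the original last dim.
real = torch.view_as_real(x)                               # shape [3, 4, 2]
print(torch.ops.aten.slice.Tensor(real, -1, 0, 2).shape)  # [3, 4, 2] -- wrong axis
print(torch.ops.aten.slice.Tensor(real, -2, 0, 2).shape)  # [3, 2, 2] -- intended axis
```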

@titaiwangms (Contributor, Author)

Added the dim adjustment.

@titaiwangms enabled auto-merge (squash) on March 7, 2025 at 21:32
@titaiwangms merged commit 68d4b9f into microsoft:main on Mar 7, 2025
22 of 29 checks passed
Comment on lines +7710 to +7712

```python
if dim < 0:
    # Account for the complex dimension in ONNX
    dim = dim - 1
```
@justinchuby (Collaborator) commented on Mar 7, 2025

I think this could be a little confusing for users when they examine the graph. What do you think if we did `dim += len(self.shape) - 1`?
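
A small sketch (an editorial check, not from the PR) of why the two adjustments pick the same axis, assuming `self` is the real-valued representation with a trailing size-2 (real, imag) axis: `dim - 1` keeps the index negative, while `dim + len(self.shape) - 1` resolves to a non-negative axis, which reads more clearly in the graph:

```python
import torch

x = torch.randn(3, 4, dtype=torch.complex64)
real = torch.view_as_real(x)        # rank 3: [3, 4, 2], trailing (real, imag) axis

dim = -1                            # negative dim on the original complex tensor

# Adjustment merged in the PR: shift past the packed axis, staying negative.
dim_a = dim - 1                     # -2

# Reviewer's suggestion: len(self.shape) - 1 is the complex tensor's rank,
# so the result is a non-negative axis index (here: 1).
dim_b = dim + len(real.shape) - 1   # 1

# Both select the same axis of the real representation.
assert torch.equal(real.narrow(dim_a, 0, 2), real.narrow(dim_b, 0, 2))
```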


Successfully merging this pull request may close these issues.

[ONNX] slice complex tensor needs implementation
3 participants