
[torchlib] torch.where(x) - default overload - is actually not implemented #1971

Merged
merged 2 commits into from
Dec 17, 2024

Conversation

Bludator
Contributor

@Bludator Bludator commented Dec 1, 2024

As the title says, the default overload is not implemented in aten_where. It should be decomposed into the nonzero function by PyTorch.

Currently it throws an error because not enough parameters are provided. Minimal reproducible example:

import torch

class Model(torch.nn.Module):
	def forward(self, x):
		return torch.where(x)

torch.onnx.export(Model(), (torch.tensor([0, 1, 2, 0, 3]),), dynamo=True)
<class 'ValueError'>: Required parameter 'self' is not provided. Signature: pkg.onnxscript.torch_lib::aten_where(condition: T_condition, self: TTensor, other: TTensor) -> (TTensor) where T_condition=BOOL, TTensor=INT8 | FLOAT16 | INT16 | INT32 | UINT8 | FLOAT | BOOL | COMPLEX128 | BFLOAT16 | COMPLEX64 | DOUBLE | INT64. Args: (SymbolicTensor('x', type=Tensor(INT64), shape=[5], producer=None, index=None),). Kwargs: {}.
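For reference, the single-argument overload torch.where(x) is documented to return the same thing as torch.nonzero(x, as_tuple=True): a tuple of index tensors, one per dimension, selecting the nonzero elements. A pure-Python sketch of that semantics for the 1-D repro input (no torch required; the helper name is illustrative):

```python
def where_single_arg(x):
    """Mimic torch.where(x) for a 1-D sequence: a 1-tuple of the
    indices at which the input is nonzero, like nonzero(as_tuple=True)."""
    return ([i for i, v in enumerate(x) if v != 0],)

# For the tensor in the repro above:
print(where_single_arg([0, 1, 2, 0, 3]))  # ([1, 2, 4],)
```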

As for the tests, I would have thought this is handled by ops_test.py, but apparently it is not.


As a side note, pylint is somehow broken for this file (at least).

@titaiwangms titaiwangms self-requested a review December 3, 2024 17:07
@titaiwangms
Contributor

I hit a NotImplementedError though; it's not decomposed.

@titaiwangms
Contributor

titaiwangms commented Dec 3, 2024

Based on this PR: pytorch/pytorch#21798
The real fix is to implement:

def aten_nonzero_numpy(self: TensorType) -> TensorType:
and register aten::where under it.

Something like:

# aten::where.default is the same as aten::nonzero_numpy
# https://github.com/pytorch/pytorch/issues/21798
@torch_op("aten::where", trace_only=True)
def aten_nonzero_numpy(self: TensorType) -> TensorType:
    """nonzero_numpy(Tensor self) -> Tensor[]"""
    nonzero = aten_nonzero(self)
    return aten_unbind(nonzero, dim=1)

And enable its test in ops_test.py.

referenced https://github.com/pytorch/pytorch/blob/9125e9119cee131d7f839c105d79e8036b44e92b/torch/onnx/symbolic_opset13.py#L297
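The nonzero → unbind(dim=1) pipeline suggested above can be sketched in pure Python for a 2-D input (helper names are illustrative, not the torch_lib implementations): nonzero yields an (N, ndim) matrix of coordinates, and unbinding along dim=1 splits it into one index list per dimension, matching torch.nonzero(x, as_tuple=True).

```python
def nonzero(matrix):
    """Coordinates of nonzero entries of a 2-D input, as [row, col] pairs,
    in row-major order -- the shape of aten_nonzero's output."""
    return [[i, j] for i, row in enumerate(matrix)
            for j, v in enumerate(row) if v != 0]

def unbind_dim1(coords):
    """Split an (N, 2) coordinate list into per-dimension index tuples,
    like aten_unbind(nonzero, dim=1)."""
    rows = [c[0] for c in coords]
    cols = [c[1] for c in coords]
    return (rows, cols)

m = [[0, 5], [3, 0]]
print(unbind_dim1(nonzero(m)))  # ([0, 1], [1, 0])
```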

@titaiwangms titaiwangms added the topic: torch_lib Related to the torch/aten function lib in development label Dec 3, 2024
@Bludator
Contributor Author

Bludator commented Dec 3, 2024

The odd thing is that it works for me just by removing it. Let me install a fresh venv and try again.


codecov bot commented Dec 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 74.72%. Comparing base (0aed232) to head (815c418).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1971      +/-   ##
==========================================
- Coverage   74.73%   74.72%   -0.02%     
==========================================
  Files         273      273              
  Lines       29329    29329              
  Branches     3367     3367              
==========================================
- Hits        21920    21916       -4     
- Misses       6368     6371       +3     
- Partials     1041     1042       +1     


@Bludator
Contributor Author

Bludator commented Dec 4, 2024

You are right: on the nightly build of torch the decomposition does not work, but on the stable version it does... 😮‍💨

I think it is better if it is decomposed by PyTorch, since there should be a decomposition of where(x) to nonzero anyway.
Something like:

if a is None and b is None:
    return torch.nonzero(pred, as_tuple=True)

here, which is actually where the NotImplementedError originates.
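The two-line snippet above can be fleshed out into a self-contained pure-Python sketch of such a decomposition (names are illustrative, not the real torch._decomp entries): when both value arguments are omitted, where(pred) falls through to nonzero(pred, as_tuple=True); otherwise it selects elementwise.

```python
def nonzero_as_tuple(pred):
    """1-D stand-in for torch.nonzero(pred, as_tuple=True)."""
    return ([i for i, v in enumerate(pred) if v],)

def where_decomposition(pred, a=None, b=None):
    """Hypothetical decomposition of where: the single-argument form
    dispatches to nonzero, the three-argument form selects elementwise."""
    if a is None and b is None:
        return nonzero_as_tuple(pred)
    if a is None or b is None:
        raise TypeError("either both or neither of a and b must be given")
    return [av if p else bv for p, av, bv in zip(pred, a, b)]

print(where_decomposition([0, 1, 0, 2]))                # ([1, 3],)
print(where_decomposition([1, 0], [10, 20], [30, 40]))  # [10, 40]
```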

@Bludator
Contributor Author

Bludator commented Dec 16, 2024

@titaiwangms with this PR it works without any other change. I have no idea why, as it seems to me there is no decomposition that would handle it, but it works.

But I would probably still make a PR from my previous comment.

Contributor

@titaiwangms titaiwangms left a comment


I was just looking at this again. Nice!

@titaiwangms
Contributor

@titaiwangms with this PR it works without any other change. I have no idea why, as it seems to me there is no decomposition that would handle it, but it works.

But I would probably still make a PR from my previous comment.

aten.where.default was overwritten by the second decomp table, leading to a different result.
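A minimal sketch of why a second decomposition table can change the outcome (illustrative keys and values, not the real torch._decomp registrations): decomposition tables behave like {op: fn} dicts, and merging a later table silently overwrites earlier entries for the same op key.

```python
# First table decomposes where.default; the second maps the same key
# to a different rule, and the merge keeps only the later entry.
table_a = {"aten.where.default": "decompose_to_nonzero"}
table_b = {"aten.where.default": "keep_as_is"}
merged = {**table_a, **table_b}
print(merged["aten.where.default"])  # keep_as_is
```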

@titaiwangms titaiwangms enabled auto-merge (squash) December 17, 2024 00:19
@titaiwangms titaiwangms merged commit f0769c3 into microsoft:main Dec 17, 2024
21 of 41 checks passed