-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torchlib] torch.where(x) - default overload - is actually not implemented #1971
Conversation
I hit NotImplementedError though. It's not decomposed. |
Based on this PR: pytorch/pytorch#21798
aten::where under it.
Something like: # aten::where.default is the same as aten::nonzero_numpy
# https://github.com/pytorch/pytorch/issues/21798
@torch_op("aten::where", trace_only=True)
def aten_nonzero_numpy(self: TensorType) -> TensorType:
"""nonzero_numpy(Tensor self) -> Tensor[]"""
nonzero = aten_nonzero(self)
return aten_unbind(nonzero, dim=1) And enable its test in ops_test.py. |
The odd thing is it works for me by just removing it. Let me just install fresh venv and try again |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1971 +/- ##
==========================================
- Coverage 74.73% 74.72% -0.02%
==========================================
Files 273 273
Lines 29329 29329
Branches 3367 3367
==========================================
- Hits 21920 21916 -4
- Misses 6368 6371 +3
- Partials 1041 1042 +1 ☔ View full report in Codecov by Sentry. |
You are right, on nightly build of torch the decompositions does not work but on the stable version it do... 😮💨 I think that it is better if it is decomposed by the pytorch, as there should be decomposition of where(x) to the nonzero in any way. if a is None and b is None:
return torch.nonzero(pred, as_tuple=True) here, which is actually where the NotImplementedError originates |
@titaiwangms with this PR it works without any other change. I have no idea why as it seems to me there is no decomposition that would handle it but it works. But I would probably make PR from my previous comment in any way. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was just looking at this again. Nice!
|
As the title said the overload is not implemented in
aten_where
. It should be decomposed intononzero
function by pytorch.Now it throws error as there is not enough parameters. Minimal reproducible example:
As for the tests I would have thought it is handled by the ops_test.py but apparently it is not.
As a side note, the
pylint
is somehow broken for this file (at least).