PyTorch support for Join and CAReduce Ops #869
Conversation
def join(axis, *tensors):
    # tensors could also be tuples, and in this case they don't have an ndim
    tensors = [torch.tensor(tensor) for tensor in tensors]

view = op.view
We want to remove this kwarg, so you can just raise NotImplementedError if the Op has it set in the outer dispatch function. #753
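A minimal sketch of what that outer-dispatch check could look like. The import paths, the registration pattern, and the assumption that view == -1 means "not set" follow the existing Join Op and PyTorch linker, but are not confirmed by this thread:

```python
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.basic import Join


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    # The `view` kwarg is slated for removal (#753); refuse to compile graphs that set it.
    # Assumption: -1 is the "view not set" default of the existing Join Op.
    if op.view != -1:
        raise NotImplementedError("Join with view is not supported in the PyTorch backend")

    def join(axis, *tensors):
        # Inputs may arrive as plain Python sequences without an ndim, so coerce them first.
        tensors = [torch.as_tensor(tensor) for tensor in tensors]
        return torch.cat(tensors, dim=int(axis))

    return join
```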
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #869      +/-   ##
==========================================
+ Coverage   80.98%   81.01%   +0.02%
==========================================
  Files         169      169
  Lines       47025    46894     -131
  Branches    11501    11492       -9
==========================================
- Hits        38084    37989      -95
+ Misses       6724     6698      -26
+ Partials     2217     2207      -10
@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    dim = op.axis[0]
Why axis[0]?
op.axis is a tuple, and PyTorch expects an integer for dim.
Then we need to change the logic, because op.axis can be a tuple with more than one entry.
We could do something like this:

    for d in op.axis:
        x = torch.prod(x, dim=d, keepdim=True)  # keepdim keeps the shape constant; we can reshape at the end
If you reduce in reversed order you don't have to worry about the keepdims. Sounds good; it's a bit surprising that they don't support multiple axes.
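As a standalone illustration of the reversed-order trick (not the PR's actual implementation; it assumes the axes are normalized, non-negative indices as in op.axis):

```python
import torch


def prod_over_axes(x, axes):
    # Reducing the largest axis first means earlier reductions don't shift the
    # indices of the axes still to be reduced, so keepdim is unnecessary.
    for d in sorted(axes, reverse=True):
        x = torch.prod(x, dim=d)
    return x


x = torch.arange(1, 25, dtype=torch.float64).reshape(2, 3, 4)
assert prod_over_axes(x, (1, 2)).shape == (2,)
assert torch.equal(prod_over_axes(x, (1, 2)), x.prod(dim=2).prod(dim=1))
```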
tests/link/pytorch/test_elemwise.py (Outdated)
@@ -57,6 +58,72 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("axis", [None, 0, 1])
Can we parametrize these tests with the reduce function? Since they all look the same, we can reduce a bunch of lines. Or at least separate only those that need numerical inputs from those that need boolean (all and any).
Also I would like to test axis = (1, 2), and have a_pt be a tensor3, so that we cover the case with more than 1 axis, but not all of them.
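Roughly, the combined parametrization could look like the sketch below. The helper compare_pytorch_and_py and the ptm/a_pt names mirror what the existing tests already use; the import locations and test values are assumptions:

```python
import numpy as np
import pytest

import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.graph.fg import FunctionGraph

from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [None, 0, 1, (1, 2)])
def test_pytorch_careduce(fn, axis):
    # A tensor3 input covers the "more than one axis, but not all of them" case.
    a_pt = pt.tensor3("a")
    test_value = np.arange(24, dtype="float64").reshape((2, 3, 4))

    x = fn(a_pt, axis=axis)
    fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(fg, [test_value])
```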
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) | ||
def test_pytorch_any(axis): | ||
def test_pytorch_any_all(fn, axis): |
Suggested change:
-def test_pytorch_any_all(fn, axis):
+def test_pytorch_careduce_bool(fn, axis):
tests/link/pytorch/test_elemwise.py (Outdated)
@@ -58,8 +58,9 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
I think this is sufficient:
-@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
-    for d in op.axis:
-        x = torch.prod(x, dim=d, keepdim=True)
-    return x.squeeze()
+    for d in op.axis[::-1]:
More readable?
-    for d in op.axis[::-1]:
+    for d in sorted(op.axis, reverse=True):
Same for the others.
Description
Add PyTorch support for common CAReduce Ops (sum, prod, all, any, ...)
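In outline, the backend maps each CAReduce scalar op to the corresponding torch reducer and applies it over the requested axes. A rough, non-authoritative sketch of that idea; the scalar-op names and the helper are illustrative, only op.axis and the listed torch reducers come from the discussion above:

```python
import torch

# Illustrative mapping from CAReduce scalar ops to torch reducers;
# the keys are informal names, not a confirmed pytensor API.
TORCH_REDUCERS = {
    "add": torch.sum,   # pt.sum
    "mul": torch.prod,  # pt.prod
    "and": torch.all,   # pt.all
    "or": torch.any,    # pt.any
}


def careduce(x, reducer, axes):
    # axes=None reduces over everything, mirroring CAReduce's default.
    if axes is None:
        return reducer(x)
    # Reduce from the largest axis down so indices stay valid without keepdim.
    for d in sorted(axes, reverse=True):
        x = reducer(x, dim=d)
    return x
```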
Related Issue
Checklist
Type of change