
Pytorch support for Join and Careduce Ops #869

Merged
11 commits merged on Jul 4, 2024
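For context, a minimal sketch of the kind of graph this PR lets the PyTorch backend handle: a CAReduce Op (sum/prod/max/min/any/all) compiled through PyTensor's PyTorch linker. The mode="PYTORCH" selector and the toy shapes are illustrative assumptions, not taken from this diff.

import numpy as np
import pytensor
import pytensor.tensor as pt

# A small graph with a CAReduce Op reducing over two axes.
a = pt.tensor3("a")
out = pt.sum(a, axis=(0, 2))

# Compile with the PyTorch backend (mode name assumed here).
f = pytensor.function([a], out, mode="PYTORCH")

x = np.ones((2, 3, 4), dtype=pytensor.config.floatX)
print(f(x))  # expected: a length-3 vector of 8.0 (2 * 4 ones summed per slice)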
18 changes: 9 additions & 9 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -41,7 +41,7 @@
@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
def torch_sum(x):
return torch.sum(x, dim=op.axis)

return torch_sum

@@ -49,7 +49,7 @@
@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
def torch_all(x):
return torch.all(x, dim=op.axis)

return torch_all

@@ -58,11 +58,11 @@
def pytorch_funcify_prod(op, **kwargs):
def torch_prod(x):
if isinstance(op.axis, tuple):
for d in op.axis:
x = torch.prod(x, dim=d, keepdim=True)
return x.squeeze()
for d in op.axis[::-1]:
ricardoV94 (Member) commented on Jul 4, 2024:

More readable?

Suggested change:
for d in op.axis[::-1]:
for d in sorted(op.axis, reverse=True):

Same for the others

x = torch.prod(x, dim=d)
return x

else:
return torch.prod(x.flatten(), dim=0)

return torch_prod
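To illustrate the reviewer's suggestion above: reducing the axes in descending order means each single-axis torch.prod call removes the highest remaining dimension first, so the indices of the axes still to be reduced stay valid, and the earlier keepdim=True plus squeeze() workaround is no longer needed. A standalone sketch follows (the reduce_prod helper is invented for illustration and assumes the axis entries are non-negative integers, since mixing positive and negative axes would break the descending sort):

import numpy as np
import torch

def reduce_prod(x: torch.Tensor, axis: tuple[int, ...]) -> torch.Tensor:
    # torch.prod only accepts a single dim, so reduce one axis at a time;
    # reducing the highest axis first keeps the remaining axis indices valid.
    for d in sorted(axis, reverse=True):
        x = torch.prod(x, dim=d)
    return x

x = torch.arange(1.0, 13.0).reshape(2, 2, 3)
expected = np.prod(x.numpy(), axis=(0, 2))
assert np.allclose(reduce_prod(x, (0, 2)).numpy(), expected)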

@@ -70,7 +70,7 @@
@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
def torch_any(x):
return torch.any(x, dim=op.axis)

return torch_any

@@ -79,11 +79,11 @@
def pytorch_funcify_max(op, **kwargs):
def torch_max(x):
if isinstance(op.axis, tuple):
for d in op.axis:
x = torch.max(x, dim=d, keepdim=True).values
return x.squeeze()
for d in op.axis[::-1]:
x = torch.max(x, dim=d).values
return x

else:
return torch.max(x.flatten(), dim=0).values

return torch_max

@@ -92,11 +92,11 @@
def pytorch_funcify_min(op, **kwargs):
def torch_min(x):
if isinstance(op.axis, tuple):
for d in op.axis:
x = torch.min(x, dim=d, keepdim=True).values
return x.squeeze()
for d in op.axis[::-1]:
x = torch.min(x, dim=d).values
return x

else:
return torch.min(x.flatten(), dim=0).values

return torch_min

105 changes: 6 additions & 99 deletions tests/link/pytorch/test_elemwise.py
@@ -58,8 +58,9 @@ def test_pytorch_elemwise():
compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
Member commented:

I think this is sufficient

Suggested change:
@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])

def test_pytorch_sum(axis):
def test_pytorch_careduce(fn, axis):
a_pt = tensor3("a")
test_value = np.array(
[
@@ -79,113 +80,19 @@ def test_pytorch_sum(axis):
]
).astype(config.floatX)

x = pt.math.sum(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_pytorch_all(axis):
a_pt = matrix("a")
test_value = np.array([[True, False, True], [False, True, True]])

x = ptm.all(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (1, 2)])
def test_pytorch_prod(axis):
a_pt = tensor3("a")
test_value = np.array(
[
[
[1, 1, 1, 1],
[2, 2, 2, 2],
],
[
[3, 3, 3, 3],
[
4,
4,
4,
4,
],
],
]
).astype(config.floatX)

x = ptm.prod(a_pt, axis=axis)
x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
def test_pytorch_any(axis):
def test_pytorch_any_all(fn, axis):
Member commented:

Suggested change:
def test_pytorch_any_all(fn, axis):
def test_pytorch_careduce_bool(fn, axis):

a_pt = matrix("a")
test_value = np.array([[True, False, True], [False, True, True]])

x = ptm.any(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [None, 0, 1, (1, -1)])
def test_pytorch_max(axis):
a_pt = tensor3("a")
test_value = np.array(
[
[
[1, 1, 1, 1],
[2, 2, 2, 2],
],
[
[3, 3, 3, 3],
[
4,
4,
4,
4,
],
],
]
).astype(config.floatX)

x = ptm.max(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [None, 0, 1, (1, -1)])
def test_pytorch_min(axis):
a_pt = tensor3("a")
test_value = np.array(
[
[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]],
[
[4, 4, 4, 4],
[
5,
5,
5,
5,
],
[
6,
6,
6,
6,
],
],
]
).astype(config.floatX)

x = ptm.min(a_pt, axis=axis)
x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])