PyTorch support for Join and CAReduce Ops #869

Merged · 11 commits · Jul 4, 2024
13 changes: 12 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -6,7 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
+from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join


@singledispatch
@@ -89,3 +89,14 @@ def arange(start, stop, step):
        return torch.arange(start, stop, step, dtype=dtype)

    return arange


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        # tensors could also be tuples; in that case they don't have an ndim,
        # so convert everything to a torch tensor first
        tensors = [torch.tensor(tensor) for tensor in tensors]

        return torch.cat(tensors, dim=axis)

    return join
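
As a rough sketch of what the registered join function does at runtime (the inputs below are invented for illustration, not taken from the PR), torch.cat concatenates the converted tensors along the requested axis:

import torch

# Illustrative only: emulate the dispatched join on two column matrices.
a = torch.tensor([[1.0], [2.0], [3.0]])
b = torch.tensor([[4.0], [5.0], [6.0]])

print(torch.cat([a, b], dim=0).shape)  # torch.Size([6, 1]), rows stacked
print(torch.cat([a, b], dim=1).shape)  # torch.Size([3, 2]), columns side by side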
64 changes: 64 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@@ -37,6 +38,69 @@
    return dimshuffle


@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
    def torch_sum(x):
        return torch.sum(x, dim=op.axis)

    return torch_sum
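
A quick aside on why torch_sum needs no special handling for multiple axes (a minimal sketch, not from the PR): unlike torch.prod, torch.max, and torch.min below, torch.sum accepts a tuple of dims directly, so the CAReduce axis can be passed straight through.

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# torch.sum reduces over several dims in a single call.
print(torch.sum(x, dim=(0, 2)).shape)  # torch.Size([3])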


@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
    def torch_all(x):
        return torch.all(x, dim=op.axis)

    return torch_all


@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    def torch_prod(x):
        if isinstance(op.axis, tuple):
            for d in sorted(op.axis, reverse=True):
                x = torch.prod(x, dim=d)
            return x
        else:
            return torch.prod(x.flatten(), dim=0)

    return torch_prod
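
Because torch.prod only reduces a single dim per call, the dispatch above loops over the requested axes from highest to lowest so that the remaining axis indices stay valid. A small equivalence check, with invented values, sketches the idea:

import numpy as np
import torch

x = torch.arange(1.0, 9.0).reshape(2, 2, 2)
axis = (0, -1)

# Reduce one dim at a time, highest first, mirroring torch_prod above.
out = x
for d in sorted(axis, reverse=True):
    out = torch.prod(out, dim=d)

# Same result as NumPy's tuple-axis reduction.
assert np.allclose(out.numpy(), x.numpy().prod(axis=axis))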


@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
    def torch_any(x):
        return torch.any(x, dim=op.axis)

    return torch_any


@pytorch_funcify.register(Max)
def pytorch_funcify_max(op, **kwargs):
    def torch_max(x):
        if isinstance(op.axis, tuple):
            for d in sorted(op.axis, reverse=True):
                x = torch.max(x, dim=d).values
            return x
        else:
            return torch.max(x.flatten(), dim=0).values

    return torch_max
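
The .values accessor is needed because, when given a dim argument, torch.max returns a (values, indices) named tuple rather than a plain tensor. A minimal sketch with invented inputs:

import torch

x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])

result = torch.max(x, dim=0)
print(result.values)   # tensor([3., 5.]), the reduced maxima
print(result.indices)  # tensor([1, 0]), the argmax along dim 0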


@pytorch_funcify.register(Min)
def pytorch_funcify_min(op, **kwargs):
    def torch_min(x):
        if isinstance(op.axis, tuple):
            for d in sorted(op.axis, reverse=True):
                x = torch.min(x, dim=d).values
            return x
        else:
            return torch.min(x.flatten(), dim=0).values

    return torch_min


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
42 changes: 41 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
-from pytensor.tensor.type import scalar, vector
+from pytensor.tensor.type import matrix, scalar, vector


torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
        FunctionGraph([start, stop, step], [out]),
        [np.array(1), np.array(10), np.array(2)],
    )


def test_pytorch_Join():
    a = matrix("a")
    b = matrix("b")

    x = ptb.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0]].astype(config.floatX),
        ],
    )

    x = ptb.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
            np.c_[[5.0, 6.0]].astype(config.floatX),
        ],
    )
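
For readers less familiar with np.c_, the test inputs above are column matrices; a short illustrative sketch of the shapes being joined (not part of the test itself):

import numpy as np

a = np.c_[[1.0, 2.0, 3.0]]          # shape (3, 1): a single column
c = np.c_[[1.0, 2.0], [3.0, 4.0]]   # shape (2, 2): two columns side by side

print(a.shape, c.shape)                      # (3, 1) (2, 2)
print(np.concatenate([a, a], axis=0).shape)  # (6, 1), like ptb.join(0, a, b)
print(np.concatenate([a, a], axis=1).shape)  # (3, 2), like ptb.join(1, a, b)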
43 changes: 42 additions & 1 deletion tests/link/pytorch/test_elemwise.py
@@ -2,11 +2,12 @@
import pytest

import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
-from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
def test_pytorch_careduce(fn, axis):
a_pt = tensor3("a")
test_value = np.array(
[
[
[1, 1, 1, 1],
[2, 2, 2, 2],
],
[
[3, 3, 3, 3],
[
4,
4,
4,
4,
],
],
]
).astype(config.floatX)

x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
def test_pytorch_any_all(fn, axis):
Suggested change (review comment):
-def test_pytorch_any_all(fn, axis):
+def test_pytorch_careduce_bool(fn, axis):

a_pt = matrix("a")
test_value = np.array([[True, False, True], [False, True, True]])

x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])

compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):