PyTorch support for Join and CAReduce Ops #869

Merged · 11 commits · Jul 4, 2024
19 changes: 18 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -6,7 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
+from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join


@singledispatch
@@ -89,3 +89,20 @@
        return torch.arange(start, stop, step, dtype=dtype)

    return arange


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        # tensors could also be tuples, in which case they don't have an ndim
        tensors = [torch.tensor(tensor) for tensor in tensors]
        view = op.view

ricardoV94 (Member) commented on Jun 28, 2024:

We want to remove this kwarg, so you can just raise NotImplementedError in the outer dispatch function if the Op has it set. #753

        if (view != -1) and all(
            tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :]
        ):
            return tensors[view]
        else:
            return torch.cat(tensors, dim=axis)

    return join
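
Following the review note above, a minimal sketch of that suggestion (an assumption, not the PR's final code) moves the check into the outer dispatch and raises NotImplementedError at funcify time:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.basic import Join


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    # Sketch per the review: reject the deprecated view optimization up
    # front (#753) instead of special-casing it at runtime.
    if op.view != -1:
        raise NotImplementedError(
            "Join with view != -1 is not supported in the PyTorch backend"
        )

    def join(axis, *tensors):
        # Inputs may arrive as plain Python sequences without an ndim.
        tensors = [torch.tensor(tensor) for tensor in tensors]
        return torch.cat(tensors, dim=axis)

    return join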
9 changes: 9 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Sum


@pytorch_funcify.register(Elemwise)
@@ -34,3 +35,11 @@
        return res

    return dimshuffle


@pytorch_funcify.register(Sum)
def pytorch_funcify_careduce(op, **kwargs):
    def torch_sum(x):
        return torch.sum(x, dim=op.axis)

    return torch_sum
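
The PR title promises CAReduce support generally, but only Sum is registered above. A hypothetical generalization (not part of this diff; the Mul branch and the axis-by-axis loop are assumptions) would dispatch on the CAReduce base class and select the torch reduction from the scalar op:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import Add, Mul
from pytensor.tensor.elemwise import CAReduce


@pytorch_funcify.register(CAReduce)
def pytorch_funcify_CAReduce(op, **kwargs):
    # Map the reduction's scalar op to a torch reduction; Add (sum) is the
    # only case this PR exercises, Mul (prod) is illustrative.
    if isinstance(op.scalar_op, Add):
        torch_reduce = torch.sum
    elif isinstance(op.scalar_op, Mul):
        torch_reduce = torch.prod
    else:
        raise NotImplementedError(f"CAReduce with scalar op {op.scalar_op}")

    def careduce(x):
        if op.axis is None:
            # Reduce over every dimension.
            return torch_reduce(x)
        # torch.prod only takes a single dim, so reduce one axis at a
        # time, highest index first so the remaining indices stay valid.
        for d in sorted(op.axis, reverse=True):
            x = torch_reduce(x, dim=d)
        return x

    return careduce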
42 changes: 41 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
-from pytensor.tensor.type import scalar, vector
+from pytensor.tensor.type import matrix, scalar, vector


torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
        FunctionGraph([start, stop, step], [out]),
        [np.array(1), np.array(10), np.array(2)],
    )


def test_pytorch_Join():
    a = matrix("a")
    b = matrix("b")

    x = ptb.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0]].astype(config.floatX),
        ],
    )

    x = ptb.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
            np.c_[[5.0, 6.0]].astype(config.floatX),
        ],
    )
10 changes: 10 additions & 0 deletions tests/link/pytorch/test_elemwise.py
@@ -53,3 +53,13 @@ def test_pytorch_elemwise():

    fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


def test_pytorch_sum():
    a_pt = vector("a")
    test_value = np.r_[1, 2, 3].astype(config.floatX)

    x = pt.math.sum(a_pt, axis=None)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])
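
The new test only exercises axis=None. A hypothetical companion case for a single-axis reduction (not in this diff; assumes matrix is imported from pytensor.tensor.type) could look like:

def test_pytorch_sum_axis():
    # Hypothetical follow-up: reduce a matrix along a single axis.
    a_pt = matrix("a")
    test_value = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(config.floatX)

    x = pt.math.sum(a_pt, axis=0)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])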