PyTorch support for Join and CAReduce Ops #869

Merged: 11 commits, Jul 4, 2024
13 changes: 12 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -6,7 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join


@singledispatch
@@ -89,3 +89,14 @@ def arange(start, stop, step):
        return torch.arange(start, stop, step, dtype=dtype)

    return arange


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        # tensors could also be tuples; in that case they don't have an ndim,
        # so convert everything to a torch tensor first
        tensors = [torch.tensor(tensor) for tensor in tensors]

        return torch.cat(tensors, dim=axis)

    return join
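
For context, a minimal sketch of the behaviour this dispatch reproduces; the inputs and shapes below are illustrative only, not part of the PR:

import torch

a = torch.ones((2, 3))
b = torch.zeros((1, 3))

# Join along axis 0 maps to torch.cat along dim 0; non-tensor inputs
# (e.g. tuples) are converted to tensors first, as in the dispatch above.
out = torch.cat([a, b], dim=0)
print(out.shape)  # torch.Size([3, 3])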
59 changes: 59 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@@ -37,6 +38,64 @@
    return dimshuffle


@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
    def torch_sum(x):
        return torch.sum(x, dim=op.axis)

    return torch_sum
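
Unlike some of the reductions below, torch.sum accepts a tuple of dimensions, so op.axis can be forwarded as-is even for multi-axis sums. A small illustrative check (the input here is made up):

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# torch.sum accepts a tuple of dims, so a multi-axis Sum can be forwarded directly
print(torch.sum(x, dim=(1, 2)))        # tensor([ 66., 210.])
print(torch.sum(x, dim=(1, 2)).shape)  # torch.Size([2])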


@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
    dim = op.axis

    def torch_all(x):
        return torch.all(x, dim=dim)

    return torch_all


@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    dim = op.axis[0]

Member: Why axis[0]?

HarshvirSandhu (author): op.axis is a tuple, and PyTorch expects an integer for dim.

Member: Then we need to change the logic, because it is possible for them to be tuples with more than one entry.

HarshvirSandhu (author, Jul 3, 2024): We could do something like this:

for d in op.axis:
    x = torch.prod(x, dim=d, keepdim=True)  # keepdim to preserve a constant shape; we can reshape at the end

Member: If you reduce in reversed order you don't have to worry about the keepdims. Sounds good; a bit surprising that they don't support multiple axes.

    def torch_prod(x):
        return torch.prod(x, dim=dim)

    return torch_prod
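
For reference, a minimal sketch of the reviewer's suggestion above: handle a multi-axis Prod by reducing one axis at a time in reversed order, so the remaining axis indices stay valid without keepdim. This is illustrative only, not the code merged in this PR:

import torch

def prod_over_axes(x, axes):
    # Reducing the highest axis first means the lower axes keep their indices,
    # so no keepdim/reshape bookkeeping is needed.
    for d in sorted(axes, reverse=True):
        x = torch.prod(x, dim=d)
    return x

x = torch.arange(1, 7, dtype=torch.float32).reshape(1, 2, 3)
print(prod_over_axes(x, (1, 2)))  # tensor([720.])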


@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
    dim = op.axis

    def torch_any(x):
        return torch.any(x, dim=dim)

    return torch_any


@pytorch_funcify.register(Max)
def pytorch_funcify_max(op, **kwargs):
    dim = op.axis[0]

    def torch_max(x):
        return torch.max(x, dim=dim).values

    return torch_max


@pytorch_funcify.register(Min)
def pytorch_funcify_min(op, **kwargs):
    dim = op.axis[0]

    def torch_min(x):
        return torch.min(x, dim=dim).values

    return torch_min
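
As a side note on the .values accessor above: when given a dim, torch.max and torch.min return a (values, indices) named tuple rather than a plain tensor, so only the reduced values are kept. A small illustration (the input is made up):

import torch

x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])

result = torch.max(x, dim=0)   # named tuple of (values, indices)
print(result.values)           # tensor([3., 5.])
print(result.indices)          # tensor([1, 0])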


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
42 changes: 41 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor.type import scalar, vector
from pytensor.tensor.type import matrix, scalar, vector


torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
        FunctionGraph([start, stop, step], [out]),
        [np.array(1), np.array(10), np.array(2)],
    )


def test_pytorch_Join():
    a = matrix("a")
    b = matrix("b")

    x = ptb.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0]].astype(config.floatX),
        ],
    )

    x = ptb.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
            np.c_[[5.0, 6.0]].astype(config.floatX),
        ],
    )
67 changes: 67 additions & 0 deletions tests/link/pytorch/test_elemwise.py
@@ -2,6 +2,7 @@
import pytest

import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
@@ -57,6 +58,72 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("axis", [None, 0, 1])
Member: Can we parametrize these tests with the reduce function? Since they all look the same, we can reduce a bunch of lines. Or at least separate only those that need numerical inputs from those that need boolean (all and any). (See the sketch after the test below.)

Member: Also I would like to test axis=(1, 2), and have a_pt be a tensor3, so that we cover the case with more than one axis, but not all of them.

def test_pytorch_sum(axis):
    a_pt = matrix("a")
    test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX)

    x = pt.math.sum(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])
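
A minimal sketch of the consolidation the reviewer suggests above, parametrizing one test over the reduction functions. The test name and the import path for compare_pytorch_and_py are assumptions, and an axis=None or tensor3 case with axis=(1, 2) could be added along the same lines:

import numpy as np
import pytest

import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import matrix

# Assumed helper import; this module already uses compare_pytorch_and_py.
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min, ptm.all, ptm.any])
@pytest.mark.parametrize("axis", [0, 1])
def test_pytorch_careduce(fn, axis):
    a_pt = matrix("a")
    # all/any reduce boolean inputs; the other reductions use numeric inputs
    if fn in (ptm.all, ptm.any):
        test_value = np.array([[True, False, True], [False, True, True]])
    else:
        test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX)

    x = fn(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])
    compare_pytorch_and_py(x_fg, [test_value])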


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_pytorch_all(axis):
    a_pt = matrix("a")
    test_value = np.array([[True, False, True], [False, True, True]])

    x = ptm.all(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [0, 1])
def test_pytorch_prod(axis):
    a_pt = matrix("a")
    test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX)

    x = ptm.prod(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_pytorch_any(axis):
    a_pt = matrix("a")
    test_value = np.array([[True, False, True], [False, True, True]])

    x = ptm.any(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [0, 1])
def test_pytorch_max(axis):
    a_pt = matrix("a")
    test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX)

    x = ptm.max(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("axis", [0, 1])
def test_pytorch_min(axis):
    a_pt = matrix("a")
    test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX)

    x = ptm.min(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):