Pytorch support for Join and Careduce Ops (#869)
HarshvirSandhu authored Jul 4, 2024
1 parent df769f6 commit e57e25b
Showing 4 changed files with 159 additions and 3 deletions.
13 changes: 12 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -6,7 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join


@singledispatch
@@ -89,3 +89,14 @@ def arange(start, stop, step):
        return torch.arange(start, stop, step, dtype=dtype)

    return arange


@pytorch_funcify.register(Join)
def pytorch_funcify_Join(op, **kwargs):
    def join(axis, *tensors):
        # Inputs can also be tuples, which have no ndim attribute, so coerce them to tensors first
        tensors = [torch.tensor(tensor) for tensor in tensors]

        return torch.cat(tensors, dim=axis)

    return join
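
For orientation (not part of the diff): a minimal sketch of how the new Join dispatch would be exercised end to end. The "PYTORCH" mode name and the variable names here are assumptions for illustration only.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    a = pt.matrix("a")
    b = pt.matrix("b")
    out = pt.join(0, a, b)  # concatenate along the first axis

    # "PYTORCH" is assumed to be the mode name registered for this backend
    fn = pytensor.function([a, b], out, mode="PYTORCH")
    print(fn(np.ones((2, 3)), np.zeros((1, 3))))  # -> a (3, 3) array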
64 changes: 64 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@@ -37,6 +38,69 @@ def dimshuffle(x):
    return dimshuffle


@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
    def torch_sum(x):
        return torch.sum(x, dim=op.axis)

    return torch_sum


@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
    def torch_all(x):
        return torch.all(x, dim=op.axis)

    return torch_all


@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    def torch_prod(x):
        if isinstance(op.axis, tuple):
            for d in sorted(op.axis, reverse=True):
                x = torch.prod(x, dim=d)
            return x
        else:
            return torch.prod(x.flatten(), dim=0)

    return torch_prod


@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
    def torch_any(x):
        return torch.any(x, dim=op.axis)

    return torch_any


@pytorch_funcify.register(Max)
def pytorch_funcify_max(op, **kwargs):
    def torch_max(x):
        if isinstance(op.axis, tuple):
            for d in sorted(op.axis, reverse=True):
                x = torch.max(x, dim=d).values
            return x
        else:
            return torch.max(x.flatten(), dim=0).values

    return torch_max


@pytorch_funcify.register(Min)
def pytorch_funcify_min(op, **kwargs):
    def torch_min(x):
        if isinstance(op.axis, tuple):
            for d in sorted(op.axis, reverse=True):
                x = torch.min(x, dim=d).values
            return x
        else:
            return torch.min(x.flatten(), dim=0).values

    return torch_min


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
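
A note on the tuple-axis branches in the Prod/Max/Min dispatches above (illustration only, not part of the commit): torch.prod, torch.max and torch.min reduce a single dim per call, so a multi-axis reduction is lowered to a loop; iterating over the axes from highest to lowest keeps the remaining axis indices valid after each reduction. A small sketch of that equivalence in plain torch:

    import torch

    x = torch.arange(24.0).reshape(2, 3, 4)

    # Reduce over axes (0, 2) one dim at a time, highest axis first, so the
    # lower axis numbers stay valid after each reduction.
    out = x
    for d in sorted((0, 2), reverse=True):
        out = torch.max(out, dim=d).values

    assert torch.equal(out, torch.amax(x, dim=(0, 2)))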
42 changes: 41 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor.type import scalar, vector
from pytensor.tensor.type import matrix, scalar, vector


torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
        FunctionGraph([start, stop, step], [out]),
        [np.array(1), np.array(10), np.array(2)],
    )


def test_pytorch_Join():
    a = matrix("a")
    b = matrix("b")

    x = ptb.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0]].astype(config.floatX),
        ],
    )

    x = ptb.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    compare_pytorch_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
            np.c_[[5.0, 6.0]].astype(config.floatX),
        ],
    )
43 changes: 42 additions & 1 deletion tests/link/pytorch/test_elemwise.py
@@ -2,11 +2,12 @@
import pytest

import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
def test_pytorch_careduce(fn, axis):
    a_pt = tensor3("a")
    test_value = np.array(
        [
            [
                [1, 1, 1, 1],
                [2, 2, 2, 2],
            ],
            [
                [3, 3, 3, 3],
                [4, 4, 4, 4],
            ],
        ]
    ).astype(config.floatX)

    x = fn(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
def test_pytorch_any_all(fn, axis):
    a_pt = matrix("a")
    test_value = np.array([[True, False, True], [False, True, True]])

    x = fn(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])


@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):
