Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements shape Ops and MakeVector in PyTorch #926

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.shape
# isort: on
18 changes: 17 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from functools import singledispatch
from types import NoneType

import torch

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector


@singledispatch
Expand All @@ -15,6 +16,11 @@ def pytorch_typify(data, dtype=None, **kwargs):
return torch.as_tensor(data, dtype=dtype)


@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
return None


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
Expand Down Expand Up @@ -116,3 +122,13 @@ def eye(N, M, k):
return zeros

return eye


@pytorch_funcify.register(MakeVector)
def pytorch_funcify_MakeVector(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)

def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector
52 changes: 52 additions & 0 deletions pytensor/link/pytorch/dispatch/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast


@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
def reshape(x, shape):
return torch.reshape(x, tuple(shape))

return reshape


@pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs):
def shape(x):
return x.shape

Check warning on line 18 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L17-L18

Added lines #L17 - L18 were not covered by tests

return shape

Check warning on line 20 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L20

Added line #L20 was not covered by tests


@pytorch_funcify.register(Shape_i)
def pytorch_funcify_Shape_i(op, **kwargs):
i = op.i

def shape_i(x):
return torch.tensor(x.shape[i])

Check warning on line 28 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L28

Added line #L28 was not covered by tests

return shape_i


@pytorch_funcify.register(SpecifyShape)
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape):
if expected is None:
continue
if actual != expected:
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")

Check warning on line 41 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L41

Added line #L41 was not covered by tests
return x

return specifyshape


@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

Check warning on line 50 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L49-L50

Added lines #L49 - L50 were not covered by tests

return unbroadcast

Check warning on line 52 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L52

Added line #L52 was not covered by tests
7 changes: 7 additions & 0 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,10 @@ def test_eye(dtype):
for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))


def test_pytorch_MakeVector():
x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])
13 changes: 1 addition & 12 deletions tests/link/pytorch/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,7 @@ def test_pytorch_CumOp(axis, dtype):
compare_pytorch_and_py(fgraph, [test_value])


@pytest.mark.parametrize(
"axis, repeats",
[
(0, (1, 2, 3)),
(1, (3, 3)),
pytest.param(
None,
3,
marks=pytest.mark.xfail(reason="Reshape not implemented"),
),
],
)
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
def test_pytorch_Repeat(axis, repeats):
a = pt.matrix("a", dtype="float64")

Expand Down
61 changes: 61 additions & 0 deletions tests/link/pytorch/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_pytorch_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [], must_be_device_array=False)

x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [], must_be_device_array=False)


def test_pytorch_specify_shape():
in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x])
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])

# When used to assert two arrays have similar shapes
in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_pytorch_and_py(
x_fg,
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)


def test_pytorch_Reshape_constant():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


def test_pytorch_Reshape_dynamic():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])


def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])
11 changes: 1 addition & 10 deletions tests/link/pytorch/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@


@pytest.mark.parametrize("func", (sort, argsort))
@pytest.mark.parametrize(
"axis",
[
pytest.param(0),
pytest.param(1),
pytest.param(
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
),
],
)
@pytest.mark.parametrize("axis", [0, 1, None])
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
Expand Down
Loading