Skip to content

Commit

Permalink
Implements shape and MakeVector Ops in PyTorch
Browse files Browse the repository at this point in the history
- Shape
- Shape_i
- Reshape
- SpecifyShape
- Unbroadcast

- MakeVector
  • Loading branch information
twaclaw committed Jul 12, 2024
1 parent a6b9585 commit bf50423
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 12 deletions.
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.shape
# isort: on
16 changes: 14 additions & 2 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
return torch.as_tensor(data, dtype=dtype)
if data is not None:
return torch.as_tensor(data, dtype=dtype)
return None


@singledispatch
Expand Down Expand Up @@ -116,3 +118,13 @@ def eye(N, M, k):
return zeros

return eye


@pytorch_funcify.register(MakeVector)
def pytorch_funcify_MakeVector(op, **kwargs):
torch_dtype = getattr(torch, op.dtype)

def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector
54 changes: 54 additions & 0 deletions pytensor/link/pytorch/dispatch/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast


@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
shape = node.inputs[1]

def reshape(x, shape=shape):
return torch.reshape(x, tuple(shape))

return reshape


@pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs):
def shape(x):
return x.shape

Check warning on line 20 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L19-L20

Added lines #L19 - L20 were not covered by tests

return shape

Check warning on line 22 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L22

Added line #L22 was not covered by tests


@pytorch_funcify.register(Shape_i)
def pytorch_funcify_Shape_i(op, **kwargs):
i = op.i

def shape_i(x):
return x.shape[i]

Check warning on line 30 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L30

Added line #L30 was not covered by tests

return shape_i


@pytorch_funcify.register(SpecifyShape)
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape):
if expected is None:
continue
if actual != expected:
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")

Check warning on line 43 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L43

Added line #L43 was not covered by tests
return x

return specifyshape


@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x

Check warning on line 52 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L51-L52

Added lines #L51 - L52 were not covered by tests

return unbroadcast

Check warning on line 54 in pytensor/link/pytorch/dispatch/shape.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/shape.py#L54

Added line #L54 was not covered by tests
7 changes: 7 additions & 0 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,10 @@ def test_eye(dtype):
for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))


def test_pytorch_MakeVector():
x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])
72 changes: 72 additions & 0 deletions tests/link/pytorch/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np

import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_pytorch_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [], must_be_device_array=False)

x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [], must_be_device_array=False)


def test_pytorch_specify_shape():
in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x])
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])

# When used to assert two arrays have similar shapes
in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_pytorch_and_py(
x_fg,
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)


def test_pytorch_Reshape_constant():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


def test_pytorch_Reshape_shape_graph_input():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])


def test_pytorch_compile_ops():
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

x = ViewOp()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])
11 changes: 1 addition & 10 deletions tests/link/pytorch/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@


@pytest.mark.parametrize("func", (sort, argsort))
@pytest.mark.parametrize(
"axis",
[
pytest.param(0),
pytest.param(1),
pytest.param(
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
),
],
)
@pytest.mark.parametrize("axis", [0, 1, None])
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
Expand Down

0 comments on commit bf50423

Please sign in to comment.