Skip to content

Commit

Permalink
Reworked tests in implementation of Shape in PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
twaclaw committed Jul 16, 2024
1 parent 4d4abd7 commit 0e455fd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 17 deletions.
4 changes: 1 addition & 3 deletions pytensor/link/pytorch/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
shape = node.inputs[1]

def reshape(x, shape=shape):
def reshape(x, shape):
return torch.reshape(x, tuple(shape))

return reshape
Expand Down
3 changes: 2 additions & 1 deletion tests/link/pytorch/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def test_pytorch_CumOp(axis, dtype):
pytest.param(
None,
3,
marks=pytest.mark.xfail(reason="Reshape not implemented"),
marks=pytest.mark.xfail(reason="Issue in Elemwise"),
# TODO: add reference to issue
),
],
)
Expand Down
15 changes: 2 additions & 13 deletions tests/link/pytorch/test_shape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np

import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
Expand Down Expand Up @@ -46,27 +45,17 @@ def test_pytorch_Reshape_constant():
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])


def test_pytorch_Reshape_shape_graph_input():
def test_pytorch_Reshape_dynamic():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])


def test_pytorch_compile_ops():
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

x = ViewOp()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])

0 comments on commit 0e455fd

Please sign in to comment.