Skip to content

Commit

Permalink
Parametrized dtype in tests for Eye Op in PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
twaclaw committed Jul 7, 2024
1 parent daa86c4 commit 08d23ab
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,16 @@ def test_pytorch_Join():
)


def test_eye():
@pytest.mark.parametrize(
"dtype",
["int64", config.floatX],
)
def test_eye(dtype):
N = scalar("N", dtype="int64")
M = scalar("M", dtype="int64")
k = scalar("k", dtype="int64")

out = eye(N, M, k, dtype="float32")
out = eye(N, M, k, dtype=dtype)

fn = function([N, M, k], out, mode=pytorch_mode)

Expand Down

0 comments on commit 08d23ab

Please sign in to comment.