Skip to content

Commit d8caaeb

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
Added tests for identity matrix function (#18)
* Added tests for identity matrix function * Fixed test_identity_invalid_shape() method * Fixed import formatting, black and flake8 automatic checks --------- Co-authored-by: Chaluvadi <[email protected]>
1 parent e404ab7 commit d8caaeb

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

tests/test_identity.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtypes
6+
import arrayfire_wrapper.lib as wrapper
7+
8+
9+
@pytest.mark.parametrize(
10+
"shape",
11+
[
12+
(),
13+
(random.randint(1, 10), 1),
14+
(random.randint(1, 10), random.randint(1, 10)),
15+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
],
18+
)
19+
def test_identity_shape(shape: tuple) -> None:
20+
"""Test if identity creates an array with the correct shape"""
21+
dtype = dtypes.s16
22+
23+
result = wrapper.identity(shape, dtype)
24+
25+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
26+
27+
28+
def test_identity_invalid_shape() -> None:
29+
"""Test if identity handles a shape with greater than 4 dimensions"""
30+
with pytest.raises(TypeError) as excinfo:
31+
invalid_shape = (
32+
random.randint(1, 10),
33+
random.randint(1, 10),
34+
random.randint(1, 10),
35+
random.randint(1, 10),
36+
random.randint(1, 10),
37+
)
38+
dtype = dtypes.s16
39+
40+
wrapper.identity(invalid_shape, dtype)
41+
42+
assert f"CShape.__init__() takes from 1 to 5 positional arguments but {len(invalid_shape) + 1} were given" in str(
43+
excinfo.value
44+
)
45+
46+
47+
def test_identity_nonsquare_shape() -> None:
48+
dtype = dtypes.s16
49+
shape = (5, 6)
50+
51+
result = wrapper.identity(shape, dtype)
52+
53+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
54+
55+
56+
@pytest.mark.parametrize(
57+
"dtype_index",
58+
[i for i in range(13)],
59+
)
60+
def test_identity_dtype(dtype_index: int) -> None:
61+
"""Test if identity creates an array with the correct dtype"""
62+
if dtype_index in [2, 3] and not wrapper.get_dbl_support():
63+
pytest.skip()
64+
65+
shape = (5, 5)
66+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
67+
68+
result = wrapper.identity(shape, dtype)
69+
70+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype

0 commit comments

Comments
 (0)