Skip to content

Commit cdc9098

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
unit tests for iotafunction (#19)
* unit tests for iotafunction * Fixed import formatting and automatic checks --------- Co-authored-by: Chaluvadi <[email protected]>
1 parent d8caaeb commit cdc9098

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

tests/test_iota.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import random
2+
3+
import numpy as np
4+
import pytest
5+
6+
import arrayfire_wrapper.dtypes as dtypes
7+
import arrayfire_wrapper.lib as wrapper
8+
9+
10+
@pytest.mark.parametrize(
11+
"shape",
12+
[
13+
(),
14+
(random.randint(1, 10), 1),
15+
(random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
18+
],
19+
)
20+
def test_iota_shape(shape: tuple) -> None:
21+
"""Test if identity creates an array with the correct shape"""
22+
dtype = dtypes.s16
23+
t_shape = (1, 1)
24+
25+
result = wrapper.iota(shape, t_shape, dtype)
26+
27+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
28+
29+
30+
def test_iota_invalid_shape() -> None:
31+
"""Test if iota handles a shape with greater than 4 dimensions"""
32+
with pytest.raises(TypeError) as excinfo:
33+
invalid_shape = (
34+
random.randint(1, 10),
35+
random.randint(1, 10),
36+
random.randint(1, 10),
37+
random.randint(1, 10),
38+
random.randint(1, 10),
39+
)
40+
dtype = dtypes.s16
41+
t_shape = ()
42+
43+
wrapper.iota(invalid_shape, t_shape, dtype)
44+
45+
assert f"CShape.__init__() takes from 1 to 5 positional arguments but {len(invalid_shape) + 1} were given" in str(
46+
excinfo.value
47+
)
48+
49+
50+
@pytest.mark.parametrize(
51+
"t_shape",
52+
[
53+
(1,),
54+
(random.randint(1, 10), 1),
55+
(random.randint(1, 10), random.randint(1, 10)),
56+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
57+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
58+
],
59+
)
60+
def test_iota_tshape(t_shape: tuple) -> None:
61+
"""Test if iota properly uses t_shape to change the size of the array and result in correct dimensions"""
62+
shape = np.array([2, 2])
63+
dtype = dtypes.s64
64+
65+
if len(shape.shape) < len(t_shape):
66+
shape = np.append(shape, np.ones(len(t_shape) - len(shape), dtype=int))
67+
68+
result_shape = shape * t_shape
69+
70+
result = wrapper.iota(tuple(shape), t_shape, dtype)
71+
72+
result_dims = tuple(int(value) for value in wrapper.get_dims(result))
73+
74+
assert (result_dims[0 : len(result_shape)] == result_shape).all() # noqa: E203
75+
76+
77+
@pytest.mark.parametrize(
78+
"t_shape",
79+
[
80+
(0,),
81+
(-1, -1),
82+
],
83+
)
84+
def test_iota_tshape_zero(t_shape: tuple) -> None:
85+
"""Test it iota properly handles negative or zero t_shapes"""
86+
with pytest.raises(RuntimeError):
87+
shape = (2, 2)
88+
89+
dtype = dtypes.s16
90+
91+
wrapper.iota(shape, t_shape, dtype)
92+
93+
94+
def test_iota_tshape_invalid() -> None:
95+
"""Test it iota properly handles a tshape with greater than 4 dimensions"""
96+
with pytest.raises(TypeError):
97+
shape = (2, 2)
98+
invalid_tshape = (
99+
random.randint(1, 10),
100+
random.randint(1, 10),
101+
random.randint(1, 10),
102+
random.randint(1, 10),
103+
random.randint(1, 10),
104+
)
105+
dtype = dtypes.s16
106+
107+
wrapper.iota(shape, invalid_tshape, dtype)
108+
109+
110+
@pytest.mark.parametrize(
111+
"dtype_index",
112+
[i for i in range(13)],
113+
)
114+
def test_iota_dtype(dtype_index: int) -> None:
115+
"""Test if iota creates an array with the correct dtype"""
116+
if (dtype_index in [1, 4]) or (dtype_index in [2, 3] and not wrapper.get_dbl_support()):
117+
pytest.skip()
118+
119+
shape = (5, 5)
120+
t_shape = (2, 2)
121+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
122+
123+
result = wrapper.iota(shape, t_shape, dtype)
124+
125+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype

0 commit comments

Comments
 (0)