Skip to content

Commit f90ef61

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
Added unit tests for the diagonal function (#23)
* added diagonal tests * added random tests * modified random tests * fixed formatting issues for random tests * reformatted manage_array file * reformatted diag tests, manage array --------- Co-authored-by: Chaluvadi <[email protected]>
1 parent a657583 commit f90ef61

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

arrayfire_wrapper/lib/create_and_modify_array/manage_array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def get_scalar(arr: AFArray, dtype: Dtype, /) -> int | float | complex | bool |
166166
out = dtype.c_type()
167167
call_from_clib(get_scalar.__name__, ctypes.pointer(out), arr)
168168
if dtype == c32 or dtype == c64:
169-
return complex(out[0], out[1]) # type: ignore
169+
return complex(out[0], out[1]) # type: ignore
170170
else:
171171
return cast(int | float | complex | bool | None, out.value)
172172

tests/test_diag.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pytest
2+
3+
import arrayfire_wrapper.dtypes as dtypes
4+
import arrayfire_wrapper.lib as wrapper
5+
6+
7+
@pytest.mark.parametrize("diagonal_shape", [(2,), (10,), (100,), (1000,)])
8+
def test_diagonal_shape(diagonal_shape: tuple) -> None:
9+
"""Test if diagonal array is keeping the shape of the passed into the input array"""
10+
in_arr = wrapper.constant(1, diagonal_shape, dtypes.s16)
11+
diag_array = wrapper.diag_create(in_arr, 0)
12+
13+
extracted_diagonal = wrapper.diag_extract(diag_array, 0)
14+
15+
assert wrapper.get_dims(extracted_diagonal)[0 : len(diagonal_shape)] == diagonal_shape # noqa: E203
16+
17+
18+
@pytest.mark.parametrize("diagonal_shape", [(2,), (10,), (100,), (1000,)])
19+
def test_diagonal_val(diagonal_shape: tuple) -> None:
20+
"""Test if diagonal array is keeping the same value as that of the values passed into the input array"""
21+
dtype = dtypes.s16
22+
in_arr = wrapper.constant(1, diagonal_shape, dtype)
23+
diag_array = wrapper.diag_create(in_arr, 0)
24+
25+
extracted_diagonal = wrapper.diag_extract(diag_array, 0)
26+
27+
assert wrapper.get_scalar(extracted_diagonal, dtype) == wrapper.get_scalar(in_arr, dtype)
28+
29+
30+
@pytest.mark.parametrize(
31+
"diagonal_shape",
32+
[
33+
(10, 10, 10),
34+
(100, 100, 100, 100),
35+
],
36+
)
37+
def test_invalid_diagonal(diagonal_shape: tuple) -> None:
38+
"""Test if an invalid diagonal shape is being properly handled"""
39+
with pytest.raises(RuntimeError):
40+
in_arr = wrapper.constant(1, diagonal_shape, dtypes.s16)
41+
diag_array = wrapper.diag_create(in_arr, 0)
42+
43+
wrapper.diag_extract(diag_array, 0)

tests/test_random.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtypes
6+
import arrayfire_wrapper.lib as wrapper
7+
8+
9+
@pytest.mark.parametrize(
10+
"shape",
11+
[
12+
(),
13+
(random.randint(1, 10), 1),
14+
(random.randint(1, 10), random.randint(1, 10)),
15+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
],
18+
)
19+
def test_randu_shape(shape: tuple) -> None:
20+
"""Test if randu function creates an array with the correct shape."""
21+
dtype = dtypes.s16
22+
23+
result = wrapper.randu(shape, dtype)
24+
25+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
26+
27+
28+
@pytest.mark.parametrize(
29+
"shape",
30+
[
31+
(),
32+
(random.randint(1, 10), 1),
33+
(random.randint(1, 10), random.randint(1, 10)),
34+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
35+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
36+
],
37+
)
38+
def test_random_uniform_shape(shape: tuple) -> None:
39+
"""Test if rand uniform function creates an array with the correct shape."""
40+
dtype = dtypes.s16
41+
engine = wrapper.create_random_engine(100, 10)
42+
43+
result = wrapper.random_uniform(shape, dtype, engine)
44+
45+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
46+
47+
48+
@pytest.mark.parametrize(
49+
"shape",
50+
[
51+
(),
52+
(random.randint(1, 10), 1),
53+
(random.randint(1, 10), random.randint(1, 10)),
54+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
55+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
56+
],
57+
)
58+
def test_randn_shape(shape: tuple) -> None:
59+
"""Test if randn function creates an array with the correct shape."""
60+
dtype = dtypes.f32
61+
62+
result = wrapper.randn(shape, dtype)
63+
64+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
65+
66+
67+
@pytest.mark.parametrize(
68+
"shape",
69+
[
70+
(),
71+
(random.randint(1, 10), 1),
72+
(random.randint(1, 10), random.randint(1, 10)),
73+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
74+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
75+
],
76+
)
77+
def test_random_normal_shape(shape: tuple) -> None:
78+
"""Test if random normal function creates an array with the correct shape."""
79+
dtype = dtypes.f32
80+
engine = wrapper.create_random_engine(100, 10)
81+
82+
result = wrapper.random_normal(shape, dtype, engine)
83+
84+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
85+
86+
87+
@pytest.mark.parametrize(
88+
"engine_index",
89+
[100, 200, 300],
90+
)
91+
def test_create_random_engine(engine_index: int) -> None:
92+
engine = wrapper.create_random_engine(engine_index, 10)
93+
94+
engine_type = wrapper.random_engine_get_type(engine)
95+
96+
assert engine_type == engine_index
97+
98+
99+
@pytest.mark.parametrize(
100+
"invalid_index",
101+
[random.randint(301, 600), random.randint(301, 600), random.randint(301, 600)],
102+
)
103+
def test_invalid_random_engine(invalid_index: int) -> None:
104+
"Test if invalid engine types are properly handled"
105+
with pytest.raises(RuntimeError):
106+
107+
invalid_engine = wrapper.create_random_engine(invalid_index, 10)
108+
109+
engine_type = wrapper.random_engine_get_type(invalid_engine)
110+
111+
assert engine_type == invalid_engine

0 commit comments

Comments
 (0)