Skip to content

Commit e404ab7

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
Add constant tests (#17)
* added constant tests * Fixed automatic checks - styling, import order, lint, static type checking * Corrected redundant checks within tests * Fixed import formatting and automatic checks * Added back in complex tests, made fix to get_scalar function in manage_array --------- Co-authored-by: Chaluvadi <[email protected]>
1 parent d64894d commit e404ab7

File tree

2 files changed

+226
-2
lines changed

2 files changed

+226
-2
lines changed

arrayfire_wrapper/lib/create_and_modify_array/manage_array.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import cast
33

44
from arrayfire_wrapper.defines import AFArray, ArrayBuffer, CDimT, CShape
5-
from arrayfire_wrapper.dtypes import Dtype
5+
from arrayfire_wrapper.dtypes import Dtype, c32, c64
66
from arrayfire_wrapper.lib._utility import call_from_clib
77

88

@@ -165,7 +165,10 @@ def get_scalar(arr: AFArray, dtype: Dtype, /) -> int | float | complex | bool |
165165
"""
166166
out = dtype.c_type()
167167
call_from_clib(get_scalar.__name__, ctypes.pointer(out), arr)
168-
return cast(int | float | complex | bool | None, out.value)
168+
if dtype == c32 or dtype == c64:
169+
return complex(out[0], out[1]) # type: ignore
170+
else:
171+
return cast(int | float | complex | bool | None, out.value)
169172

170173

171174
def get_type(arr: AFArray, /) -> int:

tests/test_constants.py

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtypes
6+
import arrayfire_wrapper.lib as wrapper
7+
8+
invalid_shape = (
9+
random.randint(1, 10),
10+
random.randint(1, 10),
11+
random.randint(1, 10),
12+
random.randint(1, 10),
13+
random.randint(1, 10),
14+
)
15+
16+
17+
@pytest.mark.parametrize(
18+
"shape",
19+
[
20+
(),
21+
(random.randint(1, 10), 1),
22+
(random.randint(1, 10), random.randint(1, 10)),
23+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
24+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
25+
],
26+
)
27+
def test_constant_shape(shape: tuple) -> None:
28+
"""Test if constant creates an array with the correct shape."""
29+
number = 5.0
30+
dtype = dtypes.s16
31+
32+
result = wrapper.constant(number, shape, dtype)
33+
34+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
35+
36+
37+
@pytest.mark.parametrize(
38+
"shape",
39+
[
40+
(),
41+
(random.randint(1, 10), 1),
42+
(random.randint(1, 10), random.randint(1, 10)),
43+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
44+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
45+
],
46+
)
47+
def test_constant_complex_shape(shape: tuple) -> None:
48+
"""Test if constant_complex creates an array with the correct shape."""
49+
dtype = dtypes.c32
50+
51+
dtype = dtypes.c32
52+
rand_array = wrapper.randu((1, 1), dtype)
53+
number = wrapper.get_scalar(rand_array, dtype)
54+
55+
if isinstance(number, (complex)):
56+
result = wrapper.constant_complex(number, shape, dtype)
57+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
58+
else:
59+
pytest.skip()
60+
61+
62+
@pytest.mark.parametrize(
63+
"shape",
64+
[
65+
(),
66+
(random.randint(1, 10), 1),
67+
(random.randint(1, 10), random.randint(1, 10)),
68+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
69+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
70+
],
71+
)
72+
def test_constant_long_shape(shape: tuple) -> None:
73+
"""Test if constant_long creates an array with the correct shape."""
74+
dtype = dtypes.s64
75+
rand_array = wrapper.randu((1, 1), dtype)
76+
number = wrapper.get_scalar(rand_array, dtype)
77+
78+
if isinstance(number, (int, float)):
79+
result = wrapper.constant_long(number, shape, dtype)
80+
81+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
82+
83+
84+
@pytest.mark.parametrize(
85+
"shape",
86+
[
87+
(),
88+
(random.randint(1, 10), 1),
89+
(random.randint(1, 10), random.randint(1, 10)),
90+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
91+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
92+
],
93+
)
94+
def test_constant_ulong_shape(shape: tuple) -> None:
95+
"""Test if constant_ulong creates an array with the correct shape."""
96+
dtype = dtypes.u64
97+
rand_array = wrapper.randu((1, 1), dtype)
98+
number = wrapper.get_scalar(rand_array, dtype)
99+
100+
if isinstance(number, (int, float)):
101+
result = wrapper.constant_ulong(number, shape, dtype)
102+
103+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
104+
else:
105+
pytest.skip()
106+
107+
108+
def test_constant_shape_invalid() -> None:
109+
"""Test if constant handles a shape with greater than 4 dimensions"""
110+
with pytest.raises(TypeError):
111+
number = 5.0
112+
dtype = dtypes.s16
113+
114+
wrapper.constant(number, invalid_shape, dtype)
115+
116+
117+
def test_constant_complex_shape_invalid() -> None:
118+
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
119+
with pytest.raises(TypeError):
120+
dtype = dtypes.c32
121+
rand_array = wrapper.randu((1, 1), dtype)
122+
number = wrapper.get_scalar(rand_array, dtype)
123+
124+
if isinstance(number, (int, float, complex)):
125+
wrapper.constant_complex(number, invalid_shape, dtype)
126+
127+
128+
def test_constant_long_shape_invalid() -> None:
129+
"""Test if constant_long handles a shape with greater than 4 dimensions"""
130+
with pytest.raises(TypeError):
131+
dtype = dtypes.s64
132+
rand_array = wrapper.randu((1, 1), dtype)
133+
number = wrapper.get_scalar(rand_array, dtype)
134+
135+
if isinstance(number, (int, float)):
136+
wrapper.constant_long(number, invalid_shape, dtype)
137+
138+
139+
def test_constant_ulong_shape_invalid() -> None:
140+
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
141+
with pytest.raises(TypeError):
142+
dtype = dtypes.u64
143+
rand_array = wrapper.randu((1, 1), dtype)
144+
number = wrapper.get_scalar(rand_array, dtype)
145+
146+
if isinstance(number, (int, float)):
147+
wrapper.constant_ulong(number, invalid_shape, dtype)
148+
149+
150+
@pytest.mark.parametrize(
151+
"dtype_index",
152+
[i for i in range(13)],
153+
)
154+
def test_constant_dtype(dtype_index: int) -> None:
155+
"""Test if constant creates an array with the correct dtype."""
156+
if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()):
157+
pytest.skip()
158+
159+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
160+
161+
rand_array = wrapper.randu((1, 1), dtype)
162+
value = wrapper.get_scalar(rand_array, dtype)
163+
shape = (2, 2)
164+
if isinstance(value, (int, float)):
165+
result = wrapper.constant(value, shape, dtype)
166+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
167+
else:
168+
pytest.skip()
169+
170+
171+
@pytest.mark.parametrize(
172+
"dtype_index",
173+
[i for i in range(13)],
174+
)
175+
def test_constant_complex_dtype(dtype_index: int) -> None:
176+
"""Test if constant_complex creates an array with the correct dtype."""
177+
if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()):
178+
pytest.skip()
179+
180+
dtype = dtypes.c_api_value_to_dtype(dtype_index)
181+
rand_array = wrapper.randu((1, 1), dtype)
182+
value = wrapper.get_scalar(rand_array, dtype)
183+
shape = (2, 2)
184+
185+
if isinstance(value, (int, float, complex)):
186+
result = wrapper.constant_complex(value, shape, dtype)
187+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
188+
else:
189+
pytest.skip()
190+
191+
192+
def test_constant_long_dtype() -> None:
193+
"""Test if constant_long creates an array with the correct dtype."""
194+
dtype = dtypes.s64
195+
196+
rand_array = wrapper.randu((1, 1), dtype)
197+
value = wrapper.get_scalar(rand_array, dtype)
198+
shape = (2, 2)
199+
200+
if isinstance(value, (int, float)):
201+
result = wrapper.constant_long(value, shape, dtype)
202+
203+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
204+
else:
205+
pytest.skip()
206+
207+
208+
def test_constant_ulong_dtype() -> None:
209+
"""Test if constant_ulong creates an array with the correct dtype."""
210+
dtype = dtypes.u64
211+
212+
rand_array = wrapper.randu((1, 1), dtype)
213+
value = wrapper.get_scalar(rand_array, dtype)
214+
shape = (2, 2)
215+
216+
if isinstance(value, (int, float)):
217+
result = wrapper.constant_ulong(value, shape, dtype)
218+
219+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
220+
else:
221+
pytest.skip()

0 commit comments

Comments
 (0)