Skip to content

Commit 527b807

Browse files
author
Chaluvadi
committedMar 12, 2024·
readability changes pt.2
1 parent 1f8d475 commit 527b807

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed
 

‎tests/test_constants.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030

3131

32-
types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
32+
all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
3333

3434

3535
@pytest.mark.parametrize(
@@ -66,7 +66,6 @@ def test_constant_complex_shape(shape: tuple) -> None:
6666
"""Test if constant_complex creates an array with the correct shape."""
6767
dtype = c32
6868

69-
dtype = c32
7069
rand_array = wrapper.randu((1, 1), dtype)
7170
number = wrapper.get_scalar(rand_array, dtype)
7271

@@ -167,11 +166,11 @@ def test_constant_ulong_shape_invalid() -> None:
167166

168167
@pytest.mark.parametrize(
169168
"dtype",
170-
types,
169+
all_types,
171170
)
172171
def test_constant_dtype(dtype: Dtype) -> None:
173172
"""Test if constant creates an array with the correct dtype."""
174-
if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()):
173+
if is_cmplx_type(dtype) or not is_system_supported(dtype):
175174
pytest.skip()
176175

177176
rand_array = wrapper.randu((1, 1), dtype)
@@ -186,11 +185,11 @@ def test_constant_dtype(dtype: Dtype) -> None:
186185

187186
@pytest.mark.parametrize(
188187
"dtype",
189-
types,
188+
all_types,
190189
)
191190
def test_constant_complex_dtype(dtype: Dtype) -> None:
192191
"""Test if constant_complex creates an array with the correct dtype."""
193-
if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()):
192+
if not is_cmplx_type(dtype) or not is_system_supported(dtype):
194193
pytest.skip()
195194

196195
rand_array = wrapper.randu((1, 1), dtype)
@@ -234,3 +233,14 @@ def test_constant_ulong_dtype() -> None:
234233
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
235234
else:
236235
pytest.skip()
236+
237+
238+
def is_cmplx_type(dtype: Dtype) -> bool:
239+
return dtype == c32 or dtype == c64
240+
241+
242+
def is_system_supported(dtype: Dtype) -> bool:
243+
if dtype in [f64, c64] and not wrapper.get_dbl_support():
244+
return False
245+
246+
return True

0 commit comments

Comments
 (0)