Skip to content

Commit 1f8d475

Browse files
author
Chaluvadi
committed
Readability changes to cosntants tests
1 parent f90ef61 commit 1f8d475

File tree

1 file changed

+42
-27
lines changed

1 file changed

+42
-27
lines changed

tests/test_constants.py

+42-27
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,23 @@
22

33
import pytest
44

5-
import arrayfire_wrapper.dtypes as dtypes
65
import arrayfire_wrapper.lib as wrapper
6+
from arrayfire_wrapper.dtypes import (
7+
Dtype,
8+
c32,
9+
c64,
10+
c_api_value_to_dtype,
11+
f16,
12+
f32,
13+
f64,
14+
s16,
15+
s32,
16+
s64,
17+
u8,
18+
u16,
19+
u32,
20+
u64,
21+
)
722

823
invalid_shape = (
924
random.randint(1, 10),
@@ -14,6 +29,9 @@
1429
)
1530

1631

32+
types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
33+
34+
1735
@pytest.mark.parametrize(
1836
"shape",
1937
[
@@ -27,7 +45,7 @@
2745
def test_constant_shape(shape: tuple) -> None:
2846
"""Test if constant creates an array with the correct shape."""
2947
number = 5.0
30-
dtype = dtypes.s16
48+
dtype = s16
3149

3250
result = wrapper.constant(number, shape, dtype)
3351

@@ -46,9 +64,9 @@ def test_constant_shape(shape: tuple) -> None:
4664
)
4765
def test_constant_complex_shape(shape: tuple) -> None:
4866
"""Test if constant_complex creates an array with the correct shape."""
49-
dtype = dtypes.c32
67+
dtype = c32
5068

51-
dtype = dtypes.c32
69+
dtype = c32
5270
rand_array = wrapper.randu((1, 1), dtype)
5371
number = wrapper.get_scalar(rand_array, dtype)
5472

@@ -71,7 +89,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
7189
)
7290
def test_constant_long_shape(shape: tuple) -> None:
7391
"""Test if constant_long creates an array with the correct shape."""
74-
dtype = dtypes.s64
92+
dtype = s64
7593
rand_array = wrapper.randu((1, 1), dtype)
7694
number = wrapper.get_scalar(rand_array, dtype)
7795

@@ -93,7 +111,7 @@ def test_constant_long_shape(shape: tuple) -> None:
93111
)
94112
def test_constant_ulong_shape(shape: tuple) -> None:
95113
"""Test if constant_ulong creates an array with the correct shape."""
96-
dtype = dtypes.u64
114+
dtype = u64
97115
rand_array = wrapper.randu((1, 1), dtype)
98116
number = wrapper.get_scalar(rand_array, dtype)
99117

@@ -109,15 +127,15 @@ def test_constant_shape_invalid() -> None:
109127
"""Test if constant handles a shape with greater than 4 dimensions"""
110128
with pytest.raises(TypeError):
111129
number = 5.0
112-
dtype = dtypes.s16
130+
dtype = s16
113131

114132
wrapper.constant(number, invalid_shape, dtype)
115133

116134

117135
def test_constant_complex_shape_invalid() -> None:
118136
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
119137
with pytest.raises(TypeError):
120-
dtype = dtypes.c32
138+
dtype = c32
121139
rand_array = wrapper.randu((1, 1), dtype)
122140
number = wrapper.get_scalar(rand_array, dtype)
123141

@@ -128,7 +146,7 @@ def test_constant_complex_shape_invalid() -> None:
128146
def test_constant_long_shape_invalid() -> None:
129147
"""Test if constant_long handles a shape with greater than 4 dimensions"""
130148
with pytest.raises(TypeError):
131-
dtype = dtypes.s64
149+
dtype = s64
132150
rand_array = wrapper.randu((1, 1), dtype)
133151
number = wrapper.get_scalar(rand_array, dtype)
134152

@@ -139,7 +157,7 @@ def test_constant_long_shape_invalid() -> None:
139157
def test_constant_ulong_shape_invalid() -> None:
140158
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
141159
with pytest.raises(TypeError):
142-
dtype = dtypes.u64
160+
dtype = u64
143161
rand_array = wrapper.randu((1, 1), dtype)
144162
number = wrapper.get_scalar(rand_array, dtype)
145163

@@ -148,50 +166,47 @@ def test_constant_ulong_shape_invalid() -> None:
148166

149167

150168
@pytest.mark.parametrize(
151-
"dtype_index",
152-
[i for i in range(13)],
169+
"dtype",
170+
types,
153171
)
154-
def test_constant_dtype(dtype_index: int) -> None:
172+
def test_constant_dtype(dtype: Dtype) -> None:
155173
"""Test if constant creates an array with the correct dtype."""
156-
if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()):
174+
if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()):
157175
pytest.skip()
158176

159-
dtype = dtypes.c_api_value_to_dtype(dtype_index)
160-
161177
rand_array = wrapper.randu((1, 1), dtype)
162178
value = wrapper.get_scalar(rand_array, dtype)
163179
shape = (2, 2)
164180
if isinstance(value, (int, float)):
165181
result = wrapper.constant(value, shape, dtype)
166-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
182+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
167183
else:
168184
pytest.skip()
169185

170186

171187
@pytest.mark.parametrize(
172-
"dtype_index",
173-
[i for i in range(13)],
188+
"dtype",
189+
types,
174190
)
175-
def test_constant_complex_dtype(dtype_index: int) -> None:
191+
def test_constant_complex_dtype(dtype: Dtype) -> None:
176192
"""Test if constant_complex creates an array with the correct dtype."""
177-
if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()):
193+
if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()):
178194
pytest.skip()
179195

180-
dtype = dtypes.c_api_value_to_dtype(dtype_index)
181196
rand_array = wrapper.randu((1, 1), dtype)
182197
value = wrapper.get_scalar(rand_array, dtype)
183198
shape = (2, 2)
184199

185200
if isinstance(value, (int, float, complex)):
186201
result = wrapper.constant_complex(value, shape, dtype)
187-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
202+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
188203
else:
189204
pytest.skip()
190205

191206

192207
def test_constant_long_dtype() -> None:
193208
"""Test if constant_long creates an array with the correct dtype."""
194-
dtype = dtypes.s64
209+
dtype = s64
195210

196211
rand_array = wrapper.randu((1, 1), dtype)
197212
value = wrapper.get_scalar(rand_array, dtype)
@@ -200,14 +215,14 @@ def test_constant_long_dtype() -> None:
200215
if isinstance(value, (int, float)):
201216
result = wrapper.constant_long(value, shape, dtype)
202217

203-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
218+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
204219
else:
205220
pytest.skip()
206221

207222

208223
def test_constant_ulong_dtype() -> None:
209224
"""Test if constant_ulong creates an array with the correct dtype."""
210-
dtype = dtypes.u64
225+
dtype = u64
211226

212227
rand_array = wrapper.randu((1, 1), dtype)
213228
value = wrapper.get_scalar(rand_array, dtype)
@@ -216,6 +231,6 @@ def test_constant_ulong_dtype() -> None:
216231
if isinstance(value, (int, float)):
217232
result = wrapper.constant_ulong(value, shape, dtype)
218233

219-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
234+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
220235
else:
221236
pytest.skip()

0 commit comments

Comments
 (0)