2
2
3
3
import pytest
4
4
5
- import arrayfire_wrapper .dtypes as dtypes
6
5
import arrayfire_wrapper .lib as wrapper
6
+ from arrayfire_wrapper .dtypes import (
7
+ Dtype ,
8
+ c32 ,
9
+ c64 ,
10
+ c_api_value_to_dtype ,
11
+ f16 ,
12
+ f32 ,
13
+ f64 ,
14
+ s16 ,
15
+ s32 ,
16
+ s64 ,
17
+ u8 ,
18
+ u16 ,
19
+ u32 ,
20
+ u64 ,
21
+ )
7
22
8
23
invalid_shape = (
9
24
random .randint (1 , 10 ),
14
29
)
15
30
16
31
32
+ types = [s16 , s32 , s64 , u8 , u16 , u32 , u64 , f16 , f32 , f64 , c32 , c64 ]
33
+
34
+
17
35
@pytest .mark .parametrize (
18
36
"shape" ,
19
37
[
27
45
def test_constant_shape (shape : tuple ) -> None :
28
46
"""Test if constant creates an array with the correct shape."""
29
47
number = 5.0
30
- dtype = dtypes . s16
48
+ dtype = s16
31
49
32
50
result = wrapper .constant (number , shape , dtype )
33
51
@@ -46,9 +64,9 @@ def test_constant_shape(shape: tuple) -> None:
46
64
)
47
65
def test_constant_complex_shape (shape : tuple ) -> None :
48
66
"""Test if constant_complex creates an array with the correct shape."""
49
- dtype = dtypes . c32
67
+ dtype = c32
50
68
51
- dtype = dtypes . c32
69
+ dtype = c32
52
70
rand_array = wrapper .randu ((1 , 1 ), dtype )
53
71
number = wrapper .get_scalar (rand_array , dtype )
54
72
@@ -71,7 +89,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
71
89
)
72
90
def test_constant_long_shape (shape : tuple ) -> None :
73
91
"""Test if constant_long creates an array with the correct shape."""
74
- dtype = dtypes . s64
92
+ dtype = s64
75
93
rand_array = wrapper .randu ((1 , 1 ), dtype )
76
94
number = wrapper .get_scalar (rand_array , dtype )
77
95
@@ -93,7 +111,7 @@ def test_constant_long_shape(shape: tuple) -> None:
93
111
)
94
112
def test_constant_ulong_shape (shape : tuple ) -> None :
95
113
"""Test if constant_ulong creates an array with the correct shape."""
96
- dtype = dtypes . u64
114
+ dtype = u64
97
115
rand_array = wrapper .randu ((1 , 1 ), dtype )
98
116
number = wrapper .get_scalar (rand_array , dtype )
99
117
@@ -109,15 +127,15 @@ def test_constant_shape_invalid() -> None:
109
127
"""Test if constant handles a shape with greater than 4 dimensions"""
110
128
with pytest .raises (TypeError ):
111
129
number = 5.0
112
- dtype = dtypes . s16
130
+ dtype = s16
113
131
114
132
wrapper .constant (number , invalid_shape , dtype )
115
133
116
134
117
135
def test_constant_complex_shape_invalid () -> None :
118
136
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
119
137
with pytest .raises (TypeError ):
120
- dtype = dtypes . c32
138
+ dtype = c32
121
139
rand_array = wrapper .randu ((1 , 1 ), dtype )
122
140
number = wrapper .get_scalar (rand_array , dtype )
123
141
@@ -128,7 +146,7 @@ def test_constant_complex_shape_invalid() -> None:
128
146
def test_constant_long_shape_invalid () -> None :
129
147
"""Test if constant_long handles a shape with greater than 4 dimensions"""
130
148
with pytest .raises (TypeError ):
131
- dtype = dtypes . s64
149
+ dtype = s64
132
150
rand_array = wrapper .randu ((1 , 1 ), dtype )
133
151
number = wrapper .get_scalar (rand_array , dtype )
134
152
@@ -139,7 +157,7 @@ def test_constant_long_shape_invalid() -> None:
139
157
def test_constant_ulong_shape_invalid () -> None :
140
158
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
141
159
with pytest .raises (TypeError ):
142
- dtype = dtypes . u64
160
+ dtype = u64
143
161
rand_array = wrapper .randu ((1 , 1 ), dtype )
144
162
number = wrapper .get_scalar (rand_array , dtype )
145
163
@@ -148,50 +166,47 @@ def test_constant_ulong_shape_invalid() -> None:
148
166
149
167
150
168
@pytest .mark .parametrize (
151
- "dtype_index " ,
152
- [ i for i in range ( 13 )] ,
169
+ "dtype " ,
170
+ types ,
153
171
)
154
- def test_constant_dtype (dtype_index : int ) -> None :
172
+ def test_constant_dtype (dtype : Dtype ) -> None :
155
173
"""Test if constant creates an array with the correct dtype."""
156
- if dtype_index in [1 , 3 ] or (dtype_index == 2 and not wrapper .get_dbl_support ()):
174
+ if dtype in [c32 , c64 ] or (dtype == f64 and not wrapper .get_dbl_support ()):
157
175
pytest .skip ()
158
176
159
- dtype = dtypes .c_api_value_to_dtype (dtype_index )
160
-
161
177
rand_array = wrapper .randu ((1 , 1 ), dtype )
162
178
value = wrapper .get_scalar (rand_array , dtype )
163
179
shape = (2 , 2 )
164
180
if isinstance (value , (int , float )):
165
181
result = wrapper .constant (value , shape , dtype )
166
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
182
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
167
183
else :
168
184
pytest .skip ()
169
185
170
186
171
187
@pytest .mark .parametrize (
172
- "dtype_index " ,
173
- [ i for i in range ( 13 )] ,
188
+ "dtype " ,
189
+ types ,
174
190
)
175
- def test_constant_complex_dtype (dtype_index : int ) -> None :
191
+ def test_constant_complex_dtype (dtype : Dtype ) -> None :
176
192
"""Test if constant_complex creates an array with the correct dtype."""
177
- if dtype_index not in [1 , 3 ] or (dtype_index == 3 and not wrapper .get_dbl_support ()):
193
+ if dtype not in [c32 , c64 ] or (dtype == c64 and not wrapper .get_dbl_support ()):
178
194
pytest .skip ()
179
195
180
- dtype = dtypes .c_api_value_to_dtype (dtype_index )
181
196
rand_array = wrapper .randu ((1 , 1 ), dtype )
182
197
value = wrapper .get_scalar (rand_array , dtype )
183
198
shape = (2 , 2 )
184
199
185
200
if isinstance (value , (int , float , complex )):
186
201
result = wrapper .constant_complex (value , shape , dtype )
187
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
202
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
188
203
else :
189
204
pytest .skip ()
190
205
191
206
192
207
def test_constant_long_dtype () -> None :
193
208
"""Test if constant_long creates an array with the correct dtype."""
194
- dtype = dtypes . s64
209
+ dtype = s64
195
210
196
211
rand_array = wrapper .randu ((1 , 1 ), dtype )
197
212
value = wrapper .get_scalar (rand_array , dtype )
@@ -200,14 +215,14 @@ def test_constant_long_dtype() -> None:
200
215
if isinstance (value , (int , float )):
201
216
result = wrapper .constant_long (value , shape , dtype )
202
217
203
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
218
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
204
219
else :
205
220
pytest .skip ()
206
221
207
222
208
223
def test_constant_ulong_dtype () -> None :
209
224
"""Test if constant_ulong creates an array with the correct dtype."""
210
- dtype = dtypes . u64
225
+ dtype = u64
211
226
212
227
rand_array = wrapper .randu ((1 , 1 ), dtype )
213
228
value = wrapper .get_scalar (rand_array , dtype )
@@ -216,6 +231,6 @@ def test_constant_ulong_dtype() -> None:
216
231
if isinstance (value , (int , float )):
217
232
result = wrapper .constant_ulong (value , shape , dtype )
218
233
219
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
234
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
220
235
else :
221
236
pytest .skip ()
0 commit comments