|
4 | 4 |
|
5 | 5 | import arrayfire_wrapper.dtypes as dtype |
6 | 6 | import arrayfire_wrapper.lib as wrapper |
7 | | - |
8 | | -dtype_map = { |
9 | | - "int16": dtype.s16, |
10 | | - "int32": dtype.s32, |
11 | | - "int64": dtype.s64, |
12 | | - "uint8": dtype.u8, |
13 | | - "uint16": dtype.u16, |
14 | | - "uint32": dtype.u32, |
15 | | - "uint64": dtype.u64, |
16 | | - "float16": dtype.f16, |
17 | | - "float32": dtype.f32, |
18 | | - # 'float64': dtype.f64, |
19 | | - # 'complex64': dtype.c64, |
20 | | - "complex32": dtype.c32, |
21 | | - "bool": dtype.b8, |
22 | | - "s16": dtype.s16, |
23 | | - "s32": dtype.s32, |
24 | | - "s64": dtype.s64, |
25 | | - "u8": dtype.u8, |
26 | | - "u16": dtype.u16, |
27 | | - "u32": dtype.u32, |
28 | | - "u64": dtype.u64, |
29 | | - "f16": dtype.f16, |
30 | | - "f32": dtype.f32, |
31 | | - # 'f64': dtype.f64, |
32 | | - "c32": dtype.c32, |
33 | | - # 'c64': dtype.c64, |
34 | | - "b8": dtype.b8, |
35 | | -} |
| 7 | +from tests.utility_functions import check_type_supported, get_all_types |
36 | 8 |
|
37 | 9 |
|
38 | 10 | @pytest.mark.parametrize( |
|
45 | 17 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
46 | 18 | ], |
47 | 19 | ) |
48 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 20 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
49 | 21 | def test_and_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
50 | 22 | """Test and_ operation between two arrays of the same shape""" |
| 23 | + check_type_supported(dtype_name) |
51 | 24 | lhs = wrapper.randu(shape, dtype_name) |
52 | 25 | rhs = wrapper.randu(shape, dtype_name) |
53 | 26 |
|
@@ -89,17 +62,10 @@ def test_and_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
89 | 62 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
90 | 63 | ], |
91 | 64 | ) |
92 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 65 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
93 | 66 | def test_bitand_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
94 | 67 | """Test bitand operation between two arrays of the same shape""" |
95 | | - if ( |
96 | | - dtype_name == dtype.c32 |
97 | | - or dtype_name == dtype.c64 |
98 | | - or dtype_name == dtype.f32 |
99 | | - or dtype_name == dtype.f64 |
100 | | - or dtype_name == dtype.f16 |
101 | | - ): |
102 | | - pytest.skip() |
| 68 | + check_type_supported(dtype_name) |
103 | 69 | lhs = wrapper.randu(shape, dtype_name) |
104 | 70 | rhs = wrapper.randu(shape, dtype_name) |
105 | 71 |
|
@@ -136,17 +102,10 @@ def test_bitand_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
136 | 102 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
137 | 103 | ], |
138 | 104 | ) |
139 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 105 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
140 | 106 | def test_bitnot_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
141 | 107 | """Test bitnot operation between two arrays of the same shape""" |
142 | | - if ( |
143 | | - dtype_name == dtype.c32 |
144 | | - or dtype_name == dtype.c64 |
145 | | - or dtype_name == dtype.f32 |
146 | | - or dtype_name == dtype.f64 |
147 | | - or dtype_name == dtype.f16 |
148 | | - ): |
149 | | - pytest.skip() |
| 108 | + check_type_supported(dtype_name) |
150 | 109 | out = wrapper.randu(shape, dtype_name) |
151 | 110 |
|
152 | 111 | result = wrapper.bitnot(out) |
@@ -181,17 +140,10 @@ def test_bitnot_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
181 | 140 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
182 | 141 | ], |
183 | 142 | ) |
184 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 143 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
185 | 144 | def test_bitor_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
186 | 145 | """Test bitor operation between two arrays of the same shape""" |
187 | | - if ( |
188 | | - dtype_name == dtype.c32 |
189 | | - or dtype_name == dtype.c64 |
190 | | - or dtype_name == dtype.f32 |
191 | | - or dtype_name == dtype.f64 |
192 | | - or dtype_name == dtype.f16 |
193 | | - ): |
194 | | - pytest.skip() |
| 146 | + check_type_supported(dtype_name) |
195 | 147 | lhs = wrapper.randu(shape, dtype_name) |
196 | 148 | rhs = wrapper.randu(shape, dtype_name) |
197 | 149 |
|
@@ -228,17 +180,10 @@ def test_bitor_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
228 | 180 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
229 | 181 | ], |
230 | 182 | ) |
231 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 183 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
232 | 184 | def test_bitxor_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
233 | 185 | """Test bitxor operation between two arrays of the same shape""" |
234 | | - if ( |
235 | | - dtype_name == dtype.c32 |
236 | | - or dtype_name == dtype.c64 |
237 | | - or dtype_name == dtype.f32 |
238 | | - or dtype_name == dtype.f64 |
239 | | - or dtype_name == dtype.f16 |
240 | | - ): |
241 | | - pytest.skip() |
| 186 | + check_type_supported(dtype_name) |
242 | 187 | lhs = wrapper.randu(shape, dtype_name) |
243 | 188 | rhs = wrapper.randu(shape, dtype_name) |
244 | 189 |
|
@@ -275,17 +220,10 @@ def test_bitxor_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
275 | 220 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
276 | 221 | ], |
277 | 222 | ) |
278 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 223 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
279 | 224 | def test_eq_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
280 | 225 | """Test eq operation between two arrays of the same shape""" |
281 | | - if ( |
282 | | - dtype_name == dtype.c32 |
283 | | - or dtype_name == dtype.c64 |
284 | | - or dtype_name == dtype.f32 |
285 | | - or dtype_name == dtype.f64 |
286 | | - or dtype_name == dtype.f16 |
287 | | - ): |
288 | | - pytest.skip() |
| 226 | + check_type_supported(dtype_name) |
289 | 227 | lhs = wrapper.randu(shape, dtype_name) |
290 | 228 | rhs = wrapper.randu(shape, dtype_name) |
291 | 229 |
|
@@ -322,9 +260,10 @@ def test_eq_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
322 | 260 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
323 | 261 | ], |
324 | 262 | ) |
325 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 263 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
326 | 264 | def test_ge_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
327 | 265 | """Test >= operation between two arrays of the same shape""" |
| 266 | + check_type_supported(dtype_name) |
328 | 267 | lhs = wrapper.randu(shape, dtype_name) |
329 | 268 | rhs = wrapper.randu(shape, dtype_name) |
330 | 269 |
|
@@ -361,9 +300,10 @@ def test_ge_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
361 | 300 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
362 | 301 | ], |
363 | 302 | ) |
364 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 303 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
365 | 304 | def test_gt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
366 | 305 | """Test > operation between two arrays of the same shape""" |
| 306 | + check_type_supported(dtype_name) |
367 | 307 | lhs = wrapper.randu(shape, dtype_name) |
368 | 308 | rhs = wrapper.randu(shape, dtype_name) |
369 | 309 |
|
@@ -400,9 +340,10 @@ def test_gt_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
400 | 340 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
401 | 341 | ], |
402 | 342 | ) |
403 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 343 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
404 | 344 | def test_le_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
405 | 345 | """Test <= operation between two arrays of the same shape""" |
| 346 | + check_type_supported(dtype_name) |
406 | 347 | lhs = wrapper.randu(shape, dtype_name) |
407 | 348 | rhs = wrapper.randu(shape, dtype_name) |
408 | 349 |
|
@@ -439,9 +380,10 @@ def test_le_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
439 | 380 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
440 | 381 | ], |
441 | 382 | ) |
442 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 383 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
443 | 384 | def test_lt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
444 | 385 | """Test < operation between two arrays of the same shape""" |
| 386 | + check_type_supported(dtype_name) |
445 | 387 | lhs = wrapper.randu(shape, dtype_name) |
446 | 388 | rhs = wrapper.randu(shape, dtype_name) |
447 | 389 |
|
@@ -478,9 +420,10 @@ def test_lt_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
478 | 420 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
479 | 421 | ], |
480 | 422 | ) |
481 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 423 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
482 | 424 | def test_neq_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
483 | 425 | """Test not equal operation between two arrays of the same shape""" |
| 426 | + check_type_supported(dtype_name) |
484 | 427 | lhs = wrapper.randu(shape, dtype_name) |
485 | 428 | rhs = wrapper.randu(shape, dtype_name) |
486 | 429 |
|
@@ -517,17 +460,10 @@ def test_neq_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
517 | 460 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
518 | 461 | ], |
519 | 462 | ) |
520 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 463 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
521 | 464 | def test_not_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
522 | 465 | """Test not operation between two arrays of the same shape""" |
523 | | - if ( |
524 | | - dtype_name == dtype.c32 |
525 | | - or dtype_name == dtype.c64 |
526 | | - or dtype_name == dtype.f32 |
527 | | - or dtype_name == dtype.f64 |
528 | | - or dtype_name == dtype.f16 |
529 | | - ): |
530 | | - pytest.skip() |
| 466 | + check_type_supported(dtype_name) |
531 | 467 | out = wrapper.randu(shape, dtype_name) |
532 | 468 |
|
533 | 469 | result = wrapper.not_(out) |
@@ -562,9 +498,10 @@ def test_not_shapes_invalid(invdtypes: dtype.Dtype) -> None: |
562 | 498 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), |
563 | 499 | ], |
564 | 500 | ) |
565 | | -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 501 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
566 | 502 | def test_or_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None: |
567 | 503 | """Test or operation between two arrays of the same shape""" |
| 504 | + check_type_supported(dtype_name) |
568 | 505 | lhs = wrapper.randu(shape, dtype_name) |
569 | 506 | rhs = wrapper.randu(shape, dtype_name) |
570 | 507 |
|
|
0 commit comments