|
4 | 4 |
|
5 | 5 | import arrayfire_wrapper.dtypes as dtype
|
6 | 6 | import arrayfire_wrapper.lib as wrapper
|
7 |
| - |
8 |
| -dtype_map = { |
9 |
| - "int16": dtype.s16, |
10 |
| - "int32": dtype.s32, |
11 |
| - "int64": dtype.s64, |
12 |
| - "uint8": dtype.u8, |
13 |
| - "uint16": dtype.u16, |
14 |
| - "uint32": dtype.u32, |
15 |
| - "uint64": dtype.u64, |
16 |
| - "float16": dtype.f16, |
17 |
| - "float32": dtype.f32, |
18 |
| - # 'float64': dtype.f64, |
19 |
| - # 'complex64': dtype.c64, |
20 |
| - "complex32": dtype.c32, |
21 |
| - "bool": dtype.b8, |
22 |
| - "s16": dtype.s16, |
23 |
| - "s32": dtype.s32, |
24 |
| - "s64": dtype.s64, |
25 |
| - "u8": dtype.u8, |
26 |
| - "u16": dtype.u16, |
27 |
| - "u32": dtype.u32, |
28 |
| - "u64": dtype.u64, |
29 |
| - "f16": dtype.f16, |
30 |
| - "f32": dtype.f32, |
31 |
| - # 'f64': dtype.f64, |
32 |
| - "c32": dtype.c32, |
33 |
| - # 'c64': dtype.c64, |
34 |
| - "b8": dtype.b8, |
35 |
| -} |
| 7 | +from tests.utility_functions import check_type_supported, get_all_types |
36 | 8 |
|
37 | 9 |
|
38 | 10 | @pytest.mark.parametrize(
|
|
45 | 17 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
46 | 18 | ],
|
47 | 19 | )
|
48 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 20 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
49 | 21 | def test_and_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
50 | 22 | """Test and_ operation between two arrays of the same shape"""
|
| 23 | + check_type_supported(dtype_name) |
51 | 24 | lhs = wrapper.randu(shape, dtype_name)
|
52 | 25 | rhs = wrapper.randu(shape, dtype_name)
|
53 | 26 |
|
@@ -89,17 +62,10 @@ def test_and_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
89 | 62 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
90 | 63 | ],
|
91 | 64 | )
|
92 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 65 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
93 | 66 | def test_bitand_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
94 | 67 | """Test bitand operation between two arrays of the same shape"""
|
95 |
| - if ( |
96 |
| - dtype_name == dtype.c32 |
97 |
| - or dtype_name == dtype.c64 |
98 |
| - or dtype_name == dtype.f32 |
99 |
| - or dtype_name == dtype.f64 |
100 |
| - or dtype_name == dtype.f16 |
101 |
| - ): |
102 |
| - pytest.skip() |
| 68 | + check_type_supported(dtype_name) |
103 | 69 | lhs = wrapper.randu(shape, dtype_name)
|
104 | 70 | rhs = wrapper.randu(shape, dtype_name)
|
105 | 71 |
|
@@ -136,17 +102,10 @@ def test_bitand_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
136 | 102 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
137 | 103 | ],
|
138 | 104 | )
|
139 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 105 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
140 | 106 | def test_bitnot_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
141 | 107 | """Test bitnot operation between two arrays of the same shape"""
|
142 |
| - if ( |
143 |
| - dtype_name == dtype.c32 |
144 |
| - or dtype_name == dtype.c64 |
145 |
| - or dtype_name == dtype.f32 |
146 |
| - or dtype_name == dtype.f64 |
147 |
| - or dtype_name == dtype.f16 |
148 |
| - ): |
149 |
| - pytest.skip() |
| 108 | + check_type_supported(dtype_name) |
150 | 109 | out = wrapper.randu(shape, dtype_name)
|
151 | 110 |
|
152 | 111 | result = wrapper.bitnot(out)
|
@@ -181,17 +140,10 @@ def test_bitnot_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
181 | 140 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
182 | 141 | ],
|
183 | 142 | )
|
184 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 143 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
185 | 144 | def test_bitor_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
186 | 145 | """Test bitor operation between two arrays of the same shape"""
|
187 |
| - if ( |
188 |
| - dtype_name == dtype.c32 |
189 |
| - or dtype_name == dtype.c64 |
190 |
| - or dtype_name == dtype.f32 |
191 |
| - or dtype_name == dtype.f64 |
192 |
| - or dtype_name == dtype.f16 |
193 |
| - ): |
194 |
| - pytest.skip() |
| 146 | + check_type_supported(dtype_name) |
195 | 147 | lhs = wrapper.randu(shape, dtype_name)
|
196 | 148 | rhs = wrapper.randu(shape, dtype_name)
|
197 | 149 |
|
@@ -228,17 +180,10 @@ def test_bitor_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
228 | 180 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
229 | 181 | ],
|
230 | 182 | )
|
231 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 183 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
232 | 184 | def test_bitxor_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
233 | 185 | """Test bitxor operation between two arrays of the same shape"""
|
234 |
| - if ( |
235 |
| - dtype_name == dtype.c32 |
236 |
| - or dtype_name == dtype.c64 |
237 |
| - or dtype_name == dtype.f32 |
238 |
| - or dtype_name == dtype.f64 |
239 |
| - or dtype_name == dtype.f16 |
240 |
| - ): |
241 |
| - pytest.skip() |
| 186 | + check_type_supported(dtype_name) |
242 | 187 | lhs = wrapper.randu(shape, dtype_name)
|
243 | 188 | rhs = wrapper.randu(shape, dtype_name)
|
244 | 189 |
|
@@ -275,17 +220,10 @@ def test_bitxor_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
275 | 220 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
276 | 221 | ],
|
277 | 222 | )
|
278 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 223 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
279 | 224 | def test_eq_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
280 | 225 | """Test eq operation between two arrays of the same shape"""
|
281 |
| - if ( |
282 |
| - dtype_name == dtype.c32 |
283 |
| - or dtype_name == dtype.c64 |
284 |
| - or dtype_name == dtype.f32 |
285 |
| - or dtype_name == dtype.f64 |
286 |
| - or dtype_name == dtype.f16 |
287 |
| - ): |
288 |
| - pytest.skip() |
| 226 | + check_type_supported(dtype_name) |
289 | 227 | lhs = wrapper.randu(shape, dtype_name)
|
290 | 228 | rhs = wrapper.randu(shape, dtype_name)
|
291 | 229 |
|
@@ -322,9 +260,10 @@ def test_eq_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
322 | 260 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
323 | 261 | ],
|
324 | 262 | )
|
325 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 263 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
326 | 264 | def test_ge_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
327 | 265 | """Test >= operation between two arrays of the same shape"""
|
| 266 | + check_type_supported(dtype_name) |
328 | 267 | lhs = wrapper.randu(shape, dtype_name)
|
329 | 268 | rhs = wrapper.randu(shape, dtype_name)
|
330 | 269 |
|
@@ -361,9 +300,10 @@ def test_ge_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
361 | 300 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
362 | 301 | ],
|
363 | 302 | )
|
364 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 303 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
365 | 304 | def test_gt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
366 | 305 | """Test > operation between two arrays of the same shape"""
|
| 306 | + check_type_supported(dtype_name) |
367 | 307 | lhs = wrapper.randu(shape, dtype_name)
|
368 | 308 | rhs = wrapper.randu(shape, dtype_name)
|
369 | 309 |
|
@@ -400,9 +340,10 @@ def test_gt_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
400 | 340 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
401 | 341 | ],
|
402 | 342 | )
|
403 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 343 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
404 | 344 | def test_le_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
405 | 345 | """Test <= operation between two arrays of the same shape"""
|
| 346 | + check_type_supported(dtype_name) |
406 | 347 | lhs = wrapper.randu(shape, dtype_name)
|
407 | 348 | rhs = wrapper.randu(shape, dtype_name)
|
408 | 349 |
|
@@ -439,9 +380,10 @@ def test_le_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
439 | 380 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
440 | 381 | ],
|
441 | 382 | )
|
442 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 383 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
443 | 384 | def test_lt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
444 | 385 | """Test < operation between two arrays of the same shape"""
|
| 386 | + check_type_supported(dtype_name) |
445 | 387 | lhs = wrapper.randu(shape, dtype_name)
|
446 | 388 | rhs = wrapper.randu(shape, dtype_name)
|
447 | 389 |
|
@@ -478,9 +420,10 @@ def test_lt_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
478 | 420 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
479 | 421 | ],
|
480 | 422 | )
|
481 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 423 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
482 | 424 | def test_neq_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
483 | 425 | """Test not equal operation between two arrays of the same shape"""
|
| 426 | + check_type_supported(dtype_name) |
484 | 427 | lhs = wrapper.randu(shape, dtype_name)
|
485 | 428 | rhs = wrapper.randu(shape, dtype_name)
|
486 | 429 |
|
@@ -517,17 +460,10 @@ def test_neq_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
517 | 460 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
518 | 461 | ],
|
519 | 462 | )
|
520 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 463 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
521 | 464 | def test_not_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
522 | 465 | """Test not operation between two arrays of the same shape"""
|
523 |
| - if ( |
524 |
| - dtype_name == dtype.c32 |
525 |
| - or dtype_name == dtype.c64 |
526 |
| - or dtype_name == dtype.f32 |
527 |
| - or dtype_name == dtype.f64 |
528 |
| - or dtype_name == dtype.f16 |
529 |
| - ): |
530 |
| - pytest.skip() |
| 466 | + check_type_supported(dtype_name) |
531 | 467 | out = wrapper.randu(shape, dtype_name)
|
532 | 468 |
|
533 | 469 | result = wrapper.not_(out)
|
@@ -562,9 +498,10 @@ def test_not_shapes_invalid(invdtypes: dtype.Dtype) -> None:
|
562 | 498 | (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
|
563 | 499 | ],
|
564 | 500 | )
|
565 |
| -@pytest.mark.parametrize("dtype_name", dtype_map.values()) |
| 501 | +@pytest.mark.parametrize("dtype_name", get_all_types()) |
566 | 502 | def test_or_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
|
567 | 503 | """Test or operation between two arrays of the same shape"""
|
| 504 | + check_type_supported(dtype_name) |
568 | 505 | lhs = wrapper.randu(shape, dtype_name)
|
569 | 506 | rhs = wrapper.randu(shape, dtype_name)
|
570 | 507 |
|
|
0 commit comments