Skip to content

Commit 6e58189

Browse files
AzeezIshsyurkevi
AzeezIsh
authored andcommitted
Added all necessary dtype coverage.
1 parent 0bc17be commit 6e58189

File tree

1 file changed

+27
-90
lines changed

1 file changed

+27
-90
lines changed

tests/test_logical.py

+27-90
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,7 @@
44

55
import arrayfire_wrapper.dtypes as dtype
66
import arrayfire_wrapper.lib as wrapper
7-
8-
dtype_map = {
9-
"int16": dtype.s16,
10-
"int32": dtype.s32,
11-
"int64": dtype.s64,
12-
"uint8": dtype.u8,
13-
"uint16": dtype.u16,
14-
"uint32": dtype.u32,
15-
"uint64": dtype.u64,
16-
"float16": dtype.f16,
17-
"float32": dtype.f32,
18-
# 'float64': dtype.f64,
19-
# 'complex64': dtype.c64,
20-
"complex32": dtype.c32,
21-
"bool": dtype.b8,
22-
"s16": dtype.s16,
23-
"s32": dtype.s32,
24-
"s64": dtype.s64,
25-
"u8": dtype.u8,
26-
"u16": dtype.u16,
27-
"u32": dtype.u32,
28-
"u64": dtype.u64,
29-
"f16": dtype.f16,
30-
"f32": dtype.f32,
31-
# 'f64': dtype.f64,
32-
"c32": dtype.c32,
33-
# 'c64': dtype.c64,
34-
"b8": dtype.b8,
35-
}
7+
from tests.utility_functions import check_type_supported, get_all_types
368

379

3810
@pytest.mark.parametrize(
@@ -45,9 +17,10 @@
4517
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
4618
],
4719
)
48-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
20+
@pytest.mark.parametrize("dtype_name", get_all_types())
4921
def test_and_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
5022
"""Test and_ operation between two arrays of the same shape"""
23+
check_type_supported(dtype_name)
5124
lhs = wrapper.randu(shape, dtype_name)
5225
rhs = wrapper.randu(shape, dtype_name)
5326

@@ -89,17 +62,10 @@ def test_and_shapes_invalid(invdtypes: dtype.Dtype) -> None:
8962
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
9063
],
9164
)
92-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
65+
@pytest.mark.parametrize("dtype_name", get_all_types())
9366
def test_bitand_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
9467
"""Test bitand operation between two arrays of the same shape"""
95-
if (
96-
dtype_name == dtype.c32
97-
or dtype_name == dtype.c64
98-
or dtype_name == dtype.f32
99-
or dtype_name == dtype.f64
100-
or dtype_name == dtype.f16
101-
):
102-
pytest.skip()
68+
check_type_supported(dtype_name)
10369
lhs = wrapper.randu(shape, dtype_name)
10470
rhs = wrapper.randu(shape, dtype_name)
10571

@@ -136,17 +102,10 @@ def test_bitand_shapes_invalid(invdtypes: dtype.Dtype) -> None:
136102
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
137103
],
138104
)
139-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
105+
@pytest.mark.parametrize("dtype_name", get_all_types())
140106
def test_bitnot_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
141107
"""Test bitnot operation between two arrays of the same shape"""
142-
if (
143-
dtype_name == dtype.c32
144-
or dtype_name == dtype.c64
145-
or dtype_name == dtype.f32
146-
or dtype_name == dtype.f64
147-
or dtype_name == dtype.f16
148-
):
149-
pytest.skip()
108+
check_type_supported(dtype_name)
150109
out = wrapper.randu(shape, dtype_name)
151110

152111
result = wrapper.bitnot(out)
@@ -181,17 +140,10 @@ def test_bitnot_shapes_invalid(invdtypes: dtype.Dtype) -> None:
181140
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
182141
],
183142
)
184-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
143+
@pytest.mark.parametrize("dtype_name", get_all_types())
185144
def test_bitor_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
186145
"""Test bitor operation between two arrays of the same shape"""
187-
if (
188-
dtype_name == dtype.c32
189-
or dtype_name == dtype.c64
190-
or dtype_name == dtype.f32
191-
or dtype_name == dtype.f64
192-
or dtype_name == dtype.f16
193-
):
194-
pytest.skip()
146+
check_type_supported(dtype_name)
195147
lhs = wrapper.randu(shape, dtype_name)
196148
rhs = wrapper.randu(shape, dtype_name)
197149

@@ -228,17 +180,10 @@ def test_bitor_shapes_invalid(invdtypes: dtype.Dtype) -> None:
228180
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
229181
],
230182
)
231-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
183+
@pytest.mark.parametrize("dtype_name", get_all_types())
232184
def test_bitxor_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
233185
"""Test bitxor operation between two arrays of the same shape"""
234-
if (
235-
dtype_name == dtype.c32
236-
or dtype_name == dtype.c64
237-
or dtype_name == dtype.f32
238-
or dtype_name == dtype.f64
239-
or dtype_name == dtype.f16
240-
):
241-
pytest.skip()
186+
check_type_supported(dtype_name)
242187
lhs = wrapper.randu(shape, dtype_name)
243188
rhs = wrapper.randu(shape, dtype_name)
244189

@@ -275,17 +220,10 @@ def test_bitxor_shapes_invalid(invdtypes: dtype.Dtype) -> None:
275220
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
276221
],
277222
)
278-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
223+
@pytest.mark.parametrize("dtype_name", get_all_types())
279224
def test_eq_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
280225
"""Test eq operation between two arrays of the same shape"""
281-
if (
282-
dtype_name == dtype.c32
283-
or dtype_name == dtype.c64
284-
or dtype_name == dtype.f32
285-
or dtype_name == dtype.f64
286-
or dtype_name == dtype.f16
287-
):
288-
pytest.skip()
226+
check_type_supported(dtype_name)
289227
lhs = wrapper.randu(shape, dtype_name)
290228
rhs = wrapper.randu(shape, dtype_name)
291229

@@ -322,9 +260,10 @@ def test_eq_shapes_invalid(invdtypes: dtype.Dtype) -> None:
322260
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
323261
],
324262
)
325-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
263+
@pytest.mark.parametrize("dtype_name", get_all_types())
326264
def test_ge_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
327265
"""Test >= operation between two arrays of the same shape"""
266+
check_type_supported(dtype_name)
328267
lhs = wrapper.randu(shape, dtype_name)
329268
rhs = wrapper.randu(shape, dtype_name)
330269

@@ -361,9 +300,10 @@ def test_ge_shapes_invalid(invdtypes: dtype.Dtype) -> None:
361300
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
362301
],
363302
)
364-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
303+
@pytest.mark.parametrize("dtype_name", get_all_types())
365304
def test_gt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
366305
"""Test > operation between two arrays of the same shape"""
306+
check_type_supported(dtype_name)
367307
lhs = wrapper.randu(shape, dtype_name)
368308
rhs = wrapper.randu(shape, dtype_name)
369309

@@ -400,9 +340,10 @@ def test_gt_shapes_invalid(invdtypes: dtype.Dtype) -> None:
400340
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
401341
],
402342
)
403-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
343+
@pytest.mark.parametrize("dtype_name", get_all_types())
404344
def test_le_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
405345
"""Test <= operation between two arrays of the same shape"""
346+
check_type_supported(dtype_name)
406347
lhs = wrapper.randu(shape, dtype_name)
407348
rhs = wrapper.randu(shape, dtype_name)
408349

@@ -439,9 +380,10 @@ def test_le_shapes_invalid(invdtypes: dtype.Dtype) -> None:
439380
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
440381
],
441382
)
442-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
383+
@pytest.mark.parametrize("dtype_name", get_all_types())
443384
def test_lt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
444385
"""Test < operation between two arrays of the same shape"""
386+
check_type_supported(dtype_name)
445387
lhs = wrapper.randu(shape, dtype_name)
446388
rhs = wrapper.randu(shape, dtype_name)
447389

@@ -478,9 +420,10 @@ def test_lt_shapes_invalid(invdtypes: dtype.Dtype) -> None:
478420
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
479421
],
480422
)
481-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
423+
@pytest.mark.parametrize("dtype_name", get_all_types())
482424
def test_neq_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
483425
"""Test not equal operation between two arrays of the same shape"""
426+
check_type_supported(dtype_name)
484427
lhs = wrapper.randu(shape, dtype_name)
485428
rhs = wrapper.randu(shape, dtype_name)
486429

@@ -517,17 +460,10 @@ def test_neq_shapes_invalid(invdtypes: dtype.Dtype) -> None:
517460
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
518461
],
519462
)
520-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
463+
@pytest.mark.parametrize("dtype_name", get_all_types())
521464
def test_not_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
522465
"""Test not operation between two arrays of the same shape"""
523-
if (
524-
dtype_name == dtype.c32
525-
or dtype_name == dtype.c64
526-
or dtype_name == dtype.f32
527-
or dtype_name == dtype.f64
528-
or dtype_name == dtype.f16
529-
):
530-
pytest.skip()
466+
check_type_supported(dtype_name)
531467
out = wrapper.randu(shape, dtype_name)
532468

533469
result = wrapper.not_(out)
@@ -562,9 +498,10 @@ def test_not_shapes_invalid(invdtypes: dtype.Dtype) -> None:
562498
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
563499
],
564500
)
565-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
501+
@pytest.mark.parametrize("dtype_name", get_all_types())
566502
def test_or_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
567503
"""Test or operation between two arrays of the same shape"""
504+
check_type_supported(dtype_name)
568505
lhs = wrapper.randu(shape, dtype_name)
569506
rhs = wrapper.randu(shape, dtype_name)
570507

0 commit comments

Comments
 (0)