Skip to content

Commit b45dcbc

Browse files
AzeezIshsyurkevi
AzeezIsh
authored andcommitted
Removed dtype map, added further dtype coverage.
1 parent ce6cef2 commit b45dcbc

File tree

1 file changed

+11
-34
lines changed

1 file changed

+11
-34
lines changed

tests/test_bitshift.py

+11-34
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,14 @@
33
import arrayfire_wrapper.dtypes as dtype
44
import arrayfire_wrapper.lib as wrapper
55
from arrayfire_wrapper.lib.create_and_modify_array.helper_functions import array_to_string
6+
from tests.utility_functions import check_type_supported, get_all_types, get_float_types, get_real_types
67

7-
dtype_map = {
8-
"int16": dtype.s16,
9-
"int32": dtype.s32,
10-
"int64": dtype.s64,
11-
"uint8": dtype.u8,
12-
"uint16": dtype.u16,
13-
"uint32": dtype.u32,
14-
"uint64": dtype.u64,
15-
# 'float16': dtype.f16,
16-
# 'float32': dtype.f32,
17-
# 'float64': dtype.f64,
18-
# 'complex64': dtype.c64,
19-
# 'complex32': dtype.c32,
20-
"bool": dtype.b8,
21-
"s16": dtype.s16,
22-
"s32": dtype.s32,
23-
"s64": dtype.s64,
24-
"u8": dtype.u8,
25-
"u16": dtype.u16,
26-
"u32": dtype.u32,
27-
"u64": dtype.u64,
28-
# 'f16': dtype.f16,
29-
# 'f32': dtype.f32,
30-
# 'f64': dtype.f64,
31-
# 'c32': dtype.c32,
32-
# 'c64': dtype.c64,
33-
"b8": dtype.b8,
34-
}
35-
36-
37-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
8+
@pytest.mark.parametrize("dtype_name", get_real_types())
389
def test_bitshiftl_dtypes(dtype_name: dtype.Dtype) -> None:
3910
"""Test bit shift operation across all supported data types."""
11+
check_type_supported(dtype_name)
12+
if dtype_name == dtype.f16 or dtype_name == dtype.f32:
13+
pytest.skip()
4014
shape = (5, 5)
4115
values = wrapper.randu(shape, dtype_name)
4216
bits_to_shift = wrapper.constant(1, shape, dtype_name)
@@ -49,7 +23,7 @@ def test_bitshiftl_dtypes(dtype_name: dtype.Dtype) -> None:
4923
@pytest.mark.parametrize(
5024
"invdtypes",
5125
[
52-
dtype.c64,
26+
dtype.c32,
5327
dtype.f64,
5428
],
5529
)
@@ -139,9 +113,12 @@ def test_bitshift_right_varying_shift_amount(shift_amount: int) -> None:
139113
assert (wrapper.get_dims(result)[0], wrapper.get_dims(result)[1]) == shape
140114

141115

142-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
116+
@pytest.mark.parametrize("dtype_name", get_real_types())
143117
def test_bitshiftr_dtypes(dtype_name: dtype.Dtype) -> None:
144118
"""Test bit shift operation across all supported data types."""
119+
check_type_supported(dtype_name)
120+
if dtype_name == dtype.f16 or dtype_name == dtype.f32:
121+
pytest.skip()
145122
shape = (5, 5)
146123
values = wrapper.randu(shape, dtype_name)
147124
bits_to_shift = wrapper.constant(1, shape, dtype_name)
@@ -154,7 +131,7 @@ def test_bitshiftr_dtypes(dtype_name: dtype.Dtype) -> None:
154131
@pytest.mark.parametrize(
155132
"invdtypes",
156133
[
157-
dtype.c64,
134+
dtype.c32,
158135
dtype.f64,
159136
],
160137
)

0 commit comments

Comments
 (0)