3
3
import arrayfire_wrapper .dtypes as dtype
4
4
import arrayfire_wrapper .lib as wrapper
5
5
from arrayfire_wrapper .lib .create_and_modify_array .helper_functions import array_to_string
6
+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types , get_real_types
6
7
7
- dtype_map = {
8
- "int16" : dtype .s16 ,
9
- "int32" : dtype .s32 ,
10
- "int64" : dtype .s64 ,
11
- "uint8" : dtype .u8 ,
12
- "uint16" : dtype .u16 ,
13
- "uint32" : dtype .u32 ,
14
- "uint64" : dtype .u64 ,
15
- # 'float16': dtype.f16,
16
- # 'float32': dtype.f32,
17
- # 'float64': dtype.f64,
18
- # 'complex64': dtype.c64,
19
- # 'complex32': dtype.c32,
20
- "bool" : dtype .b8 ,
21
- "s16" : dtype .s16 ,
22
- "s32" : dtype .s32 ,
23
- "s64" : dtype .s64 ,
24
- "u8" : dtype .u8 ,
25
- "u16" : dtype .u16 ,
26
- "u32" : dtype .u32 ,
27
- "u64" : dtype .u64 ,
28
- # 'f16': dtype.f16,
29
- # 'f32': dtype.f32,
30
- # 'f64': dtype.f64,
31
- # 'c32': dtype.c32,
32
- # 'c64': dtype.c64,
33
- "b8" : dtype .b8 ,
34
- }
35
-
36
-
37
- @pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
8
+ @pytest .mark .parametrize ("dtype_name" , get_real_types ())
38
9
def test_bitshiftl_dtypes (dtype_name : dtype .Dtype ) -> None :
39
10
"""Test bit shift operation across all supported data types."""
11
+ check_type_supported (dtype_name )
12
+ if dtype_name == dtype .f16 or dtype_name == dtype .f32 :
13
+ pytest .skip ()
40
14
shape = (5 , 5 )
41
15
values = wrapper .randu (shape , dtype_name )
42
16
bits_to_shift = wrapper .constant (1 , shape , dtype_name )
@@ -49,7 +23,7 @@ def test_bitshiftl_dtypes(dtype_name: dtype.Dtype) -> None:
49
23
@pytest .mark .parametrize (
50
24
"invdtypes" ,
51
25
[
52
- dtype .c64 ,
26
+ dtype .c32 ,
53
27
dtype .f64 ,
54
28
],
55
29
)
@@ -139,9 +113,12 @@ def test_bitshift_right_varying_shift_amount(shift_amount: int) -> None:
139
113
assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
140
114
141
115
142
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
116
+ @pytest .mark .parametrize ("dtype_name" , get_real_types ())
143
117
def test_bitshiftr_dtypes (dtype_name : dtype .Dtype ) -> None :
144
118
"""Test bit shift operation across all supported data types."""
119
+ check_type_supported (dtype_name )
120
+ if dtype_name == dtype .f16 or dtype_name == dtype .f32 :
121
+ pytest .skip ()
145
122
shape = (5 , 5 )
146
123
values = wrapper .randu (shape , dtype_name )
147
124
bits_to_shift = wrapper .constant (1 , shape , dtype_name )
@@ -154,7 +131,7 @@ def test_bitshiftr_dtypes(dtype_name: dtype.Dtype) -> None:
154
131
@pytest .mark .parametrize (
155
132
"invdtypes" ,
156
133
[
157
- dtype .c64 ,
134
+ dtype .c32 ,
158
135
dtype .f64 ,
159
136
],
160
137
)
0 commit comments