5
5
from arrayfire_wrapper .lib .create_and_modify_array .helper_functions import array_to_string
6
6
7
7
dtype_map = {
8
- ' int16' : dtype .s16 ,
9
- ' int32' : dtype .s32 ,
10
- ' int64' : dtype .s64 ,
11
- ' uint8' : dtype .u8 ,
12
- ' uint16' : dtype .u16 ,
13
- ' uint32' : dtype .u32 ,
14
- ' uint64' : dtype .u64 ,
8
+ " int16" : dtype .s16 ,
9
+ " int32" : dtype .s32 ,
10
+ " int64" : dtype .s64 ,
11
+ " uint8" : dtype .u8 ,
12
+ " uint16" : dtype .u16 ,
13
+ " uint32" : dtype .u32 ,
14
+ " uint64" : dtype .u64 ,
15
15
# 'float16': dtype.f16,
16
16
# 'float32': dtype.f32,
17
17
# 'float64': dtype.f64,
18
18
# 'complex64': dtype.c64,
19
19
# 'complex32': dtype.c32,
20
- ' bool' : dtype .b8 ,
21
- ' s16' : dtype .s16 ,
22
- ' s32' : dtype .s32 ,
23
- ' s64' : dtype .s64 ,
24
- 'u8' : dtype .u8 ,
25
- ' u16' : dtype .u16 ,
26
- ' u32' : dtype .u32 ,
27
- ' u64' : dtype .u64 ,
20
+ " bool" : dtype .b8 ,
21
+ " s16" : dtype .s16 ,
22
+ " s32" : dtype .s32 ,
23
+ " s64" : dtype .s64 ,
24
+ "u8" : dtype .u8 ,
25
+ " u16" : dtype .u16 ,
26
+ " u32" : dtype .u32 ,
27
+ " u64" : dtype .u64 ,
28
28
# 'f16': dtype.f16,
29
29
# 'f32': dtype.f32,
30
30
# 'f64': dtype.f64,
31
31
# 'c32': dtype.c32,
32
32
# 'c64': dtype.c64,
33
- 'b8' : dtype .b8 ,
33
+ "b8" : dtype .b8 ,
34
34
}
35
+
36
+
35
37
@pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
36
38
def test_bitshiftl_dtypes (dtype_name : dtype .Dtype ) -> None :
37
39
"""Test bit shift operation across all supported data types."""
38
40
shape = (5 , 5 )
39
41
values = wrapper .randu (shape , dtype_name )
40
42
bits_to_shift = wrapper .constant (1 , shape , dtype_name )
41
-
43
+
42
44
result = wrapper .bitshiftl (values , bits_to_shift )
43
45
44
46
assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
47
+
48
+
45
49
@pytest .mark .parametrize (
46
50
"invdtypes" ,
47
51
[
@@ -54,25 +58,28 @@ def test_bitshiftl_supported_dtypes(invdtypes: dtype.Dtype) -> None:
54
58
shape = (5 , 5 )
55
59
with pytest .raises (RuntimeError ):
56
60
value = wrapper .randu (shape , invdtypes )
57
- shift_amount = 1
61
+ bits_to_shift = wrapper . constant ( 1 , shape , invdtypes )
58
62
59
- result = wrapper .bitshiftl (value , shift_amount )
63
+ result = wrapper .bitshiftl (value , bits_to_shift )
60
64
assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == invdtypes , f"Failed for dtype: { invdtypes } "
65
+
66
+
61
67
@pytest .mark .parametrize ("input_size" , [8 , 10 , 12 ])
62
- def test_bitshiftl_varying_input_size (input_size ) :
68
+ def test_bitshiftl_varying_input_size (input_size : int ) -> None :
63
69
"""Test bitshift left operation with varying input sizes"""
64
70
shape = (input_size , input_size )
65
71
value = wrapper .randu (shape , dtype .int16 )
66
72
shift_amount = wrapper .constant (1 , shape , dtype .int16 ) # Fixed shift amount for simplicity
67
73
68
74
result = wrapper .bitshiftl (value , shift_amount )
69
75
70
- assert wrapper .get_dims (result )[0 : len (shape )] == shape
76
+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
77
+
71
78
72
79
@pytest .mark .parametrize (
73
80
"shape" ,
74
81
[
75
- (10 , ),
82
+ (10 ,),
76
83
(5 , 5 ),
77
84
(2 , 3 , 4 ),
78
85
],
@@ -81,21 +88,23 @@ def test_bitshiftl_varying_shapes(shape: tuple) -> None:
81
88
"""Test left bit shifting with arrays of varying shapes."""
82
89
values = wrapper .randu (shape , dtype .int16 )
83
90
bits_to_shift = wrapper .constant (1 , shape , dtype .int16 )
84
-
91
+
85
92
result = wrapper .bitshiftl (values , bits_to_shift )
86
93
87
- assert wrapper .get_dims (result )[0 : len (shape )] == shape
94
+ assert wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
95
+
88
96
89
97
@pytest .mark .parametrize ("shift_amount" , [- 1 , 0 , 2 , 30 ])
90
- def test_bitshift_left_varying_shift_amount (shift_amount ) :
98
+ def test_bitshift_left_varying_shift_amount (shift_amount : int ) -> None :
91
99
"""Test bitshift left operation with varying shift amounts."""
92
100
shape = (5 , 5 )
93
101
value = wrapper .randu (shape , dtype .int16 )
94
102
shift_amount_arr = wrapper .constant (shift_amount , shape , dtype .int16 )
95
103
96
104
result = wrapper .bitshiftl (value , shift_amount_arr )
97
105
98
- assert wrapper .get_dims (result )[0 : len (shape )] == shape
106
+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
107
+
99
108
100
109
@pytest .mark .parametrize (
101
110
"shape_a, shape_b" ,
@@ -113,29 +122,35 @@ def test_bitshiftl_different_shapes(shape_a: tuple, shape_b: tuple) -> None:
113
122
bits_to_shift = wrapper .constant (1 , shape_b , dtype .int16 )
114
123
result = wrapper .bitshiftl (values , bits_to_shift )
115
124
print (array_to_string ("" , result , 3 , False ))
116
- assert wrapper .get_dims (result )[0 : len (shape_a )] == shape_a , f"Failed for shapes { shape_a } and { shape_b } "
125
+ assert (
126
+ wrapper .get_dims (result )[0 : len (shape_a )] == shape_a # noqa
127
+ ), f"Failed for shapes { shape_a } and { shape_b } "
128
+
117
129
118
130
@pytest .mark .parametrize ("shift_amount" , [- 1 , 0 , 2 , 30 ])
119
- def test_bitshift_right_varying_shift_amount (shift_amount ) :
131
+ def test_bitshift_right_varying_shift_amount (shift_amount : int ) -> None :
120
132
"""Test bitshift right operation with varying shift amounts."""
121
133
shape = (5 , 5 )
122
134
value = wrapper .randu (shape , dtype .int16 )
123
135
shift_amount_arr = wrapper .constant (shift_amount , shape , dtype .int16 )
124
136
125
137
result = wrapper .bitshiftr (value , shift_amount_arr )
126
138
127
- assert wrapper .get_dims (result )[0 : len (shape )] == shape
139
+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
140
+
128
141
129
142
@pytest .mark .parametrize ("dtype_name" , dtype_map .values ())
130
143
def test_bitshiftr_dtypes (dtype_name : dtype .Dtype ) -> None :
131
144
"""Test bit shift operation across all supported data types."""
132
145
shape = (5 , 5 )
133
146
values = wrapper .randu (shape , dtype_name )
134
147
bits_to_shift = wrapper .constant (1 , shape , dtype_name )
135
-
148
+
136
149
result = wrapper .bitshiftr (values , bits_to_shift )
137
150
138
151
assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == dtype_name , f"Failed for dtype: { dtype_name } "
152
+
153
+
139
154
@pytest .mark .parametrize (
140
155
"invdtypes" ,
141
156
[
@@ -148,25 +163,28 @@ def test_bitshiftr_supported_dtypes(invdtypes: dtype.Dtype) -> None:
148
163
shape = (5 , 5 )
149
164
with pytest .raises (RuntimeError ):
150
165
value = wrapper .randu (shape , invdtypes )
151
- shift_amount = 1
166
+ shift_amount = wrapper . constant ( 1 , shape , invdtypes )
152
167
153
168
result = wrapper .bitshiftr (value , shift_amount )
154
169
assert dtype .c_api_value_to_dtype (wrapper .get_type (result )) == invdtypes , f"Failed for dtype: { invdtypes } "
155
170
171
+
156
172
@pytest .mark .parametrize ("input_size" , [8 , 10 , 12 ])
157
- def test_bitshift_right_varying_input_size (input_size ) :
173
+ def test_bitshift_right_varying_input_size (input_size : int ) -> None :
158
174
"""Test bitshift right operation with varying input sizes"""
159
175
shape = (input_size , input_size )
160
176
value = wrapper .randu (shape , dtype .int16 )
161
177
shift_amount = wrapper .constant (1 , shape , dtype .int16 ) # Fixed shift amount for simplicity
162
178
163
179
result = wrapper .bitshiftr (value , shift_amount )
164
180
165
- assert wrapper .get_dims (result )[0 : len (shape )] == shape
181
+ assert (wrapper .get_dims (result )[0 ], wrapper .get_dims (result )[1 ]) == shape
182
+
183
+
166
184
@pytest .mark .parametrize (
167
185
"shape" ,
168
186
[
169
- (10 , ),
187
+ (10 ,),
170
188
(5 , 5 ),
171
189
(2 , 3 , 4 ),
172
190
],
@@ -175,11 +193,10 @@ def test_bitshiftr_varying_shapes(shape: tuple) -> None:
175
193
"""Test right bit shifting with arrays of varying shapes."""
176
194
values = wrapper .randu (shape , dtype .int16 )
177
195
bits_to_shift = wrapper .constant (1 , shape , dtype .int16 )
178
-
179
- result = wrapper .bitshiftr (values , bits_to_shift )
180
196
181
- assert wrapper .get_dims ( result )[ 0 : len ( shape )] == shape
197
+ result = wrapper .bitshiftr ( values , bits_to_shift )
182
198
199
+ assert wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
183
200
184
201
185
202
@pytest .mark .parametrize (
@@ -198,4 +215,6 @@ def test_bitshiftr_different_shapes(shape_a: tuple, shape_b: tuple) -> None:
198
215
bits_to_shift = wrapper .constant (1 , shape_b , dtype .int16 )
199
216
result = wrapper .bitshiftr (values , bits_to_shift )
200
217
print (array_to_string ("" , result , 3 , False ))
201
- assert wrapper .get_dims (result )[0 : len (shape_a )] == shape_a , f"Failed for shapes { shape_a } and { shape_b } "
218
+ assert (
219
+ wrapper .get_dims (result )[0 : len (shape_a )] == shape_a # noqa
220
+ ), f"Failed for shapes { shape_a } and { shape_b } "
0 commit comments