4
4
5
5
import arrayfire_wrapper .dtypes as dtype
6
6
import arrayfire_wrapper .lib as wrapper
7
-
8
- dtype_map = {
9
- "int16" : dtype .s16 ,
10
- "int32" : dtype .s32 ,
11
- "int64" : dtype .s64 ,
12
- "uint8" : dtype .u8 ,
13
- "uint16" : dtype .u16 ,
14
- "uint32" : dtype .u32 ,
15
- "uint64" : dtype .u64 ,
16
- "float16" : dtype .f16 ,
17
- "float32" : dtype .f32 ,
18
- # 'float64': dtype.f64,
19
- # 'complex64': dtype.c64,
20
- # 'complex32': dtype.c32,
21
- "bool" : dtype .b8 ,
22
- "s16" : dtype .s16 ,
23
- "s32" : dtype .s32 ,
24
- "s64" : dtype .s64 ,
25
- "u8" : dtype .u8 ,
26
- "u16" : dtype .u16 ,
27
- "u32" : dtype .u32 ,
28
- "u64" : dtype .u64 ,
29
- "f16" : dtype .f16 ,
30
- "f32" : dtype .f32 ,
31
- # 'f64': dtype.f64,
32
- # 'c32': dtype.c32,
33
- # 'c64': dtype.c64,
34
- "b8" : dtype .b8 ,
35
- }
7
+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types
36
8
37
9
38
10
@pytest .mark .parametrize (
39
11
"shape" ,
40
12
[
41
13
(),
42
14
(random .randint (1 , 10 ),),
43
- (random .randint (1 , 10 ),),
44
- (random .randint (1 , 10 ),),
45
15
(random .randint (1 , 10 ), random .randint (1 , 10 )),
46
16
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
47
17
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
48
18
],
49
19
)
50
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
20
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
51
21
def test_asinh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
52
22
"""Test inverse hyperbolic sine operation across all supported data types."""
23
+ check_type_supported (dtype_name )
53
24
values = wrapper .randu (shape , dtype_name )
54
25
result = wrapper .asinh (values )
55
26
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -73,16 +44,15 @@ def test_asinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
73
44
[
74
45
(),
75
46
(random .randint (1 , 10 ),),
76
- (random .randint (1 , 10 ),),
77
- (random .randint (1 , 10 ),),
78
47
(random .randint (1 , 10 ), random .randint (1 , 10 )),
79
48
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
80
49
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
81
50
],
82
51
)
83
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
52
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
84
53
def test_acosh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
85
54
"""Test inverse hyperbolic cosine operation across all supported data types."""
55
+ check_type_supported (dtype_name )
86
56
values = wrapper .randu (shape , dtype_name )
87
57
result = wrapper .acosh (values )
88
58
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -106,16 +76,15 @@ def test_acosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
106
76
[
107
77
(),
108
78
(random .randint (1 , 10 ),),
109
- (random .randint (1 , 10 ),),
110
- (random .randint (1 , 10 ),),
111
79
(random .randint (1 , 10 ), random .randint (1 , 10 )),
112
80
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
113
81
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
114
82
],
115
83
)
116
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
84
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
117
85
def test_atanh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
118
86
"""Test inverse hyperbolic tan operation across all supported data types."""
87
+ check_type_supported (dtype_name )
119
88
values = wrapper .randu (shape , dtype_name )
120
89
result = wrapper .atanh (values )
121
90
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -139,16 +108,15 @@ def test_atanh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
139
108
[
140
109
(),
141
110
(random .randint (1 , 10 ),),
142
- (random .randint (1 , 10 ),),
143
- (random .randint (1 , 10 ),),
144
111
(random .randint (1 , 10 ), random .randint (1 , 10 )),
145
112
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
146
113
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
147
114
],
148
115
)
149
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
116
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
150
117
def test_cosh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
151
118
"""Test hyperbolic cosine operation across all supported data types."""
119
+ check_type_supported (dtype_name )
152
120
values = wrapper .randu (shape , dtype_name )
153
121
result = wrapper .cosh (values )
154
122
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -172,16 +140,15 @@ def test_cosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
172
140
[
173
141
(),
174
142
(random .randint (1 , 10 ),),
175
- (random .randint (1 , 10 ),),
176
- (random .randint (1 , 10 ),),
177
143
(random .randint (1 , 10 ), random .randint (1 , 10 )),
178
144
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
179
145
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
180
146
],
181
147
)
182
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
148
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
183
149
def test_sinh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
184
150
"""Test hyberbolic sin operation across all supported data types."""
151
+ check_type_supported (dtype_name )
185
152
values = wrapper .randu (shape , dtype_name )
186
153
result = wrapper .sinh (values )
187
154
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -205,16 +172,15 @@ def test_sinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
205
172
[
206
173
(),
207
174
(random .randint (1 , 10 ),),
208
- (random .randint (1 , 10 ),),
209
- (random .randint (1 , 10 ),),
210
175
(random .randint (1 , 10 ), random .randint (1 , 10 )),
211
176
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
212
177
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
213
178
],
214
179
)
215
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
180
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
216
181
def test_tanh_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
217
182
"""Test hyberbolic tan operation across all supported data types."""
183
+ check_type_supported (dtype_name )
218
184
values = wrapper .randu (shape , dtype_name )
219
185
result = wrapper .tanh (values )
220
186
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
0 commit comments