1
1
import random
2
2
3
- # import numpy as np
4
3
import pytest
5
4
6
5
import arrayfire_wrapper .dtypes as dtype
7
6
import arrayfire_wrapper .lib as wrapper
8
-
9
- dtype_map = {
10
- # "int16": dtype.s16,
11
- # "int32": dtype.s32,
12
- # "int64": dtype.s64,
13
- # "uint8": dtype.u8,
14
- # "uint16": dtype.u16,
15
- # "uint32": dtype.u32,
16
- # "uint64": dtype.u64,
17
- # "float16": dtype.f16,
18
- "float32" : dtype .f32 ,
19
- # 'float64': dtype.f64,
20
- # 'complex64': dtype.c64,
21
- # "complex32": dtype.c32,
22
- # "bool": dtype.b8,
23
- # "s16": dtype.s16,
24
- # "s32": dtype.s32,
25
- # "s64": dtype.s64,
26
- # "u8": dtype.u8,
27
- # "u16": dtype.u16,
28
- # "u32": dtype.u32,
29
- # "u64": dtype.u64,
30
- # "f16": dtype.f16,
31
- "f32" : dtype .f32 ,
32
- # 'f64': dtype.f64,
33
- # "c32": dtype.c32,
34
- # 'c64': dtype.c64,
35
- # "b8": dtype.b8,
36
- }
7
+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types , get_real_types
37
8
38
9
39
10
@pytest .mark .parametrize (
46
17
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
47
18
],
48
19
)
49
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
20
+ @pytest .mark .parametrize ("dtype_name" , get_float_types ())
50
21
def test_complex_supported_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
51
22
"""Test complex operation across all supported data types."""
23
+ check_type_supported (dtype_name )
24
+ if dtype_name == dtype .f16 :
25
+ pytest .skip ()
52
26
tester = wrapper .randu (shape , dtype_name )
53
27
result = wrapper .cplx (tester )
54
28
assert wrapper .is_complex (result ), f"Failed for dtype: { dtype_name } "
@@ -57,8 +31,8 @@ def test_complex_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None
57
31
@pytest .mark .parametrize (
58
32
"invdtypes" ,
59
33
[
60
- dtype .c64 ,
61
- dtype .f64 ,
34
+ dtype .int32 ,
35
+ dtype .complex32 ,
62
36
],
63
37
)
64
38
def test_complex_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
@@ -79,9 +53,10 @@ def test_complex_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
79
53
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
80
54
],
81
55
)
82
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
56
+ @pytest .mark .parametrize ("dtype_name" , get_real_types ())
83
57
def test_complex2_supported_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
84
58
"""Test complex2 operation across all supported data types."""
59
+ check_type_supported (dtype_name )
85
60
lhs = wrapper .randu (shape , dtype_name )
86
61
rhs = wrapper .randu (shape , dtype_name )
87
62
result = wrapper .cplx2 (lhs , rhs )
@@ -91,8 +66,7 @@ def test_complex2_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> Non
91
66
@pytest .mark .parametrize (
92
67
"invdtypes" ,
93
68
[
94
- dtype .c64 ,
95
- dtype .f64 ,
69
+ dtype .c32 ,
96
70
],
97
71
)
98
72
def test_complex2_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
@@ -114,26 +88,13 @@ def test_complex2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
114
88
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
115
89
],
116
90
)
117
- def test_conj_supported_dtypes (shape : tuple ) -> None :
91
+ @pytest .mark .parametrize ("dtypes" , get_all_types ())
92
+ def test_conj_supported_dtypes (shape : tuple , dtypes : dtype .Dtype ) -> None :
118
93
"""Test conjugate operation for supported data types."""
119
- arr = wrapper .constant (7 , shape , dtype .c32 )
94
+ check_type_supported (dtypes )
95
+ arr = wrapper .constant (7 , shape , dtypes )
120
96
result = wrapper .conjg (arr )
121
- assert wrapper .is_complex (result ), f"Failed for shape: { shape } "
122
-
123
-
124
- @pytest .mark .parametrize (
125
- "invdtypes" ,
126
- [
127
- dtype .c64 ,
128
- dtype .f64 ,
129
- ],
130
- )
131
- def test_conj_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
132
- """Test conjugate operation for unsupported data types."""
133
- with pytest .raises (RuntimeError ):
134
- shape = (5 , 5 )
135
- arr = wrapper .randu (shape , invdtypes )
136
- wrapper .conjg (arr )
97
+ assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"Failed for shape: { shape } , and dtype: { dtypes } " # noqa
137
98
138
99
139
100
@pytest .mark .parametrize (
@@ -146,40 +107,29 @@ def test_conj_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
146
107
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
147
108
],
148
109
)
149
- def test_imag_real_supported_dtypes (shape : tuple ) -> None :
110
+ @pytest .mark .parametrize ("dtypes" , get_all_types ())
111
+ def test_imag_supported_dtypes (shape : tuple , dtypes : dtype .Dtype ) -> None :
150
112
"""Test imaginary and real operations for supported data types."""
151
- arr = wrapper . randu ( shape , dtype . c32 )
152
- imaginary = wrapper .imag ( arr )
113
+ check_type_supported ( dtypes )
114
+ arr = wrapper .randu ( shape , dtypes )
153
115
real = wrapper .real (arr )
154
- assert not wrapper .is_empty (imaginary ), f"Failed for shape: { shape } "
155
- assert not wrapper .is_empty (real ), f"Failed for shape: { shape } "
156
-
157
-
158
- @pytest .mark .parametrize (
159
- "invdtypes" ,
160
- [
161
- dtype .c64 ,
162
- dtype .f64 ,
163
- ],
164
- )
165
- def test_imag_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
166
- """Test conjugate operation for unsupported data types."""
167
- with pytest .raises (RuntimeError ):
168
- shape = (5 , 5 )
169
- arr = wrapper .randu (shape , invdtypes )
170
- wrapper .imag (arr )
116
+ assert wrapper .is_real (real ), f"Failed for shape: { shape } "
171
117
172
118
173
119
@pytest .mark .parametrize (
174
- "invdtypes " ,
120
+ "shape " ,
175
121
[
176
- dtype .c64 ,
177
- dtype .f64 ,
122
+ (),
123
+ (random .randint (1 , 10 ),),
124
+ (random .randint (1 , 10 ), random .randint (1 , 10 )),
125
+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
126
+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
178
127
],
179
128
)
180
- def test_real_unsupported_dtypes (invdtypes : dtype .Dtype ) -> None :
181
- """Test real operation for unsupported data types."""
182
- with pytest .raises (RuntimeError ):
183
- shape = (5 , 5 )
184
- arr = wrapper .randu (shape , invdtypes )
185
- wrapper .real (arr )
129
+ @pytest .mark .parametrize ("dtypes" , get_all_types ())
130
+ def test_real_supported_dtypes (shape : tuple , dtypes : dtype .Dtype ) -> None :
131
+ """Test imaginary and real operations for supported data types."""
132
+ check_type_supported (dtypes )
133
+ arr = wrapper .randu (shape , dtypes )
134
+ real = wrapper .real (arr )
135
+ assert wrapper .is_real (real ), f"Failed for shape: { shape } "
0 commit comments