Skip to content

Commit ce6cef2

Browse files
AzeezIshsyurkevi
AzeezIsh
authored andcommitted
Adhered to checkstyle and fixed dtype checking.
1 parent 3785b2a commit ce6cef2

File tree

1 file changed

+33
-83
lines changed

1 file changed

+33
-83
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,10 @@
11
import random
22

3-
# import numpy as np
43
import pytest
54

65
import arrayfire_wrapper.dtypes as dtype
76
import arrayfire_wrapper.lib as wrapper
8-
9-
dtype_map = {
10-
# "int16": dtype.s16,
11-
# "int32": dtype.s32,
12-
# "int64": dtype.s64,
13-
# "uint8": dtype.u8,
14-
# "uint16": dtype.u16,
15-
# "uint32": dtype.u32,
16-
# "uint64": dtype.u64,
17-
# "float16": dtype.f16,
18-
"float32": dtype.f32,
19-
# 'float64': dtype.f64,
20-
# 'complex64': dtype.c64,
21-
# "complex32": dtype.c32,
22-
# "bool": dtype.b8,
23-
# "s16": dtype.s16,
24-
# "s32": dtype.s32,
25-
# "s64": dtype.s64,
26-
# "u8": dtype.u8,
27-
# "u16": dtype.u16,
28-
# "u32": dtype.u32,
29-
# "u64": dtype.u64,
30-
# "f16": dtype.f16,
31-
"f32": dtype.f32,
32-
# 'f64': dtype.f64,
33-
# "c32": dtype.c32,
34-
# 'c64': dtype.c64,
35-
# "b8": dtype.b8,
36-
}
7+
from tests.utility_functions import check_type_supported, get_all_types, get_float_types, get_real_types
378

389

3910
@pytest.mark.parametrize(
@@ -46,9 +17,12 @@
4617
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
4718
],
4819
)
49-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
20+
@pytest.mark.parametrize("dtype_name", get_float_types())
5021
def test_complex_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
5122
"""Test complex operation across all supported data types."""
23+
check_type_supported(dtype_name)
24+
if dtype_name == dtype.f16:
25+
pytest.skip()
5226
tester = wrapper.randu(shape, dtype_name)
5327
result = wrapper.cplx(tester)
5428
assert wrapper.is_complex(result), f"Failed for dtype: {dtype_name}"
@@ -57,8 +31,8 @@ def test_complex_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None
5731
@pytest.mark.parametrize(
5832
"invdtypes",
5933
[
60-
dtype.c64,
61-
dtype.f64,
34+
dtype.int32,
35+
dtype.complex32,
6236
],
6337
)
6438
def test_complex_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
@@ -79,9 +53,10 @@ def test_complex_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
7953
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
8054
],
8155
)
82-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
56+
@pytest.mark.parametrize("dtype_name", get_real_types())
8357
def test_complex2_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
8458
"""Test complex2 operation across all supported data types."""
59+
check_type_supported(dtype_name)
8560
lhs = wrapper.randu(shape, dtype_name)
8661
rhs = wrapper.randu(shape, dtype_name)
8762
result = wrapper.cplx2(lhs, rhs)
@@ -91,8 +66,7 @@ def test_complex2_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> Non
9166
@pytest.mark.parametrize(
9267
"invdtypes",
9368
[
94-
dtype.c64,
95-
dtype.f64,
69+
dtype.c32,
9670
],
9771
)
9872
def test_complex2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
@@ -114,26 +88,13 @@ def test_complex2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
11488
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
11589
],
11690
)
117-
def test_conj_supported_dtypes(shape: tuple) -> None:
91+
@pytest.mark.parametrize("dtypes", get_all_types())
92+
def test_conj_supported_dtypes(shape: tuple, dtypes: dtype.Dtype) -> None:
11893
"""Test conjugate operation for supported data types."""
119-
arr = wrapper.constant(7, shape, dtype.c32)
94+
check_type_supported(dtypes)
95+
arr = wrapper.constant(7, shape, dtypes)
12096
result = wrapper.conjg(arr)
121-
assert wrapper.is_complex(result), f"Failed for shape: {shape}"
122-
123-
124-
@pytest.mark.parametrize(
125-
"invdtypes",
126-
[
127-
dtype.c64,
128-
dtype.f64,
129-
],
130-
)
131-
def test_conj_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
132-
"""Test conjugate operation for unsupported data types."""
133-
with pytest.raises(RuntimeError):
134-
shape = (5, 5)
135-
arr = wrapper.randu(shape, invdtypes)
136-
wrapper.conjg(arr)
97+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"Failed for shape: {shape}, and dtype: {dtypes}" # noqa
13798

13899

139100
@pytest.mark.parametrize(
@@ -146,40 +107,29 @@ def test_conj_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
146107
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
147108
],
148109
)
149-
def test_imag_real_supported_dtypes(shape: tuple) -> None:
110+
@pytest.mark.parametrize("dtypes", get_all_types())
111+
def test_imag_supported_dtypes(shape: tuple, dtypes: dtype.Dtype) -> None:
150112
"""Test imaginary and real operations for supported data types."""
151-
arr = wrapper.randu(shape, dtype.c32)
152-
imaginary = wrapper.imag(arr)
113+
check_type_supported(dtypes)
114+
arr = wrapper.randu(shape, dtypes)
153115
real = wrapper.real(arr)
154-
assert not wrapper.is_empty(imaginary), f"Failed for shape: {shape}"
155-
assert not wrapper.is_empty(real), f"Failed for shape: {shape}"
156-
157-
158-
@pytest.mark.parametrize(
159-
"invdtypes",
160-
[
161-
dtype.c64,
162-
dtype.f64,
163-
],
164-
)
165-
def test_imag_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
166-
"""Test conjugate operation for unsupported data types."""
167-
with pytest.raises(RuntimeError):
168-
shape = (5, 5)
169-
arr = wrapper.randu(shape, invdtypes)
170-
wrapper.imag(arr)
116+
assert wrapper.is_real(real), f"Failed for shape: {shape}"
171117

172118

173119
@pytest.mark.parametrize(
174-
"invdtypes",
120+
"shape",
175121
[
176-
dtype.c64,
177-
dtype.f64,
122+
(),
123+
(random.randint(1, 10),),
124+
(random.randint(1, 10), random.randint(1, 10)),
125+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
126+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
178127
],
179128
)
180-
def test_real_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
181-
"""Test real operation for unsupported data types."""
182-
with pytest.raises(RuntimeError):
183-
shape = (5, 5)
184-
arr = wrapper.randu(shape, invdtypes)
185-
wrapper.real(arr)
129+
@pytest.mark.parametrize("dtypes", get_all_types())
130+
def test_real_supported_dtypes(shape: tuple, dtypes: dtype.Dtype) -> None:
131+
"""Test imaginary and real operations for supported data types."""
132+
check_type_supported(dtypes)
133+
arr = wrapper.randu(shape, dtypes)
134+
real = wrapper.real(arr)
135+
assert wrapper.is_real(real), f"Failed for shape: {shape}"

0 commit comments

Comments
 (0)