Skip to content

Commit 3785b2a

Browse files
AzeezIshsyurkevi
AzeezIsh
authored andcommitted
Adhered to all checkstyle and checked for dtypes
and shapes for all complex math functions.
1 parent dd6d4fc commit 3785b2a

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed

tests/complex_testing.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import random
2+
3+
# import numpy as np
4+
import pytest
5+
6+
import arrayfire_wrapper.dtypes as dtype
7+
import arrayfire_wrapper.lib as wrapper
8+
9+
dtype_map = {
10+
# "int16": dtype.s16,
11+
# "int32": dtype.s32,
12+
# "int64": dtype.s64,
13+
# "uint8": dtype.u8,
14+
# "uint16": dtype.u16,
15+
# "uint32": dtype.u32,
16+
# "uint64": dtype.u64,
17+
# "float16": dtype.f16,
18+
"float32": dtype.f32,
19+
# 'float64': dtype.f64,
20+
# 'complex64': dtype.c64,
21+
# "complex32": dtype.c32,
22+
# "bool": dtype.b8,
23+
# "s16": dtype.s16,
24+
# "s32": dtype.s32,
25+
# "s64": dtype.s64,
26+
# "u8": dtype.u8,
27+
# "u16": dtype.u16,
28+
# "u32": dtype.u32,
29+
# "u64": dtype.u64,
30+
# "f16": dtype.f16,
31+
"f32": dtype.f32,
32+
# 'f64': dtype.f64,
33+
# "c32": dtype.c32,
34+
# 'c64': dtype.c64,
35+
# "b8": dtype.b8,
36+
}
37+
38+
39+
@pytest.mark.parametrize(
40+
"shape",
41+
[
42+
(),
43+
(random.randint(1, 10),),
44+
(random.randint(1, 10), random.randint(1, 10)),
45+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
46+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
47+
],
48+
)
49+
@pytest.mark.parametrize("dtype_name", dtype_map.values())
50+
def test_complex_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
51+
"""Test complex operation across all supported data types."""
52+
tester = wrapper.randu(shape, dtype_name)
53+
result = wrapper.cplx(tester)
54+
assert wrapper.is_complex(result), f"Failed for dtype: {dtype_name}"
55+
56+
57+
@pytest.mark.parametrize(
58+
"invdtypes",
59+
[
60+
dtype.c64,
61+
dtype.f64,
62+
],
63+
)
64+
def test_complex_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
65+
"""Test complex operation for unsupported data types."""
66+
with pytest.raises(RuntimeError):
67+
shape = (5, 5)
68+
out = wrapper.randu(shape, invdtypes)
69+
wrapper.cplx(out)
70+
71+
72+
@pytest.mark.parametrize(
73+
"shape",
74+
[
75+
(),
76+
(random.randint(1, 10),),
77+
(random.randint(1, 10), random.randint(1, 10)),
78+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
79+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
80+
],
81+
)
82+
@pytest.mark.parametrize("dtype_name", dtype_map.values())
83+
def test_complex2_supported_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
84+
"""Test complex2 operation across all supported data types."""
85+
lhs = wrapper.randu(shape, dtype_name)
86+
rhs = wrapper.randu(shape, dtype_name)
87+
result = wrapper.cplx2(lhs, rhs)
88+
assert wrapper.is_complex(result), f"Failed for dtype: {dtype_name}"
89+
90+
91+
@pytest.mark.parametrize(
92+
"invdtypes",
93+
[
94+
dtype.c64,
95+
dtype.f64,
96+
],
97+
)
98+
def test_complex2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
99+
"""Test complex2 operation for unsupported data types."""
100+
with pytest.raises(RuntimeError):
101+
shape = (5, 5)
102+
lhs = wrapper.randu(shape, invdtypes)
103+
rhs = wrapper.randu(shape, invdtypes)
104+
wrapper.cplx2(lhs, rhs)
105+
106+
107+
@pytest.mark.parametrize(
108+
"shape",
109+
[
110+
(),
111+
(random.randint(1, 10),),
112+
(random.randint(1, 10), random.randint(1, 10)),
113+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
114+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
115+
],
116+
)
117+
def test_conj_supported_dtypes(shape: tuple) -> None:
118+
"""Test conjugate operation for supported data types."""
119+
arr = wrapper.constant(7, shape, dtype.c32)
120+
result = wrapper.conjg(arr)
121+
assert wrapper.is_complex(result), f"Failed for shape: {shape}"
122+
123+
124+
@pytest.mark.parametrize(
125+
"invdtypes",
126+
[
127+
dtype.c64,
128+
dtype.f64,
129+
],
130+
)
131+
def test_conj_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
132+
"""Test conjugate operation for unsupported data types."""
133+
with pytest.raises(RuntimeError):
134+
shape = (5, 5)
135+
arr = wrapper.randu(shape, invdtypes)
136+
wrapper.conjg(arr)
137+
138+
139+
@pytest.mark.parametrize(
140+
"shape",
141+
[
142+
(),
143+
(random.randint(1, 10),),
144+
(random.randint(1, 10), random.randint(1, 10)),
145+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
146+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
147+
],
148+
)
149+
def test_imag_real_supported_dtypes(shape: tuple) -> None:
150+
"""Test imaginary and real operations for supported data types."""
151+
arr = wrapper.randu(shape, dtype.c32)
152+
imaginary = wrapper.imag(arr)
153+
real = wrapper.real(arr)
154+
assert not wrapper.is_empty(imaginary), f"Failed for shape: {shape}"
155+
assert not wrapper.is_empty(real), f"Failed for shape: {shape}"
156+
157+
158+
@pytest.mark.parametrize(
159+
"invdtypes",
160+
[
161+
dtype.c64,
162+
dtype.f64,
163+
],
164+
)
165+
def test_imag_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
166+
"""Test conjugate operation for unsupported data types."""
167+
with pytest.raises(RuntimeError):
168+
shape = (5, 5)
169+
arr = wrapper.randu(shape, invdtypes)
170+
wrapper.imag(arr)
171+
172+
173+
@pytest.mark.parametrize(
174+
"invdtypes",
175+
[
176+
dtype.c64,
177+
dtype.f64,
178+
],
179+
)
180+
def test_real_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
181+
"""Test real operation for unsupported data types."""
182+
with pytest.raises(RuntimeError):
183+
shape = (5, 5)
184+
arr = wrapper.randu(shape, invdtypes)
185+
wrapper.real(arr)

0 commit comments

Comments
 (0)