Skip to content

Commit 8a0d0b2

Browse files
author
AzeezIsh
committed
Incorperated Utility Functions, added checkstyle
Checked for all dtype compatibility issues where needed.
1 parent 6402f56 commit 8a0d0b2

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed

tests/test_trig.py

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtype
6+
import arrayfire_wrapper.lib as wrapper
7+
from . import utility_functions as util
8+
9+
10+
@pytest.mark.parametrize(
11+
"shape",
12+
[
13+
(),
14+
(random.randint(1, 10),),
15+
(random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
18+
],
19+
)
20+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
21+
def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
22+
"""Test inverse sine operation across all supported data types."""
23+
util.check_type_supported(dtype_name)
24+
values = wrapper.randu(shape, dtype_name)
25+
result = wrapper.asin(values)
26+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
27+
28+
29+
@pytest.mark.parametrize(
30+
"shape",
31+
[
32+
(),
33+
(random.randint(1, 10),),
34+
(random.randint(1, 10), random.randint(1, 10)),
35+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
36+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
37+
],
38+
)
39+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
40+
def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
41+
"""Test inverse cosine operation across all supported data types."""
42+
util.check_type_supported(dtype_name)
43+
values = wrapper.randu(shape, dtype_name)
44+
result = wrapper.acos(values)
45+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
46+
47+
48+
@pytest.mark.parametrize(
49+
"shape",
50+
[
51+
(),
52+
(random.randint(1, 10),),
53+
(random.randint(1, 10), random.randint(1, 10)),
54+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
55+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
56+
],
57+
)
58+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
59+
def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
60+
"""Test inverse tan operation across all supported data types."""
61+
util.check_type_supported(dtype_name)
62+
values = wrapper.randu(shape, dtype_name)
63+
result = wrapper.atan(values)
64+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
65+
66+
67+
@pytest.mark.parametrize(
68+
"shape",
69+
[
70+
(),
71+
(random.randint(1, 10),),
72+
(random.randint(1, 10), random.randint(1, 10)),
73+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
74+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
75+
],
76+
)
77+
@pytest.mark.parametrize("dtype_name", util.get_float_types())
78+
def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
79+
"""Test inverse tan operation across all supported data types."""
80+
util.check_type_supported(dtype_name)
81+
if dtype_name == dtype.f16:
82+
pytest.skip()
83+
lhs = wrapper.randu(shape, dtype_name)
84+
rhs = wrapper.randu(shape, dtype_name)
85+
result = wrapper.atan2(lhs, rhs)
86+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
87+
88+
89+
@pytest.mark.parametrize(
90+
"invdtypes",
91+
[
92+
dtype.int16,
93+
dtype.bool,
94+
],
95+
)
96+
def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
97+
"""Test inverse tan operation for unsupported data types."""
98+
with pytest.raises(RuntimeError):
99+
wrapper.atan2(wrapper.randu((10, 10), invdtypes), wrapper.randu((10, 10), invdtypes))
100+
101+
@pytest.mark.parametrize(
102+
"shape",
103+
[
104+
(),
105+
(random.randint(1, 10),),
106+
(random.randint(1, 10), random.randint(1, 10)),
107+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
108+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
109+
],
110+
)
111+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
112+
def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
113+
"""Test cosine operation across all supported data types."""
114+
util.check_type_supported(dtype_name)
115+
values = wrapper.randu(shape, dtype_name)
116+
result = wrapper.cos(values)
117+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
118+
119+
120+
@pytest.mark.parametrize(
121+
"shape",
122+
[
123+
(),
124+
(random.randint(1, 10),),
125+
(random.randint(1, 10), random.randint(1, 10)),
126+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
127+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
128+
],
129+
)
130+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
131+
def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
132+
"""Test sin operation across all supported data types."""
133+
util.check_type_supported(dtype_name)
134+
values = wrapper.randu(shape, dtype_name)
135+
result = wrapper.sin(values)
136+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
137+
138+
139+
@pytest.mark.parametrize(
140+
"shape",
141+
[
142+
(),
143+
(random.randint(1, 10),),
144+
(random.randint(1, 10), random.randint(1, 10)),
145+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
146+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
147+
],
148+
)
149+
@pytest.mark.parametrize("dtype_name", util.get_all_types())
150+
def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
151+
"""Test tan operation across all supported data types."""
152+
util.check_type_supported(dtype_name)
153+
values = wrapper.randu(shape, dtype_name)
154+
result = wrapper.tan(values)
155+
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa

0 commit comments

Comments
 (0)