Skip to content

Commit b51518e

Browse files
AzeezIshsyurkevi
AzeezIsh
authored andcommitted
Removed type map, added all necessary coverage.
1 parent 2b5352c commit b51518e

File tree

1 file changed

+13
-47
lines changed

1 file changed

+13
-47
lines changed

tests/test_hyperbolic.py

+13-47
Original file line numberDiff line numberDiff line change
@@ -4,52 +4,23 @@
44

55
import arrayfire_wrapper.dtypes as dtype
66
import arrayfire_wrapper.lib as wrapper
7-
8-
dtype_map = {
9-
"int16": dtype.s16,
10-
"int32": dtype.s32,
11-
"int64": dtype.s64,
12-
"uint8": dtype.u8,
13-
"uint16": dtype.u16,
14-
"uint32": dtype.u32,
15-
"uint64": dtype.u64,
16-
"float16": dtype.f16,
17-
"float32": dtype.f32,
18-
# 'float64': dtype.f64,
19-
# 'complex64': dtype.c64,
20-
# 'complex32': dtype.c32,
21-
"bool": dtype.b8,
22-
"s16": dtype.s16,
23-
"s32": dtype.s32,
24-
"s64": dtype.s64,
25-
"u8": dtype.u8,
26-
"u16": dtype.u16,
27-
"u32": dtype.u32,
28-
"u64": dtype.u64,
29-
"f16": dtype.f16,
30-
"f32": dtype.f32,
31-
# 'f64': dtype.f64,
32-
# 'c32': dtype.c32,
33-
# 'c64': dtype.c64,
34-
"b8": dtype.b8,
35-
}
7+
from tests.utility_functions import check_type_supported, get_all_types, get_float_types
368

379

3810
@pytest.mark.parametrize(
3911
"shape",
4012
[
4113
(),
4214
(random.randint(1, 10),),
43-
(random.randint(1, 10),),
44-
(random.randint(1, 10),),
4515
(random.randint(1, 10), random.randint(1, 10)),
4616
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
4717
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
4818
],
4919
)
50-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
20+
@pytest.mark.parametrize("dtype_name", get_all_types())
5121
def test_asinh_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
5222
"""Test inverse hyperbolic sine operation across all supported data types."""
23+
check_type_supported(dtype_name)
5324
values = wrapper.randu(shape, dtype_name)
5425
result = wrapper.asinh(values)
5526
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -73,16 +44,15 @@ def test_asinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
7344
[
7445
(),
7546
(random.randint(1, 10),),
76-
(random.randint(1, 10),),
77-
(random.randint(1, 10),),
7847
(random.randint(1, 10), random.randint(1, 10)),
7948
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
8049
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
8150
],
8251
)
83-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
52+
@pytest.mark.parametrize("dtype_name", get_all_types())
8453
def test_acosh_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
8554
"""Test inverse hyperbolic cosine operation across all supported data types."""
55+
check_type_supported(dtype_name)
8656
values = wrapper.randu(shape, dtype_name)
8757
result = wrapper.acosh(values)
8858
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -106,16 +76,15 @@ def test_acosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
10676
[
10777
(),
10878
(random.randint(1, 10),),
109-
(random.randint(1, 10),),
110-
(random.randint(1, 10),),
11179
(random.randint(1, 10), random.randint(1, 10)),
11280
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
11381
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
11482
],
11583
)
116-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
84+
@pytest.mark.parametrize("dtype_name", get_all_types())
11785
def test_atanh_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
11886
"""Test inverse hyperbolic tan operation across all supported data types."""
87+
check_type_supported(dtype_name)
11988
values = wrapper.randu(shape, dtype_name)
12089
result = wrapper.atanh(values)
12190
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -139,16 +108,15 @@ def test_atanh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
139108
[
140109
(),
141110
(random.randint(1, 10),),
142-
(random.randint(1, 10),),
143-
(random.randint(1, 10),),
144111
(random.randint(1, 10), random.randint(1, 10)),
145112
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
146113
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
147114
],
148115
)
149-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
116+
@pytest.mark.parametrize("dtype_name", get_all_types())
150117
def test_cosh_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
151118
"""Test hyperbolic cosine operation across all supported data types."""
119+
check_type_supported(dtype_name)
152120
values = wrapper.randu(shape, dtype_name)
153121
result = wrapper.cosh(values)
154122
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -172,16 +140,15 @@ def test_cosh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
172140
[
173141
(),
174142
(random.randint(1, 10),),
175-
(random.randint(1, 10),),
176-
(random.randint(1, 10),),
177143
(random.randint(1, 10), random.randint(1, 10)),
178144
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
179145
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
180146
],
181147
)
182-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
148+
@pytest.mark.parametrize("dtype_name", get_all_types())
183149
def test_sinh_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
184150
"""Test hyberbolic sin operation across all supported data types."""
151+
check_type_supported(dtype_name)
185152
values = wrapper.randu(shape, dtype_name)
186153
result = wrapper.sinh(values)
187154
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -205,16 +172,15 @@ def test_sinh_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
205172
[
206173
(),
207174
(random.randint(1, 10),),
208-
(random.randint(1, 10),),
209-
(random.randint(1, 10),),
210175
(random.randint(1, 10), random.randint(1, 10)),
211176
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
212177
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
213178
],
214179
)
215-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
180+
@pytest.mark.parametrize("dtype_name", get_all_types())
216181
def test_tanh_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
217182
"""Test hyberbolic tan operation across all supported data types."""
183+
check_type_supported(dtype_name)
218184
values = wrapper.randu(shape, dtype_name)
219185
result = wrapper.tanh(values)
220186
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa

0 commit comments

Comments
 (0)