Skip to content

Commit 034af6c

Browse files
author
AzeezIsh
committed
Added all dtype coverage, removed dmap.
1 parent 9134686 commit 034af6c

File tree

1 file changed

+37
-63
lines changed

1 file changed

+37
-63
lines changed

tests/test_exp_log.py

+37-63
Original file line numberDiff line numberDiff line change
@@ -4,51 +4,23 @@
44

55
import arrayfire_wrapper.dtypes as dtype
66
import arrayfire_wrapper.lib as wrapper
7-
8-
dtype_map = {
9-
"int16": dtype.s16,
10-
"int32": dtype.s32,
11-
"int64": dtype.s64,
12-
"uint8": dtype.u8,
13-
"uint16": dtype.u16,
14-
"uint32": dtype.u32,
15-
"uint64": dtype.u64,
16-
"float16": dtype.f16,
17-
"float32": dtype.f32,
18-
# 'float64': dtype.f64,
19-
# 'complex64': dtype.c64,
20-
# 'complex32': dtype.c32,
21-
"bool": dtype.b8,
22-
"s16": dtype.s16,
23-
"s32": dtype.s32,
24-
"s64": dtype.s64,
25-
"u8": dtype.u8,
26-
"u16": dtype.u16,
27-
"u32": dtype.u32,
28-
"u64": dtype.u64,
29-
"f16": dtype.f16,
30-
"f32": dtype.f32,
31-
# 'f64': dtype.f64,
32-
# 'c32': dtype.c32,
33-
# 'c64': dtype.c64,
34-
"b8": dtype.b8,
35-
}
7+
from tests.utility_functions import check_type_supported, get_all_types, get_float_types
368

379

3810
@pytest.mark.parametrize(
3911
"shape",
4012
[
4113
(),
4214
(random.randint(1, 10),),
43-
(random.randint(1, 10),),
4415
(random.randint(1, 10), random.randint(1, 10)),
4516
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
4617
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
4718
],
4819
)
49-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
20+
@pytest.mark.parametrize("dtype_name", get_all_types())
5021
def test_cbrt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
5122
"""Test cube root operation across all supported data types."""
23+
check_type_supported(dtype_name)
5224
values = wrapper.randu(shape, dtype_name)
5325
result = wrapper.cbrt(values)
5426
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -72,15 +44,15 @@ def test_cbrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
7244
[
7345
(),
7446
(random.randint(1, 10),),
75-
(random.randint(1, 10),),
7647
(random.randint(1, 10), random.randint(1, 10)),
7748
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
7849
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
7950
],
8051
)
81-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
52+
@pytest.mark.parametrize("dtype_name", get_all_types())
8253
def test_erf_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
8354
"""Test gaussian error operation across all supported data types."""
55+
check_type_supported(dtype_name)
8456
values = wrapper.randu(shape, dtype_name)
8557
result = wrapper.erf(values)
8658
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -104,15 +76,15 @@ def test_erf_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
10476
[
10577
(),
10678
(random.randint(1, 10),),
107-
(random.randint(1, 10),),
10879
(random.randint(1, 10), random.randint(1, 10)),
10980
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
11081
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
11182
],
11283
)
113-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
84+
@pytest.mark.parametrize("dtype_name", get_all_types())
11485
def test_erfc_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
11586
"""Test gaussian error complement operation across all supported data types."""
87+
check_type_supported(dtype_name)
11688
values = wrapper.randu(shape, dtype_name)
11789
result = wrapper.erfc(values)
11890
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -136,15 +108,15 @@ def test_erfc_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
136108
[
137109
(),
138110
(random.randint(1, 10),),
139-
(random.randint(1, 10),),
140111
(random.randint(1, 10), random.randint(1, 10)),
141112
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
142113
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
143114
],
144115
)
145-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
116+
@pytest.mark.parametrize("dtype_name", get_all_types())
146117
def test_exp_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
147118
"""Test exponent operation across all supported data types."""
119+
check_type_supported(dtype_name)
148120
values = wrapper.randu(shape, dtype_name)
149121
result = wrapper.exp(values)
150122
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -168,15 +140,15 @@ def test_exp_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
168140
[
169141
(),
170142
(random.randint(1, 10),),
171-
(random.randint(1, 10),),
172143
(random.randint(1, 10), random.randint(1, 10)),
173144
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
174145
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
175146
],
176147
)
177-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
148+
@pytest.mark.parametrize("dtype_name", get_all_types())
178149
def test_exp1_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
179150
"""Test exponent - 1 operation across all supported data types."""
151+
check_type_supported(dtype_name)
180152
values = wrapper.randu(shape, dtype_name)
181153
result = wrapper.expm1(values)
182154
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -200,15 +172,15 @@ def test_expm1_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
200172
[
201173
(),
202174
(random.randint(1, 10),),
203-
(random.randint(1, 10),),
204175
(random.randint(1, 10), random.randint(1, 10)),
205176
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
206177
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
207178
],
208179
)
209-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
180+
@pytest.mark.parametrize("dtype_name", get_all_types())
210181
def test_fac_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
211182
"""Test exponent operation across all supported data types."""
183+
check_type_supported(dtype_name)
212184
values = wrapper.randu(shape, dtype_name)
213185
result = wrapper.factorial(values)
214186
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -238,9 +210,10 @@ def test_fac_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
238210
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
239211
],
240212
)
241-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
213+
@pytest.mark.parametrize("dtype_name", get_all_types())
242214
def test_lgamma_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
243215
"""Test lgamma operation across all supported data types."""
216+
check_type_supported(dtype_name)
244217
values = wrapper.randu(shape, dtype_name)
245218
result = wrapper.lgamma(values)
246219
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -264,15 +237,15 @@ def test_lgamma_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
264237
[
265238
(),
266239
(random.randint(1, 10),),
267-
(random.randint(1, 10),),
268240
(random.randint(1, 10), random.randint(1, 10)),
269241
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
270242
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
271243
],
272244
)
273-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
245+
@pytest.mark.parametrize("dtype_name", get_all_types())
274246
def test_log_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
275247
"""Test log operation across all supported data types."""
248+
check_type_supported(dtype_name)
276249
values = wrapper.randu(shape, dtype_name)
277250
result = wrapper.log(values)
278251
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -296,15 +269,15 @@ def test_log_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
296269
[
297270
(),
298271
(random.randint(1, 10),),
299-
(random.randint(1, 10),),
300272
(random.randint(1, 10), random.randint(1, 10)),
301273
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
302274
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
303275
],
304276
)
305-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
277+
@pytest.mark.parametrize("dtype_name", get_all_types())
306278
def test_log10_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
307279
"""Test log10 operation across all supported data types."""
280+
check_type_supported(dtype_name)
308281
values = wrapper.randu(shape, dtype_name)
309282
result = wrapper.log10(values)
310283
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -328,15 +301,15 @@ def test_log10_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
328301
[
329302
(),
330303
(random.randint(1, 10),),
331-
(random.randint(1, 10),),
332304
(random.randint(1, 10), random.randint(1, 10)),
333305
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
334306
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
335307
],
336308
)
337-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
309+
@pytest.mark.parametrize("dtype_name", get_all_types())
338310
def test_log1p_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
339311
"""Test natural logarithm of 1 + input operation across all supported data types."""
312+
check_type_supported(dtype_name)
340313
values = wrapper.randu(shape, dtype_name)
341314
result = wrapper.log1p(values)
342315
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -360,15 +333,15 @@ def test_log1p_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
360333
[
361334
(),
362335
(random.randint(1, 10),),
363-
(random.randint(1, 10),),
364336
(random.randint(1, 10), random.randint(1, 10)),
365337
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
366338
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
367339
],
368340
)
369-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
341+
@pytest.mark.parametrize("dtype_name", get_all_types())
370342
def test_log2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
371343
"""Test log2 operation across all supported data types."""
344+
check_type_supported(dtype_name)
372345
values = wrapper.randu(shape, dtype_name)
373346
result = wrapper.log2(values)
374347
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -392,15 +365,15 @@ def test_log2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
392365
[
393366
(),
394367
(random.randint(1, 10),),
395-
(random.randint(1, 10),),
396368
(random.randint(1, 10), random.randint(1, 10)),
397369
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
398370
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
399371
],
400372
)
401-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
373+
@pytest.mark.parametrize("dtype_name", get_all_types())
402374
def test_pow_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
403375
"""Test power operation across all supported data types."""
376+
check_type_supported(dtype_name)
404377
lhs = wrapper.randu(shape, dtype_name)
405378
rhs = wrapper.randu(shape, dtype_name)
406379
result = wrapper.pow(lhs, rhs)
@@ -425,15 +398,15 @@ def test_pow_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
425398
[
426399
(),
427400
(random.randint(1, 10),),
428-
(random.randint(1, 10),),
429401
(random.randint(1, 10), random.randint(1, 10)),
430402
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
431403
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
432404
],
433405
)
434-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
406+
@pytest.mark.parametrize("dtype_name", get_all_types())
435407
def test_root_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
436408
"""Test root operation across all supported data types."""
409+
check_type_supported(dtype_name)
437410
lhs = wrapper.randu(shape, dtype_name)
438411
rhs = wrapper.randu(shape, dtype_name)
439412
result = wrapper.root(lhs, rhs)
@@ -464,9 +437,10 @@ def test_root_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
464437
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
465438
],
466439
)
467-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
440+
@pytest.mark.parametrize("dtype_name", get_all_types())
468441
def test_pow2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
469442
"""Test 2 to power operation across all supported data types."""
443+
check_type_supported(dtype_name)
470444
values = wrapper.randu(shape, dtype_name)
471445
result = wrapper.pow2(values)
472446
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -490,15 +464,15 @@ def test_pow2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
490464
[
491465
(),
492466
(random.randint(1, 10),),
493-
(random.randint(1, 10),),
494467
(random.randint(1, 10), random.randint(1, 10)),
495468
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
496469
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
497470
],
498471
)
499-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
472+
@pytest.mark.parametrize("dtype_name", get_all_types())
500473
def test_rsqrt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
501474
"""Test reciprocal square root operation across all supported data types."""
475+
check_type_supported(dtype_name)
502476
values = wrapper.randu(shape, dtype_name)
503477
result = wrapper.rsqrt(values)
504478
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -522,15 +496,15 @@ def test_rsqrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
522496
[
523497
(),
524498
(random.randint(1, 10),),
525-
(random.randint(1, 10),),
526499
(random.randint(1, 10), random.randint(1, 10)),
527500
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
528501
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
529502
],
530503
)
531-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
504+
@pytest.mark.parametrize("dtype_name", get_all_types())
532505
def test_sqrt_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
533506
"""Test square root operation across all supported data types."""
507+
check_type_supported(dtype_name)
534508
values = wrapper.randu(shape, dtype_name)
535509
result = wrapper.sqrt(values)
536510
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -554,15 +528,15 @@ def test_sqrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
554528
[
555529
(),
556530
(random.randint(1, 10),),
557-
(random.randint(1, 10),),
558531
(random.randint(1, 10), random.randint(1, 10)),
559532
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
560533
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
561534
],
562535
)
563-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
536+
@pytest.mark.parametrize("dtype_name", get_all_types())
564537
def test_tgamma_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
565538
"""Test gamma operation across all supported data types."""
539+
check_type_supported(dtype_name)
566540
values = wrapper.randu(shape, dtype_name)
567541
result = wrapper.tgamma(values)
568542
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
@@ -586,15 +560,15 @@ def test_tgamma_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
586560
[
587561
(),
588562
(random.randint(1, 10),),
589-
(random.randint(1, 10),),
590563
(random.randint(1, 10), random.randint(1, 10)),
591564
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
592565
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
593566
],
594567
)
595-
@pytest.mark.parametrize("dtype_name", dtype_map.values())
568+
@pytest.mark.parametrize("dtype_name", get_all_types())
596569
def test_sigmoid_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
597570
"""Test sigmoid operation across all supported data types."""
571+
check_type_supported(dtype_name)
598572
values = wrapper.randu(shape, dtype_name)
599573
result = wrapper.sigmoid(values)
600574
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa

0 commit comments

Comments
 (0)