4
4
5
5
import arrayfire_wrapper .dtypes as dtype
6
6
import arrayfire_wrapper .lib as wrapper
7
-
8
- dtype_map = {
9
- "int16" : dtype .s16 ,
10
- "int32" : dtype .s32 ,
11
- "int64" : dtype .s64 ,
12
- "uint8" : dtype .u8 ,
13
- "uint16" : dtype .u16 ,
14
- "uint32" : dtype .u32 ,
15
- "uint64" : dtype .u64 ,
16
- "float16" : dtype .f16 ,
17
- "float32" : dtype .f32 ,
18
- # 'float64': dtype.f64,
19
- # 'complex64': dtype.c64,
20
- # 'complex32': dtype.c32,
21
- "bool" : dtype .b8 ,
22
- "s16" : dtype .s16 ,
23
- "s32" : dtype .s32 ,
24
- "s64" : dtype .s64 ,
25
- "u8" : dtype .u8 ,
26
- "u16" : dtype .u16 ,
27
- "u32" : dtype .u32 ,
28
- "u64" : dtype .u64 ,
29
- "f16" : dtype .f16 ,
30
- "f32" : dtype .f32 ,
31
- # 'f64': dtype.f64,
32
- # 'c32': dtype.c32,
33
- # 'c64': dtype.c64,
34
- "b8" : dtype .b8 ,
35
- }
7
+ from tests .utility_functions import check_type_supported , get_all_types , get_float_types
36
8
37
9
38
10
@pytest .mark .parametrize (
39
11
"shape" ,
40
12
[
41
13
(),
42
14
(random .randint (1 , 10 ),),
43
- (random .randint (1 , 10 ),),
44
15
(random .randint (1 , 10 ), random .randint (1 , 10 )),
45
16
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
46
17
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
47
18
],
48
19
)
49
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
20
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
50
21
def test_cbrt_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
51
22
"""Test cube root operation across all supported data types."""
23
+ check_type_supported (dtype_name )
52
24
values = wrapper .randu (shape , dtype_name )
53
25
result = wrapper .cbrt (values )
54
26
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -72,15 +44,15 @@ def test_cbrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
72
44
[
73
45
(),
74
46
(random .randint (1 , 10 ),),
75
- (random .randint (1 , 10 ),),
76
47
(random .randint (1 , 10 ), random .randint (1 , 10 )),
77
48
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
78
49
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
79
50
],
80
51
)
81
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
52
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
82
53
def test_erf_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
83
54
"""Test gaussian error operation across all supported data types."""
55
+ check_type_supported (dtype_name )
84
56
values = wrapper .randu (shape , dtype_name )
85
57
result = wrapper .erf (values )
86
58
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -104,15 +76,15 @@ def test_erf_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
104
76
[
105
77
(),
106
78
(random .randint (1 , 10 ),),
107
- (random .randint (1 , 10 ),),
108
79
(random .randint (1 , 10 ), random .randint (1 , 10 )),
109
80
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
110
81
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
111
82
],
112
83
)
113
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
84
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
114
85
def test_erfc_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
115
86
"""Test gaussian error complement operation across all supported data types."""
87
+ check_type_supported (dtype_name )
116
88
values = wrapper .randu (shape , dtype_name )
117
89
result = wrapper .erfc (values )
118
90
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -136,15 +108,15 @@ def test_erfc_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
136
108
[
137
109
(),
138
110
(random .randint (1 , 10 ),),
139
- (random .randint (1 , 10 ),),
140
111
(random .randint (1 , 10 ), random .randint (1 , 10 )),
141
112
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
142
113
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
143
114
],
144
115
)
145
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
116
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
146
117
def test_exp_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
147
118
"""Test exponent operation across all supported data types."""
119
+ check_type_supported (dtype_name )
148
120
values = wrapper .randu (shape , dtype_name )
149
121
result = wrapper .exp (values )
150
122
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -168,15 +140,15 @@ def test_exp_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
168
140
[
169
141
(),
170
142
(random .randint (1 , 10 ),),
171
- (random .randint (1 , 10 ),),
172
143
(random .randint (1 , 10 ), random .randint (1 , 10 )),
173
144
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
174
145
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
175
146
],
176
147
)
177
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
148
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
178
149
def test_exp1_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
179
150
"""Test exponent - 1 operation across all supported data types."""
151
+ check_type_supported (dtype_name )
180
152
values = wrapper .randu (shape , dtype_name )
181
153
result = wrapper .expm1 (values )
182
154
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -200,15 +172,15 @@ def test_expm1_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
200
172
[
201
173
(),
202
174
(random .randint (1 , 10 ),),
203
- (random .randint (1 , 10 ),),
204
175
(random .randint (1 , 10 ), random .randint (1 , 10 )),
205
176
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
206
177
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
207
178
],
208
179
)
209
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
180
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
210
181
def test_fac_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
211
182
"""Test exponent operation across all supported data types."""
183
+ check_type_supported (dtype_name )
212
184
values = wrapper .randu (shape , dtype_name )
213
185
result = wrapper .factorial (values )
214
186
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -238,9 +210,10 @@ def test_fac_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
238
210
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
239
211
],
240
212
)
241
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
213
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
242
214
def test_lgamma_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
243
215
"""Test lgamma operation across all supported data types."""
216
+ check_type_supported (dtype_name )
244
217
values = wrapper .randu (shape , dtype_name )
245
218
result = wrapper .lgamma (values )
246
219
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -264,15 +237,15 @@ def test_lgamma_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
264
237
[
265
238
(),
266
239
(random .randint (1 , 10 ),),
267
- (random .randint (1 , 10 ),),
268
240
(random .randint (1 , 10 ), random .randint (1 , 10 )),
269
241
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
270
242
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
271
243
],
272
244
)
273
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
245
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
274
246
def test_log_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
275
247
"""Test log operation across all supported data types."""
248
+ check_type_supported (dtype_name )
276
249
values = wrapper .randu (shape , dtype_name )
277
250
result = wrapper .log (values )
278
251
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -296,15 +269,15 @@ def test_log_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
296
269
[
297
270
(),
298
271
(random .randint (1 , 10 ),),
299
- (random .randint (1 , 10 ),),
300
272
(random .randint (1 , 10 ), random .randint (1 , 10 )),
301
273
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
302
274
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
303
275
],
304
276
)
305
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
277
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
306
278
def test_log10_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
307
279
"""Test log10 operation across all supported data types."""
280
+ check_type_supported (dtype_name )
308
281
values = wrapper .randu (shape , dtype_name )
309
282
result = wrapper .log10 (values )
310
283
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -328,15 +301,15 @@ def test_log10_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
328
301
[
329
302
(),
330
303
(random .randint (1 , 10 ),),
331
- (random .randint (1 , 10 ),),
332
304
(random .randint (1 , 10 ), random .randint (1 , 10 )),
333
305
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
334
306
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
335
307
],
336
308
)
337
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
309
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
338
310
def test_log1p_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
339
311
"""Test natural logarithm of 1 + input operation across all supported data types."""
312
+ check_type_supported (dtype_name )
340
313
values = wrapper .randu (shape , dtype_name )
341
314
result = wrapper .log1p (values )
342
315
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -360,15 +333,15 @@ def test_log1p_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
360
333
[
361
334
(),
362
335
(random .randint (1 , 10 ),),
363
- (random .randint (1 , 10 ),),
364
336
(random .randint (1 , 10 ), random .randint (1 , 10 )),
365
337
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
366
338
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
367
339
],
368
340
)
369
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
341
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
370
342
def test_log2_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
371
343
"""Test log2 operation across all supported data types."""
344
+ check_type_supported (dtype_name )
372
345
values = wrapper .randu (shape , dtype_name )
373
346
result = wrapper .log2 (values )
374
347
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -392,15 +365,15 @@ def test_log2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
392
365
[
393
366
(),
394
367
(random .randint (1 , 10 ),),
395
- (random .randint (1 , 10 ),),
396
368
(random .randint (1 , 10 ), random .randint (1 , 10 )),
397
369
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
398
370
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
399
371
],
400
372
)
401
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
373
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
402
374
def test_pow_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
403
375
"""Test power operation across all supported data types."""
376
+ check_type_supported (dtype_name )
404
377
lhs = wrapper .randu (shape , dtype_name )
405
378
rhs = wrapper .randu (shape , dtype_name )
406
379
result = wrapper .pow (lhs , rhs )
@@ -425,15 +398,15 @@ def test_pow_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
425
398
[
426
399
(),
427
400
(random .randint (1 , 10 ),),
428
- (random .randint (1 , 10 ),),
429
401
(random .randint (1 , 10 ), random .randint (1 , 10 )),
430
402
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
431
403
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
432
404
],
433
405
)
434
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
406
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
435
407
def test_root_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
436
408
"""Test root operation across all supported data types."""
409
+ check_type_supported (dtype_name )
437
410
lhs = wrapper .randu (shape , dtype_name )
438
411
rhs = wrapper .randu (shape , dtype_name )
439
412
result = wrapper .root (lhs , rhs )
@@ -464,9 +437,10 @@ def test_root_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
464
437
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
465
438
],
466
439
)
467
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
440
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
468
441
def test_pow2_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
469
442
"""Test 2 to power operation across all supported data types."""
443
+ check_type_supported (dtype_name )
470
444
values = wrapper .randu (shape , dtype_name )
471
445
result = wrapper .pow2 (values )
472
446
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -490,15 +464,15 @@ def test_pow2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
490
464
[
491
465
(),
492
466
(random .randint (1 , 10 ),),
493
- (random .randint (1 , 10 ),),
494
467
(random .randint (1 , 10 ), random .randint (1 , 10 )),
495
468
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
496
469
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
497
470
],
498
471
)
499
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
472
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
500
473
def test_rsqrt_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
501
474
"""Test reciprocal square root operation across all supported data types."""
475
+ check_type_supported (dtype_name )
502
476
values = wrapper .randu (shape , dtype_name )
503
477
result = wrapper .rsqrt (values )
504
478
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -522,15 +496,15 @@ def test_rsqrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
522
496
[
523
497
(),
524
498
(random .randint (1 , 10 ),),
525
- (random .randint (1 , 10 ),),
526
499
(random .randint (1 , 10 ), random .randint (1 , 10 )),
527
500
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
528
501
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
529
502
],
530
503
)
531
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
504
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
532
505
def test_sqrt_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
533
506
"""Test square root operation across all supported data types."""
507
+ check_type_supported (dtype_name )
534
508
values = wrapper .randu (shape , dtype_name )
535
509
result = wrapper .sqrt (values )
536
510
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -554,15 +528,15 @@ def test_sqrt_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
554
528
[
555
529
(),
556
530
(random .randint (1 , 10 ),),
557
- (random .randint (1 , 10 ),),
558
531
(random .randint (1 , 10 ), random .randint (1 , 10 )),
559
532
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
560
533
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
561
534
],
562
535
)
563
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
536
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
564
537
def test_tgamma_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
565
538
"""Test gamma operation across all supported data types."""
539
+ check_type_supported (dtype_name )
566
540
values = wrapper .randu (shape , dtype_name )
567
541
result = wrapper .tgamma (values )
568
542
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
@@ -586,15 +560,15 @@ def test_tgamma_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
586
560
[
587
561
(),
588
562
(random .randint (1 , 10 ),),
589
- (random .randint (1 , 10 ),),
590
563
(random .randint (1 , 10 ), random .randint (1 , 10 )),
591
564
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
592
565
(random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
593
566
],
594
567
)
595
- @pytest .mark .parametrize ("dtype_name" , dtype_map . values ())
568
+ @pytest .mark .parametrize ("dtype_name" , get_all_types ())
596
569
def test_sigmoid_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
597
570
"""Test sigmoid operation across all supported data types."""
571
+ check_type_supported (dtype_name )
598
572
values = wrapper .randu (shape , dtype_name )
599
573
result = wrapper .sigmoid (values )
600
574
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
0 commit comments