1
1
import pytest
2
2
3
3
import arrayfire_wrapper .lib as wrapper
4
- from arrayfire_wrapper .dtypes import Dtype , c32 , c64 , f16 , f32 , f64 , s16 , s32 , s64 , u8 , u16 , u32 , u64
4
+ from arrayfire_wrapper .dtypes import Dtype , b8 , c32 , c64 , f16 , f32 , f64 , s16 , s32 , s64 , u8 , u16 , u32 , u64
5
5
6
6
7
7
def check_type_supported (dtype : Dtype ) -> None :
8
8
"""Checks to see if the specified type is supported by the current system"""
9
9
if dtype in [f64 , c64 ] and not wrapper .get_dbl_support ():
10
10
pytest .skip ("Device does not support double types" )
11
-
12
11
if dtype == f16 and not wrapper .get_half_support ():
13
12
pytest .skip ("Device does not support half types." )
14
13
@@ -25,4 +24,8 @@ def get_real_types() -> list:
25
24
26
25
def get_all_types () -> list :
27
26
"""Returns all types"""
28
- return [s16 , s32 , s64 , u8 , u16 , u32 , u64 , f16 , f32 , f64 , c32 , c64 ]
27
+ return [b8 , s16 , s32 , s64 , u8 , u16 , u32 , u64 , f16 , f32 , f64 , c32 , c64 ]
28
+
29
+ def get_float_types () -> list :
30
+ """Returns all types"""
31
+ return [f16 , f32 , f64 ]
0 commit comments