4
4
import pytest
5
5
6
6
import arrayfire_wrapper .dtypes as dtypes
7
- from arrayfire_wrapper .lib .create_and_modify_array .create_array .iota import iota
8
- from arrayfire_wrapper .lib .create_and_modify_array .manage_array import get_dims , get_type
9
- from arrayfire_wrapper .lib .create_and_modify_array .manage_device import get_dbl_support
7
+ import arrayfire_wrapper .lib as wrapper
10
8
11
9
12
10
@pytest .mark .parametrize (
@@ -24,9 +22,9 @@ def test_iota_shape(shape: tuple) -> None:
24
22
dtype = dtypes .s16
25
23
t_shape = (1 , 1 )
26
24
27
- result = iota (shape , t_shape , dtype )
25
+ result = wrapper . iota (shape , t_shape , dtype )
28
26
29
- assert get_dims (result )[0 : len (shape )] == shape
27
+ assert wrapper . get_dims (result )[0 : len (shape )] == shape # noqa: E203
30
28
31
29
32
30
def test_iota_invalid_shape () -> None :
@@ -42,7 +40,7 @@ def test_iota_invalid_shape() -> None:
42
40
dtype = dtypes .s16
43
41
t_shape = ()
44
42
45
- iota (invalid_shape , t_shape , dtype )
43
+ wrapper . iota (invalid_shape , t_shape , dtype )
46
44
47
45
assert f"CShape.__init__() takes from 1 to 5 positional arguments but { len (invalid_shape ) + 1 } were given" in str (
48
46
excinfo .value
@@ -69,11 +67,11 @@ def test_iota_tshape(t_shape: tuple) -> None:
69
67
70
68
result_shape = shape * t_shape
71
69
72
- result = iota (tuple (shape ), t_shape , dtype )
70
+ result = wrapper . iota (tuple (shape ), t_shape , dtype )
73
71
74
- result_dims = tuple (int (value ) for value in get_dims (result ))
72
+ result_dims = tuple (int (value ) for value in wrapper . get_dims (result ))
75
73
76
- assert (result_dims [0 : len (result_shape )] == result_shape ).all ()
74
+ assert (result_dims [0 : len (result_shape )] == result_shape ).all () # noqa: E203
77
75
78
76
79
77
@pytest .mark .parametrize (
@@ -90,18 +88,7 @@ def test_iota_tshape_zero(t_shape: tuple) -> None:
90
88
91
89
dtype = dtypes .s16
92
90
93
- iota (shape , t_shape , dtype )
94
-
95
-
96
- def test_iota_tshape_float () -> None :
97
- """Test it iota properly handles float t_shapes"""
98
- with pytest .raises (TypeError ):
99
- shape = (2 , 2 )
100
- t_shape = (1.5 , 1.5 )
101
-
102
- dtype = dtypes .s16
103
-
104
- iota (shape , t_shape , dtype )
91
+ wrapper .iota (shape , t_shape , dtype )
105
92
106
93
107
94
def test_iota_tshape_invalid () -> None :
@@ -117,7 +104,7 @@ def test_iota_tshape_invalid() -> None:
117
104
)
118
105
dtype = dtypes .s16
119
106
120
- iota (shape , invalid_tshape , dtype )
107
+ wrapper . iota (shape , invalid_tshape , dtype )
121
108
122
109
123
110
@pytest .mark .parametrize (
@@ -126,13 +113,13 @@ def test_iota_tshape_invalid() -> None:
126
113
)
127
114
def test_iota_dtype (dtype_index : int ) -> None :
128
115
"""Test if iota creates an array with the correct dtype"""
129
- if (dtype_index in [1 , 4 ]) or (dtype_index in [2 , 3 ] and not get_dbl_support ()):
116
+ if (dtype_index in [1 , 4 ]) or (dtype_index in [2 , 3 ] and not wrapper . get_dbl_support ()):
130
117
pytest .skip ()
131
118
132
119
shape = (5 , 5 )
133
120
t_shape = (2 , 2 )
134
121
dtype = dtypes .c_api_value_to_dtype (dtype_index )
135
122
136
- result = iota (shape , t_shape , dtype )
123
+ result = wrapper . iota (shape , t_shape , dtype )
137
124
138
- assert dtypes .c_api_value_to_dtype (get_type (result )) == dtype
125
+ assert dtypes .c_api_value_to_dtype (wrapper . get_type (result )) == dtype
0 commit comments