Skip to content

Commit d01fc1a

Browse files
Chaluvadiroaffix
Chaluvadi
authored andcommitted
Fixed import formatting and automatic checks
1 parent 8711de9 commit d01fc1a

File tree

1 file changed

+12
-25
lines changed

1 file changed

+12
-25
lines changed

tests/test_iota.py

+12-25
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
import pytest
55

66
import arrayfire_wrapper.dtypes as dtypes
7-
from arrayfire_wrapper.lib.create_and_modify_array.create_array.iota import iota
8-
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_dims, get_type
9-
from arrayfire_wrapper.lib.create_and_modify_array.manage_device import get_dbl_support
7+
import arrayfire_wrapper.lib as wrapper
108

119

1210
@pytest.mark.parametrize(
@@ -24,9 +22,9 @@ def test_iota_shape(shape: tuple) -> None:
2422
dtype = dtypes.s16
2523
t_shape = (1, 1)
2624

27-
result = iota(shape, t_shape, dtype)
25+
result = wrapper.iota(shape, t_shape, dtype)
2826

29-
assert get_dims(result)[0:len(shape)] == shape
27+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
3028

3129

3230
def test_iota_invalid_shape() -> None:
@@ -42,7 +40,7 @@ def test_iota_invalid_shape() -> None:
4240
dtype = dtypes.s16
4341
t_shape = ()
4442

45-
iota(invalid_shape, t_shape, dtype)
43+
wrapper.iota(invalid_shape, t_shape, dtype)
4644

4745
assert f"CShape.__init__() takes from 1 to 5 positional arguments but {len(invalid_shape) + 1} were given" in str(
4846
excinfo.value
@@ -69,11 +67,11 @@ def test_iota_tshape(t_shape: tuple) -> None:
6967

7068
result_shape = shape * t_shape
7169

72-
result = iota(tuple(shape), t_shape, dtype)
70+
result = wrapper.iota(tuple(shape), t_shape, dtype)
7371

74-
result_dims = tuple(int(value) for value in get_dims(result))
72+
result_dims = tuple(int(value) for value in wrapper.get_dims(result))
7573

76-
assert (result_dims[0:len(result_shape)] == result_shape).all()
74+
assert (result_dims[0 : len(result_shape)] == result_shape).all() # noqa: E203
7775

7876

7977
@pytest.mark.parametrize(
@@ -90,18 +88,7 @@ def test_iota_tshape_zero(t_shape: tuple) -> None:
9088

9189
dtype = dtypes.s16
9290

93-
iota(shape, t_shape, dtype)
94-
95-
96-
def test_iota_tshape_float() -> None:
97-
"""Test it iota properly handles float t_shapes"""
98-
with pytest.raises(TypeError):
99-
shape = (2, 2)
100-
t_shape = (1.5, 1.5)
101-
102-
dtype = dtypes.s16
103-
104-
iota(shape, t_shape, dtype)
91+
wrapper.iota(shape, t_shape, dtype)
10592

10693

10794
def test_iota_tshape_invalid() -> None:
@@ -117,7 +104,7 @@ def test_iota_tshape_invalid() -> None:
117104
)
118105
dtype = dtypes.s16
119106

120-
iota(shape, invalid_tshape, dtype)
107+
wrapper.iota(shape, invalid_tshape, dtype)
121108

122109

123110
@pytest.mark.parametrize(
@@ -126,13 +113,13 @@ def test_iota_tshape_invalid() -> None:
126113
)
127114
def test_iota_dtype(dtype_index: int) -> None:
128115
"""Test if iota creates an array with the correct dtype"""
129-
if (dtype_index in [1, 4]) or (dtype_index in [2, 3] and not get_dbl_support()):
116+
if (dtype_index in [1, 4]) or (dtype_index in [2, 3] and not wrapper.get_dbl_support()):
130117
pytest.skip()
131118

132119
shape = (5, 5)
133120
t_shape = (2, 2)
134121
dtype = dtypes.c_api_value_to_dtype(dtype_index)
135122

136-
result = iota(shape, t_shape, dtype)
123+
result = wrapper.iota(shape, t_shape, dtype)
137124

138-
assert dtypes.c_api_value_to_dtype(get_type(result)) == dtype
125+
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype

0 commit comments

Comments
 (0)