33import pytest
44
55import arrayfire_wrapper .dtypes as dtypes
6- from arrayfire_wrapper .lib .create_and_modify_array .create_array .identity import identity
7- from arrayfire_wrapper .lib .create_and_modify_array .manage_array import get_dims , get_type
8- from arrayfire_wrapper .lib .create_and_modify_array .manage_device import get_dbl_support
6+ import arrayfire_wrapper .lib as wrapper
97
108
119@pytest .mark .parametrize (
@@ -22,9 +20,9 @@ def test_identity_shape(shape: tuple) -> None:
2220 """Test if identity creates an array with the correct shape"""
2321 dtype = dtypes .s16
2422
25- result = identity (shape , dtype )
23+ result = wrapper . identity (shape , dtype )
2624
27- assert get_dims (result )[0 : len (shape )] == shape
25+ assert wrapper . get_dims (result )[0 : len (shape )] == shape # noqa: E203
2826
2927
3028def test_identity_invalid_shape () -> None :
@@ -39,7 +37,7 @@ def test_identity_invalid_shape() -> None:
3937 )
4038 dtype = dtypes .s16
4139
42- identity (invalid_shape , dtype )
40+ wrapper . identity (invalid_shape , dtype )
4341
4442 assert f"CShape.__init__() takes from 1 to 5 positional arguments but { len (invalid_shape ) + 1 } were given" in str (
4543 excinfo .value
@@ -50,9 +48,9 @@ def test_identity_nonsquare_shape() -> None:
5048 dtype = dtypes .s16
5149 shape = (5 , 6 )
5250
53- result = identity (shape , dtype )
51+ result = wrapper . identity (shape , dtype )
5452
55- assert get_dims (result )[0 : len (shape )] == shape
53+ assert wrapper . get_dims (result )[0 : len (shape )] == shape # noqa: E203
5654
5755
5856@pytest .mark .parametrize (
@@ -61,12 +59,12 @@ def test_identity_nonsquare_shape() -> None:
6159)
6260def test_identity_dtype (dtype_index : int ) -> None :
6361 """Test if identity creates an array with the correct dtype"""
64- if dtype_index in [2 , 3 ] and not get_dbl_support ():
62+ if dtype_index in [2 , 3 ] and not wrapper . get_dbl_support ():
6563 pytest .skip ()
6664
6765 shape = (5 , 5 )
6866 dtype = dtypes .c_api_value_to_dtype (dtype_index )
6967
70- result = identity (shape , dtype )
68+ result = wrapper . identity (shape , dtype )
7169
72- assert dtypes .c_api_value_to_dtype (get_type (result )) == dtype
70+ assert dtypes .c_api_value_to_dtype (wrapper . get_type (result )) == dtype
0 commit comments