1
1
import pytest
2
2
3
3
import arrayfire_wrapper .dtypes as dtypes
4
- from arrayfire_wrapper .lib .create_and_modify_array .create_array .constant import constant
5
- from arrayfire_wrapper .lib .create_and_modify_array .create_array .diag import diag_extract
6
- from arrayfire_wrapper .lib .create_and_modify_array .create_array .upper import upper
7
- from arrayfire_wrapper .lib .create_and_modify_array .manage_array import get_scalar
4
+ import arrayfire_wrapper .lib as wrapper
8
5
9
6
10
7
@pytest .mark .parametrize (
18
15
def test_diag_is_unit (shape : tuple ) -> None :
19
16
"""Test if when is_unit_diag in lower returns an array with a unit diagonal"""
20
17
dtype = dtypes .s64
21
- constant_array = constant (3 , shape , dtype )
18
+ constant_array = wrapper . constant (3 , shape , dtype )
22
19
23
- lower_array = upper (constant_array , True )
24
- diagonal = diag_extract (lower_array , 0 )
25
- diagonal_value = get_scalar (diagonal , dtype )
20
+ lower_array = wrapper . upper (constant_array , True )
21
+ diagonal = wrapper . diag_extract (lower_array , 0 )
22
+ diagonal_value = wrapper . get_scalar (diagonal , dtype )
26
23
27
24
assert diagonal_value == 1
28
25
@@ -38,11 +35,11 @@ def test_diag_is_unit(shape: tuple) -> None:
38
35
def test_is_original (shape : tuple ) -> None :
39
36
"""Test if is_original keeps the diagonal the same as the original array"""
40
37
dtype = dtypes .s64
41
- constant_array = constant (3 , shape , dtype )
42
- original_value = get_scalar (constant_array , dtype )
38
+ constant_array = wrapper . constant (3 , shape , dtype )
39
+ original_value = wrapper . get_scalar (constant_array , dtype )
43
40
44
- lower_array = upper (constant_array , False )
45
- diagonal = diag_extract (lower_array , 0 )
46
- diagonal_value = get_scalar (diagonal , dtype )
41
+ lower_array = wrapper . upper (constant_array , False )
42
+ diagonal = wrapper . diag_extract (lower_array , 0 )
43
+ diagonal_value = wrapper . get_scalar (diagonal , dtype )
47
44
48
45
assert original_value == diagonal_value
0 commit comments