Skip to content

Commit 21f72a2

Browse files
Chaluvadiroaffix
Chaluvadi
authored andcommitted
fixed import formatting, black and flake8 checks
1 parent 886f411 commit 21f72a2

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

tests/test_upper.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import pytest
22

33
import arrayfire_wrapper.dtypes as dtypes
4-
from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant
5-
from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract
6-
from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper
7-
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar
4+
import arrayfire_wrapper.lib as wrapper
85

96

107
@pytest.mark.parametrize(
@@ -18,11 +15,11 @@
1815
def test_diag_is_unit(shape: tuple) -> None:
1916
"""Test if when is_unit_diag in lower returns an array with a unit diagonal"""
2017
dtype = dtypes.s64
21-
constant_array = constant(3, shape, dtype)
18+
constant_array = wrapper.constant(3, shape, dtype)
2219

23-
lower_array = upper(constant_array, True)
24-
diagonal = diag_extract(lower_array, 0)
25-
diagonal_value = get_scalar(diagonal, dtype)
20+
lower_array = wrapper.upper(constant_array, True)
21+
diagonal = wrapper.diag_extract(lower_array, 0)
22+
diagonal_value = wrapper.get_scalar(diagonal, dtype)
2623

2724
assert diagonal_value == 1
2825

@@ -38,11 +35,11 @@ def test_diag_is_unit(shape: tuple) -> None:
3835
def test_is_original(shape: tuple) -> None:
3936
"""Test if is_original keeps the diagonal the same as the original array"""
4037
dtype = dtypes.s64
41-
constant_array = constant(3, shape, dtype)
42-
original_value = get_scalar(constant_array, dtype)
38+
constant_array = wrapper.constant(3, shape, dtype)
39+
original_value = wrapper.get_scalar(constant_array, dtype)
4340

44-
lower_array = upper(constant_array, False)
45-
diagonal = diag_extract(lower_array, 0)
46-
diagonal_value = get_scalar(diagonal, dtype)
41+
lower_array = wrapper.upper(constant_array, False)
42+
diagonal = wrapper.diag_extract(lower_array, 0)
43+
diagonal_value = wrapper.get_scalar(diagonal, dtype)
4744

4845
assert original_value == diagonal_value

0 commit comments

Comments
 (0)