Skip to content

Commit fab31a7

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
Added unit tests for the lower function (#20)
* Added unit tests for the lower function * Fixed import formatting, black and flake8 automatic checks --------- Co-authored-by: Chaluvadi <[email protected]>
1 parent cdc9098 commit fab31a7

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/test_lower.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pytest
2+
3+
import arrayfire_wrapper.dtypes as dtypes
4+
import arrayfire_wrapper.lib as wrapper
5+
6+
7+
@pytest.mark.parametrize(
8+
"shape",
9+
[
10+
(3, 3),
11+
(3, 3, 3),
12+
(3, 3, 3, 3),
13+
],
14+
)
15+
def test_diag_is_unit(shape: tuple) -> None:
16+
"""Test if when is_unit_diag in lower returns an array with a unit diagonal"""
17+
dtype = dtypes.s64
18+
constant_array = wrapper.constant(3, shape, dtype)
19+
20+
lower_array = wrapper.lower(constant_array, True)
21+
diagonal = wrapper.diag_extract(lower_array, 0)
22+
diagonal_value = wrapper.get_scalar(diagonal, dtype)
23+
24+
assert diagonal_value == 1
25+
26+
27+
@pytest.mark.parametrize(
28+
"shape",
29+
[
30+
(3, 3),
31+
(3, 3, 3),
32+
(3, 3, 3, 3),
33+
],
34+
)
35+
def test_is_original(shape: tuple) -> None:
36+
"""Test if is_original keeps the diagonal the same as the original array"""
37+
dtype = dtypes.s64
38+
constant_array = wrapper.constant(3, shape, dtype)
39+
original_value = wrapper.get_scalar(constant_array, dtype)
40+
41+
lower_array = wrapper.lower(constant_array, False)
42+
diagonal = wrapper.diag_extract(lower_array, 0)
43+
diagonal_value = wrapper.get_scalar(diagonal, dtype)
44+
45+
assert original_value == diagonal_value

0 commit comments

Comments
 (0)