File tree 1 file changed +45
-0
lines changed
1 file changed +45
-0
lines changed Original file line number Diff line number Diff line change
1
+ import pytest
2
+
3
+ import arrayfire_wrapper .dtypes as dtypes
4
+ import arrayfire_wrapper .lib as wrapper
5
+
6
+
7
+ @pytest .mark .parametrize (
8
+ "shape" ,
9
+ [
10
+ (3 , 3 ),
11
+ (3 , 3 , 3 ),
12
+ (3 , 3 , 3 , 3 ),
13
+ ],
14
+ )
15
+ def test_diag_is_unit (shape : tuple ) -> None :
16
+ """Test if when is_unit_diag in lower returns an array with a unit diagonal"""
17
+ dtype = dtypes .s64
18
+ constant_array = wrapper .constant (3 , shape , dtype )
19
+
20
+ lower_array = wrapper .upper (constant_array , True )
21
+ diagonal = wrapper .diag_extract (lower_array , 0 )
22
+ diagonal_value = wrapper .get_scalar (diagonal , dtype )
23
+
24
+ assert diagonal_value == 1
25
+
26
+
27
+ @pytest .mark .parametrize (
28
+ "shape" ,
29
+ [
30
+ (3 , 3 ),
31
+ (3 , 3 , 3 ),
32
+ (3 , 3 , 3 , 3 ),
33
+ ],
34
+ )
35
+ def test_is_original (shape : tuple ) -> None :
36
+ """Test if is_original keeps the diagonal the same as the original array"""
37
+ dtype = dtypes .s64
38
+ constant_array = wrapper .constant (3 , shape , dtype )
39
+ original_value = wrapper .get_scalar (constant_array , dtype )
40
+
41
+ lower_array = wrapper .upper (constant_array , False )
42
+ diagonal = wrapper .diag_extract (lower_array , 0 )
43
+ diagonal_value = wrapper .get_scalar (diagonal , dtype )
44
+
45
+ assert original_value == diagonal_value
You can’t perform that action at this time.
0 commit comments