Skip to content

Commit 20ea893

Browse files
sakchalsyurkevi
authored andcommitted
added unit tests for matrix operations
1 parent d7fb4dd commit 20ea893

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed

tests/test_matrix_operations.py

+304
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.lib as wrapper
6+
from arrayfire_wrapper.defines import AFArray
7+
from arrayfire_wrapper.dtypes import Dtype, b8, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64
8+
9+
from .utility_functions import check_type_supported
10+
11+
12+
# det tests
13+
@pytest.mark.parametrize(
14+
"shape",
15+
[(1, 1), (10, 10), (100, 100), (1000, 1000), (10000, 10000)],
16+
)
17+
def test_det_type(shape: tuple) -> None:
18+
"""Test if det returns a complex number"""
19+
arr = wrapper.randn(shape, f32)
20+
determinant = wrapper.det(arr)
21+
22+
assert isinstance(determinant, complex)
23+
24+
25+
@pytest.mark.parametrize(
26+
"shape",
27+
[
28+
(15, 17),
29+
(105, 325),
30+
(567, 803),
31+
(5324, 7865),
32+
(10, 10, 10),
33+
(100, 100, 100),
34+
(1000, 1000, 1000),
35+
(10000, 10000, 10000),
36+
(10, 10, 10, 10),
37+
(100, 100, 100, 100),
38+
(1000, 1000, 1000, 1000),
39+
(10000, 10000, 10000, 10000),
40+
],
41+
)
42+
def test_det_invalid_shape(shape: tuple) -> None:
43+
"""Test if det can properly handle invalid shapes"""
44+
with pytest.raises(RuntimeError):
45+
arr = wrapper.randn(shape, f32)
46+
wrapper.det(arr)
47+
48+
49+
@pytest.mark.parametrize("dtype", [s16, s32, s64, u8, u16, u32, u64, f16, b8])
50+
def test_det_invalid_dtype(dtype: Dtype) -> None:
51+
"""Test if det can properly handle invalid dtypes"""
52+
with pytest.raises(RuntimeError):
53+
shape = (10, 10)
54+
55+
arr = wrapper.identity(shape, dtype)
56+
wrapper.det(arr)
57+
58+
59+
@pytest.mark.parametrize("dtype", [f32, f64, c32, c64])
60+
def test_det_valid_dtype(dtype: Dtype) -> None:
61+
"""Test if det can properly handle all valid dtypes"""
62+
check_type_supported(dtype)
63+
shape = (10, 10)
64+
65+
arr = wrapper.identity(shape, dtype)
66+
determinant = wrapper.det(arr)
67+
68+
assert isinstance(determinant, complex)
69+
70+
71+
# inverse tests
72+
@pytest.mark.parametrize(
73+
"shape",
74+
[(1, 1), (10, 10), (100, 100), (1000, 1000), (10000, 10000)],
75+
)
76+
def test_inverse_type(shape: tuple) -> None:
77+
"""Test if inverse returns an AFArray"""
78+
arr = wrapper.randn(shape, f32)
79+
inv = wrapper.inverse(arr, wrapper.MatProp(0))
80+
81+
assert isinstance(inv, AFArray)
82+
83+
84+
@pytest.mark.parametrize(
85+
"shape",
86+
[
87+
(15, 17),
88+
(105, 325),
89+
(567, 803),
90+
(5324, 7865),
91+
(10, 10, 10),
92+
(100, 100, 100),
93+
(1000, 1000, 1000),
94+
(10000, 10000, 10000),
95+
(10, 10, 10, 10),
96+
(100, 100, 100, 100),
97+
(1000, 1000, 1000, 1000),
98+
(10000, 10000, 10000, 10000),
99+
],
100+
)
101+
def test_inverse_invalid_shape(shape: tuple) -> None:
102+
"""Test if inverse can properly handle invalid shapes"""
103+
with pytest.raises(RuntimeError):
104+
arr = wrapper.randn(shape, f32)
105+
wrapper.inverse(arr, wrapper.MatProp(0))
106+
107+
108+
@pytest.mark.parametrize("dtype", [s16, s32, s64, u8, u16, u32, u64, f16, b8])
109+
def test_inverse_invalid_dtype(dtype: Dtype) -> None:
110+
"""Test if inverse can properly handle invalid dtypes"""
111+
with pytest.raises(RuntimeError):
112+
shape = (10, 10)
113+
114+
arr = wrapper.identity(shape, dtype)
115+
wrapper.inverse(arr, wrapper.MatProp(0))
116+
117+
118+
@pytest.mark.parametrize("dtype", [f32, f64, c32, c64])
119+
def test_inverse_valid_dtype(dtype: Dtype) -> None:
120+
"""Test if inverse can properly handle all valid dtypes"""
121+
check_type_supported(dtype)
122+
shape = (10, 10)
123+
124+
arr = wrapper.identity(shape, dtype)
125+
wrapper.inverse(arr, wrapper.MatProp(0))
126+
127+
128+
# norm tests
129+
@pytest.mark.parametrize(
130+
"shape",
131+
[(1, 1), (10, 10), (100, 100), (1000, 1000), (10000, 10000)],
132+
)
133+
def test_norm_output_type(shape: tuple) -> None:
134+
"""Test if norm returns a float"""
135+
arr = wrapper.randn(shape, f32)
136+
nor = wrapper.norm(arr, wrapper.Norm(2), 1, 1)
137+
138+
assert isinstance(nor, float)
139+
140+
141+
@pytest.mark.parametrize(
142+
"norm",
143+
[0, 1, 2, 3, 4, 5, 7], # VECTOR_1 # VECTOR_INF # VECTOR_2 # VECTOR_3 # MATRIX_1 # MATRIX_INF # MATRIX_L_PQ
144+
)
145+
def test_norm_types(norm: wrapper.Norm) -> None:
146+
"""Test if norm can handle all valid norm types"""
147+
shape = (3, 1)
148+
arr = wrapper.randn(shape, f32)
149+
nor = wrapper.norm(arr, wrapper.Norm(norm), 1, 2)
150+
151+
assert isinstance(nor, float)
152+
153+
154+
@pytest.mark.parametrize(
155+
"shape",
156+
[
157+
(10, 10, 10),
158+
(100, 100, 100),
159+
(1000, 1000, 1000),
160+
(10000, 10000, 10000),
161+
(10, 10, 10, 10),
162+
(100, 100, 100, 100),
163+
(1000, 1000, 1000, 1000),
164+
(10000, 10000, 10000, 10000),
165+
],
166+
)
167+
def test_norm_invalid_shape(shape: tuple) -> None:
168+
"""Test if norm can properly handle invalid shapes"""
169+
with pytest.raises(RuntimeError):
170+
arr = wrapper.randn(shape, f32)
171+
wrapper.norm(arr, wrapper.Norm(0), 1, 1)
172+
173+
174+
@pytest.mark.parametrize("dtype", [s16, s32, s64, u8, u16, u32, u64, f16, b8])
175+
def test_norm_invalid_dtype(dtype: Dtype) -> None:
176+
"""Test if norm can properly handle invalid dtypes"""
177+
with pytest.raises(RuntimeError):
178+
shape = (10, 10)
179+
180+
arr = wrapper.identity(shape, dtype)
181+
wrapper.norm(arr, wrapper.Norm(0), 1, 1)
182+
183+
184+
@pytest.mark.parametrize("dtype", [f32, f64, c32, c64])
185+
def test_norm_valid_dtype(dtype: Dtype) -> None:
186+
"""Test if norm can properly handle all valid dtypes"""
187+
check_type_supported(dtype)
188+
shape = (10, 10)
189+
190+
arr = wrapper.identity(shape, dtype)
191+
wrapper.norm(arr, wrapper.Norm(0), 1, 1)
192+
193+
194+
# pinverse tests
195+
@pytest.mark.parametrize(
196+
"shape",
197+
[
198+
(1, 1),
199+
(10, 10),
200+
(100, 100),
201+
(1000, 1000),
202+
(1, 1, 1),
203+
(10, 10, 10),
204+
(100, 100, 100),
205+
(1, 1, 1, 1),
206+
(10, 10, 10, 10),
207+
(100, 100, 100, 100),
208+
],
209+
)
210+
def test_pinverse_output_type(shape: tuple) -> None:
211+
"""Test if pinverse returns an AFArray"""
212+
arr = wrapper.randn(shape, f32)
213+
pin = wrapper.pinverse(arr, 1e-6, wrapper.MatProp(0))
214+
215+
assert isinstance(pin, AFArray)
216+
217+
218+
@pytest.mark.parametrize("dtype", [s16, s32, s64, u8, u16, u32, u64, f16, b8])
219+
def test_pinverse_invalid_dtype(dtype: Dtype) -> None:
220+
"""Test if pinverse can properly handle invalid dtypes"""
221+
with pytest.raises(RuntimeError):
222+
shape = (10, 10)
223+
224+
arr = wrapper.identity(shape, dtype)
225+
wrapper.pinverse(arr, 1e-6, wrapper.MatProp(0))
226+
227+
228+
@pytest.mark.parametrize("dtype", [f32, f64, c32, c64])
229+
def test_pinverse_valid_dtype(dtype: Dtype) -> None:
230+
"""Test if pinverse can properly handle all valid dtypes"""
231+
check_type_supported(dtype)
232+
shape = (10, 10)
233+
234+
arr = wrapper.identity(shape, dtype)
235+
pin = wrapper.pinverse(arr, 1e-6, wrapper.MatProp(0))
236+
237+
assert isinstance(pin, AFArray)
238+
239+
240+
@pytest.mark.parametrize("tolerance", [-0.0001, -1, -10, -100, -1000])
241+
def test_pinverse_invalid_tol(tolerance: int) -> None:
242+
"""Test if pinverse can properly handle invalid tolerance values"""
243+
with pytest.raises(RuntimeError):
244+
shape = (10, 10)
245+
246+
arr = wrapper.identity(shape, f32)
247+
wrapper.pinverse(arr, tolerance, wrapper.MatProp(0))
248+
249+
250+
# rank tests
251+
@pytest.mark.parametrize(
252+
"shape",
253+
[(1, 1), (10, 10), (100, 100), (1000, 1000), (random.randint(1, 1000), random.randint(1, 1000))],
254+
)
255+
def test_rank_output_type(shape: tuple) -> None:
256+
"""Test if rank returns an AFArray"""
257+
arr = wrapper.randn(shape, f32)
258+
rk = wrapper.rank(arr, 1e-6)
259+
260+
assert isinstance(rk, int)
261+
262+
263+
@pytest.mark.parametrize(
264+
"shape",
265+
[
266+
(10, 10, 10),
267+
(100, 100, 100),
268+
(1000, 1000, 1000),
269+
(10000, 10000, 10000),
270+
(random.randint(1, 1000), random.randint(1, 1000), random.randint(1, 1000)),
271+
(10, 10, 10, 10),
272+
(100, 100, 100, 100),
273+
(1000, 1000, 1000, 1000),
274+
(10000, 10000, 10000, 10000),
275+
(random.randint(1, 1000), random.randint(1, 1000), random.randint(1, 1000), random.randint(1, 10000)),
276+
],
277+
)
278+
def test_rank_invalid_shape(shape: tuple) -> None:
279+
"""Test if rank can properly handle invalid shapes"""
280+
with pytest.raises(RuntimeError):
281+
arr = wrapper.randn(shape, f32)
282+
wrapper.rank(arr, 1e-6)
283+
284+
285+
@pytest.mark.parametrize("dtype", [s16, s32, s64, u8, u16, u32, u64, f16, b8])
286+
def test_rank_invalid_dtype(dtype: Dtype) -> None:
287+
"""Test if rank can properly handle invalid dtypes"""
288+
with pytest.raises(RuntimeError):
289+
shape = (10, 10)
290+
291+
arr = wrapper.identity(shape, dtype)
292+
wrapper.rank(arr, 1e-6)
293+
294+
295+
@pytest.mark.parametrize("dtype", [f32, f64, c32, c64])
296+
def test_rank_valid_dtype(dtype: Dtype) -> None:
297+
"""Test if rank can properly handle all valid dtypes"""
298+
check_type_supported(dtype)
299+
shape = (10, 10)
300+
301+
arr = wrapper.identity(shape, dtype)
302+
rk = wrapper.rank(arr, 1e-6)
303+
304+
assert isinstance(rk, int)

0 commit comments

Comments
 (0)