Skip to content

Commit 504927c

Browse files
Chaluvadiroaffix
Chaluvadi
authored andcommitted
Added unit tests for the pad function and fixed pad function return value
1 parent e56731e commit 504927c

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ def pad(arr: AFArray, begin_shape: tuple[int, ...], end_shape: tuple[int, ...],
2222
end_c_shape.c_array,
2323
border_type.value,
2424
)
25-
return NotImplemented
25+
return out

tests/test_pad.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import random
2+
3+
import numpy as np
4+
import pytest
5+
6+
import arrayfire_wrapper.dtypes as dtypes
7+
from arrayfire_wrapper.lib._constants import Pad
8+
from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant
9+
from arrayfire_wrapper.lib.create_and_modify_array.create_array.pad import pad
10+
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_dims
11+
12+
13+
@pytest.mark.parametrize(
14+
"original_shape",
15+
[
16+
(random.randint(1, 100),),
17+
(random.randint(1, 100), random.randint(1, 100)),
18+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
19+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
20+
],
21+
)
22+
def test_zero_padding(original_shape: tuple) -> None:
23+
"""Test if pad creates an array with no padding if no padding is given"""
24+
original_array = constant(2, original_shape, dtypes.s64)
25+
padding = Pad(0)
26+
27+
zero_shape = tuple(0 for _ in range(len(original_shape)))
28+
result = pad(original_array, zero_shape, zero_shape, padding)
29+
30+
assert get_dims(result)[0:len(original_shape)] == original_shape
31+
32+
33+
@pytest.mark.parametrize(
34+
"original_shape",
35+
[
36+
(random.randint(1, 100),),
37+
(random.randint(1, 100), random.randint(1, 100)),
38+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
39+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
40+
],
41+
)
42+
def test_negative_padding(original_shape: tuple) -> None:
43+
"""Test if pad can properly handle if negative padding is given"""
44+
with pytest.raises(RuntimeError):
45+
original_array = constant(2, original_shape, dtypes.s64)
46+
padding = Pad(0)
47+
48+
neg_shape = tuple(-1 for _ in range(len(original_shape)))
49+
result = pad(original_array, neg_shape, neg_shape, padding)
50+
51+
assert get_dims(result)[0:len(original_shape)] == original_shape
52+
53+
54+
@pytest.mark.parametrize(
55+
"original_shape",
56+
[
57+
(random.randint(1, 100),),
58+
(random.randint(1, 100), random.randint(1, 100)),
59+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
60+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
61+
],
62+
)
63+
def test_padding_shape(original_shape: tuple) -> None:
64+
"""Test if pad outputs the correct shape when a padding is adding to the original array"""
65+
original_array = constant(2, original_shape, dtypes.s64)
66+
padding = Pad(0)
67+
68+
beg_shape = tuple(random.randint(1, 10) for _ in range(len(original_shape)))
69+
end_shape = tuple(random.randint(1, 10) for _ in range(len(original_shape)))
70+
71+
result = pad(original_array, beg_shape, end_shape, padding)
72+
new_shape = np.array(beg_shape) + np.array(end_shape) + np.array(original_shape)
73+
74+
assert get_dims(result)[0:len(original_shape)] == tuple(new_shape)

0 commit comments

Comments
 (0)