diff --git a/arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py b/arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py index 5542d8b..7fd8b21 100644 --- a/arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py +++ b/arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py @@ -22,4 +22,4 @@ def pad(arr: AFArray, begin_shape: tuple[int, ...], end_shape: tuple[int, ...], end_c_shape.c_array, border_type.value, ) - return NotImplemented + return out diff --git a/tests/test_pad.py b/tests/test_pad.py new file mode 100644 index 0000000..d31d71a --- /dev/null +++ b/tests/test_pad.py @@ -0,0 +1,71 @@ +import random + +import numpy as np +import pytest + +import arrayfire_wrapper.dtypes as dtypes +import arrayfire_wrapper.lib as wrapper + + +@pytest.mark.parametrize( + "original_shape", + [ + (random.randint(1, 100),), + (random.randint(1, 100), random.randint(1, 100)), + (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)), + (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)), + ], +) +def test_zero_padding(original_shape: tuple) -> None: + """Test if pad creates an array with no padding if no padding is given""" + original_array = wrapper.constant(2, original_shape, dtypes.s64) + padding = wrapper.Pad(0) + + zero_shape = tuple(0 for _ in range(len(original_shape))) + result = wrapper.pad(original_array, zero_shape, zero_shape, padding) + + assert wrapper.get_dims(result)[0 : len(original_shape)] == original_shape # noqa: E203 + + +@pytest.mark.parametrize( + "original_shape", + [ + (random.randint(1, 100),), + (random.randint(1, 100), random.randint(1, 100)), + (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)), + (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)), + ], +) +def test_negative_padding(original_shape: tuple) -> None: + """Test if pad can properly handle if negative padding is given""" + with pytest.raises(RuntimeError): + original_array = wrapper.constant(2, original_shape, dtypes.s64) + padding = wrapper.Pad(0) + + neg_shape = tuple(-1 for _ in range(len(original_shape))) + result = wrapper.pad(original_array, neg_shape, neg_shape, padding) + + assert wrapper.get_dims(result)[0 : len(original_shape)] == original_shape # noqa: E203 + + +@pytest.mark.parametrize( + "original_shape", + [ + (random.randint(1, 100),), + (random.randint(1, 100), random.randint(1, 100)), + (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)), + (random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)), + ], +) +def test_padding_shape(original_shape: tuple) -> None: + """Test if pad outputs the correct shape when a padding is adding to the original array""" + original_array = wrapper.constant(2, original_shape, dtypes.s64) + padding = wrapper.Pad(0) + + beg_shape = tuple(random.randint(1, 10) for _ in range(len(original_shape))) + end_shape = tuple(random.randint(1, 10) for _ in range(len(original_shape))) + + result = wrapper.pad(original_array, beg_shape, end_shape, padding) + new_shape = np.array(beg_shape) + np.array(end_shape) + np.array(original_shape) + + assert wrapper.get_dims(result)[0 : len(original_shape)] == tuple(new_shape) # noqa: E203