Skip to content

Commit 2f30ed1

Browse files
author
AzeezIsh
committed
Applied all checkstyle changes, fixed conv_grad.
1 parent 1c07686 commit 2f30ed1

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

tests/test_conv_f.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import arrayfire_wrapper.lib as wrapper
55
import arrayfire_wrapper.lib.signal_processing.convolutions as convolutions
66
from arrayfire_wrapper.lib._constants import ConvDomain, ConvMode
7-
87
from tests.utility_functions import check_type_supported, get_all_types, get_float_types
98

9+
1010
# Parameterization for input shapes
1111
@pytest.mark.parametrize(
1212
"inputShape",
@@ -219,6 +219,7 @@ def test_convolve1_conv_domain(conv_domain: int) -> None:
219219

220220
assert wrapper.get_dims(result)[0] == input_size, f"Failed for conv_domain: {ConvDomain(conv_domain)}"
221221

222+
222223
@pytest.mark.parametrize("dtypes", get_all_types())
223224
def test_convolve1_valid(dtypes: dtype.Dtype) -> None:
224225
"""Test convolve1 with valid dtypes."""

tests/test_conv_grad.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import arrayfire_wrapper.dtypes as dtype
44
import arrayfire_wrapper.lib as wrapper
55
from arrayfire_wrapper.lib._constants import ConvGradient
6+
from tests.utility_functions import get_float_types
67

78

8-
# First parameterization for grad_types
99
@pytest.mark.parametrize(
1010
"grad_type",
1111
[
@@ -15,15 +15,7 @@
1515
3, # ConvGradient.BIAS
1616
],
1717
)
18-
# Second parameterization for dtypes
19-
@pytest.mark.parametrize(
20-
"dtypes",
21-
[
22-
dtype.float16, # Floating point 16-bit
23-
dtype.float32, # Floating point 32-bit
24-
dtype.float64, # Floating point 64-bit
25-
],
26-
)
18+
@pytest.mark.parametrize("dtypes", get_float_types())
2719
def test_convolve2_gradient_data(grad_type: int, dtypes: dtype.Dtype) -> None:
2820
"""Test if convolve gradient returns the correct shape with varying data type and grad type."""
2921
incoming_gradient = wrapper.randu((8, 8), dtypes)
@@ -50,7 +42,6 @@ def test_convolve2_gradient_data(grad_type: int, dtypes: dtype.Dtype) -> None:
5042
assert wrapper.get_dims(result) == expected_shape, f"Failed for grad_type: {grad_type_enum}, dtype: {dtypes}"
5143

5244

53-
# Third parameterization for dtypes
5445
@pytest.mark.parametrize(
5546
"invdtypes",
5647
[

0 commit comments

Comments
 (0)