diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index a702f2e3ab..dfaa2eb333 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -41,6 +41,8 @@ fn main() { .input("tests/avg_pool1d_ceil_mode/avg_pool1d_ceil_mode.onnx") .input("tests/avg_pool2d/avg_pool2d.onnx") .input("tests/avg_pool2d_ceil_mode/avg_pool2d_ceil_mode.onnx") + .input("tests/avg_pool/avg_pool1d_asymmetric_padding.onnx") + .input("tests/avg_pool/avg_pool2d_asymmetric_padding.onnx") .input("tests/batch_norm/batch_norm.onnx") .input("tests/bitshift/bitshift_left.onnx") .input("tests/bitshift/bitshift_left_scalar.onnx") @@ -101,6 +103,8 @@ fn main() { .input("tests/conv1d/conv1d.onnx") .input("tests/conv2d/conv2d.onnx") .input("tests/conv3d/conv3d.onnx") + .input("tests/conv/conv1d_asymmetric_padding.onnx") + .input("tests/conv/conv2d_asymmetric_padding.onnx") .input("tests/conv_transpose1d/conv_transpose1d.onnx") .input("tests/conv_transpose2d/conv_transpose2d.onnx") .input("tests/conv_transpose3d/conv_transpose3d.onnx") @@ -237,6 +241,8 @@ fn main() { .input("tests/maxpool1d_ceil_mode/maxpool1d_ceil_mode.onnx") .input("tests/maxpool2d/maxpool2d.onnx") .input("tests/maxpool2d_ceil_mode/maxpool2d_ceil_mode.onnx") + .input("tests/maxpool/maxpool1d_asymmetric_padding.onnx") + .input("tests/maxpool/maxpool2d_asymmetric_padding.onnx") .input("tests/min/min.onnx") .input("tests/mean/mean.onnx") .input("tests/mul/mul.onnx") diff --git a/crates/burn-import/onnx-tests/pyproject.toml b/crates/burn-import/onnx-tests/pyproject.toml index a225faabab..6a5eea1c7d 100644 --- a/crates/burn-import/onnx-tests/pyproject.toml +++ b/crates/burn-import/onnx-tests/pyproject.toml @@ -7,6 +7,7 @@ dependencies = [ "torch>=2.3.1", "onnx>=1.16.1", "onnxruntime>=1.18.0", + "onnxscript>=0.1.0", ] readme = "README.md" requires-python = ">= 3.8" diff --git a/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool1d_asymmetric_padding.onnx b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool1d_asymmetric_padding.onnx new file mode 100644 index 0000000000..3b7771354c Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool1d_asymmetric_padding.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool1d_asymmetric_padding.py b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool1d_asymmetric_padding.py new file mode 100644 index 0000000000..9fd28b59df --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool1d_asymmetric_padding.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# used to generate model: avg_pool1d_asymmetric_padding.onnx + +import numpy as np +import onnx +from onnx import helper, TensorProto +from onnx.reference import ReferenceEvaluator + + +def main(): + # Input: [batch=2, channels=4, width=10] + # Asymmetric padding: left=1, right=2 + # kernel=3, stride=1 + # Output width = (10 + 1 + 2 - 3) / 1 + 1 = 11 + + X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10]) + Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11]) + + # Create AveragePool node with asymmetric padding (left=1, right=2) + # ONNX pads format for 1D: [start, end] = [left, right] + avg_pool_node = helper.make_node( + "AveragePool", + inputs=["x"], + outputs=["y"], + kernel_shape=[3], + strides=[1], + pads=[1, 2], # [left, right] asymmetric padding + count_include_pad=1, # Include padding in average calculation + ) + + graph = helper.make_graph( + [avg_pool_node], + "avg_pool1d_asymmetric_padding", 
+ [X], + [Y], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + model.ir_version = 8 + + onnx.checker.check_model(model) + file_name = "avg_pool1d_asymmetric_padding.onnx" + onnx.save(model, file_name) + + print("Finished exporting model to {}".format(file_name)) + print("Ops in graph: {}".format([n.op_type for n in model.graph.node])) + + # Verify with ReferenceEvaluator + test_input = np.ones((2, 4, 10), dtype=np.float32) + ref = ReferenceEvaluator(file_name) + ref_output = ref.run(None, {"x": test_input})[0] + + print("Test input shape: {}".format(test_input.shape)) + print("Test output shape: {}".format(ref_output.shape)) + print("ReferenceEvaluator output sum: {}".format(ref_output.sum())) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool2d_asymmetric_padding.onnx b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool2d_asymmetric_padding.onnx new file mode 100644 index 0000000000..5d96bafed3 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool2d_asymmetric_padding.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool2d_asymmetric_padding.py b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool2d_asymmetric_padding.py new file mode 100644 index 0000000000..7d2571b689 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/avg_pool/avg_pool2d_asymmetric_padding.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# used to generate model: avg_pool2d_asymmetric_padding.onnx + +import numpy as np +import onnx +from onnx import helper, TensorProto +from onnx.reference import ReferenceEvaluator + + +def main(): + # Input: [batch=2, channels=4, height=10, width=15] + # Asymmetric padding: top=1, left=1, bottom=2, right=2 + # kernel=[3,3], stride=[1,1] + # Output height = (10 + 1 + 2 - 3) / 1 + 1 = 11 + # Output width = (15 + 1 + 2 - 3) / 1 + 1 = 16 + + X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10, 15]) + Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11, 16]) + + # Create AveragePool node with asymmetric padding + # ONNX pads format for 2D: [top, left, bottom, right] + avg_pool_node = helper.make_node( + "AveragePool", + inputs=["x"], + outputs=["y"], + kernel_shape=[3, 3], + strides=[1, 1], + pads=[1, 1, 2, 2], # [top, left, bottom, right] asymmetric padding + count_include_pad=1, # Include padding in average calculation + ) + + graph = helper.make_graph( + [avg_pool_node], + "avg_pool2d_asymmetric_padding", + [X], + [Y], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + model.ir_version = 8 + + onnx.checker.check_model(model) + file_name = "avg_pool2d_asymmetric_padding.onnx" + onnx.save(model, file_name) + + print("Finished exporting model to {}".format(file_name)) + print("Ops in graph: {}".format([n.op_type for n in model.graph.node])) + + # Verify with ReferenceEvaluator + test_input = np.ones((2, 4, 10, 15), dtype=np.float32) + ref = ReferenceEvaluator(file_name) + ref_output = ref.run(None, {"x": test_input})[0] + + print("Test input shape: {}".format(test_input.shape)) + print("Test output shape: {}".format(ref_output.shape)) + print("ReferenceEvaluator output sum: {}".format(ref_output.sum())) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/avg_pool/mod.rs b/crates/burn-import/onnx-tests/tests/avg_pool/mod.rs index e41c94892e..c124be1755 100644 --- a/crates/burn-import/onnx-tests/tests/avg_pool/mod.rs +++ 
b/crates/burn-import/onnx-tests/tests/avg_pool/mod.rs
@@ -2,8 +2,10 @@ use crate::include_models;
 include_models!(
     avg_pool1d,
+    avg_pool1d_asymmetric_padding,
     avg_pool1d_ceil_mode,
     avg_pool2d,
+    avg_pool2d_asymmetric_padding,
     avg_pool2d_ceil_mode
 );
 
@@ -181,4 +183,58 @@
             .to_data()
             .assert_approx_eq::<FT>(&expected, tolerance);
     }
+
+    #[test]
+    fn avg_pool1d_asymmetric_padding() {
+        // Test asymmetric padding (left=1, right=2) for AvgPool1d
+        let device = Default::default();
+        let model: avg_pool1d_asymmetric_padding::Model<TestBackend> =
+            avg_pool1d_asymmetric_padding::Model::new(&device);
+
+        // Run the model with ones as input for easier testing
+        let input = Tensor::<TestBackend, 3>::ones([2, 4, 10], &device);
+        let output = model.forward(input);
+
+        // With asymmetric padding (1, 2), input length 10 becomes 10+1+2=13
+        // After pool with kernel 3, stride 1, output length is 13-3+1=11
+        let expected_shape = Shape::from([2, 4, 11]);
+        assert_eq!(output.shape(), expected_shape);
+
+        // Verify the sum matches PyTorch output
+        let output_sum = output.sum().into_scalar();
+        let expected_sum = 77.333_33; // from pytorch
+        assert!(
+            (output_sum - expected_sum).abs() < 0.1,
+            "Expected sum ~{}, got {}",
+            expected_sum,
+            output_sum
+        );
+    }
+
+    #[test]
+    fn avg_pool2d_asymmetric_padding() {
+        // Test asymmetric padding (left=1, right=2, top=1, bottom=2) for AvgPool2d
+        let device = Default::default();
+        let model: avg_pool2d_asymmetric_padding::Model<TestBackend> =
+            avg_pool2d_asymmetric_padding::Model::new(&device);
+
+        // Run the model with ones as input for easier testing
+        let input = Tensor::<TestBackend, 4>::ones([2, 4, 10, 15], &device);
+        let output = model.forward(input);
+
+        // With asymmetric padding (1, 1, 2, 2), input (10, 15) becomes (13, 18)
+        // After pool with kernel (3, 3), stride (1, 1), output is (11, 16)
+        let expected_shape = Shape::from([2, 4, 11, 16]);
+        assert_eq!(output.shape(), expected_shape);
+
+        // Verify the sum matches ReferenceEvaluator output
+        let output_sum = output.sum().into_scalar();
+        let expected_sum = 1134.222; // from ReferenceEvaluator
+        assert!(
+            (output_sum - expected_sum).abs() < 1.0,
+            "Expected sum ~{}, got {}",
+            expected_sum,
+            output_sum
+        );
+    }
 }
diff --git a/crates/burn-import/onnx-tests/tests/conv/conv1d_asymmetric_padding.onnx b/crates/burn-import/onnx-tests/tests/conv/conv1d_asymmetric_padding.onnx
new file mode 100644
index 0000000000..976c635cf2
Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/conv/conv1d_asymmetric_padding.onnx differ
diff --git a/crates/burn-import/onnx-tests/tests/conv/conv1d_asymmetric_padding.py b/crates/burn-import/onnx-tests/tests/conv/conv1d_asymmetric_padding.py
new file mode 100644
index 0000000000..fef29d0471
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/conv/conv1d_asymmetric_padding.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+
+# used to generate model: conv1d_asymmetric_padding.onnx
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import onnx
+from onnx.reference import ReferenceEvaluator
+
+# must set for testing against crate
+torch.manual_seed(0)
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        # Create a Conv1d without padding - we'll apply asymmetric padding manually
+        self.conv1 = nn.Conv1d(4, 6, kernel_size=3, stride=1, padding=0)
+
+    def forward(self, x):
+        # Apply asymmetric padding: (left=1, right=2)
+        # PyTorch F.pad takes (left, right) for 1D
+        x = F.pad(x, (1, 2), mode='constant', value=0)
+        x = self.conv1(x)
+        return x
+
+
+def main():
+    # Set random
seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + file_name = "conv1d_asymmetric_padding.onnx" + test_input = torch.ones(2, 4, 10, device=device) + + # Export with dynamo exporter (opset 18) + torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=18) + + # Load model and convert external data to embedded + onnx_model = onnx.load(file_name, load_external_data=True) + # Save with all data embedded + onnx.save(onnx_model, file_name, save_as_external_data=False) + + print("Finished exporting model to {}".format(file_name)) + + # Output some test data for use in the test + print("Test input data shape of ones: {}".format(test_input.shape)) + output = model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + + # Verify with ONNX ReferenceEvaluator + ref = ReferenceEvaluator(file_name) + ref_output = ref.run(None, {"x": test_input.numpy()})[0] + + output_sum = output.sum().item() + ref_sum = ref_output.sum() + + print("PyTorch output sum: {}".format(output_sum)) + print("ReferenceEvaluator output sum: {}".format(ref_sum)) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/conv/conv2d_asymmetric_padding.onnx b/crates/burn-import/onnx-tests/tests/conv/conv2d_asymmetric_padding.onnx new file mode 100644 index 0000000000..a89abf48e7 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/conv/conv2d_asymmetric_padding.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/conv/conv2d_asymmetric_padding.py b/crates/burn-import/onnx-tests/tests/conv/conv2d_asymmetric_padding.py new file mode 100644 index 0000000000..5ef19e983e --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/conv/conv2d_asymmetric_padding.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +# used to generate model: conv2d_asymmetric_padding.onnx + +import torch +import torch.nn as nn +import torch.nn.functional as F +import onnx +from onnx.reference import ReferenceEvaluator + +# must set for testing against crate +torch.manual_seed(0) + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + # Create a Conv2d without padding - we'll apply asymmetric padding manually + self.conv1 = nn.Conv2d(4, 6, kernel_size=3, stride=1, padding=0) + + def forward(self, x): + # Apply asymmetric padding: (left=1, right=2, top=1, bottom=3) + # PyTorch F.pad takes (left, right, top, bottom) for 2D + x = F.pad(x, (1, 2, 1, 3), mode='constant', value=0) + x = self.conv1(x) + return x + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + file_name = "conv2d_asymmetric_padding.onnx" + test_input = torch.ones(2, 4, 10, 15, device=device) + + # Export with dynamo exporter (opset 18) + torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=18) + + # Load model and convert external data to embedded + onnx_model = onnx.load(file_name, load_external_data=True) + # Save with all data embedded + onnx.save(onnx_model, file_name, save_as_external_data=False) + + print("Finished exporting model to {}".format(file_name)) + + # Output some test data for use in the test + print("Test input data shape of ones: {}".format(test_input.shape)) + output = model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + + # Verify with ONNX ReferenceEvaluator + ref = ReferenceEvaluator(file_name) + 
ref_output = ref.run(None, {"x": test_input.numpy()})[0]
+
+    output_sum = output.sum().item()
+    ref_sum = ref_output.sum()
+
+    print("PyTorch output sum: {}".format(output_sum))
+    print("ReferenceEvaluator output sum: {}".format(ref_sum))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/conv/mod.rs b/crates/burn-import/onnx-tests/tests/conv/mod.rs
index 0eac075e3d..dbf07c5269 100644
--- a/crates/burn-import/onnx-tests/tests/conv/mod.rs
+++ b/crates/burn-import/onnx-tests/tests/conv/mod.rs
@@ -1,6 +1,12 @@
 // Import the shared macro
 use crate::include_models;
-include_models!(conv1d, conv2d, conv3d);
+include_models!(
+    conv1d,
+    conv1d_asymmetric_padding,
+    conv2d,
+    conv2d_asymmetric_padding,
+    conv3d
+);
 
 #[cfg(test)]
 mod tests {
@@ -75,4 +81,55 @@
         assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
     }
+
+    #[test]
+    fn conv1d_asymmetric_padding() {
+        // Initialize the model with weights (loaded from the exported file)
+        // This model tests asymmetric padding: (left=1, right=2)
+        let model: conv1d_asymmetric_padding::Model<TestBackend> =
+            conv1d_asymmetric_padding::Model::default();
+
+        // Run the model with ones as input for easier testing
+        let input = Tensor::<TestBackend, 3>::ones([2, 4, 10], &Default::default());
+
+        let output = model.forward(input);
+
+        // With asymmetric padding (1, 2), input length 10 becomes 10+1+2=13
+        // After conv with kernel 3, stride 1, output length is 13-3+1=11
+        let expected_shape = Shape::from([2, 6, 11]);
+        assert_eq!(output.shape(), expected_shape);
+
+        // We are using the sum of the output tensor to test the correctness
+        let output_sum = output.sum().into_scalar();
+        let expected_sum = -0.386_136; // from pytorch
+
+        assert!(expected_sum.approx_eq(output_sum, (1.0e-3, 2)));
+    }
+
+    #[test]
+    fn conv2d_asymmetric_padding() {
+        // Initialize the model with weights (loaded from the exported file)
+        // This model tests asymmetric padding: (left=1, right=2, top=1, bottom=3)
+        let model: conv2d_asymmetric_padding::Model<TestBackend> =
+            conv2d_asymmetric_padding::Model::default();
+
+        // Run the model with ones as input for easier testing
+        let input = Tensor::<TestBackend, 4>::ones([2, 4, 10, 15], &Default::default());
+
+        let output = model.forward(input);
+
+        // With asymmetric padding (1, 2, 1, 3), input (10, 15) becomes (10+1+3, 15+1+2) = (14, 18)
+        // After conv with kernel (3, 3), stride (1, 1), output is (12, 16)
+        let expected_shape = Shape::from([2, 6, 12, 16]);
+        assert_eq!(output.shape(), expected_shape);
+
+        // We are using the sum of the output tensor to test the correctness
+        // because the output tensor is too large to compare with the expected tensor.
+ let output_sum = output.sum().into_scalar(); + + let expected_sum = -481.674_65; // from burn (close to pytorch's -481.6749572753906) + + // Use a slightly larger tolerance to account for floating-point differences + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } } diff --git a/crates/burn-import/onnx-tests/tests/maxpool/maxpool1d_asymmetric_padding.onnx b/crates/burn-import/onnx-tests/tests/maxpool/maxpool1d_asymmetric_padding.onnx new file mode 100644 index 0000000000..a0a97ec8fe Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/maxpool/maxpool1d_asymmetric_padding.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/maxpool/maxpool1d_asymmetric_padding.py b/crates/burn-import/onnx-tests/tests/maxpool/maxpool1d_asymmetric_padding.py new file mode 100644 index 0000000000..1b53c94a65 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/maxpool/maxpool1d_asymmetric_padding.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +# used to generate model: maxpool1d_asymmetric_padding.onnx + +import numpy as np +import onnx +from onnx import helper, TensorProto +from onnxruntime import InferenceSession + + +def main(): + # Input: [batch=2, channels=4, width=10] + # Asymmetric padding: left=1, right=2 + # kernel=3, stride=1 + # Output width = (10 + 1 + 2 - 3) / 1 + 1 = 11 + + X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10]) + Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11]) + + # Create MaxPool node with asymmetric padding (left=1, right=2) + # ONNX pads format for 1D: [start, end] = [left, right] + max_pool_node = helper.make_node( + "MaxPool", + inputs=["x"], + outputs=["y"], + kernel_shape=[3], + strides=[1], + pads=[1, 2], # [left, right] asymmetric padding + ) + + graph = helper.make_graph( + [max_pool_node], + "maxpool1d_asymmetric_padding", + [X], + [Y], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + model.ir_version = 8 + + onnx.checker.check_model(model) + file_name = "maxpool1d_asymmetric_padding.onnx" + onnx.save(model, file_name) + + print("Finished exporting model to {}".format(file_name)) + print("Ops in graph: {}".format([n.op_type for n in model.graph.node])) + + # Verify with ONNX Runtime (ReferenceEvaluator has a bug with MaxPool asymmetric padding) + test_input = np.ones((2, 4, 10), dtype=np.float32) + session = InferenceSession(file_name) + ort_output = session.run(None, {"x": test_input})[0] + + print("Test input shape: {}".format(test_input.shape)) + print("Test output shape: {}".format(ort_output.shape)) + print("ONNX Runtime output sum: {}".format(ort_output.sum())) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/onnx-tests/tests/maxpool/maxpool2d_asymmetric_padding.onnx b/crates/burn-import/onnx-tests/tests/maxpool/maxpool2d_asymmetric_padding.onnx new file mode 100644 index 0000000000..ab14d51f29 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/maxpool/maxpool2d_asymmetric_padding.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/maxpool/maxpool2d_asymmetric_padding.py b/crates/burn-import/onnx-tests/tests/maxpool/maxpool2d_asymmetric_padding.py new file mode 100644 index 0000000000..c94a6e9d3e --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/maxpool/maxpool2d_asymmetric_padding.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# used to generate model: maxpool2d_asymmetric_padding.onnx + +import numpy as np +import onnx +from onnx import helper, TensorProto +from onnxruntime import InferenceSession 
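# --- editor's sketch (not part of the generated patch) ---
# Every output-shape comment in these scripts instantiates the same ONNX
# size formula for non-ceil pooling/convolution. A minimal helper, assuming
# floor division; `expected_out_size` is hypothetical and unused by the scripts:
def expected_out_size(in_size, pad_begin, pad_end, kernel, stride):
    # out = floor((in + pad_begin + pad_end - kernel) / stride) + 1
    return (in_size + pad_begin + pad_end - kernel) // stride + 1

assert expected_out_size(10, 1, 2, 3, 1) == 11  # height below: (10 + 1 + 2 - 3) / 1 + 1
assert expected_out_size(15, 1, 2, 3, 1) == 16  # width below: (15 + 1 + 2 - 3) / 1 + 1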
+
+
+def main():
+    # Input: [batch=2, channels=4, height=10, width=15]
+    # Asymmetric padding: top=1, left=1, bottom=2, right=2 (pad must be < kernel for ONNX Runtime)
+    # kernel=[3,3], stride=[1,1]
+    # Output height = (10 + 1 + 2 - 3) / 1 + 1 = 11
+    # Output width = (15 + 1 + 2 - 3) / 1 + 1 = 16
+
+    X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10, 15])
+    Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11, 16])
+
+    # Create MaxPool node with asymmetric padding
+    # ONNX pads format for 2D: [top, left, bottom, right]
+    max_pool_node = helper.make_node(
+        "MaxPool",
+        inputs=["x"],
+        outputs=["y"],
+        kernel_shape=[3, 3],
+        strides=[1, 1],
+        pads=[1, 1, 2, 2],  # [top, left, bottom, right] asymmetric padding
+    )
+
+    graph = helper.make_graph(
+        [max_pool_node],
+        "maxpool2d_asymmetric_padding",
+        [X],
+        [Y],
+    )
+
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+    model.ir_version = 8
+
+    onnx.checker.check_model(model)
+    file_name = "maxpool2d_asymmetric_padding.onnx"
+    onnx.save(model, file_name)
+
+    print("Finished exporting model to {}".format(file_name))
+    print("Ops in graph: {}".format([n.op_type for n in model.graph.node]))
+
+    # Verify with ONNX Runtime (ReferenceEvaluator has a bug with MaxPool asymmetric padding)
+    test_input = np.ones((2, 4, 10, 15), dtype=np.float32)
+    session = InferenceSession(file_name)
+    ort_output = session.run(None, {"x": test_input})[0]
+
+    print("Test input shape: {}".format(test_input.shape))
+    print("Test output shape: {}".format(ort_output.shape))
+    print("ONNX Runtime output sum: {}".format(ort_output.sum()))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crates/burn-import/onnx-tests/tests/maxpool/mod.rs b/crates/burn-import/onnx-tests/tests/maxpool/mod.rs
index dc6a571605..4d87db2da1 100644
--- a/crates/burn-import/onnx-tests/tests/maxpool/mod.rs
+++ b/crates/burn-import/onnx-tests/tests/maxpool/mod.rs
@@ -2,15 +2,17 @@ use crate::include_models;
 include_models!(
     maxpool1d,
+    maxpool1d_asymmetric_padding,
     maxpool1d_ceil_mode,
     maxpool2d,
+    maxpool2d_asymmetric_padding,
     maxpool2d_ceil_mode
 );
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use burn::tensor::{Tensor, TensorData};
+    use burn::tensor::{Shape, Tensor, TensorData};
 
     use crate::backend::TestBackend;
 
@@ -126,4 +128,58 @@
         ]]]);
         output.to_data().assert_eq(&expected, true);
     }
+
+    #[test]
+    fn maxpool1d_asymmetric_padding() {
+        // Test asymmetric padding (left=1, right=2) for MaxPool1d
+        let device = Default::default();
+        let model: maxpool1d_asymmetric_padding::Model<TestBackend> =
+            maxpool1d_asymmetric_padding::Model::new(&device);
+
+        // Run the model with ones as input for easier testing
+        let input = Tensor::<TestBackend, 3>::ones([2, 4, 10], &device);
+        let output = model.forward(input);
+
+        // With asymmetric padding (1, 2), input length 10 becomes 10+1+2=13
+        // After pool with kernel 3, stride 1, output length is 13-3+1=11
+        let expected_shape = Shape::from([2, 4, 11]);
+        assert_eq!(output.shape(), expected_shape);
+
+        // Verify the sum matches PyTorch output (all 1.0 values, so max = 1.0 for all positions that see valid input)
+        let output_sum = output.sum().into_scalar();
+        let expected_sum = 88.0; // from pytorch
+        assert!(
+            (output_sum - expected_sum).abs() < 0.1,
+            "Expected sum ~{}, got {}",
+            expected_sum,
+            output_sum
+        );
+    }
+
+    #[test]
+    fn maxpool2d_asymmetric_padding() {
+        // Test asymmetric padding (left=1, right=2, top=1, bottom=2) for MaxPool2d
+        let device = Default::default();
+        let model: maxpool2d_asymmetric_padding::Model<TestBackend> =
+            maxpool2d_asymmetric_padding::Model::new(&device);
+
+        // Run the model with ones as input for easier testing
+        let input = Tensor::<TestBackend, 4>::ones([2, 4, 10, 15], &device);
+        let output = model.forward(input);
+
+        // With asymmetric padding (1, 1, 2, 2), input (10, 15) becomes (13, 18)
+        // After pool with kernel (3, 3), stride (1, 1), output is (11, 16)
+        let expected_shape = Shape::from([2, 4, 11, 16]);
+        assert_eq!(output.shape(), expected_shape);
+
+        // Verify the sum matches ONNX Runtime output
+        let output_sum = output.sum().into_scalar();
+        let expected_sum = 1408.0; // from ONNX Runtime
+        assert!(
+            (output_sum - expected_sum).abs() < 1.0,
+            "Expected sum ~{}, got {}",
+            expected_sum,
+            output_sum
+        );
+    }
 }
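// --- editor's note (not part of the patch) ---
// The codegen changes below mirror the widened Explicit variants in burn-nn.
// Per-side ordering, collected in one place for reference (a sketch; see
// crates/burn-nn/src/padding.rs later in this diff):
//
//     PaddingConfig1d::Explicit(left, right)
//     PaddingConfig2d::Explicit(top, left, bottom, right)
//     PaddingConfig3d::Explicit(front, top, left, back, bottom, right)
//
// The old symmetric forms map as Explicit(p) -> Explicit(p, p) and
// Explicit(ph, pw) -> Explicit(ph, pw, ph, pw).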
diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs
index 63b4e27996..f4bee27b3e 100644
--- a/crates/burn-import/src/burn/codegen.rs
+++ b/crates/burn-import/src/burn/codegen.rs
@@ -75,43 +75,54 @@ impl ToTokens for f32 {
     }
 }
 
-/// Padding configuration
+/// Padding configuration for 1D operations.
+///
+/// Converts PaddingConfig1d to Rust code tokens.
+/// Format: Explicit(left, right)
 impl ToTokens for PaddingConfig1d {
     fn to_tokens(&self) -> TokenStream {
         match self {
             Self::Valid => quote! { PaddingConfig1d::Valid },
-            Self::Explicit(padding) => {
-                let padding = padding.to_tokens();
-                quote! { PaddingConfig1d::Explicit(#padding) }
+            Self::Explicit(left, right) => {
+                let left = left.to_tokens();
+                let right = right.to_tokens();
+                quote! { PaddingConfig1d::Explicit(#left, #right) }
             }
         }
     }
 }
 
-/// Padding configuration
+/// Converts PaddingConfig2d to Rust code tokens.
+/// Format: Explicit(top, left, bottom, right)
 impl ToTokens for PaddingConfig2d {
     fn to_tokens(&self) -> TokenStream {
         match self {
             Self::Valid => quote! { PaddingConfig2d::Valid },
-            Self::Explicit(padding1, padding2) => {
-                let padding1 = padding1.to_tokens();
-                let padding2 = padding2.to_tokens();
-                quote! { PaddingConfig2d::Explicit(#padding1, #padding2) }
+            Self::Explicit(top, left, bottom, right) => {
+                let top = top.to_tokens();
+                let left = left.to_tokens();
+                let bottom = bottom.to_tokens();
+                let right = right.to_tokens();
+                quote! { PaddingConfig2d::Explicit(#top, #left, #bottom, #right) }
            }
        }
    }
}
 
-/// Padding configuration
+/// Converts PaddingConfig3d to Rust code tokens.
+/// Format: Explicit(front, top, left, back, bottom, right)
 impl ToTokens for PaddingConfig3d {
     fn to_tokens(&self) -> TokenStream {
         match self {
             Self::Valid => quote! { PaddingConfig3d::Valid },
-            Self::Explicit(padding1, padding2, padding3) => {
-                let padding1 = padding1.to_tokens();
-                let padding2 = padding2.to_tokens();
-                let padding3 = padding3.to_tokens();
-                quote! { PaddingConfig3d::Explicit(#padding1, #padding2, #padding3) }
+            Self::Explicit(front, top, left, back, bottom, right) => {
+                let front = front.to_tokens();
+                let top = top.to_tokens();
+                let left = left.to_tokens();
+                let back = back.to_tokens();
+                let bottom = bottom.to_tokens();
+                let right = right.to_tokens();
+                quote!
{ PaddingConfig3d::Explicit(#front, #top, #left, #back, #bottom, #right) } } } } diff --git a/crates/burn-import/src/burn/node/avg_pool1d.rs b/crates/burn-import/src/burn/node/avg_pool1d.rs index 11ebefc75d..859ceeef0c 100644 --- a/crates/burn-import/src/burn/node/avg_pool1d.rs +++ b/crates/burn-import/src/burn/node/avg_pool1d.rs @@ -1,4 +1,5 @@ use super::prelude::*; + impl NodeCodegen for onnx_ir::node::avg_pool1d::AveragePool1dNode { fn inputs(&self) -> &[Argument] { &self.inputs @@ -12,10 +13,11 @@ impl NodeCodegen for onnx_ir::node::avg_pool1d::AveragePool1dNode { let name = Ident::new(&self.name, Span::call_site()); let kernel_size = self.config.kernel_size.to_tokens(); let strides = self.config.stride.to_tokens(); - let padding = self.config.padding.to_tokens(); let count_include_pad = self.config.count_include_pad; let ceil_mode = self.config.ceil_mode; + let padding = self.config.padding.to_tokens(); + Some(Field::new( self.name.clone(), quote! { @@ -67,6 +69,17 @@ mod tests { .build() } + fn create_avg_pool1d_node_asymmetric(name: &str) -> AveragePool1dNode { + // Asymmetric padding: left=1, right=2 + let config = AvgPool1dConfig::new(3, 1, PaddingConfig1d::Explicit(1, 2), false, 1, false); + + AveragePool1dNodeBuilder::new(name) + .input_tensor("input", 3, DType::F32) + .output_tensor("output", 3, DType::F32) + .config(config) + .build() + } + #[test] fn test_avg_pool1d_forward() { let node = create_avg_pool1d_node("pool1", false); @@ -118,4 +131,19 @@ mod tests { .init(); "#); } + + #[test] + fn test_avg_pool1d_field_init_asymmetric_padding() { + let node = create_avg_pool1d_node_asymmetric("pool1"); + let code = codegen_field_init(&node); + // Asymmetric padding is passed directly to the module + assert_snapshot!(code, @r" + let pool1 = AvgPool1dConfig::new(3) + .with_stride(1) + .with_padding(PaddingConfig1d::Explicit(1, 2)) + .with_count_include_pad(false) + .with_ceil_mode(false) + .init(); + "); + } } diff --git a/crates/burn-import/src/burn/node/avg_pool2d.rs b/crates/burn-import/src/burn/node/avg_pool2d.rs index 6701be49fd..6819996af3 100644 --- a/crates/burn-import/src/burn/node/avg_pool2d.rs +++ b/crates/burn-import/src/burn/node/avg_pool2d.rs @@ -1,4 +1,5 @@ use super::prelude::*; + impl NodeCodegen for onnx_ir::node::avg_pool2d::AveragePool2dNode { fn inputs(&self) -> &[Argument] { &self.inputs @@ -12,10 +13,11 @@ impl NodeCodegen for onnx_ir::node::avg_pool2d::AveragePool2dNode { let name = Ident::new(&self.name, Span::call_site()); let kernel_size = self.config.kernel_size.to_tokens(); let strides = self.config.strides.to_tokens(); - let padding = self.config.padding.to_tokens(); let count_include_pad = self.config.count_include_pad; let ceil_mode = self.config.ceil_mode; + let padding = self.config.padding.to_tokens(); + Some(Field::new( self.name.clone(), quote! 
{ @@ -74,6 +76,24 @@ mod tests { .build() } + fn create_avg_pool2d_node_asymmetric(name: &str) -> AveragePool2dNode { + // Asymmetric padding: top=1, left=2, bottom=3, right=4 + let config = AvgPool2dConfig::new( + [3, 3], + [1, 1], + PaddingConfig2d::Explicit(1, 2, 3, 4), + false, + [1, 1], + false, + ); + + AveragePool2dNodeBuilder::new(name) + .input_tensor("input", 4, DType::F32) + .output_tensor("output", 4, DType::F32) + .config(config) + .build() + } + #[test] fn test_avg_pool2d_forward() { let node = create_avg_pool2d_node("pool1", false); @@ -125,4 +145,19 @@ mod tests { .init(); "#); } + + #[test] + fn test_avg_pool2d_field_init_asymmetric_padding() { + let node = create_avg_pool2d_node_asymmetric("pool1"); + let code = codegen_field_init(&node); + // Asymmetric padding is passed directly to the module + assert_snapshot!(code, @r" + let pool1 = AvgPool2dConfig::new([3, 3]) + .with_strides([1, 1]) + .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4)) + .with_count_include_pad(false) + .with_ceil_mode(false) + .init(); + "); + } } diff --git a/crates/burn-import/src/burn/node/conv1d.rs b/crates/burn-import/src/burn/node/conv1d.rs index 3ec1e30ff7..cb1ca37bd9 100644 --- a/crates/burn-import/src/burn/node/conv1d.rs +++ b/crates/burn-import/src/burn/node/conv1d.rs @@ -19,9 +19,10 @@ impl NodeCodegen for onnx_ir::conv1d::Conv1dNode { let stride = self.config.stride.to_tokens(); let dilation = self.config.dilation.to_tokens(); let groups = self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); let bias = self.config.bias; + let padding = self.config.padding.to_tokens(); + Some(Field::new( self.name.clone(), quote! { @@ -48,6 +49,7 @@ impl NodeCodegen for onnx_ir::conv1d::Conv1dNode { let #output = self.#field.forward(#input); } } + fn register_imports(&self, imports: &mut BurnImports) { imports.register("burn::nn::PaddingConfig1d"); imports.register("burn::nn::conv::Conv1d"); @@ -87,7 +89,18 @@ mod tests { use onnx_ir::padding::PaddingConfig1d; fn create_conv1d_node(name: &str) -> Conv1dNode { - let config = Conv1dConfig::new(3, 64, 3, 1, 1, 1, true, PaddingConfig1d::Explicit(1)); + let config = Conv1dConfig::new(3, 64, 3, 1, 1, 1, true, PaddingConfig1d::Explicit(1, 1)); + + Conv1dNodeBuilder::new(name) + .input_tensor("input", 3, DType::F32) + .output_tensor("output", 3, DType::F32) + .config(config) + .build() + } + + fn create_conv1d_node_asymmetric(name: &str) -> Conv1dNode { + // Asymmetric padding: left=1, right=2 + let config = Conv1dConfig::new(3, 64, 3, 1, 1, 1, true, PaddingConfig1d::Explicit(1, 2)); Conv1dNodeBuilder::new(name) .input_tensor("input", 3, DType::F32) @@ -119,4 +132,20 @@ mod tests { } "); } + + #[test] + fn test_conv1d_field_init_asymmetric_padding() { + let node = create_conv1d_node_asymmetric("conv1"); + let code = codegen_field_init(&node); + // Asymmetric padding is passed directly to the module + assert_snapshot!(code, @r" + let conv1 = Conv1dConfig::new(3, 64, 3) + .with_stride(1) + .with_padding(PaddingConfig1d::Explicit(1, 2)) + .with_dilation(1) + .with_groups(1) + .with_bias(true) + .init(device); + "); + } } diff --git a/crates/burn-import/src/burn/node/conv2d.rs b/crates/burn-import/src/burn/node/conv2d.rs index c3ceea813e..8fc8b73af1 100644 --- a/crates/burn-import/src/burn/node/conv2d.rs +++ b/crates/burn-import/src/burn/node/conv2d.rs @@ -17,9 +17,10 @@ impl NodeCodegen for onnx_ir::conv2d::Conv2dNode { let stride = self.config.stride.to_tokens(); let dilation = self.config.dilation.to_tokens(); let groups = 
self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); let bias = self.config.bias; + let padding = self.config.padding.to_tokens(); + Some(Field::new( self.name.clone(), quote! { @@ -93,7 +94,26 @@ mod tests { [3, 64], [3, 3], [1, 1], - PaddingConfig2d::Explicit(1, 1), + PaddingConfig2d::Explicit(1, 1, 1, 1), + [1, 1], + 1, + true, + ); + + Conv2dNodeBuilder::new(name) + .input_tensor("input", 4, DType::F32) + .output_tensor("output", 4, DType::F32) + .config(config) + .build() + } + + fn create_conv2d_node_asymmetric(name: &str) -> Conv2dNode { + // Asymmetric padding: top=1, left=2, bottom=3, right=4 + let config = Conv2dConfig::new( + [3, 64], + [3, 3], + [1, 1], + PaddingConfig2d::Explicit(1, 2, 3, 4), [1, 1], 1, true, @@ -129,4 +149,20 @@ mod tests { } "); } + + #[test] + fn test_conv2d_field_init_asymmetric_padding() { + let node = create_conv2d_node_asymmetric("conv1"); + let code = codegen_field_init(&node); + // Asymmetric padding is passed directly to the module + assert_snapshot!(code, @r" + let conv1 = Conv2dConfig::new([3, 64], [3, 3]) + .with_stride([1, 1]) + .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4)) + .with_dilation([1, 1]) + .with_groups(1) + .with_bias(true) + .init(device); + "); + } } diff --git a/crates/burn-import/src/burn/node/conv3d.rs b/crates/burn-import/src/burn/node/conv3d.rs index ad46274a48..15e7dbc99d 100644 --- a/crates/burn-import/src/burn/node/conv3d.rs +++ b/crates/burn-import/src/burn/node/conv3d.rs @@ -18,9 +18,11 @@ impl NodeCodegen for onnx_ir::conv3d::Conv3dNode { let stride = self.config.stride.to_tokens(); let dilation = self.config.dilation.to_tokens(); let groups = self.config.groups.to_tokens(); - let padding = self.config.padding.to_tokens(); let bias = self.config.bias; + // Asymmetric 3D padding is handled by the burn-nn module (will panic if attempted) + let padding = self.config.padding.to_tokens(); + Some(Field::new( self.name.clone(), quote! { @@ -43,10 +45,12 @@ impl NodeCodegen for onnx_ir::conv3d::Conv3dNode { let output = arg_to_ident(self.outputs.first().unwrap()); let field = Ident::new(&self.name, Span::call_site()); + // Asymmetric 3D padding will panic at runtime in the burn-nn module quote! 
{ let #output = self.#field.forward(#input); } } + fn register_imports(&self, imports: &mut BurnImports) { imports.register("burn::nn::PaddingConfig3d"); imports.register("burn::nn::conv::Conv3d"); @@ -93,7 +97,26 @@ mod tests { [1, 1, 1], 1, true, - PaddingConfig3d::Explicit(1, 1, 1), + PaddingConfig3d::Explicit(1, 1, 1, 1, 1, 1), + ); + + Conv3dNodeBuilder::new(name) + .input_tensor("input", 5, DType::F32) + .output_tensor("output", 5, DType::F32) + .config(config) + .build() + } + + fn create_conv3d_node_asymmetric(name: &str) -> Conv3dNode { + // Asymmetric padding: front=1, top=2, left=3, back=4, bottom=5, right=6 + let config = Conv3dConfig::new( + [3, 64], + [3, 3, 3], + [1, 1, 1], + [1, 1, 1], + 1, + true, + PaddingConfig3d::Explicit(1, 2, 3, 4, 5, 6), ); Conv3dNodeBuilder::new(name) @@ -126,4 +149,20 @@ mod tests { } "); } + + #[test] + fn test_conv3d_field_init_asymmetric_padding() { + let node = create_conv3d_node_asymmetric("conv1"); + let code = codegen_field_init(&node); + // Asymmetric padding is passed directly to the module (will panic at runtime) + assert_snapshot!(code, @r" + let conv1 = Conv3dConfig::new([3, 64], [3, 3, 3]) + .with_stride([1, 1, 1]) + .with_padding(PaddingConfig3d::Explicit(1, 2, 3, 4, 5, 6)) + .with_dilation([1, 1, 1]) + .with_groups(1) + .with_bias(true) + .init(device); + "); + } } diff --git a/crates/burn-import/src/burn/node/max_pool1d.rs b/crates/burn-import/src/burn/node/max_pool1d.rs index e7f4269525..510fac3e00 100644 --- a/crates/burn-import/src/burn/node/max_pool1d.rs +++ b/crates/burn-import/src/burn/node/max_pool1d.rs @@ -1,4 +1,5 @@ use super::prelude::*; + impl NodeCodegen for onnx_ir::max_pool1d::MaxPool1dNode { fn inputs(&self) -> &[Argument] { &self.inputs @@ -12,10 +13,11 @@ impl NodeCodegen for onnx_ir::max_pool1d::MaxPool1dNode { let name = Ident::new(&self.name, Span::call_site()); let kernel_size = self.config.kernel_size.to_tokens(); let strides = self.config.stride.to_tokens(); - let padding = self.config.padding.to_tokens(); let dilation = self.config.dilation.to_tokens(); let ceil_mode = self.config.ceil_mode; + let padding = self.config.padding.to_tokens(); + Some(Field::new( self.name.clone(), quote! 
{
@@ -67,6 +69,17 @@ mod tests {
         .build()
     }
 
+    fn create_max_pool1d_node_asymmetric(name: &str) -> MaxPool1dNode {
+        // Asymmetric padding: left=1, right=2
+        let config = MaxPool1dConfig::new(3, 1, 1, PaddingConfig1d::Explicit(1, 2), false);
+
+        MaxPool1dNodeBuilder::new(name)
+            .input_tensor("input", 3, DType::F32)
+            .output_tensor("output", 3, DType::F32)
+            .config(config)
+            .build()
+    }
+
     #[test]
     fn test_max_pool1d_forward() {
         let node = create_max_pool1d_node("pool1", false);
@@ -118,4 +131,32 @@
         .init();
         "#);
     }
+
+    #[test]
+    fn test_max_pool1d_forward_asymmetric_padding() {
+        let node = create_max_pool1d_node_asymmetric("pool1");
+        let code = codegen_forward_default(&node);
+        // Asymmetric padding is now handled by the burn-nn module
+        assert_snapshot!(code, @r"
+        pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+            let output = self.pool1.forward(input);
+            output
+        }
+        ");
+    }
+
+    #[test]
+    fn test_max_pool1d_field_init_asymmetric_padding() {
+        let node = create_max_pool1d_node_asymmetric("pool1");
+        let code = codegen_field_init(&node);
+        // Asymmetric padding is passed directly to the module
+        assert_snapshot!(code, @r"
+        let pool1 = MaxPool1dConfig::new(3)
+            .with_stride(1)
+            .with_padding(PaddingConfig1d::Explicit(1, 2))
+            .with_dilation(1)
+            .with_ceil_mode(false)
+            .init();
+        ");
+    }
 }
diff --git a/crates/burn-import/src/burn/node/max_pool2d.rs b/crates/burn-import/src/burn/node/max_pool2d.rs
index 2b9c783ab4..246a302fd3 100644
--- a/crates/burn-import/src/burn/node/max_pool2d.rs
+++ b/crates/burn-import/src/burn/node/max_pool2d.rs
@@ -1,4 +1,5 @@
 use super::prelude::*;
+
 impl NodeCodegen for onnx_ir::max_pool2d::MaxPool2dNode {
     fn inputs(&self) -> &[Argument] {
         &self.inputs
@@ -12,10 +13,11 @@
         let name = Ident::new(&self.name, Span::call_site());
         let kernel_size = self.config.kernel_size.to_tokens();
         let strides = self.config.strides.to_tokens();
-        let padding = self.config.padding.to_tokens();
         let dilation = self.config.dilation.to_tokens();
         let ceil_mode = self.config.ceil_mode;
 
+        let padding = self.config.padding.to_tokens();
+
         Some(Field::new(
             self.name.clone(),
             quote!
{
@@ -68,6 +70,23 @@ mod tests {
         .build()
     }
 
+    fn create_max_pool2d_node_asymmetric(name: &str) -> MaxPool2dNode {
+        // Asymmetric padding: top=1, left=2, bottom=3, right=4
+        let config = MaxPool2dConfig::new(
+            [3, 3],
+            [1, 1],
+            PaddingConfig2d::Explicit(1, 2, 3, 4),
+            [1, 1],
+            false,
+        );
+
+        MaxPool2dNodeBuilder::new(name)
+            .input_tensor("input", 4, DType::F32)
+            .output_tensor("output", 4, DType::F32)
+            .config(config)
+            .build()
+    }
+
     #[test]
     fn test_max_pool2d_forward() {
         let node = create_max_pool2d_node("pool1", false);
@@ -119,4 +138,32 @@
         .init();
         "#);
     }
+
+    #[test]
+    fn test_max_pool2d_forward_asymmetric_padding() {
+        let node = create_max_pool2d_node_asymmetric("pool1");
+        let code = codegen_forward_default(&node);
+        // Asymmetric padding is now handled by the burn-nn module
+        assert_snapshot!(code, @r"
+        pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+            let output = self.pool1.forward(input);
+            output
+        }
+        ");
+    }
+
+    #[test]
+    fn test_max_pool2d_field_init_asymmetric_padding() {
+        let node = create_max_pool2d_node_asymmetric("pool1");
+        let code = codegen_field_init(&node);
+        // Asymmetric padding is passed directly to the module
+        assert_snapshot!(code, @r"
+        let pool1 = MaxPool2dConfig::new([3, 3])
+            .with_strides([1, 1])
+            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
+            .with_dilation([1, 1])
+            .with_ceil_mode(false)
+            .init();
+        ");
+    }
 }
diff --git a/crates/burn-nn/src/modules/conv/conv1d.rs b/crates/burn-nn/src/modules/conv/conv1d.rs
index 450ae49ea4..9b4222f318 100644
--- a/crates/burn-nn/src/modules/conv/conv1d.rs
+++ b/crates/burn-nn/src/modules/conv/conv1d.rs
@@ -3,7 +3,12 @@ use alloc::format;
 use burn_core as burn;
 
 use crate::{PaddingConfig1d, conv::checks};
-use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::ConvOptions};
+use burn::tensor::{
+    Tensor,
+    backend::Backend,
+    module::conv1d,
+    ops::{ConvOptions, PadMode},
+};
 use burn::{
     config::Config,
     module::{Content, DisplaySettings, Ignored, Initializer, Module, ModuleDisplay, Param},
@@ -149,17 +154,32 @@
     /// - input: `[batch_size, channels_in, length_in]`
     /// - output: `[batch_size, channels_out, length_out]`
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
-        let length = input.dims()[2];
-        let padding = self
-            .padding
-            .calculate_padding_1d(length, self.kernel_size, self.stride);
-
-        conv1d(
-            input,
-            self.weight.val(),
-            self.bias.as_ref().map(|bias| bias.val()),
-            ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
-        )
+        // Handle asymmetric padding by applying explicit pad operation first
+        if self.padding.is_asymmetric() {
+            let (left, right) = self.padding.as_tuple();
+            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
+            // For 1D (NCL format), we only pad L (last dim), so top/bottom = 0
+            let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
+            // Use zero padding for the conv operation since we already padded
+            conv1d(
+                padded,
+                self.weight.val(),
+                self.bias.as_ref().map(|bias| bias.val()),
+                ConvOptions::new([self.stride], [0], [self.dilation], self.groups),
+            )
+        } else {
+            let length = input.dims()[2];
+            let padding = self
+                .padding
+                .calculate_padding_1d(length, self.kernel_size, self.stride);
+
+            conv1d(
+                input,
+                self.weight.val(),
+                self.bias.as_ref().map(|bias| bias.val()),
+                ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
+            )
+        }
     }
 }
 
@@ -227,4 +247,42 @@
         let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
         let _ = conv.forward(input);
     }
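// --- editor's sketch (not part of the patch) ---
// The asymmetric branch in `forward` above is a decomposition: pad explicitly,
// then convolve with zero op-level padding. Schematically (names illustrative):
//
//     let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
//     conv1d(padded, weight, bias,
//         ConvOptions::new([stride], [0], [dilation], groups))
//
// For left == right this is expected to agree with the symmetric path that
// hands the padding to `conv1d` directly; the two tests below check the
// resulting shapes for both branches.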
+
+    #[test]
+    fn asymmetric_padding_forward() {
+        let device = Default::default();
+        // Create conv with asymmetric padding: left=1, right=2
+        let config = Conv1dConfig::new(2, 3, 3)
+            .with_padding(PaddingConfig1d::Explicit(1, 2))
+            .with_initializer(Initializer::Constant { value: 1.0 })
+            .with_bias(false);
+        let conv = config.init::<TestBackend>(&device);
+
+        // Input: [batch=1, channels=2, length=4]
+        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
+        let output = conv.forward(input);
+
+        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
+        // Output length = (7 - 3) / 1 + 1 = 5
+        assert_eq!(output.dims(), [1, 3, 5]);
+    }
+
+    #[test]
+    fn symmetric_explicit_padding_forward() {
+        let device = Default::default();
+        // Create conv with symmetric explicit padding: left=2, right=2
+        let config = Conv1dConfig::new(2, 3, 3)
+            .with_padding(PaddingConfig1d::Explicit(2, 2))
+            .with_initializer(Initializer::Constant { value: 1.0 })
+            .with_bias(false);
+        let conv = config.init::<TestBackend>(&device);
+
+        // Input: [batch=1, channels=2, length=4]
+        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
+        let output = conv.forward(input);
+
+        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
+        // Output length = (8 - 3) / 1 + 1 = 6
+        assert_eq!(output.dims(), [1, 3, 6]);
+    }
 }
diff --git a/crates/burn-nn/src/modules/conv/conv2d.rs b/crates/burn-nn/src/modules/conv/conv2d.rs
index ebf35f2d41..cda5cb061f 100644
--- a/crates/burn-nn/src/modules/conv/conv2d.rs
+++ b/crates/burn-nn/src/modules/conv/conv2d.rs
@@ -9,7 +9,7 @@ use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
 use burn::tensor::Tensor;
 use burn::tensor::backend::Backend;
 use burn::tensor::module::conv2d;
-use burn::tensor::ops::ConvOptions;
+use burn::tensor::ops::{ConvOptions, PadMode};
 
 use crate::conv::checks;
 
@@ -168,16 +168,33 @@
     /// println!("{:?}", y.dims()); // [1, 8, 26, 26]
     /// ```
     pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
-        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
-        let padding =
-            self.padding
-                .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
-        conv2d(
-            input,
-            self.weight.val(),
-            self.bias.as_ref().map(|bias| bias.val()),
-            ConvOptions::new(self.stride, padding, self.dilation, self.groups),
-        )
+        // Handle asymmetric padding by applying explicit pad operation first
+        if self.padding.is_asymmetric() {
+            let (top, left, bottom, right) = self.padding.as_tuple();
+            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
+            let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0));
+            // Use zero padding for the conv operation since we already padded
+            conv2d(
+                padded,
+                self.weight.val(),
+                self.bias.as_ref().map(|bias| bias.val()),
+                ConvOptions::new(self.stride, [0, 0], self.dilation, self.groups),
+            )
+        } else {
+            let [_batch_size, _channels_in, height_in, width_in] = input.dims();
+            let padding = self.padding.calculate_padding_2d(
+                height_in,
+                width_in,
+                &self.kernel_size,
+                &self.stride,
+            );
+            conv2d(
+                input,
+                self.weight.val(),
+                self.bias.as_ref().map(|bias| bias.val()),
+                ConvOptions::new(self.stride, padding, self.dilation, self.groups),
+            )
+        }
     }
 }
 
@@ -289,4 +306,42 @@
         let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
         let _ = conv.forward(input);
     }
+
+    #[test]
+    fn asymmetric_padding_forward() {
+        let device = Default::default();
+        // Create conv with asymmetric padding: top=1, left=2, bottom=3, right=4
+        let config = Conv2dConfig::new([2, 3], [3, 3])
+            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
+            .with_initializer(Initializer::Constant { value: 1.0 })
+            .with_bias(false);
+        let conv = config.init::<TestBackend>(&device);
+
+        // Input: [batch=1, channels=2, height=4, width=5]
+        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
+        let output = conv.forward(input);
+
+        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
+        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
+        assert_eq!(output.dims(), [1, 3, 6, 9]);
+    }
+
+    #[test]
+    fn symmetric_explicit_padding_forward() {
+        let device = Default::default();
+        // Create conv with symmetric explicit padding: top=2, left=2, bottom=2, right=2
+        let config = Conv2dConfig::new([2, 3], [3, 3])
+            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
+            .with_initializer(Initializer::Constant { value: 1.0 })
+            .with_bias(false);
+        let conv = config.init::<TestBackend>(&device);
+
+        // Input: [batch=1, channels=2, height=4, width=5]
+        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
+        let output = conv.forward(input);
+
+        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
+        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
+        assert_eq!(output.dims(), [1, 3, 6, 7]);
+    }
 }
diff --git a/crates/burn-nn/src/modules/conv/conv3d.rs b/crates/burn-nn/src/modules/conv/conv3d.rs
index e882d38af5..bf626f19fc 100644
--- a/crates/burn-nn/src/modules/conv/conv3d.rs
+++ b/crates/burn-nn/src/modules/conv/conv3d.rs
@@ -154,6 +154,15 @@
     /// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`
     /// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`
     pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {
+        // Asymmetric 3D padding is not currently supported because burn's pad API
+        // only supports padding the last 2 dimensions
+        if self.padding.is_asymmetric() {
+            panic!(
+                "Asymmetric 3D padding is not currently supported. \
+                burn's pad API only supports 2D padding (last two dimensions)."
+            );
+        }
+
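// --- editor's note (not part of the patch) ---
// The 1D/2D decomposition used elsewhere in this PR cannot be reused here:
// `Tensor::pad` currently pads only the last two dimensions, while asymmetric
// 3D padding would also need the depth axis, e.g. a hypothetical
// input.pad((left, right, top, bottom, front, back), ...) over the last three
// dims. Hence the explicit panic above instead of a silently wrong result.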
         let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims();
         let padding = self.padding.calculate_padding_3d(
             depth_in,
diff --git a/crates/burn-nn/src/modules/pool/avg_pool1d.rs b/crates/burn-nn/src/modules/pool/avg_pool1d.rs
index 8e51d7de96..1c0e54658d 100644
--- a/crates/burn-nn/src/modules/pool/avg_pool1d.rs
+++ b/crates/burn-nn/src/modules/pool/avg_pool1d.rs
@@ -7,6 +7,7 @@ use burn::module::{Content, DisplaySettings, ModuleDisplay};
 use burn::module::{Ignored, Module};
 use burn::tensor::Tensor;
 use burn::tensor::backend::Backend;
+use burn::tensor::ops::PadMode;
 
 use burn::tensor::module::avg_pool1d;
 
@@ -103,25 +104,43 @@ impl AvgPool1d {
     /// - input: `[batch_size, channels, length_in]`
     /// - output: `[batch_size, channels, length_out]`
     pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
-        let [_batch_size, _channels, length] = input.dims();
-        let padding = self
-            .padding
-            .calculate_padding_1d(length, self.kernel_size, self.stride);
-
-        avg_pool1d(
-            input,
-            self.kernel_size,
-            self.stride,
-            padding,
-            self.count_include_pad,
-            self.ceil_mode,
-        )
+        // Handle asymmetric padding by applying explicit pad operation first
+        if self.padding.is_asymmetric() {
+            let (left, right) = self.padding.as_tuple();
+            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
+            // For 1D (NCL format), we only pad L (last dim), so top/bottom = 0
+            let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
+            // Use zero padding for the pool operation since we already padded
+            avg_pool1d(
+                padded,
+                self.kernel_size,
+                self.stride,
+                0,
+                self.count_include_pad,
+                self.ceil_mode,
+            )
+        } else {
+            let [_batch_size, _channels, length] = input.dims();
+            let padding = self
+                .padding
+                .calculate_padding_1d(length, self.kernel_size, self.stride);
+
+            avg_pool1d(
+                input,
+                self.kernel_size,
+                self.stride,
+                padding,
+                self.count_include_pad,
+                self.ceil_mode,
+            )
+        }
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::TestBackend;
     use rstest::rstest;
 
     #[test]
@@ -154,4 +173,40 @@
             config.stride, config.kernel_size
         );
     }
+
+    #[test]
+    fn asymmetric_padding_forward() {
+        let device = Default::default();
+        // Create avg pool with asymmetric padding: left=1, right=2
+        let config = AvgPool1dConfig::new(3)
+            .with_stride(1)
+            .with_padding(PaddingConfig1d::Explicit(1, 2));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, length=4]
+        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
+        let output = pool.forward(input);
+
+        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
+        // Output length = (7 - 3) / 1 + 1 = 5
+        assert_eq!(output.dims(), [1, 2, 5]);
+    }
+
+    #[test]
+    fn symmetric_explicit_padding_forward() {
+        let device = Default::default();
+        // Create avg pool with symmetric explicit padding: left=2, right=2
+        let config = AvgPool1dConfig::new(3)
+            .with_stride(1)
+            .with_padding(PaddingConfig1d::Explicit(2, 2));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, length=4]
+        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
+        let output = pool.forward(input);
+
+        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
+        // Output length = (8 - 3) / 1 + 1 = 6
+        assert_eq!(output.dims(), [1, 2, 6]);
+    }
 }
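// --- editor's note (not part of the patch) ---
// One subtlety of pre-padding for average pooling: once zeros are baked into
// the input and `avg_pool1d` runs with padding 0, the padded positions are
// counted like ordinary elements, i.e. the result follows
// count_include_pad = true semantics. For example, with kernel 3 and pads
// (1, 0) over input [1.0, 1.0, 1.0], the first window averages {0, 1, 1}
// to 2/3 regardless of the configured flag. The ONNX test models in this PR
// set count_include_pad=1, which is consistent with that behavior.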
diff --git a/crates/burn-nn/src/modules/pool/avg_pool2d.rs b/crates/burn-nn/src/modules/pool/avg_pool2d.rs
index 97e90e1f60..c708d67d98 100644
--- a/crates/burn-nn/src/modules/pool/avg_pool2d.rs
+++ b/crates/burn-nn/src/modules/pool/avg_pool2d.rs
@@ -7,6 +7,7 @@ use burn::module::{Content, DisplaySettings, ModuleDisplay};
 use burn::module::{Ignored, Module};
 use burn::tensor::Tensor;
 use burn::tensor::backend::Backend;
+use burn::tensor::ops::PadMode;
 
 use burn::tensor::module::avg_pool2d;
 
@@ -103,25 +104,45 @@ impl AvgPool2d {
     /// - input: `[batch_size, channels, height_in, width_in]`
     /// - output: `[batch_size, channels, height_out, width_out]`
     pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
-        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
-        let padding =
-            self.padding
-                .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
-
-        avg_pool2d(
-            input,
-            self.kernel_size,
-            self.stride,
-            padding,
-            self.count_include_pad,
-            self.ceil_mode,
-        )
+        // Handle asymmetric padding by applying explicit pad operation first
+        if self.padding.is_asymmetric() {
+            let (top, left, bottom, right) = self.padding.as_tuple();
+            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
+            let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0));
+            // Use zero padding for the pool operation since we already padded
+            avg_pool2d(
+                padded,
+                self.kernel_size,
+                self.stride,
+                [0, 0],
+                self.count_include_pad,
+                self.ceil_mode,
+            )
+        } else {
+            let [_batch_size, _channels_in, height_in, width_in] = input.dims();
+            let padding = self.padding.calculate_padding_2d(
+                height_in,
+                width_in,
+                &self.kernel_size,
+                &self.stride,
+            );
+
+            avg_pool2d(
+                input,
+                self.kernel_size,
+                self.stride,
+                padding,
+                self.count_include_pad,
+                self.ceil_mode,
+            )
+        }
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::TestBackend;
     use rstest::rstest;
 
     #[test]
@@ -155,4 +176,40 @@
             config.strides, config.kernel_size
         );
     }
+
+    #[test]
+    fn asymmetric_padding_forward() {
+        let device = Default::default();
+        // Create avg pool with asymmetric padding: top=1, left=2, bottom=3, right=4
+        let config = AvgPool2dConfig::new([3, 3])
+            .with_strides([1, 1])
+            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, height=4, width=5]
+        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
+        let output = pool.forward(input);
+
+        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
+        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
+        assert_eq!(output.dims(), [1, 2, 6, 9]);
+    }
+
+    #[test]
+    fn symmetric_explicit_padding_forward() {
+        let device = Default::default();
+        // Create avg pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2
+        let config = AvgPool2dConfig::new([3, 3])
+            .with_strides([1, 1])
+            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, height=4, width=5]
+        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
+        let output = pool.forward(input);
+
+        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
+        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
+        assert_eq!(output.dims(), [1, 2, 6, 7]);
+    }
 }
diff --git a/crates/burn-nn/src/modules/pool/max_pool1d.rs b/crates/burn-nn/src/modules/pool/max_pool1d.rs
index 99524438d0..53a38106dd 100644
--- a/crates/burn-nn/src/modules/pool/max_pool1d.rs
+++ b/crates/burn-nn/src/modules/pool/max_pool1d.rs
@@ -7,6 +7,7 @@ use burn::module::{Content, DisplaySettings, ModuleDisplay};
 use burn::module::{Ignored, Module};
 use burn::tensor::Tensor;
 use burn::tensor::backend::Backend;
+use burn::tensor::ops::PadMode;
 
 use burn::tensor::module::max_pool1d;
 
@@ -95,25 +96,44 @@ impl MaxPool1d {
     /// - input: `[batch_size, channels, length_in]`
     /// - output: `[batch_size, channels, length_out]`
     pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
-        let [_batch_size, _channels, length] = input.dims();
-        let padding = self
-            .padding
-            .calculate_padding_1d(length, self.kernel_size, self.stride);
-
-        max_pool1d(
-            input,
-            self.kernel_size,
-            self.stride,
-            padding,
-            self.dilation,
-            self.ceil_mode,
-        )
+        // Handle asymmetric padding by applying explicit pad operation first
+        if self.padding.is_asymmetric() {
+            let (left, right) = self.padding.as_tuple();
+            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
+            // For 1D (NCL format), we only pad L (last dim), so top/bottom = 0
+            // Use -inf for max pooling so padded values don't affect the max
+            let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY));
+            // Use zero padding for the pool operation since we already padded
+            max_pool1d(
+                padded,
+                self.kernel_size,
+                self.stride,
+                0,
+                self.dilation,
+                self.ceil_mode,
+            )
+        } else {
+            let [_batch_size, _channels, length] = input.dims();
+            let padding = self
+                .padding
+                .calculate_padding_1d(length, self.kernel_size, self.stride);
+
+            max_pool1d(
+                input,
+                self.kernel_size,
+                self.stride,
+                padding,
+                self.dilation,
+                self.ceil_mode,
+            )
+        }
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::TestBackend;
     use rstest::rstest;
 
    #[test]
@@ -147,4 +167,40 @@
             config.stride, config.kernel_size
         );
     }
+
+    #[test]
+    fn asymmetric_padding_forward() {
+        let device = Default::default();
+        // Create max pool with asymmetric padding: left=1, right=2
+        let config = MaxPool1dConfig::new(3)
+            .with_stride(1)
+            .with_padding(PaddingConfig1d::Explicit(1, 2));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, length=4]
+        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
+        let output = pool.forward(input);
+
+        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7
+        // Output length = (7 - 3) / 1 + 1 = 5
+        assert_eq!(output.dims(), [1, 2, 5]);
+    }
+
+    #[test]
+    fn symmetric_explicit_padding_forward() {
+        let device = Default::default();
+        // Create max pool with symmetric explicit padding: left=2, right=2
+        let config = MaxPool1dConfig::new(3)
+            .with_stride(1)
+            .with_padding(PaddingConfig1d::Explicit(2, 2));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, length=4]
+        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
+        let output = pool.forward(input);
+
+        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8
+        // Output length = (8 - 3) / 1 + 1 = 6
+        assert_eq!(output.dims(), [1, 2, 6]);
+    }
 }
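// --- editor's sketch (not part of the patch) ---
// Why NEG_INFINITY rather than 0.0 in the asymmetric max-pool branch above:
// with zero pre-padding, a window containing only negative inputs would
// wrongly report 0. Illustration (values made up):
//
//     input = [-3.0, -1.0, -2.0], kernel 3, pads (1, 2)
//     zero pre-pad:  first window max(0.0, -3.0, -1.0)  =  0.0   // wrong
//     -inf pre-pad:  first window max(-inf, -3.0, -1.0) = -1.0   // matches ONNX MaxPool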
diff --git a/crates/burn-nn/src/modules/pool/max_pool2d.rs b/crates/burn-nn/src/modules/pool/max_pool2d.rs
index c3b9f489a5..4301899112 100644
--- a/crates/burn-nn/src/modules/pool/max_pool2d.rs
+++ b/crates/burn-nn/src/modules/pool/max_pool2d.rs
@@ -7,6 +7,7 @@ use burn::module::{Content, DisplaySettings, ModuleDisplay};
 use burn::module::{Ignored, Module};
 use burn::tensor::Tensor;
 use burn::tensor::backend::Backend;
+use burn::tensor::ops::PadMode;
 
 use burn::tensor::module::max_pool2d;
 
@@ -95,25 +96,49 @@ impl<B: Backend> MaxPool2d<B> {
     /// - input: `[batch_size, channels, height_in, width_in]`
     /// - output: `[batch_size, channels, height_out, width_out]`
     pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
-        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
-        let padding =
-            self.padding
-                .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
-
-        max_pool2d(
-            input,
-            self.kernel_size,
-            self.stride,
-            padding,
-            self.dilation,
-            self.ceil_mode,
-        )
+        // Handle asymmetric padding by applying an explicit pad operation first
+        if self.padding.is_asymmetric() {
+            let (top, left, bottom, right) = self.padding.as_tuple();
+            // Burn's pad takes (left, right, top, bottom) for the last two dimensions.
+            // Use -inf for max pooling so padded values don't affect the max.
+            let padded = input.pad(
+                (left, right, top, bottom),
+                PadMode::Constant(f32::NEG_INFINITY),
+            );
+            // Use zero padding for the pool operation since the input is already padded
+            max_pool2d(
+                padded,
+                self.kernel_size,
+                self.stride,
+                [0, 0],
+                self.dilation,
+                self.ceil_mode,
+            )
+        } else {
+            let [_batch_size, _channels_in, height_in, width_in] = input.dims();
+            let padding = self.padding.calculate_padding_2d(
+                height_in,
+                width_in,
+                &self.kernel_size,
+                &self.stride,
+            );
+
+            max_pool2d(
+                input,
+                self.kernel_size,
+                self.stride,
+                padding,
+                self.dilation,
+                self.ceil_mode,
+            )
+        }
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::TestBackend;
     use rstest::rstest;
 
     #[test]
@@ -147,4 +172,40 @@ mod tests {
             config.strides, config.kernel_size
         );
     }
+
+    #[test]
+    fn asymmetric_padding_forward() {
+        let device = Default::default();
+        // Create a max pool with asymmetric padding: top=1, left=2, bottom=3, right=4
+        let config = MaxPool2dConfig::new([3, 3])
+            .with_strides([1, 1])
+            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, height=4, width=5]
+        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
+        let output = pool.forward(input);
+
+        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
+        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
+        assert_eq!(output.dims(), [1, 2, 6, 9]);
+    }
+
+    #[test]
+    fn symmetric_explicit_padding_forward() {
+        let device = Default::default();
+        // Create a max pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2
+        let config = MaxPool2dConfig::new([3, 3])
+            .with_strides([1, 1])
+            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2));
+        let pool = config.init();
+
+        // Input: [batch=1, channels=2, height=4, width=5]
+        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
+        let output = pool.forward(input);
+
+        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
+        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
+        assert_eq!(output.dims(), [1, 2, 6, 7]);
+    }
 }
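The shape assertions in both test modules follow the standard pooling output formula, with the two pad sides entering the sum separately. A small illustrative helper (not part of the diff) that reproduces the numbers in the comments above:

/// Output size along one dimension with per-side padding (illustrative only).
fn out_size(input: usize, pad_begin: usize, pad_end: usize, kernel: usize, stride: usize, dilation: usize) -> usize {
    (input + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // MaxPool2d asymmetric test: height 4 with pads (1, 3), width 5 with pads (2, 4), kernel 3.
    assert_eq!(out_size(4, 1, 3, 3, 1, 1), 6);
    assert_eq!(out_size(5, 2, 4, 3, 1, 1), 9);
}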
diff --git a/crates/burn-nn/src/padding.rs b/crates/burn-nn/src/padding.rs
index 21edc61f89..73de37674c 100644
--- a/crates/burn-nn/src/padding.rs
+++ b/crates/burn-nn/src/padding.rs
@@ -10,11 +10,16 @@ pub enum PaddingConfig1d {
     Same,
     /// No padding applied.
     Valid,
-    /// Applies a specific amount of padding to all inputs.
-    Explicit(usize),
+    /// Applies explicit padding values.
+    /// Format: (left, right)
+    /// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`).
+    Explicit(usize, usize),
 }
 
 impl PaddingConfig1d {
+    /// Calculate symmetric padding for 1D operations.
+    /// Returns a single padding value (same for both sides).
+    /// Panics if asymmetric padding is used.
     pub(crate) fn calculate_padding_1d(
         &self,
         length: usize,
@@ -25,7 +30,29 @@
         match self {
             Self::Valid => 0,
             Self::Same => same_padding(),
-            Self::Explicit(value) => *value,
+            Self::Explicit(left, right) => {
+                if left != right {
+                    panic!("Asymmetric padding should be handled separately via is_asymmetric()")
+                }
+                *left
+            }
+        }
+    }
+
+    /// Returns true if this padding is asymmetric (left != right).
+    pub fn is_asymmetric(&self) -> bool {
+        match self {
+            Self::Explicit(left, right) => left != right,
+            _ => false,
+        }
+    }
+
+    /// Returns the padding values (left, right).
+    /// Panics if not Explicit padding.
+    pub fn as_tuple(&self) -> (usize, usize) {
+        match self {
+            Self::Explicit(left, right) => (*left, *right),
+            _ => panic!("as_tuple() only works with Explicit padding"),
         }
     }
 }
@@ -37,11 +64,16 @@ pub enum PaddingConfig2d {
     Same,
     /// No padding applied.
     Valid,
-    /// Applies specified padding values to height and width dimensions.
-    Explicit(usize, usize),
+    /// Applies explicit padding values.
+    /// Format: (top, left, bottom, right)
+    /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`).
+    Explicit(usize, usize, usize, usize),
 }
 
 impl PaddingConfig2d {
+    /// Calculate symmetric padding for 2D operations.
+    /// Returns padding values [height, width] (same for both sides).
+    /// Panics if asymmetric padding is used.
     pub(crate) fn calculate_padding_2d(
         &self,
         height: usize,
@@ -59,7 +91,29 @@
         match self {
             Self::Same => same_padding(),
             Self::Valid => [0, 0],
-            Self::Explicit(v1, v2) => [*v1, *v2],
+            Self::Explicit(top, left, bottom, right) => {
+                if top != bottom || left != right {
+                    panic!("Asymmetric padding should be handled separately via is_asymmetric()")
+                }
+                [*top, *left]
+            }
+        }
+    }
+
+    /// Returns true if this padding is asymmetric (top != bottom or left != right).
+    pub fn is_asymmetric(&self) -> bool {
+        match self {
+            Self::Explicit(top, left, bottom, right) => top != bottom || left != right,
+            _ => false,
+        }
+    }
+
+    /// Returns the padding values (top, left, bottom, right).
+    /// Panics if not Explicit padding.
+    pub fn as_tuple(&self) -> (usize, usize, usize, usize) {
+        match self {
+            Self::Explicit(top, left, bottom, right) => (*top, *left, *bottom, *right),
+            _ => panic!("as_tuple() only works with Explicit padding"),
         }
     }
 }
@@ -71,11 +125,16 @@ pub enum PaddingConfig3d {
     Same,
     /// No padding applied.
     Valid,
-    /// Applies specified padding values to depth, height, and width dimensions.
-    Explicit(usize, usize, usize),
+    /// Applies explicit padding values.
+    /// Format: (front, top, left, back, bottom, right)
+    /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1, 1, 1)`).
+    Explicit(usize, usize, usize, usize, usize, usize),
 }
 
 impl PaddingConfig3d {
+    /// Calculate symmetric padding for 3D operations.
+    /// Returns padding values [depth, height, width] (same for both sides).
+    /// Panics if asymmetric padding is used.
    pub(crate) fn calculate_padding_3d(
         &self,
         depth: usize,
@@ -95,7 +154,288 @@
         match self {
             Self::Same => same_padding(),
             Self::Valid => [0, 0, 0],
-            Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
+            Self::Explicit(front, top, left, back, bottom, right) => {
+                if front != back || top != bottom || left != right {
+                    panic!("Asymmetric padding should be handled separately via is_asymmetric()")
+                }
+                [*front, *top, *left]
+            }
+        }
+    }
+
+    /// Returns true if this padding is asymmetric.
+    pub fn is_asymmetric(&self) -> bool {
+        match self {
+            Self::Explicit(front, top, left, back, bottom, right) => {
+                front != back || top != bottom || left != right
+            }
+            _ => false,
+        }
+    }
+
+    /// Returns the padding values (front, top, left, back, bottom, right).
+    /// Panics if not Explicit padding.
+    pub fn as_tuple(&self) -> (usize, usize, usize, usize, usize, usize) {
+        match self {
+            Self::Explicit(front, top, left, back, bottom, right) => {
+                (*front, *top, *left, *back, *bottom, *right)
+            }
+            _ => panic!("as_tuple() only works with Explicit padding"),
         }
     }
 }
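The pool and conv forward paths consume these helpers with the same two-branch pattern. A condensed sketch (the `describe` function is hypothetical; the real modules pad and pool instead of formatting strings):

fn describe(padding: &PaddingConfig2d) -> String {
    if padding.is_asymmetric() {
        // Asymmetric: pre-pad the input, then run the op with zero padding.
        let (top, left, bottom, right) = padding.as_tuple();
        format!("pre-pad (left={left}, right={right}, top={top}, bottom={bottom}), pool with [0, 0]")
    } else {
        // Symmetric, Same, or Valid: resolve via calculate_padding_2d as before.
        "resolve with calculate_padding_2d and pass to the op".to_string()
    }
}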
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // ==================== PaddingConfig1d Tests ====================
+
+    #[test]
+    fn test_padding_config_1d_is_asymmetric_symmetric() {
+        // Symmetric padding (left == right) should return false
+        let padding = PaddingConfig1d::Explicit(2, 2);
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_1d_is_asymmetric_asymmetric() {
+        // Asymmetric padding (left != right) should return true
+        let padding = PaddingConfig1d::Explicit(1, 2);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_1d_is_asymmetric_valid() {
+        // Valid padding should return false
+        let padding = PaddingConfig1d::Valid;
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_1d_is_asymmetric_same() {
+        // Same padding should return false
+        let padding = PaddingConfig1d::Same;
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_1d_as_tuple() {
+        let padding = PaddingConfig1d::Explicit(1, 2);
+        assert_eq!(padding.as_tuple(), (1, 2));
+    }
+
+    #[test]
+    #[should_panic(expected = "as_tuple() only works with Explicit padding")]
+    fn test_padding_config_1d_as_tuple_valid_panics() {
+        let padding = PaddingConfig1d::Valid;
+        let _ = padding.as_tuple();
+    }
+
+    #[test]
+    #[should_panic(expected = "as_tuple() only works with Explicit padding")]
+    fn test_padding_config_1d_as_tuple_same_panics() {
+        let padding = PaddingConfig1d::Same;
+        let _ = padding.as_tuple();
+    }
+
+    #[test]
+    fn test_padding_config_1d_calculate_valid() {
+        let padding = PaddingConfig1d::Valid;
+        assert_eq!(padding.calculate_padding_1d(10, 3, 1), 0);
+    }
+
+    #[test]
+    fn test_padding_config_1d_calculate_explicit_symmetric() {
+        let padding = PaddingConfig1d::Explicit(2, 2);
+        assert_eq!(padding.calculate_padding_1d(10, 3, 1), 2);
+    }
+
+    #[test]
+    #[should_panic(expected = "Asymmetric padding should be handled separately")]
+    fn test_padding_config_1d_calculate_explicit_asymmetric_panics() {
+        let padding = PaddingConfig1d::Explicit(1, 2);
+        let _ = padding.calculate_padding_1d(10, 3, 1);
+    }
+
+    // ==================== PaddingConfig2d Tests ====================
+
+    #[test]
+    fn test_padding_config_2d_is_asymmetric_symmetric() {
+        // Symmetric padding should return false
+        let padding = PaddingConfig2d::Explicit(2, 2, 2, 2);
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_is_asymmetric_top_bottom() {
+        // top != bottom should return true
+        let padding = PaddingConfig2d::Explicit(1, 2, 3, 2);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_is_asymmetric_left_right() {
+        // left != right should return true
+        let padding = PaddingConfig2d::Explicit(2, 1, 2, 3);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_is_asymmetric_all_different() {
+        // All different values should return true
+        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_is_asymmetric_valid() {
+        let padding = PaddingConfig2d::Valid;
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_is_asymmetric_same() {
+        let padding = PaddingConfig2d::Same;
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_as_tuple() {
+        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
+        assert_eq!(padding.as_tuple(), (1, 2, 3, 4));
+    }
+
+    #[test]
+    #[should_panic(expected = "as_tuple() only works with Explicit padding")]
+    fn test_padding_config_2d_as_tuple_valid_panics() {
+        let padding = PaddingConfig2d::Valid;
+        let _ = padding.as_tuple();
+    }
+
+    #[test]
+    #[should_panic(expected = "as_tuple() only works with Explicit padding")]
+    fn test_padding_config_2d_as_tuple_same_panics() {
+        let padding = PaddingConfig2d::Same;
+        let _ = padding.as_tuple();
+    }
+
+    #[test]
+    fn test_padding_config_2d_calculate_valid() {
+        let padding = PaddingConfig2d::Valid;
+        assert_eq!(
+            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
+            [0, 0]
+        );
+    }
+
+    #[test]
+    fn test_padding_config_2d_calculate_explicit_symmetric() {
+        let padding = PaddingConfig2d::Explicit(2, 3, 2, 3);
+        assert_eq!(
+            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
+            [2, 3]
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Asymmetric padding should be handled separately")]
+    fn test_padding_config_2d_calculate_explicit_asymmetric_panics() {
+        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
+        let _ = padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
+    }
+
+    // ==================== PaddingConfig3d Tests ====================
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_symmetric() {
+        // Symmetric padding should return false
+        let padding = PaddingConfig3d::Explicit(2, 3, 1, 2, 3, 1);
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_front_back() {
+        // front != back should return true
+        let padding = PaddingConfig3d::Explicit(1, 3, 1, 2, 3, 1);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_top_bottom() {
+        // top != bottom should return true
+        let padding = PaddingConfig3d::Explicit(2, 1, 1, 2, 3, 1);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_left_right() {
+        // left != right should return true
+        let padding = PaddingConfig3d::Explicit(2, 3, 1, 2, 3, 4);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_all_different() {
+        // All different values should return true
+        let padding = PaddingConfig3d::Explicit(1, 2, 3, 4, 5, 6);
+        assert!(padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_valid() {
+        let padding = PaddingConfig3d::Valid;
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_is_asymmetric_same() {
+        let padding = PaddingConfig3d::Same;
+        assert!(!padding.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_3d_as_tuple() {
+        let padding = PaddingConfig3d::Explicit(1, 2, 3, 4, 5, 6);
+        assert_eq!(padding.as_tuple(), (1, 2, 3, 4, 5, 6));
+    }
+
+    #[test]
+    #[should_panic(expected = "as_tuple() only works with Explicit padding")]
+    fn test_padding_config_3d_as_tuple_valid_panics() {
+        let padding = PaddingConfig3d::Valid;
+        let _ = padding.as_tuple();
+    }
+
+    #[test]
+    #[should_panic(expected = "as_tuple() only works with Explicit padding")]
+    fn test_padding_config_3d_as_tuple_same_panics() {
+        let padding = PaddingConfig3d::Same;
+        let _ = padding.as_tuple();
+    }
+
+    #[test]
+    fn test_padding_config_3d_calculate_valid() {
+        let padding = PaddingConfig3d::Valid;
+        assert_eq!(
+            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
+            [0, 0, 0]
+        );
+    }
+
+    #[test]
+    fn test_padding_config_3d_calculate_explicit_symmetric() {
+        let padding = PaddingConfig3d::Explicit(1, 2, 3, 1, 2, 3);
+        assert_eq!(
+            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
+            [1, 2, 3]
+        );
+    }
+
+    #[test]
+    #[should_panic(expected = "Asymmetric padding should be handled separately")]
+    fn test_padding_config_3d_calculate_explicit_asymmetric_panics() {
+        let padding = PaddingConfig3d::Explicit(1, 2, 3, 4, 5, 6);
+        let _ = padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]);
+    }
+}
diff --git a/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs b/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs
index 4a1cb50a93..3e64006523 100644
--- a/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs
+++ b/crates/burn-store/src/safetensors/tests/multi_layer_verify.rs
@@ -23,7 +23,7 @@ impl<B: Backend> Net<B> {
     pub fn new(device: &B::Device) -> Self {
         Self {
             conv1: Conv2dConfig::new([3, 4], [3, 3])
-                .with_padding(PaddingConfig2d::Explicit(1, 1))
+                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
                 .init(device),
             norm1: BatchNormConfig::new(4).init(device),
             fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),
diff --git a/crates/burn-store/src/safetensors/tests/pytorch_import.rs b/crates/burn-store/src/safetensors/tests/pytorch_import.rs
index 733381374c..40849b2126 100644
--- a/crates/burn-store/src/safetensors/tests/pytorch_import.rs
+++ b/crates/burn-store/src/safetensors/tests/pytorch_import.rs
@@ -23,7 +23,7 @@ impl<B: Backend> Net<B> {
     pub fn new(device: &B::Device) -> Self {
         Self {
             conv1: Conv2dConfig::new([3, 4], [3, 3])
-                .with_padding(PaddingConfig2d::Explicit(1, 1))
+                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
                 .init(device),
             norm1: BatchNormConfig::new(4).init(device),
             fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),
diff --git a/crates/onnx-ir/src/node/avg_pool1d.rs b/crates/onnx-ir/src/node/avg_pool1d.rs
index 98bb639936..dd622f909a 100644
--- a/crates/onnx-ir/src/node/avg_pool1d.rs
+++ b/crates/onnx-ir/src/node/avg_pool1d.rs
@@ -245,7 +245,7 @@ mod tests {
         assert_eq!(config.stride, 2);
         assert_eq!(config.dilation, 1);
         assert!(!config.count_include_pad);
-        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2)));
+        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2, 2)));
     }
 
     #[test]
@@ -261,7 +261,7 @@ mod tests {
         assert_eq!(config.stride, 1);
         assert_eq!(config.dilation, 1);
         assert!(config.count_include_pad);
-        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2)));
+        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2, 2)));
     }
 
     #[test]
diff --git a/crates/onnx-ir/src/node/avg_pool2d.rs b/crates/onnx-ir/src/node/avg_pool2d.rs
index 7d88b8754c..263a5acd19 100644
--- a/crates/onnx-ir/src/node/avg_pool2d.rs
+++ b/crates/onnx-ir/src/node/avg_pool2d.rs
@@ -243,7 +243,10 @@ mod tests {
         assert_eq!(config.strides, [2, 2]);
         assert_eq!(config.dilation, [1, 1]);
         assert!(!config.count_include_pad);
-        assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1)));
+        assert!(matches!(
+            config.padding,
+            PaddingConfig2d::Explicit(1, 1, 1, 1)
+        ));
     }
 
     #[test]
@@ -259,7 +262,10 @@ mod tests {
         assert_eq!(config.strides, [1, 1]);
         assert_eq!(config.dilation, [1, 1]);
         assert!(config.count_include_pad);
-        assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1)));
+        assert!(matches!(
+            config.padding,
+            PaddingConfig2d::Explicit(1, 1, 1, 1)
+        ));
     }
 
     #[test]
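The two burn-store fixes above show the migration pattern for downstream code: `PaddingConfig2d::Explicit` went from one value per spatial dimension to one value per side (and the 1D/3D variants from 1 to 2 and 3 to 6 values). A sketch of the change for a typical config (fragment only; `device` comes from the surrounding test):

// Old API: Explicit(height_pad, width_pad), applied to both sides of each dim.
// .with_padding(PaddingConfig2d::Explicit(1, 2))
// New API: Explicit(top, left, bottom, right); a symmetric config repeats values.
let conv = Conv2dConfig::new([3, 4], [3, 3])
    .with_padding(PaddingConfig2d::Explicit(1, 2, 1, 2))
    .init(device);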
diff --git a/crates/onnx-ir/src/node/conv1d.rs b/crates/onnx-ir/src/node/conv1d.rs
index 6fb6527a62..df1fda539c 100644
--- a/crates/onnx-ir/src/node/conv1d.rs
+++ b/crates/onnx-ir/src/node/conv1d.rs
@@ -369,7 +369,7 @@ mod tests {
         assert_eq!(config.dilation, 1);
         assert_eq!(config.groups, 1);
         assert!(config.bias);
-        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2)));
+        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2, 2)));
     }
 
     #[test]
@@ -413,12 +413,14 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "Asymmetric padding is not supported")]
     fn test_conv1d_config_asymmetric_padding() {
         let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 1, false, None)
             .build_with_graph_data(16);
         let processor = Conv1dProcessor;
-        let _ = processor.extract_config(&node, 16);
+        let config = processor.extract_config(&node, 16).unwrap();
+        // Asymmetric padding should now be captured instead of panicking
+        assert!(matches!(config.padding, PaddingConfig1d::Explicit(1, 2)));
+        assert!(config.padding.is_asymmetric());
     }
 
     #[test]
diff --git a/crates/onnx-ir/src/node/conv2d.rs b/crates/onnx-ir/src/node/conv2d.rs
index ddee9386b7..d94ee544fc 100644
--- a/crates/onnx-ir/src/node/conv2d.rs
+++ b/crates/onnx-ir/src/node/conv2d.rs
@@ -359,7 +359,10 @@ mod tests {
         processor.infer_types(&mut node, 16, &prefs).unwrap();
 
         assert_eq!(config.kernel_size, [3, 3]);
-        assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1)));
+        assert!(matches!(
+            config.padding,
+            PaddingConfig2d::Explicit(1, 1, 1, 1)
+        ));
     }
 
     #[test]
@@ -424,7 +427,10 @@ mod tests {
         processor.infer_types(&mut node, 16, &prefs).unwrap();
 
         assert_eq!(config.kernel_size, [3, 3]);
-        assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1)));
+        assert!(matches!(
+            config.padding,
+            PaddingConfig2d::Explicit(1, 1, 1, 1)
+        ));
     }
 
     #[test]
diff --git a/crates/onnx-ir/src/node/conv3d.rs b/crates/onnx-ir/src/node/conv3d.rs
index 2ebb53ee57..7630c5afda 100644
--- a/crates/onnx-ir/src/node/conv3d.rs
+++ b/crates/onnx-ir/src/node/conv3d.rs
@@ -283,7 +283,10 @@ mod tests {
         processor.infer_types(&mut node, 16, &prefs).unwrap();
 
         assert_eq!(config.kernel_size, [3, 3, 3]);
-        assert!(matches!(config.padding, PaddingConfig3d::Explicit(1, 1, 1)));
+        assert!(matches!(
+            config.padding,
+            PaddingConfig3d::Explicit(1, 1, 1, 1, 1, 1)
+        ));
     }
 
     #[test]
diff --git a/crates/onnx-ir/src/node/max_pool1d.rs b/crates/onnx-ir/src/node/max_pool1d.rs
index 4bdd2075f0..7a80339388 100644
--- a/crates/onnx-ir/src/node/max_pool1d.rs
+++ b/crates/onnx-ir/src/node/max_pool1d.rs
@@ -260,7 +260,7 @@ mod tests {
         assert_eq!(config.kernel_size, 4);
         assert_eq!(config.stride, 2);
         assert_eq!(config.dilation, 1);
-        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2)));
+        assert!(matches!(config.padding, PaddingConfig1d::Explicit(2, 2)));
     }
 
     #[test]
@@ -279,11 +279,13 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "Asymmetric padding is not supported")]
     fn test_max_pool1d_config_asymmetric_padding() {
         let node = create_test_node(vec![4], vec![1], vec![1, 2], vec![1], 0, None);
         let processor = MaxPool1dProcessor;
-        let _ = processor.extract_config(&node, 16);
+        let config = processor.extract_config(&node, 16).unwrap();
+        // Asymmetric padding should now be captured instead of panicking
+        assert!(matches!(config.padding, PaddingConfig1d::Explicit(1, 2)));
+        assert!(config.padding.is_asymmetric());
     }
 
     #[test]
diff --git a/crates/onnx-ir/src/node/max_pool2d.rs b/crates/onnx-ir/src/node/max_pool2d.rs
index 3be95999d2..e3e3bcf2e5 100644
--- a/crates/onnx-ir/src/node/max_pool2d.rs
+++ b/crates/onnx-ir/src/node/max_pool2d.rs
@@ -271,7 +271,10 @@ mod tests {
         assert_eq!(config.kernel_size, [2, 2]);
         assert_eq!(config.strides, [2, 2]);
         assert_eq!(config.dilation, [1, 1]);
-        assert!(matches!(config.padding, PaddingConfig2d::Explicit(1, 1)));
+        assert!(matches!(
+            config.padding,
+            PaddingConfig2d::Explicit(1, 1, 1, 1)
+        ));
     }
 
     #[test]
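With the importer and module changes combined, an ONNX model with asymmetric pads now round-trips instead of aborting at import time. A hedged sketch of the flow, using only types and functions defined in this diff:

// 1. onnx-ir: pads = [1, 2] from the ONNX node become an Explicit config.
let config = padding_config_1d(&[1, 2]);
assert_eq!(config, PaddingConfig1d::Explicit(1, 2));
// 2. burn-nn: the module detects asymmetry and takes the pre-pad path
//    (pad with -inf for max pool, then pool with zero padding).
assert!(config.is_asymmetric());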
diff --git a/crates/onnx-ir/src/node/padding.rs b/crates/onnx-ir/src/node/padding.rs
index 7f32088e79..3748a2d61f 100644
--- a/crates/onnx-ir/src/node/padding.rs
+++ b/crates/onnx-ir/src/node/padding.rs
@@ -4,8 +4,6 @@
 //!
 //! Provides `PaddingConfig1d`, `PaddingConfig2d`, `PaddingConfig3d` enums and helper
 //! functions to convert ONNX padding arrays.
-//!
-//! **Limitations**: Only symmetric, non-negative padding is supported.
 
 use std::fmt;
 
@@ -15,15 +13,35 @@ pub enum PaddingConfig1d {
     /// No padding (valid padding)
     #[default]
     Valid,
-    /// Explicit padding with a specific size
-    Explicit(usize),
+    /// Explicit padding with values for the left and right sides
+    /// Format: (left, right)
+    /// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`).
+    Explicit(usize, usize),
 }
 
 impl fmt::Display for PaddingConfig1d {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
             PaddingConfig1d::Valid => write!(f, "Valid"),
-            PaddingConfig1d::Explicit(size) => write!(f, "Explicit({size})"),
+            PaddingConfig1d::Explicit(left, right) => write!(f, "Explicit({left}, {right})"),
+        }
+    }
+}
+
+impl PaddingConfig1d {
+    /// Returns true if this padding configuration is asymmetric (left != right)
+    pub fn is_asymmetric(&self) -> bool {
+        match self {
+            PaddingConfig1d::Explicit(left, right) => left != right,
+            _ => false,
+        }
+    }
+
+    /// Returns the padding values as a (left, right) tuple
+    pub fn as_tuple(&self) -> (usize, usize) {
+        match self {
+            PaddingConfig1d::Valid => (0, 0),
+            PaddingConfig1d::Explicit(left, right) => (*left, *right),
         }
     }
 }
@@ -32,16 +50,15 @@ impl fmt::Display for PaddingConfig1d {
 ///
 /// # Arguments
 ///
-/// * `pads` - The padding values
+/// * `pads` - The padding values [left, right]
 ///
 /// # Panics
 ///
 /// * If the padding is negative
-/// * If the padding is not symmetric
 ///
 /// # Returns
 ///
-/// * The padding configuration
+/// * The padding configuration (Valid or Explicit)
 ///
 /// # Remarks
 ///
@@ -52,17 +69,10 @@ pub(crate) fn padding_config_1d(pads: &[i64]) -> PaddingConfig1d {
 
     if left < 0 || right < 0 {
         panic!("Negative pad values are not supported");
-    } else if left != right {
-        panic!("Asymmetric padding is not supported");
     } else if left == 0 && right == 0 {
-        // i.e. [0, 0]
         PaddingConfig1d::Valid
-    } else if left == right {
-        // i.e. [2, 2]
-        PaddingConfig1d::Explicit(left as usize)
     } else {
-        // Unaccounted for padding configuration
-        panic!("Padding configuration ({pads:?}) not supported");
+        PaddingConfig1d::Explicit(left as usize, right as usize)
     }
 }
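One subtlety worth flagging: the onnx-ir `as_tuple` above is total (`Valid` maps to zero padding on both sides), while the burn-nn `as_tuple` earlier in this diff panics on anything but `Explicit`. A sketch of the contrast (both types are named `PaddingConfig1d`; the crate is indicated in comments):

// onnx-ir flavor: total function, Valid degrades gracefully.
assert_eq!(PaddingConfig1d::Valid.as_tuple(), (0, 0));
// burn-nn flavor: partial function; callers check is_asymmetric() or the variant first.
// PaddingConfig1d::Valid.as_tuple(); // panics: "as_tuple() only works with Explicit padding"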
@@ -72,38 +82,37 @@ pub enum PaddingConfig2d {
     /// No padding (valid padding)
     #[default]
     Valid,
-    /// Explicit padding with specific width and height
-    Explicit(usize, usize),
+    /// Explicit padding with values for each side
+    /// Format: (top, left, bottom, right)
+    /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`).
+    Explicit(usize, usize, usize, usize),
 }
 
 impl fmt::Display for PaddingConfig2d {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
             PaddingConfig2d::Valid => write!(f, "Valid"),
-            PaddingConfig2d::Explicit(width, height) => {
-                write!(f, "Explicit({width}, {height})")
+            PaddingConfig2d::Explicit(top, left, bottom, right) => {
+                write!(f, "Explicit({top}, {left}, {bottom}, {right})")
             }
         }
     }
 }
 
-/// Padding configuration for 3D operations such as convolution
-#[derive(Debug, Clone, PartialEq, Eq, Default)]
-pub enum PaddingConfig3d {
-    /// No padding (valid padding)
-    #[default]
-    Valid,
-    /// Explicit padding with specific width, height, and depth
-    Explicit(usize, usize, usize),
-}
+impl PaddingConfig2d {
+    /// Returns true if this padding configuration is asymmetric (top != bottom or left != right)
+    pub fn is_asymmetric(&self) -> bool {
+        match self {
+            PaddingConfig2d::Explicit(top, left, bottom, right) => top != bottom || left != right,
+            _ => false,
+        }
+    }
 
-impl fmt::Display for PaddingConfig3d {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    /// Returns the padding values as a (top, left, bottom, right) tuple
+    pub fn as_tuple(&self) -> (usize, usize, usize, usize) {
         match self {
-            PaddingConfig3d::Valid => write!(f, "Valid"),
-            PaddingConfig3d::Explicit(width, height, depth) => {
-                write!(f, "Explicit({width}, {height}, {depth})")
-            }
+            PaddingConfig2d::Valid => (0, 0, 0, 0),
+            PaddingConfig2d::Explicit(top, left, bottom, right) => (*top, *left, *bottom, *right),
         }
     }
 }
@@ -112,16 +121,15 @@ impl fmt::Display for PaddingConfig3d {
 ///
 /// # Arguments
 ///
-/// * `pads` - The padding values [left, right, top, bottom]
+/// * `pads` - The padding values [top, left, bottom, right] (ONNX format)
 ///
 /// # Panics
 ///
 /// * If the padding is negative
-/// * If the padding is not symmetric
 ///
 /// # Returns
 ///
-/// * The padding configuration
+/// * The padding configuration (Valid or Explicit)
 ///
 /// # Remarks
 ///
@@ -132,15 +140,58 @@ pub(crate) fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d {
 
     if left < 0 || right < 0 || top < 0 || bottom < 0 {
         panic!("Negative pad values are not supported");
-    } else if left != right || top != bottom {
-        panic!("Asymmetric padding is not supported");
     } else if left == 0 && right == 0 && top == 0 && bottom == 0 {
         PaddingConfig2d::Valid
-    } else if left == right && top == bottom {
-        PaddingConfig2d::Explicit(top as usize, left as usize)
     } else {
-        // Unaccounted for padding configuration
-        panic!("Padding configuration ({pads:?}) not supported");
+        PaddingConfig2d::Explicit(top as usize, left as usize, bottom as usize, right as usize)
+    }
+}
+
+/// Padding configuration for 3D operations such as convolution
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub enum PaddingConfig3d {
+    /// No padding (valid padding)
+    #[default]
+    Valid,
+    /// Explicit padding with values for each side
+    /// Format: (front, top, left, back, bottom, right)
+    /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1, 1, 1)`).
+    Explicit(usize, usize, usize, usize, usize, usize),
+}
+
+impl fmt::Display for PaddingConfig3d {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PaddingConfig3d::Valid => write!(f, "Valid"),
+            PaddingConfig3d::Explicit(front, top, left, back, bottom, right) => {
+                write!(
+                    f,
+                    "Explicit({front}, {top}, {left}, {back}, {bottom}, {right})"
+                )
+            }
+        }
+    }
+}
+
+impl PaddingConfig3d {
+    /// Returns true if this padding configuration is asymmetric
+    pub fn is_asymmetric(&self) -> bool {
+        match self {
+            PaddingConfig3d::Explicit(front, top, left, back, bottom, right) => {
+                front != back || top != bottom || left != right
+            }
+            _ => false,
+        }
+    }
+
+    /// Returns the padding values as a (front, top, left, back, bottom, right) tuple
+    pub fn as_tuple(&self) -> (usize, usize, usize, usize, usize, usize) {
+        match self {
+            PaddingConfig3d::Valid => (0, 0, 0, 0, 0, 0),
+            PaddingConfig3d::Explicit(front, top, left, back, bottom, right) => {
+                (*front, *top, *left, *back, *bottom, *right)
+            }
+        }
     }
 }
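For 2D, ONNX orders `pads` as `[begin_axis1, begin_axis2, end_axis1, end_axis2]`, i.e. `[top, left, bottom, right]` for NCHW, which is exactly the order the new `Explicit` variant stores; Burn's `pad` tuple over the last two dimensions is `(left, right, top, bottom)`. A small sketch of the reorder the max-pool forward performs:

let config = padding_config_2d(&[1, 2, 3, 4]); // ONNX: [top, left, bottom, right]
let (top, left, bottom, right) = config.as_tuple();
// Burn pads the last two dimensions as (left, right, top, bottom):
assert_eq!((left, right, top, bottom), (2, 4, 1, 3));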
@@ -148,16 +199,15 @@ pub(crate) fn padding_config_2d(pads: &[i64]) -> PaddingConfig2d {
 ///
 /// # Arguments
 ///
-/// * `pads` - The padding values [left, right, top, bottom, front, back]
+/// * `pads` - The padding values [front, top, left, back, bottom, right] (ONNX format)
 ///
 /// # Panics
 ///
 /// * If the padding is negative
-/// * If the padding is not symmetric
 ///
 /// # Returns
 ///
-/// * The padding configuration
+/// * The padding configuration (Valid or Explicit)
 ///
 /// # Remarks
 ///
@@ -169,15 +219,17 @@ pub(crate) fn padding_config_3d(pads: &[i64]) -> PaddingConfig3d {
 
     if left < 0 || right < 0 || top < 0 || bottom < 0 || front < 0 || back < 0 {
         panic!("Negative pad values are not supported");
-    } else if left != right || top != bottom || front != back {
-        panic!("Asymmetric padding is not supported");
     } else if left == 0 && right == 0 && top == 0 && bottom == 0 && front == 0 && back == 0 {
         PaddingConfig3d::Valid
-    } else if left == right && top == bottom && front == back {
-        PaddingConfig3d::Explicit(front as usize, top as usize, left as usize)
     } else {
-        // Unaccounted for padding configuration
-        panic!("Padding configuration ({pads:?}) not supported");
+        PaddingConfig3d::Explicit(
+            front as usize,
+            top as usize,
+            left as usize,
+            back as usize,
+            bottom as usize,
+            right as usize,
+        )
     }
 }
 
@@ -185,25 +237,83 @@ mod tests {
     use super::*;
 
+    // 1D padding tests
+    #[test]
+    fn test_padding_config_1d_valid() {
+        let pads = vec![0, 0];
+        let config = padding_config_1d(&pads);
+        assert!(matches!(config, PaddingConfig1d::Valid));
+    }
+
+    #[test]
+    fn test_padding_config_1d_explicit_symmetric() {
+        let pads = vec![2, 2];
+        let config = padding_config_1d(&pads);
+        assert!(matches!(config, PaddingConfig1d::Explicit(2, 2)));
+        assert!(!config.is_asymmetric());
+        assert_eq!(config.as_tuple(), (2, 2));
+    }
+
+    #[test]
+    fn test_padding_config_1d_explicit_asymmetric() {
+        let pads = vec![1, 2];
+        let config = padding_config_1d(&pads);
+        assert!(matches!(config, PaddingConfig1d::Explicit(1, 2)));
+        assert!(config.is_asymmetric());
+        assert_eq!(config.as_tuple(), (1, 2));
+    }
+
+    #[test]
+    #[should_panic(expected = "Negative pad values are not supported")]
+    fn test_padding_config_1d_negative() {
+        let pads = vec![-1, -1];
+        let _ = padding_config_1d(&pads);
+    }
+
+    // 2D padding tests
     #[test]
     fn test_padding_config_2d_valid() {
         let pads = vec![0, 0, 0, 0];
         let config = padding_config_2d(&pads);
         assert!(matches!(config, PaddingConfig2d::Valid));
+        assert!(!config.is_asymmetric());
     }
 
     #[test]
-    fn test_padding_config_2d_explicit() {
+    fn test_padding_config_2d_explicit_symmetric() {
         let pads = vec![2, 2, 2, 2];
         let config = padding_config_2d(&pads);
-        assert!(matches!(config, PaddingConfig2d::Explicit(2, 2)));
+        assert!(matches!(config, PaddingConfig2d::Explicit(2, 2, 2, 2)));
+        assert!(!config.is_asymmetric());
+        assert_eq!(config.as_tuple(), (2, 2, 2, 2));
     }
 
     #[test]
-    #[should_panic(expected = "Asymmetric padding is not supported")]
-    fn test_padding_config_2d_asymmetric() {
-        let pads = vec![2, 3, 2, 2];
-        let _ = padding_config_2d(&pads);
+    fn test_padding_config_2d_explicit_asymmetric() {
+        // pads = [top, left, bottom, right]
+        let pads = vec![1, 2, 3, 4];
+        let config = padding_config_2d(&pads);
+        assert!(matches!(config, PaddingConfig2d::Explicit(1, 2, 3, 4)));
+        assert!(config.is_asymmetric());
+        assert_eq!(config.as_tuple(), (1, 2, 3, 4));
+    }
+
+    #[test]
+    fn test_padding_config_2d_explicit_asymmetric_top_bottom() {
+        // top != bottom but left == right
+        let pads = vec![1, 2, 3, 2];
+        let config = padding_config_2d(&pads);
+        assert!(matches!(config, PaddingConfig2d::Explicit(1, 2, 3, 2)));
+        assert!(config.is_asymmetric());
+    }
+
+    #[test]
+    fn test_padding_config_2d_explicit_asymmetric_left_right() {
+        // left != right but top == bottom
+        let pads = vec![2, 1, 2, 3];
+        let config = padding_config_2d(&pads);
+        assert!(matches!(config, PaddingConfig2d::Explicit(2, 1, 2, 3)));
+        assert!(config.is_asymmetric());
     }
 
     #[test]
@@ -213,25 +323,50 @@ mod tests {
         let _ = padding_config_2d(&pads);
     }
 
+    // 3D padding tests
     #[test]
     fn test_padding_config_3d_valid() {
         let pads = vec![0, 0, 0, 0, 0, 0];
         let config = padding_config_3d(&pads);
         assert!(matches!(config, PaddingConfig3d::Valid));
+        assert!(!config.is_asymmetric());
     }
 
     #[test]
-    fn test_padding_config_3d_explicit() {
+    fn test_padding_config_3d_explicit_symmetric() {
         let pads = vec![2, 3, 1, 2, 3, 1];
         let config = padding_config_3d(&pads);
-        assert!(matches!(config, PaddingConfig3d::Explicit(2, 3, 1)));
+        assert!(matches!(
+            config,
+            PaddingConfig3d::Explicit(2, 3, 1, 2, 3, 1)
+        ));
+        assert!(!config.is_asymmetric());
+        assert_eq!(config.as_tuple(), (2, 3, 1, 2, 3, 1));
     }
 
     #[test]
-    #[should_panic(expected = "Asymmetric padding is not supported")]
-    fn test_padding_config_3d_asymmetric() {
-        let pads = vec![2, 3, 1, 3, 3, 1];
-        let _ = padding_config_3d(&pads);
+    fn test_padding_config_3d_explicit_asymmetric() {
+        // pads = [front, top, left, back, bottom, right]
+        let pads = vec![1, 2, 3, 4, 5, 6];
+        let config = padding_config_3d(&pads);
+        assert!(matches!(
+            config,
+            PaddingConfig3d::Explicit(1, 2, 3, 4, 5, 6)
+        ));
+        assert!(config.is_asymmetric());
+        assert_eq!(config.as_tuple(), (1, 2, 3, 4, 5, 6));
+    }
+
+    #[test]
+    fn test_padding_config_3d_explicit_asymmetric_partial() {
+        // Only front != back
+        let pads = vec![1, 3, 1, 2, 3, 1];
+        let config = padding_config_3d(&pads);
+        assert!(matches!(
+            config,
+            PaddingConfig3d::Explicit(1, 3, 1, 2, 3, 1)
+        ));
+        assert!(config.is_asymmetric());
     }
 
     #[test]