Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ fn main() {
.input("tests/avg_pool1d_ceil_mode/avg_pool1d_ceil_mode.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/avg_pool2d_ceil_mode/avg_pool2d_ceil_mode.onnx")
.input("tests/avg_pool/avg_pool1d_asymmetric_padding.onnx")
.input("tests/avg_pool/avg_pool2d_asymmetric_padding.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/bitshift/bitshift_left.onnx")
.input("tests/bitshift/bitshift_left_scalar.onnx")
Expand Down Expand Up @@ -101,6 +103,8 @@ fn main() {
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/conv3d/conv3d.onnx")
.input("tests/conv/conv1d_asymmetric_padding.onnx")
.input("tests/conv/conv2d_asymmetric_padding.onnx")
.input("tests/conv_transpose1d/conv_transpose1d.onnx")
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
Expand Down Expand Up @@ -237,6 +241,8 @@ fn main() {
.input("tests/maxpool1d_ceil_mode/maxpool1d_ceil_mode.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/maxpool2d_ceil_mode/maxpool2d_ceil_mode.onnx")
.input("tests/maxpool/maxpool1d_asymmetric_padding.onnx")
.input("tests/maxpool/maxpool2d_asymmetric_padding.onnx")
.input("tests/min/min.onnx")
.input("tests/mean/mean.onnx")
.input("tests/mul/mul.onnx")
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies = [
"torch>=2.3.1",
"onnx>=1.16.1",
"onnxruntime>=1.18.0",
"onnxscript>=0.1.0",
]
readme = "README.md"
requires-python = ">= 3.8"
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# used to generate model: avg_pool1d_asymmetric_padding.onnx

import numpy as np
import onnx
from onnx import helper, TensorProto
from onnx.reference import ReferenceEvaluator


def main():
# Input: [batch=2, channels=4, width=10]
# Asymmetric padding: left=1, right=2
# kernel=3, stride=1
# Output width = (10 + 1 + 2 - 3) / 1 + 1 = 11

X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10])
Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11])

# Create AveragePool node with asymmetric padding (left=1, right=2)
# ONNX pads format for 1D: [start, end] = [left, right]
avg_pool_node = helper.make_node(
"AveragePool",
inputs=["x"],
outputs=["y"],
kernel_shape=[3],
strides=[1],
pads=[1, 2], # [left, right] asymmetric padding
count_include_pad=1, # Include padding in average calculation
)

graph = helper.make_graph(
[avg_pool_node],
"avg_pool1d_asymmetric_padding",
[X],
[Y],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 8

onnx.checker.check_model(model)
file_name = "avg_pool1d_asymmetric_padding.onnx"
onnx.save(model, file_name)

print("Finished exporting model to {}".format(file_name))
print("Ops in graph: {}".format([n.op_type for n in model.graph.node]))

# Verify with ReferenceEvaluator
test_input = np.ones((2, 4, 10), dtype=np.float32)
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input})[0]

print("Test input shape: {}".format(test_input.shape))
print("Test output shape: {}".format(ref_output.shape))
print("ReferenceEvaluator output sum: {}".format(ref_output.sum()))


if __name__ == "__main__":
main()
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# used to generate model: avg_pool2d_asymmetric_padding.onnx

import numpy as np
import onnx
from onnx import helper, TensorProto
from onnx.reference import ReferenceEvaluator


def main():
# Input: [batch=2, channels=4, height=10, width=15]
# Asymmetric padding: top=1, left=1, bottom=2, right=2
# kernel=[3,3], stride=[1,1]
# Output height = (10 + 1 + 2 - 3) / 1 + 1 = 11
# Output width = (15 + 1 + 2 - 3) / 1 + 1 = 16

X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10, 15])
Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11, 16])

# Create AveragePool node with asymmetric padding
# ONNX pads format for 2D: [top, left, bottom, right]
avg_pool_node = helper.make_node(
"AveragePool",
inputs=["x"],
outputs=["y"],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 2, 2], # [top, left, bottom, right] asymmetric padding
count_include_pad=1, # Include padding in average calculation
)

graph = helper.make_graph(
[avg_pool_node],
"avg_pool2d_asymmetric_padding",
[X],
[Y],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 8

onnx.checker.check_model(model)
file_name = "avg_pool2d_asymmetric_padding.onnx"
onnx.save(model, file_name)

print("Finished exporting model to {}".format(file_name))
print("Ops in graph: {}".format([n.op_type for n in model.graph.node]))

# Verify with ReferenceEvaluator
test_input = np.ones((2, 4, 10, 15), dtype=np.float32)
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input})[0]

print("Test input shape: {}".format(test_input.shape))
print("Test output shape: {}".format(ref_output.shape))
print("ReferenceEvaluator output sum: {}".format(ref_output.sum()))


if __name__ == "__main__":
main()
56 changes: 56 additions & 0 deletions crates/burn-import/onnx-tests/tests/avg_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
use crate::include_models;
include_models!(
avg_pool1d,
avg_pool1d_asymmetric_padding,
avg_pool1d_ceil_mode,
avg_pool2d,
avg_pool2d_asymmetric_padding,
avg_pool2d_ceil_mode
);

Expand Down Expand Up @@ -181,4 +183,58 @@ mod tests {
.to_data()
.assert_approx_eq::<FT>(&expected, tolerance);
}

#[test]
fn avg_pool1d_asymmetric_padding() {
// Test asymmetric padding (left=1, right=2) for AvgPool1d
let device = Default::default();
let model: avg_pool1d_asymmetric_padding::Model<TestBackend> =
avg_pool1d_asymmetric_padding::Model::new(&device);

// Run the model with ones as input for easier testing
let input = Tensor::<TestBackend, 3>::ones([2, 4, 10], &device);
let output = model.forward(input);

// With asymmetric padding (1, 2), input length 10 becomes 10+1+2=13
// After pool with kernel 3, stride 1, output length is 13-3+1=11
let expected_shape = Shape::from([2, 4, 11]);
assert_eq!(output.shape(), expected_shape);

// Verify the sum matches PyTorch output
let output_sum = output.sum().into_scalar();
let expected_sum = 77.333_33; // from pytorch
assert!(
(output_sum - expected_sum).abs() < 0.1,
"Expected sum ~{}, got {}",
expected_sum,
output_sum
);
}

#[test]
fn avg_pool2d_asymmetric_padding() {
// Test asymmetric padding (left=1, right=2, top=1, bottom=2) for AvgPool2d
let device = Default::default();
let model: avg_pool2d_asymmetric_padding::Model<TestBackend> =
avg_pool2d_asymmetric_padding::Model::new(&device);

// Run the model with ones as input for easier testing
let input = Tensor::<TestBackend, 4>::ones([2, 4, 10, 15], &device);
let output = model.forward(input);

// With asymmetric padding (1, 1, 2, 2), input (10, 15) becomes (13, 18)
// After pool with kernel (3, 3), stride (1, 1), output is (11, 16)
let expected_shape = Shape::from([2, 4, 11, 16]);
assert_eq!(output.shape(), expected_shape);

// Verify the sum matches ReferenceEvaluator output
let output_sum = output.sum().into_scalar();
let expected_sum = 1134.222; // from ReferenceEvaluator
assert!(
(output_sum - expected_sum).abs() < 1.0,
"Expected sum ~{}, got {}",
expected_sum,
output_sum
);
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3

# used to generate model: conv1d_asymmetric_padding.onnx

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from onnx.reference import ReferenceEvaluator

# must set for testing against crate
torch.manual_seed(0)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# Create a Conv1d without padding - we'll apply asymmetric padding manually
self.conv1 = nn.Conv1d(4, 6, kernel_size=3, stride=1, padding=0)

def forward(self, x):
# Apply asymmetric padding: (left=1, right=2)
# PyTorch F.pad takes (left, right) for 1D
x = F.pad(x, (1, 2), mode='constant', value=0)
x = self.conv1(x)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "conv1d_asymmetric_padding.onnx"
test_input = torch.ones(2, 4, 10, device=device)

# Export with dynamo exporter (opset 18)
torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=18)

# Load model and convert external data to embedded
onnx_model = onnx.load(file_name, load_external_data=True)
# Save with all data embedded
onnx.save(onnx_model, file_name, save_as_external_data=False)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

# Verify with ONNX ReferenceEvaluator
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input.numpy()})[0]

output_sum = output.sum().item()
ref_sum = ref_output.sum()

print("PyTorch output sum: {}".format(output_sum))
print("ReferenceEvaluator output sum: {}".format(ref_sum))


if __name__ == "__main__":
main()
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3

# used to generate model: conv2d_asymmetric_padding.onnx

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from onnx.reference import ReferenceEvaluator

# must set for testing against crate
torch.manual_seed(0)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# Create a Conv2d without padding - we'll apply asymmetric padding manually
self.conv1 = nn.Conv2d(4, 6, kernel_size=3, stride=1, padding=0)

def forward(self, x):
# Apply asymmetric padding: (left=1, right=2, top=1, bottom=3)
# PyTorch F.pad takes (left, right, top, bottom) for 2D
x = F.pad(x, (1, 2, 1, 3), mode='constant', value=0)
x = self.conv1(x)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "conv2d_asymmetric_padding.onnx"
test_input = torch.ones(2, 4, 10, 15, device=device)

# Export with dynamo exporter (opset 18)
torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=18)

# Load model and convert external data to embedded
onnx_model = onnx.load(file_name, load_external_data=True)
# Save with all data embedded
onnx.save(onnx_model, file_name, save_as_external_data=False)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

# Verify with ONNX ReferenceEvaluator
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input.numpy()})[0]

output_sum = output.sum().item()
ref_sum = ref_output.sum()

print("PyTorch output sum: {}".format(output_sum))
print("ReferenceEvaluator output sum: {}".format(ref_sum))


if __name__ == "__main__":
main()
Loading
Loading