Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ fn main() {
.input("tests/avg_pool1d_ceil_mode/avg_pool1d_ceil_mode.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/avg_pool2d_ceil_mode/avg_pool2d_ceil_mode.onnx")
.input("tests/avg_pool/avg_pool1d_asymmetric_padding.onnx")
.input("tests/avg_pool/avg_pool2d_asymmetric_padding.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/bitshift/bitshift_left.onnx")
.input("tests/bitshift/bitshift_left_scalar.onnx")
Expand Down Expand Up @@ -101,6 +103,8 @@ fn main() {
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/conv3d/conv3d.onnx")
.input("tests/conv/conv1d_asymmetric_padding.onnx")
.input("tests/conv/conv2d_asymmetric_padding.onnx")
.input("tests/conv_transpose1d/conv_transpose1d.onnx")
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
Expand Down Expand Up @@ -237,6 +241,8 @@ fn main() {
.input("tests/maxpool1d_ceil_mode/maxpool1d_ceil_mode.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/maxpool2d_ceil_mode/maxpool2d_ceil_mode.onnx")
.input("tests/maxpool/maxpool1d_asymmetric_padding.onnx")
.input("tests/maxpool/maxpool2d_asymmetric_padding.onnx")
.input("tests/min/min.onnx")
.input("tests/mean/mean.onnx")
.input("tests/mul/mul.onnx")
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies = [
"torch>=2.3.1",
"onnx>=1.16.1",
"onnxruntime>=1.18.0",
"onnxscript>=0.1.0",
]
readme = "README.md"
requires-python = ">= 3.8"
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# used to generate model: avg_pool1d_asymmetric_padding.onnx

import numpy as np
import onnx
from onnx import helper, TensorProto
from onnx.reference import ReferenceEvaluator


def main():
# Input: [batch=2, channels=4, width=10]
# Asymmetric padding: left=1, right=2
# kernel=3, stride=1
# Output width = (10 + 1 + 2 - 3) / 1 + 1 = 11

X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10])
Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11])

# Create AveragePool node with asymmetric padding (left=1, right=2)
# ONNX pads format for 1D: [start, end] = [left, right]
avg_pool_node = helper.make_node(
"AveragePool",
inputs=["x"],
outputs=["y"],
kernel_shape=[3],
strides=[1],
pads=[1, 2], # [left, right] asymmetric padding
count_include_pad=1, # Include padding in average calculation
)

graph = helper.make_graph(
[avg_pool_node],
"avg_pool1d_asymmetric_padding",
[X],
[Y],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 8

onnx.checker.check_model(model)
file_name = "avg_pool1d_asymmetric_padding.onnx"
onnx.save(model, file_name)

print("Finished exporting model to {}".format(file_name))
print("Ops in graph: {}".format([n.op_type for n in model.graph.node]))

# Verify with ReferenceEvaluator
test_input = np.ones((2, 4, 10), dtype=np.float32)
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input})[0]

print("Test input shape: {}".format(test_input.shape))
print("Test output shape: {}".format(ref_output.shape))
print("ReferenceEvaluator output sum: {}".format(ref_output.sum()))


if __name__ == "__main__":
main()
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# used to generate model: avg_pool2d_asymmetric_padding.onnx

import numpy as np
import onnx
from onnx import helper, TensorProto
from onnx.reference import ReferenceEvaluator


def main():
# Input: [batch=2, channels=4, height=10, width=15]
# Asymmetric padding: top=1, left=1, bottom=2, right=2
# kernel=[3,3], stride=[1,1]
# Output height = (10 + 1 + 2 - 3) / 1 + 1 = 11
# Output width = (15 + 1 + 2 - 3) / 1 + 1 = 16

X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4, 10, 15])
Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 4, 11, 16])

# Create AveragePool node with asymmetric padding
# ONNX pads format for 2D: [top, left, bottom, right]
avg_pool_node = helper.make_node(
"AveragePool",
inputs=["x"],
outputs=["y"],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 2, 2], # [top, left, bottom, right] asymmetric padding
count_include_pad=1, # Include padding in average calculation
)

graph = helper.make_graph(
[avg_pool_node],
"avg_pool2d_asymmetric_padding",
[X],
[Y],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
model.ir_version = 8

onnx.checker.check_model(model)
file_name = "avg_pool2d_asymmetric_padding.onnx"
onnx.save(model, file_name)

print("Finished exporting model to {}".format(file_name))
print("Ops in graph: {}".format([n.op_type for n in model.graph.node]))

# Verify with ReferenceEvaluator
test_input = np.ones((2, 4, 10, 15), dtype=np.float32)
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input})[0]

print("Test input shape: {}".format(test_input.shape))
print("Test output shape: {}".format(ref_output.shape))
print("ReferenceEvaluator output sum: {}".format(ref_output.sum()))


if __name__ == "__main__":
main()
56 changes: 56 additions & 0 deletions crates/burn-import/onnx-tests/tests/avg_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
use crate::include_models;
include_models!(
avg_pool1d,
avg_pool1d_asymmetric_padding,
avg_pool1d_ceil_mode,
avg_pool2d,
avg_pool2d_asymmetric_padding,
avg_pool2d_ceil_mode
);

Expand Down Expand Up @@ -181,4 +183,58 @@ mod tests {
.to_data()
.assert_approx_eq::<FT>(&expected, tolerance);
}

#[test]
fn avg_pool1d_asymmetric_padding() {
// Test asymmetric padding (left=1, right=2) for AvgPool1d
let device = Default::default();
let model: avg_pool1d_asymmetric_padding::Model<TestBackend> =
avg_pool1d_asymmetric_padding::Model::new(&device);

// Run the model with ones as input for easier testing
let input = Tensor::<TestBackend, 3>::ones([2, 4, 10], &device);
let output = model.forward(input);

// With asymmetric padding (1, 2), input length 10 becomes 10+1+2=13
// After pool with kernel 3, stride 1, output length is 13-3+1=11
let expected_shape = Shape::from([2, 4, 11]);
assert_eq!(output.shape(), expected_shape);

// Verify the sum matches PyTorch output
let output_sum = output.sum().into_scalar();
let expected_sum = 77.333_33; // from pytorch
assert!(
(output_sum - expected_sum).abs() < 0.1,
"Expected sum ~{}, got {}",
expected_sum,
output_sum
);
}

#[test]
fn avg_pool2d_asymmetric_padding() {
// Test asymmetric padding (left=1, right=2, top=1, bottom=2) for AvgPool2d
let device = Default::default();
let model: avg_pool2d_asymmetric_padding::Model<TestBackend> =
avg_pool2d_asymmetric_padding::Model::new(&device);

// Run the model with ones as input for easier testing
let input = Tensor::<TestBackend, 4>::ones([2, 4, 10, 15], &device);
let output = model.forward(input);

// With asymmetric padding (1, 1, 2, 2), input (10, 15) becomes (13, 18)
// After pool with kernel (3, 3), stride (1, 1), output is (11, 16)
let expected_shape = Shape::from([2, 4, 11, 16]);
assert_eq!(output.shape(), expected_shape);

// Verify the sum matches ReferenceEvaluator output
let output_sum = output.sum().into_scalar();
let expected_sum = 1134.222; // from ReferenceEvaluator
assert!(
(output_sum - expected_sum).abs() < 1.0,
"Expected sum ~{}, got {}",
expected_sum,
output_sum
);
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3

# used to generate model: conv1d_asymmetric_padding.onnx

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from onnx.reference import ReferenceEvaluator

# must set for testing against crate
torch.manual_seed(0)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# Create a Conv1d without padding - we'll apply asymmetric padding manually
self.conv1 = nn.Conv1d(4, 6, kernel_size=3, stride=1, padding=0)

def forward(self, x):
# Apply asymmetric padding: (left=1, right=2)
# PyTorch F.pad takes (left, right) for 1D
x = F.pad(x, (1, 2), mode='constant', value=0)
x = self.conv1(x)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "conv1d_asymmetric_padding.onnx"
test_input = torch.ones(2, 4, 10, device=device)

# Export with dynamo exporter (opset 18)
torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=18)

# Load model and convert external data to embedded
onnx_model = onnx.load(file_name, load_external_data=True)
# Save with all data embedded
onnx.save(onnx_model, file_name, save_as_external_data=False)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

# Verify with ONNX ReferenceEvaluator
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input.numpy()})[0]

output_sum = output.sum().item()
ref_sum = ref_output.sum()

print("PyTorch output sum: {}".format(output_sum))
print("ReferenceEvaluator output sum: {}".format(ref_sum))


if __name__ == "__main__":
main()
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python3

# used to generate model: conv2d_asymmetric_padding.onnx

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from onnx.reference import ReferenceEvaluator

# must set for testing against crate
torch.manual_seed(0)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# Create a Conv2d without padding - we'll apply asymmetric padding manually
self.conv1 = nn.Conv2d(4, 6, kernel_size=3, stride=1, padding=0)

def forward(self, x):
# Apply asymmetric padding: (left=1, right=2, top=1, bottom=3)
# PyTorch F.pad takes (left, right, top, bottom) for 2D
x = F.pad(x, (1, 2, 1, 3), mode='constant', value=0)
x = self.conv1(x)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

file_name = "conv2d_asymmetric_padding.onnx"
test_input = torch.ones(2, 4, 10, 15, device=device)

# Export with dynamo exporter (opset 18)
torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=18)

# Load model and convert external data to embedded
onnx_model = onnx.load(file_name, load_external_data=True)
# Save with all data embedded
onnx.save(onnx_model, file_name, save_as_external_data=False)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data shape of ones: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))

# Verify with ONNX ReferenceEvaluator
ref = ReferenceEvaluator(file_name)
ref_output = ref.run(None, {"x": test_input.numpy()})[0]

output_sum = output.sum().item()
ref_sum = ref_output.sum()

print("PyTorch output sum: {}".format(output_sum))
print("ReferenceEvaluator output sum: {}".format(ref_sum))


if __name__ == "__main__":
main()
Loading
Loading