Skip to content

Commit ba641e0

Browse files
committed
Add support for Dynamo-based ONNX export
1 parent 2ec2f1a commit ba641e0

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

onnx_export.py

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
help='Export in training mode (default is eval)')
5858
parser.add_argument('--verbose', default=False, action='store_true',
5959
help='Extra stdout output')
60+
parser.add_argument('--dynamo', default=False, action='store_true',
61+
help='Use torch dynamo export.')
6062

6163
def main():
6264
args = parser.parse_args()
@@ -90,6 +92,7 @@ def main():
9092
check_forward=args.check_forward,
9193
training=args.training,
9294
verbose=args.verbose,
95+
use_dynamo=args.dynamo,
9396
input_size=(3, args.img_size, args.img_size),
9497
batch_size=args.batch_size,
9598
)

timm/utils/onnx.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def onnx_export(
2828
dynamic_size: bool = False,
2929
aten_fallback: bool = False,
3030
keep_initializers: Optional[bool] = None,
31+
use_dynamo: bool = False,
3132
input_names: List[str] = None,
3233
output_names: List[str] = None,
3334
):
@@ -53,7 +54,8 @@ def onnx_export(
5354
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
5455
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
5556
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
56-
original_out = model(example_input)
57+
with torch.no_grad():
58+
original_out = model(example_input)
5759

5860
input_names = input_names or ["input0"]
5961
output_names = output_names or ["output0"]
@@ -68,28 +70,40 @@ def onnx_export(
6870
else:
6971
export_type = torch.onnx.OperatorExportTypes.ONNX
7072

71-
torch_out = torch.onnx._export(
72-
model,
73-
example_input,
74-
output_file,
75-
training=training_mode,
76-
export_params=True,
77-
verbose=verbose,
78-
input_names=input_names,
79-
output_names=output_names,
80-
keep_initializers_as_inputs=keep_initializers,
81-
dynamic_axes=dynamic_axes,
82-
opset_version=opset,
83-
operator_export_type=export_type
84-
)
73+
if use_dynamo:
74+
export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size)
75+
export_output = torch.onnx.dynamo_export(
76+
model,
77+
example_input,
78+
export_options=export_options,
79+
)
80+
export_output.save(output_file)
81+
torch_out = None
82+
else:
83+
torch_out = torch.onnx._export(
84+
model,
85+
example_input,
86+
output_file,
87+
training=training_mode,
88+
export_params=True,
89+
verbose=verbose,
90+
input_names=input_names,
91+
output_names=output_names,
92+
keep_initializers_as_inputs=keep_initializers,
93+
dynamic_axes=dynamic_axes,
94+
opset_version=opset,
95+
operator_export_type=export_type
96+
)
8597

8698
if check:
8799
onnx_model = onnx.load(output_file)
88100
onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error
89101
if check_forward and not training:
90102
import numpy as np
91103
onnx_out = onnx_forward(output_file, example_input)
92-
np.testing.assert_almost_equal(torch_out.data.numpy(), onnx_out, decimal=3)
93-
np.testing.assert_almost_equal(original_out.data.numpy(), torch_out.data.numpy(), decimal=5)
94-
104+
if torch_out is not None:
105+
np.testing.assert_almost_equal(torch_out.numpy(), onnx_out, decimal=3)
106+
np.testing.assert_almost_equal(original_out.numpy(), torch_out.numpy(), decimal=5)
107+
else:
108+
np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
95109

0 commit comments

Comments
 (0)