@@ -28,6 +28,7 @@ def onnx_export(
28
28
dynamic_size : bool = False ,
29
29
aten_fallback : bool = False ,
30
30
keep_initializers : Optional [bool ] = None ,
31
+ use_dynamo : bool = False ,
31
32
input_names : List [str ] = None ,
32
33
output_names : List [str ] = None ,
33
34
):
@@ -53,7 +54,8 @@ def onnx_export(
53
54
# Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
54
55
# issues in the tracing of the dynamic padding or errors attempting to export the model after jit
55
56
# scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
56
- original_out = model (example_input )
57
+ with torch .no_grad ():
58
+ original_out = model (example_input )
57
59
58
60
input_names = input_names or ["input0" ]
59
61
output_names = output_names or ["output0" ]
@@ -68,28 +70,40 @@ def onnx_export(
68
70
else :
69
71
export_type = torch .onnx .OperatorExportTypes .ONNX
70
72
71
- torch_out = torch .onnx ._export (
72
- model ,
73
- example_input ,
74
- output_file ,
75
- training = training_mode ,
76
- export_params = True ,
77
- verbose = verbose ,
78
- input_names = input_names ,
79
- output_names = output_names ,
80
- keep_initializers_as_inputs = keep_initializers ,
81
- dynamic_axes = dynamic_axes ,
82
- opset_version = opset ,
83
- operator_export_type = export_type
84
- )
73
+ if use_dynamo :
74
+ export_options = torch .onnx .ExportOptions (dynamic_shapes = dynamic_size )
75
+ export_output = torch .onnx .dynamo_export (
76
+ model ,
77
+ example_input ,
78
+ export_options = export_options ,
79
+ )
80
+ export_output .save (output_file )
81
+ torch_out = None
82
+ else :
83
+ torch_out = torch .onnx ._export (
84
+ model ,
85
+ example_input ,
86
+ output_file ,
87
+ training = training_mode ,
88
+ export_params = True ,
89
+ verbose = verbose ,
90
+ input_names = input_names ,
91
+ output_names = output_names ,
92
+ keep_initializers_as_inputs = keep_initializers ,
93
+ dynamic_axes = dynamic_axes ,
94
+ opset_version = opset ,
95
+ operator_export_type = export_type
96
+ )
85
97
86
98
if check :
87
99
onnx_model = onnx .load (output_file )
88
100
onnx .checker .check_model (onnx_model , full_check = True ) # assuming throw on error
89
101
if check_forward and not training :
90
102
import numpy as np
91
103
onnx_out = onnx_forward (output_file , example_input )
92
- np .testing .assert_almost_equal (torch_out .data .numpy (), onnx_out , decimal = 3 )
93
- np .testing .assert_almost_equal (original_out .data .numpy (), torch_out .data .numpy (), decimal = 5 )
94
-
104
+ if torch_out is not None :
105
+ np .testing .assert_almost_equal (torch_out .numpy (), onnx_out , decimal = 3 )
106
+ np .testing .assert_almost_equal (original_out .numpy (), torch_out .numpy (), decimal = 5 )
107
+ else :
108
+ np .testing .assert_almost_equal (original_out .numpy (), onnx_out , decimal = 3 )
95
109
0 commit comments