@@ -61,9 +61,11 @@ def forward(self, x):
6161 assert y .shape == (1 , 1 , 32 , 20 )
6262
6363
64+ @pytest .mark .parametrize ("use_pfto" , [False , True ])
6465@pytest .mark .filterwarnings ("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning" )
66+ @pytest .mark .filterwarnings ("ignore:Specified output_names .*:UserWarning" )
6567@pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
66- def test_grad ():
68+ def test_grad (use_pfto : bool ):
6769 if not pytorch_pfn_extras .requires ('1.8.0' ):
6870 pytest .skip ('skip for PyTorch 1.7 or earlier' )
6971
@@ -96,12 +98,13 @@ def forward(self, x):
9698 x ,
9799 'grad' ,
98100 enable_onnx_checker = False ,
99- use_pfto = False ,
101+ use_pfto = use_pfto ,
102+ output_names = ["h" ],
100103 )
101104
102105 actual_onnx = onnx .load (os .path .join (output_dir , 'model.onnx' ))
103106 named_nodes = {n .name : n for n in actual_onnx .graph .node }
104- if pytorch_pfn_extras .requires ("1.13" ):
107+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
105108 assert '/_ppe_as_out_module/conv/Conv' in named_nodes
106109 assert '/_ppe_as_out_module/Gradient' in named_nodes
107110 assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -111,20 +114,22 @@ def forward(self, x):
111114 assert 'MatMul_6' in named_nodes
112115
113116 assert list ([v .name for v in actual_onnx .graph .output ]) == [
114- "v10_MatMul " , "Gradient_y_0" , "Gradient_x_0_0"
117+ "h " , "Gradient_y_0" , "Gradient_x_0_0"
115118 ]
116119 y_in , _ = _get_name (actual_onnx .graph , "Gradient_y_0" )
117- if pytorch_pfn_extras .requires ("1.13" ):
120+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
118121 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].input [0 ] == "Gradient_x_0_0"
119122 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].output [0 ] == y_in
120123 else :
121124 assert named_nodes ["Conv_2" ].input [0 ] == "Gradient_x_0_0"
122125 assert named_nodes ["Conv_2" ].output [0 ] == y_in
123126
124127
128+ @pytest .mark .parametrize ("use_pfto" , [False , True ])
125129@pytest .mark .filterwarnings ("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning" )
130+ @pytest .mark .filterwarnings ("ignore:Specified output_names .*:UserWarning" )
126131@pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
127- def test_grad_multiple_times ():
132+ def test_grad_multiple_times (use_pfto : bool ):
128133 if not pytorch_pfn_extras .requires ("1.8.0" ):
129134 pytest .skip ('skip for PyTorch 1.7 or earlier' )
130135
@@ -166,12 +171,13 @@ def forward(self, x):
166171 x ,
167172 'grad' ,
168173 enable_onnx_checker = False ,
169- use_pfto = False ,
174+ use_pfto = use_pfto ,
175+ output_names = ["h" ],
170176 )
171177
172178 actual_onnx = onnx .load (os .path .join (output_dir , 'model.onnx' ))
173179 named_nodes = {n .name : n for n in actual_onnx .graph .node }
174- if pytorch_pfn_extras .requires ("1.13" ):
180+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
175181 assert '/_ppe_as_out_module/conv/Conv' in named_nodes
176182 assert '/_ppe_as_out_module/conv_1/Conv' in named_nodes
177183 assert '/_ppe_as_out_module/Gradient' in named_nodes
@@ -185,11 +191,11 @@ def forward(self, x):
185191 assert 'MatMul_12' in named_nodes
186192
187193 assert list ([v .name for v in actual_onnx .graph .output ]) == [
188- "v16_MatMul " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_y_1" , "Gradient_x_0_1"
194+ "h " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_y_1" , "Gradient_x_0_1"
189195 ]
190196 y0_in , _ = _get_name (actual_onnx .graph , "Gradient_y_0" )
191197 y1_in , _ = _get_name (actual_onnx .graph , "Gradient_y_1" )
192- if pytorch_pfn_extras .requires ("1.13" ):
198+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
193199 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].input [0 ] == "Gradient_x_0_0"
194200 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].output [0 ] == y0_in
195201 assert named_nodes ["/_ppe_as_out_module/conv_1/Conv" ].input [0 ] == "Gradient_x_0_1"
@@ -201,9 +207,11 @@ def forward(self, x):
201207 assert named_nodes ["Conv_7" ].output [0 ] == y1_in
202208
203209
210+ @pytest .mark .parametrize ("use_pfto" , [False , True ])
204211@pytest .mark .filterwarnings ("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning" )
212+ @pytest .mark .filterwarnings ("ignore:Specified output_names .*:UserWarning" )
205213@pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
206- def test_grad_with_multiple_inputs ():
214+ def test_grad_with_multiple_inputs (use_pfto : bool ):
207215 if not pytorch_pfn_extras .requires ("1.8.0" ):
208216 pytest .skip ('skip for PyTorch 1.7 or earlier' )
209217
@@ -238,12 +246,13 @@ def forward(self, x):
238246 x ,
239247 'grad' ,
240248 enable_onnx_checker = False ,
241- use_pfto = False ,
249+ use_pfto = use_pfto ,
250+ output_names = ["h" ],
242251 )
243252
244253 actual_onnx = onnx .load (os .path .join (output_dir , 'model.onnx' ))
245254 named_nodes = {n .name : n for n in actual_onnx .graph .node }
246- if pytorch_pfn_extras .requires ("1.13" ):
255+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
247256 assert '/_ppe_as_out_module/conv/Conv' in named_nodes
248257 assert '/_ppe_as_out_module/Gradient' in named_nodes
249258 assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -253,10 +262,10 @@ def forward(self, x):
253262 assert 'MatMul_9' in named_nodes
254263
255264 assert list ([v .name for v in actual_onnx .graph .output ]) == [
256- "v14_MatMul " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_x_1_0"
265+ "h " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_x_1_0"
257266 ]
258267 y_in , _ = _get_name (actual_onnx .graph , "Gradient_y_0" )
259- if pytorch_pfn_extras .requires ("1.13" ):
268+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
260269 assert named_nodes ["/_ppe_as_out_module/Concat" ].input [0 ] == "Gradient_x_0_0"
261270 assert named_nodes ["/_ppe_as_out_module/Concat" ].input [1 ] == "Gradient_x_1_0"
262271 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].output [0 ] == y_in
0 commit comments