@@ -61,9 +61,11 @@ def forward(self, x):
6161 assert y .shape == (1 , 1 , 32 , 20 )
6262
6363
64+ @pytest .mark .parametrize ("use_pfto" , [False , True ])
6465@pytest .mark .filterwarnings ("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning" )
66+ @pytest .mark .filterwarnings ("ignore:Specified output_names .*:UserWarning" )
6567@pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
66- def test_grad ():
68+ def test_grad (use_pfto : bool ):
6769 if not pytorch_pfn_extras .requires ('1.8.0' ):
6870 pytest .skip ('skip for PyTorch 1.7 or earlier' )
6971
@@ -96,13 +98,14 @@ def forward(self, x):
9698 x ,
9799 'grad' ,
98100 enable_onnx_checker = False ,
99- use_pfto = False ,
101+ use_pfto = use_pfto ,
102+ output_names = ["h" ],
100103 )
101104
102105 actual_onnx = onnx .load (os .path .join (output_dir , 'model.onnx' ))
103106 print (actual_onnx )
104107 named_nodes = {n .name : n for n in actual_onnx .graph .node }
105- if pytorch_pfn_extras .requires ("1.13" ):
108+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
106109 assert '/_ppe_as_out_module/conv/Conv' in named_nodes
107110 assert '/_ppe_as_out_module/Gradient' in named_nodes
108111 assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -112,20 +115,22 @@ def forward(self, x):
112115 assert 'MatMul_6' in named_nodes
113116
114117 assert list ([v .name for v in actual_onnx .graph .output ]) == [
115- "v10_MatMul " , "Gradient_y_0" , "Gradient_x_0_0"
118+ "h " , "Gradient_y_0" , "Gradient_x_0_0"
116119 ]
117120 y_in , _ = _get_name (actual_onnx .graph , "Gradient_y_0" )
118- if pytorch_pfn_extras .requires ("1.13" ):
121+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
119122 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].input [0 ] == "Gradient_x_0_0"
120123 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].output [0 ] == y_in
121124 else :
122125 assert named_nodes ["Conv_2" ].input [0 ] == "Gradient_x_0_0"
123126 assert named_nodes ["Conv_2" ].output [0 ] == y_in
124127
125128
129+ @pytest .mark .parametrize ("use_pfto" , [False , True ])
126130@pytest .mark .filterwarnings ("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning" )
131+ @pytest .mark .filterwarnings ("ignore:Specified output_names .*:UserWarning" )
127132@pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
128- def test_grad_multiple_times ():
133+ def test_grad_multiple_times (use_pfto : bool ):
129134 if not pytorch_pfn_extras .requires ("1.8.0" ):
130135 pytest .skip ('skip for PyTorch 1.7 or earlier' )
131136
@@ -167,12 +172,13 @@ def forward(self, x):
167172 x ,
168173 'grad' ,
169174 enable_onnx_checker = False ,
170- use_pfto = False ,
175+ use_pfto = use_pfto ,
176+ output_names = ["h" ],
171177 )
172178
173179 actual_onnx = onnx .load (os .path .join (output_dir , 'model.onnx' ))
174180 named_nodes = {n .name : n for n in actual_onnx .graph .node }
175- if pytorch_pfn_extras .requires ("1.13" ):
181+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
176182 assert '/_ppe_as_out_module/conv/Conv' in named_nodes
177183 assert '/_ppe_as_out_module/conv_1/Conv' in named_nodes
178184 assert '/_ppe_as_out_module/Gradient' in named_nodes
@@ -186,11 +192,11 @@ def forward(self, x):
186192 assert 'MatMul_12' in named_nodes
187193
188194 assert list ([v .name for v in actual_onnx .graph .output ]) == [
189- "v16_MatMul " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_y_1" , "Gradient_x_0_1"
195+ "h " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_y_1" , "Gradient_x_0_1"
190196 ]
191197 y0_in , _ = _get_name (actual_onnx .graph , "Gradient_y_0" )
192198 y1_in , _ = _get_name (actual_onnx .graph , "Gradient_y_1" )
193- if pytorch_pfn_extras .requires ("1.13" ):
199+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
194200 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].input [0 ] == "Gradient_x_0_0"
195201 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].output [0 ] == y0_in
196202 assert named_nodes ["/_ppe_as_out_module/conv_1/Conv" ].input [0 ] == "Gradient_x_0_1"
@@ -202,9 +208,11 @@ def forward(self, x):
202208 assert named_nodes ["Conv_7" ].output [0 ] == y1_in
203209
204210
211+ @pytest .mark .parametrize ("use_pfto" , [False , True ])
205212@pytest .mark .filterwarnings ("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning" )
213+ @pytest .mark .filterwarnings ("ignore:Specified output_names .*:UserWarning" )
206214@pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
207- def test_grad_with_multiple_inputs ():
215+ def test_grad_with_multiple_inputs (use_pfto : bool ):
208216 if not pytorch_pfn_extras .requires ("1.8.0" ):
209217 pytest .skip ('skip for PyTorch 1.7 or earlier' )
210218
@@ -239,12 +247,13 @@ def forward(self, x):
239247 x ,
240248 'grad' ,
241249 enable_onnx_checker = False ,
242- use_pfto = False ,
250+ use_pfto = use_pfto ,
251+ output_names = ["h" ],
243252 )
244253
245254 actual_onnx = onnx .load (os .path .join (output_dir , 'model.onnx' ))
246255 named_nodes = {n .name : n for n in actual_onnx .graph .node }
247- if pytorch_pfn_extras .requires ("1.13" ):
256+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
248257 assert '/_ppe_as_out_module/conv/Conv' in named_nodes
249258 assert '/_ppe_as_out_module/Gradient' in named_nodes
250259 assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -254,10 +263,10 @@ def forward(self, x):
254263 assert 'MatMul_9' in named_nodes
255264
256265 assert list ([v .name for v in actual_onnx .graph .output ]) == [
257- "v14_MatMul " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_x_1_0"
266+ "h " , "Gradient_y_0" , "Gradient_x_0_0" , "Gradient_x_1_0"
258267 ]
259268 y_in , _ = _get_name (actual_onnx .graph , "Gradient_y_0" )
260- if pytorch_pfn_extras .requires ("1.13" ):
269+ if pytorch_pfn_extras .requires ("1.13" ) and not use_pfto :
261270 assert named_nodes ["/_ppe_as_out_module/Concat" ].input [0 ] == "Gradient_x_0_0"
262271 assert named_nodes ["/_ppe_as_out_module/Concat" ].input [1 ] == "Gradient_x_1_0"
263272 assert named_nodes ["/_ppe_as_out_module/conv/Conv" ].output [0 ] == y_in
0 commit comments