Skip to content

Commit a04e7e5

Browse files
author
emcastillo
authored
Merge pull request #666 from take-cheeze/pfto_grad
[onnx] Parameterize onnx grad test with pfto
2 parents bfbc9cb + 0aff084 commit a04e7e5

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

pytorch_pfn_extras/onnx/pfto_exporter/export.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import onnx.shape_inference
1313
import pytorch_pfn_extras
1414
import pytorch_pfn_extras.onnx._constants
15+
from pytorch_pfn_extras.onnx import _grad as grad
1516
from pytorch_pfn_extras.onnx._globals import GLOBALS
1617
from pytorch_pfn_extras.torchscript import run_jit_pass
1718
import torch
@@ -318,26 +319,30 @@ def _restore_state(self) -> None:
318319
if torch.cuda.is_available():
319320
torch.cuda.set_rng_state_all(self.cuda_rng_state)
320321

322+
# TODO(twata): Use `self.traced` instead or use traced result outputs
323+
def _get_original_outputs(self) -> None:
324+
self._restore_state()
325+
with _force_tracing(), grad.init_grad_state():
326+
self.original_outputs = self.original_model(*self.inputs)
327+
self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
328+
321329
def _run_trace(self) -> None:
322330
# TODO(twata): Use `torch._C._create_graph_by_tracing` instead.
323331
# So that we don't need to run heavy models multiple times
324-
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
325-
self.original_model,
326-
self.inputs,
327-
check_trace=self.check_trace,
328-
strict=self.strict_trace,
329-
_force_outplace=self.force_outplace_trace,
330-
)
332+
self._restore_state()
333+
with grad.init_grad_state():
334+
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
335+
self.original_model,
336+
self.inputs,
337+
check_trace=self.check_trace,
338+
strict=self.strict_trace,
339+
_force_outplace=self.force_outplace_trace,
340+
)
331341

332342
self.graph_doc_string = f"""
333343
# Model: {self.traced.original_name}
334344
"""
335345

336-
# TODO(twata): Use `self.traced` instead or use traced result outputs
337-
self._restore_state()
338-
with _force_tracing():
339-
self.original_outputs = self.original_model(*self.inputs)
340-
self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
341346
self.g: torch._C.Graph = self.traced.inlined_graph
342347
"""
343348
`self.trace` ignores the override of `state_dict` method in `self.original_model`.
@@ -1079,6 +1084,7 @@ def _convert(self) -> None:
10791084
sym_hel._set_onnx_shape_inference( # type: ignore[no-untyped-call]
10801085
False # TODO(twata): Use `self.onnx_shape_inference`
10811086
)
1087+
self._get_original_outputs()
10821088
self._run_trace()
10831089
self.model: onnx.ModelProto = self.generate_onnx()
10841090
finally:

tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ def forward(self, x):
6161
assert y.shape == (1, 1, 32, 20)
6262

6363

64+
@pytest.mark.parametrize("use_pfto", [False, True])
6465
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
66+
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
6567
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
66-
def test_grad():
68+
def test_grad(use_pfto: bool):
6769
if not pytorch_pfn_extras.requires('1.8.0'):
6870
pytest.skip('skip for PyTorch 1.7 or earlier')
6971

@@ -96,12 +98,13 @@ def forward(self, x):
9698
x,
9799
'grad',
98100
enable_onnx_checker=False,
99-
use_pfto=False,
101+
use_pfto=use_pfto,
102+
output_names=["h"],
100103
)
101104

102105
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
103106
named_nodes = {n.name: n for n in actual_onnx.graph.node}
104-
if pytorch_pfn_extras.requires("1.13"):
107+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
105108
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
106109
assert '/_ppe_as_out_module/Gradient' in named_nodes
107110
assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -111,20 +114,22 @@ def forward(self, x):
111114
assert 'MatMul_6' in named_nodes
112115

113116
assert list([v.name for v in actual_onnx.graph.output]) == [
114-
"v10_MatMul", "Gradient_y_0", "Gradient_x_0_0"
117+
"h", "Gradient_y_0", "Gradient_x_0_0"
115118
]
116119
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
117-
if pytorch_pfn_extras.requires("1.13"):
120+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
118121
assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0"
119122
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in
120123
else:
121124
assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0"
122125
assert named_nodes["Conv_2"].output[0] == y_in
123126

124127

128+
@pytest.mark.parametrize("use_pfto", [False, True])
125129
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
130+
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
126131
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
127-
def test_grad_multiple_times():
132+
def test_grad_multiple_times(use_pfto: bool):
128133
if not pytorch_pfn_extras.requires("1.8.0"):
129134
pytest.skip('skip for PyTorch 1.7 or earlier')
130135

@@ -166,12 +171,13 @@ def forward(self, x):
166171
x,
167172
'grad',
168173
enable_onnx_checker=False,
169-
use_pfto=False,
174+
use_pfto=use_pfto,
175+
output_names=["h"],
170176
)
171177

172178
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
173179
named_nodes = {n.name: n for n in actual_onnx.graph.node}
174-
if pytorch_pfn_extras.requires("1.13"):
180+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
175181
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
176182
assert '/_ppe_as_out_module/conv_1/Conv' in named_nodes
177183
assert '/_ppe_as_out_module/Gradient' in named_nodes
@@ -185,11 +191,11 @@ def forward(self, x):
185191
assert 'MatMul_12' in named_nodes
186192

187193
assert list([v.name for v in actual_onnx.graph.output]) == [
188-
"v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1"
194+
"h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1"
189195
]
190196
y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
191197
y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1")
192-
if pytorch_pfn_extras.requires("1.13"):
198+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
193199
assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0"
194200
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in
195201
assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1"
@@ -201,9 +207,11 @@ def forward(self, x):
201207
assert named_nodes["Conv_7"].output[0] == y1_in
202208

203209

210+
@pytest.mark.parametrize("use_pfto", [False, True])
204211
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
212+
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
205213
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
206-
def test_grad_with_multiple_inputs():
214+
def test_grad_with_multiple_inputs(use_pfto: bool):
207215
if not pytorch_pfn_extras.requires("1.8.0"):
208216
pytest.skip('skip for PyTorch 1.7 or earlier')
209217

@@ -238,12 +246,13 @@ def forward(self, x):
238246
x,
239247
'grad',
240248
enable_onnx_checker=False,
241-
use_pfto=False,
249+
use_pfto=use_pfto,
250+
output_names=["h"],
242251
)
243252

244253
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
245254
named_nodes = {n.name: n for n in actual_onnx.graph.node}
246-
if pytorch_pfn_extras.requires("1.13"):
255+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
247256
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
248257
assert '/_ppe_as_out_module/Gradient' in named_nodes
249258
assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -253,10 +262,10 @@ def forward(self, x):
253262
assert 'MatMul_9' in named_nodes
254263

255264
assert list([v.name for v in actual_onnx.graph.output]) == [
256-
"v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
265+
"h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
257266
]
258267
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
259-
if pytorch_pfn_extras.requires("1.13"):
268+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
260269
assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0"
261270
assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0"
262271
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in

0 commit comments

Comments (0)