Skip to content

Commit 71bc4f8

Browse files
authored
Merge branch 'master' into grad_domain
2 parents 8995e70 + a04e7e5 commit 71bc4f8

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

pytorch_pfn_extras/onnx/pfto_exporter/export.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import onnx.shape_inference
1313
import pytorch_pfn_extras
1414
import pytorch_pfn_extras.onnx._constants
15+
from pytorch_pfn_extras.onnx import _grad as grad
1516
from pytorch_pfn_extras.onnx._globals import GLOBALS
1617
from pytorch_pfn_extras.torchscript import run_jit_pass
1718
import torch
@@ -318,26 +319,30 @@ def _restore_state(self) -> None:
318319
if torch.cuda.is_available():
319320
torch.cuda.set_rng_state_all(self.cuda_rng_state)
320321

322+
# TODO(twata): Use `self.traced` instead or use traced result outputs
323+
def _get_original_outputs(self) -> None:
324+
self._restore_state()
325+
with _force_tracing(), grad.init_grad_state():
326+
self.original_outputs = self.original_model(*self.inputs)
327+
self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
328+
321329
def _run_trace(self) -> None:
322330
# TODO(twata): Use `torch._C._create_graph_by_tracing` instead.
323331
# So that we don't need to run heavy models multiple times
324-
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
325-
self.original_model,
326-
self.inputs,
327-
check_trace=self.check_trace,
328-
strict=self.strict_trace,
329-
_force_outplace=self.force_outplace_trace,
330-
)
332+
self._restore_state()
333+
with grad.init_grad_state():
334+
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
335+
self.original_model,
336+
self.inputs,
337+
check_trace=self.check_trace,
338+
strict=self.strict_trace,
339+
_force_outplace=self.force_outplace_trace,
340+
)
331341

332342
self.graph_doc_string = f"""
333343
# Model: {self.traced.original_name}
334344
"""
335345

336-
# TODO(twata): Use `self.traced` instead or use traced result outputs
337-
self._restore_state()
338-
with _force_tracing():
339-
self.original_outputs = self.original_model(*self.inputs)
340-
self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
341346
self.g: torch._C.Graph = self.traced.inlined_graph
342347
"""
343348
`self.trace` ignores the override of `state_dict` method in `self.original_model`.
@@ -1079,6 +1084,7 @@ def _convert(self) -> None:
10791084
sym_hel._set_onnx_shape_inference( # type: ignore[no-untyped-call]
10801085
False # TODO(twata): Use `self.onnx_shape_inference`
10811086
)
1087+
self._get_original_outputs()
10821088
self._run_trace()
10831089
self.model: onnx.ModelProto = self.generate_onnx()
10841090
finally:

tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ def forward(self, x):
6161
assert y.shape == (1, 1, 32, 20)
6262

6363

64+
@pytest.mark.parametrize("use_pfto", [False, True])
6465
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
66+
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
6567
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
66-
def test_grad():
68+
def test_grad(use_pfto: bool):
6769
if not pytorch_pfn_extras.requires('1.8.0'):
6870
pytest.skip('skip for PyTorch 1.7 or earlier')
6971

@@ -96,13 +98,14 @@ def forward(self, x):
9698
x,
9799
'grad',
98100
enable_onnx_checker=False,
99-
use_pfto=False,
101+
use_pfto=use_pfto,
102+
output_names=["h"],
100103
)
101104

102105
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
103106
print(actual_onnx)
104107
named_nodes = {n.name: n for n in actual_onnx.graph.node}
105-
if pytorch_pfn_extras.requires("1.13"):
108+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
106109
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
107110
assert '/_ppe_as_out_module/Gradient' in named_nodes
108111
assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -112,20 +115,22 @@ def forward(self, x):
112115
assert 'MatMul_6' in named_nodes
113116

114117
assert list([v.name for v in actual_onnx.graph.output]) == [
115-
"v10_MatMul", "Gradient_y_0", "Gradient_x_0_0"
118+
"h", "Gradient_y_0", "Gradient_x_0_0"
116119
]
117120
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
118-
if pytorch_pfn_extras.requires("1.13"):
121+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
119122
assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0"
120123
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in
121124
else:
122125
assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0"
123126
assert named_nodes["Conv_2"].output[0] == y_in
124127

125128

129+
@pytest.mark.parametrize("use_pfto", [False, True])
126130
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
131+
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
127132
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
128-
def test_grad_multiple_times():
133+
def test_grad_multiple_times(use_pfto: bool):
129134
if not pytorch_pfn_extras.requires("1.8.0"):
130135
pytest.skip('skip for PyTorch 1.7 or earlier')
131136

@@ -167,12 +172,13 @@ def forward(self, x):
167172
x,
168173
'grad',
169174
enable_onnx_checker=False,
170-
use_pfto=False,
175+
use_pfto=use_pfto,
176+
output_names=["h"],
171177
)
172178

173179
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
174180
named_nodes = {n.name: n for n in actual_onnx.graph.node}
175-
if pytorch_pfn_extras.requires("1.13"):
181+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
176182
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
177183
assert '/_ppe_as_out_module/conv_1/Conv' in named_nodes
178184
assert '/_ppe_as_out_module/Gradient' in named_nodes
@@ -186,11 +192,11 @@ def forward(self, x):
186192
assert 'MatMul_12' in named_nodes
187193

188194
assert list([v.name for v in actual_onnx.graph.output]) == [
189-
"v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1"
195+
"h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1"
190196
]
191197
y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
192198
y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1")
193-
if pytorch_pfn_extras.requires("1.13"):
199+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
194200
assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0"
195201
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in
196202
assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1"
@@ -202,9 +208,11 @@ def forward(self, x):
202208
assert named_nodes["Conv_7"].output[0] == y1_in
203209

204210

211+
@pytest.mark.parametrize("use_pfto", [False, True])
205212
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
213+
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
206214
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
207-
def test_grad_with_multiple_inputs():
215+
def test_grad_with_multiple_inputs(use_pfto: bool):
208216
if not pytorch_pfn_extras.requires("1.8.0"):
209217
pytest.skip('skip for PyTorch 1.7 or earlier')
210218

@@ -239,12 +247,13 @@ def forward(self, x):
239247
x,
240248
'grad',
241249
enable_onnx_checker=False,
242-
use_pfto=False,
250+
use_pfto=use_pfto,
251+
output_names=["h"],
243252
)
244253

245254
actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
246255
named_nodes = {n.name: n for n in actual_onnx.graph.node}
247-
if pytorch_pfn_extras.requires("1.13"):
256+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
248257
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
249258
assert '/_ppe_as_out_module/Gradient' in named_nodes
250259
assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
@@ -254,10 +263,10 @@ def forward(self, x):
254263
assert 'MatMul_9' in named_nodes
255264

256265
assert list([v.name for v in actual_onnx.graph.output]) == [
257-
"v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
266+
"h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
258267
]
259268
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
260-
if pytorch_pfn_extras.requires("1.13"):
269+
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
261270
assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0"
262271
assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0"
263272
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in

0 commit comments

Comments
 (0)