Commit 9ea2293

sidt-meta authored and pobin6 committed
[IG] Avoid generation of empty merge cpu submodule by splitter v2 (pytorch#140794)
Summary:
Customize splitter behavior to mark `get_attr` nodes as acc supported. Currently these nodes are excluded by `FxNetAccNodesFinder` which marks all nodes with op not in `CALLABLE_NODE_OPS` ("call_module", "call_function", "call_method") as unsupported.

Before this change, merge-net is split into an almost empty cpu submodule with a single empty output node:
```
INFO:caffe2.torch.fb.model_transform.experimental.prepare_fx_model:###### debug_print nodes for _run_on_cpu_0
INFO:caffe2.torch.fb.model_transform.experimental.prepare_fx_model:Found output node: n.name='output', n.target='output', n.args=((),), n.kwargs={}, n.meta={}
INFO:caffe2.torch.fb.model_transform.experimental.prepare_fx_model:return ()
INFO:caffe2.torch.fb.model_transform.experimental.prepare_fx_model: _run_on_cpu_0 stats for merge: [output] output: 1
```
full log: P1678727348 (generated using same command as below)

Test Plan:
Tested by lowering `ig_organic_feed_cn_v2_mtml` using cmd:
```
buck run mode/opt-split-dwarf //tgif/cli:cli -- --model-name=ig_organic_feed_cn_v2_mtml --model-type ig_organic_feed_cn_v2_mtml --world-size=1 --storage-mode 1 --inference-dtype=FP16 --meta-transform=False --use-random-weights=True --accelerator-arch=3 --enable-input-dist=True --embedding-tables-dtype=FP16 --mtia-use-torch-export=True embedding-quantization-pass torchrec-sharding-pass tgif-split-pass gen-app-graph-pass tgif-mtia-lowering-pass dense-quantization-pass save-torch-package-pass generate-model-package-pass pack-weights-and-save-pass 2>&1 | tee /tmp/publish_ig_organic_feed_cn_v2_mtml_mtia_export_20241114_splitter_2.log
```
Output shows only 1 acc submodule is generated for merge:
```
INFO 18:33:15.951 1735650 utils.py:235: [TGIF] num of acc submodules: 1
INFO 18:33:15.952 1735650 utils.py:236: [TGIF] num of cpu submodules: 0
INFO 18:33:16.534 1735650 logging_utils.py:53: [TGIF] _run_on_acc_0 graph module debug info: https://www.internalfb.com/intern/everpaste/?color=0&handle=GK4VKhWsDKF9VdsDAKxhR6KAlhJ0br0LAAAz
INFO 18:33:16.534 1735650 utils.py:257: [TGIF] Start MTIA lowering _run_on_acc_0 in merge, device ordinal: -1
```
full log: P1679596796

Differential Revision: D65983916

Pull Request resolved: pytorch#140794
Approved by: https://github.com/ezyang
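
The filtering rule described in the summary can be illustrated with a small, torch-free sketch. `ToyNode`, `default_acc_nodes`, and `acc_nodes_with_get_attr` are hypothetical stand-ins and not part of `torch.fx`; only the three op names in `CALLABLE_NODE_OPS` are taken from the commit message, and the real finder's operator-support check is left out for brevity.

```python
from dataclasses import dataclass

# Op kinds named in the commit message as the only ones the default finder considers.
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for an FX graph node: just a name and an op kind."""
    name: str
    op: str


def default_acc_nodes(nodes):
    """Mimics the default rule: any op outside CALLABLE_NODE_OPS is unsupported."""
    return {n for n in nodes if n.op in CALLABLE_NODE_OPS}


def acc_nodes_with_get_attr(nodes):
    """Customized rule (assumption): additionally treat get_attr nodes as supported."""
    return default_acc_nodes(nodes) | {n for n in nodes if n.op == "get_attr"}


graph = [
    ToyNode("weight", "get_attr"),
    ToyNode("linear", "call_function"),
    ToyNode("relu", "call_method"),
]

# Nodes that each rule would leave outside the accelerator set.
left_on_cpu_default = [n.name for n in graph if n not in default_acc_nodes(graph)]
left_on_cpu_custom = [n.name for n in graph if n not in acc_nodes_with_get_attr(graph)]
```

Under the default rule the lone `get_attr` node is the only one left outside the accelerator set; with the customized rule nothing is left over in this toy graph.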
1 parent b55d923 commit 9ea2293

1 file changed: +6 -3 lines changed

torch/fx/passes/splitter_base.py

Lines changed: 6 additions & 3 deletions
@@ -331,6 +331,7 @@ def __init__(
         settings: _SplitterSettingBase,
         non_acc_submodule_name: str = "_run_on_cpu_",
         return_tuple: bool = False,
+        nodes_finder: Optional[FxNetAccNodesFinder] = None,
     ):
         """
         Preprocesses graph before splitting:
@@ -348,9 +349,11 @@ def __init__(
         self.settings = settings
         self.operator_support = operator_support
         self.sample_input = sample_input
-        self.acc_nodes = FxNetAccNodesFinder(
-            self.module, self.operator_support, self.settings.allow_non_tensor
-        )()
+        if nodes_finder is None:
+            nodes_finder = FxNetAccNodesFinder(
+                self.module, self.operator_support, self.settings.allow_non_tensor
+            )
+        self.acc_nodes = nodes_finder()
 
         if self.settings.skip_fusion:
             self.fusions = {}
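
The constructor change above follows a common optional-dependency pattern: the caller may pass its own finder, and the default is built only when none is given. A minimal torch-free sketch of that pattern is shown here; `DefaultFinder`, `GetAttrFinder`, and `ToySplitter` are hypothetical names that model the shape of the change, not the actual `torch.fx` classes.

```python
from typing import Callable, Dict, Optional, Set


class DefaultFinder:
    """Hypothetical stand-in for FxNetAccNodesFinder: a callable returning node names."""

    def __init__(self, ops: Dict[str, str]):
        self.ops = ops  # maps node name -> op kind

    def __call__(self) -> Set[str]:
        callable_ops = {"call_module", "call_function", "call_method"}
        return {name for name, op in self.ops.items() if op in callable_ops}


class GetAttrFinder(DefaultFinder):
    """Hypothetical customized finder that also accepts get_attr nodes."""

    def __call__(self) -> Set[str]:
        extra = {name for name, op in self.ops.items() if op == "get_attr"}
        return super().__call__() | extra


class ToySplitter:
    """Toy splitter mirroring the diff: use the injected finder, else build the default."""

    def __init__(
        self,
        ops: Dict[str, str],
        nodes_finder: Optional[Callable[[], Set[str]]] = None,
    ):
        if nodes_finder is None:
            nodes_finder = DefaultFinder(ops)
        self.acc_nodes = nodes_finder()


ops = {"weight": "get_attr", "linear": "call_function"}
default_split = ToySplitter(ops)
custom_split = ToySplitter(ops, nodes_finder=GetAttrFinder(ops))
```

Because the new parameter defaults to `None`, existing callers that do not pass a finder keep the previous behavior, while a caller that wants `get_attr` nodes on the accelerator can supply its own finder.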
