Skip to content

Commit 9d6ce01

Browse files
TroyGarden authored and facebook-github-bot committed
fix corner case where train_pipeline PostProc can't trace through a fx wrap module (#3377)
Summary:
Pull Request resolved: #3377

# context
* When working on UDD LSR MC5, we observed a 5% QPS regression caused by the SDD pipeline being disabled ([mast log](https://fburl.com/mlhub/rvuibmcp)).
* The regression is mainly due to a code change: the `torch.fx.wrap` of a [postproc function](https://fburl.com/code/spvnz047).
* Although the `pipeline_postproc` flag is `True`, the postproc in question is actually an fx-wrapped function.
* Currently, postproc tracing only works for a `torch.nn.Module`, not for a function.
* Added a test case to verify this; future work is needed to better support tracing of fx-wrapped functions.

Reviewed By: spmex

Differential Revision: D82501416

fbshipit-source-id: 9cadbafaaee6536c95fe3050b60cb0a2ad7a91b1
1 parent fcfc8ec commit 9d6ce01

File tree

2 files changed

+131
-9
lines changed

2 files changed

+131
-9
lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,23 @@
1010
import copy
1111
import enum
1212
import unittest
13+
from typing import Tuple, Union
1314
from unittest.mock import MagicMock
1415

1516
import torch
1617

1718
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
18-
from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
19+
from torchrec.distributed.test_utils.test_model import (
20+
ModelInput,
21+
TestNegSamplingModule,
22+
TestSparseNN,
23+
)
1924
from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
2025
from torchrec.distributed.train_pipeline.runtime_forwards import PipelinedForward
21-
2226
from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
2327
TrainPipelineSparseDistTestBase,
2428
)
25-
from torchrec.distributed.train_pipeline.tracing import (
26-
ArgInfo,
27-
ArgInfoStepFactory,
28-
CallArgs,
29-
NodeArgsHelper,
30-
PipelinedPostproc,
31-
)
29+
from torchrec.distributed.train_pipeline.tracing import CallArgs, PipelinedPostproc
3230
from torchrec.distributed.train_pipeline.utils import _rewrite_model
3331
from torchrec.distributed.types import ShardingType
3432
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -40,6 +38,15 @@ class ModelType(enum.Enum):
4038
PIPELINED = "pipelined"
4139

4240

41+
@torch.fx.wrap
def enrich_hstu_features(
    kjt: KeyedJaggedTensor, hstu_factor: float
) -> KeyedJaggedTensor:
    """Scale the weights of ``kjt`` in place and hand the same object back.

    The function is registered with ``torch.fx.wrap``, so it is used by the
    tests in this file as an example of an fx-wrapped postproc *function*
    (as opposed to a ``torch.nn.Module`` postproc).

    Args:
        kjt: object whose ``_weights`` attribute is scaled when it is not None.
        hstu_factor: multiplier applied to ``kjt._weights``.

    Returns:
        The ``kjt`` argument itself (mutated in place when it has weights).
    """
    # Nothing to scale when there are no weights.
    if kjt._weights is None:
        return kjt
    kjt._weights *= hstu_factor
    return kjt
48+
49+
4350
class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
4451
# pyre-fixme[56]: Pyre was not able to infer the type of argument
4552
@unittest.skipIf(
@@ -257,3 +264,115 @@ def test_restore_from_snapshot(self) -> None:
257264
]
258265
for source_model_type, recipient_model_type in variants:
259266
self._test_restore_from_snapshot(source_model_type, recipient_model_type)
267+
268+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
    not torch.cuda.is_available(),
    "Not enough GPUs, this test requires at least one GPU",
)
def test_rewrite_model_with_fx_wrap(self) -> None:
    """Exercise ``_rewrite_model`` against an fx-wrapped postproc function.

    The same sharded model is rewritten twice with ``pipeline_postproc=True``:

    1. While the model's ``forward`` calls the fx-wrapped function
       ``enrich_hstu_features`` directly, the test expects the ``ebc`` and
       ``weighted_ebc`` forwards NOT to become ``PipelinedForward``.
    2. After switching the model to call the same function through a
       ``torch.nn.Module`` postproc, the test expects both forwards to be
       ``PipelinedForward`` and their first call-arg to reference the model's
       ``postproc_module``.

    Finally the state dict is loaded back into the rewritten model and must
    produce no missing or unexpected keys.
    """
    sharding_type = ShardingType.TABLE_WISE.value
    kernel_type = EmbeddingComputeKernel.FUSED.value
    fused_params = {}

    # Module wrapper around the fx-wrapped function: the module-based postproc.
    class TestPostProcModule(torch.nn.Module):
        def __init__(self, f: float):
            super().__init__()
            self.f = f

        def forward(self, x: KeyedJaggedTensor) -> KeyedJaggedTensor:
            return enrich_hstu_features(x, self.f)

    postproc_module = TestPostProcModule(0.3)

    class TestModel(TestSparseNN):
        # Class-level switch, read through type(self) in forward(), so that
        # flipping it on the class after construction changes which branch
        # the already-built model takes.
        use_postproc_module: bool = False

        def forward(
            self,
            input: ModelInput,
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            if (type(self)).use_postproc_module:
                # postproc applied via the nn.Module
                input = self.postproc_module(input)
            else:
                # postproc applied via the bare fx-wrapped function
                input = enrich_hstu_features(input, 0.3)
            return self.dense_forward(input, self.sparse_forward(input))

    model = TestModel(
        tables=self.tables,
        weighted_tables=self.weighted_tables,
        dense_device=self.device,
        sparse_device=torch.device("meta"),
        postproc_module=postproc_module,
    )

    sharded_model, optim = self._generate_sharded_model_and_optimizer(
        model, sharding_type, kernel_type, fused_params
    )

    # Try to rewrite model using a function for postproc
    # EBC forwards not overwritten to PipelinedForward due to KJT modification
    self.assertFalse(model.use_postproc_module)
    _rewrite_model(
        model=sharded_model,
        batch=None,
        context=TrainPipelineContext(),
        dist_stream=None,
        pipeline_postproc=True,
    )
    self.assertNotIsInstance(
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `sparse`.
        sharded_model.module.sparse.ebc.forward,
        PipelinedForward,
    )
    self.assertNotIsInstance(
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `sparse`.
        sharded_model.module.sparse.weighted_ebc.forward,
        PipelinedForward,
    )

    # Now use postproc module
    # (flip the class flag, then rewrite the same sharded model a second time)
    TestModel.use_postproc_module = True
    self.assertTrue(model.use_postproc_module)
    _rewrite_model(
        model=sharded_model,
        batch=None,
        context=TrainPipelineContext(),
        dist_stream=None,
        pipeline_postproc=True,
    )

    # Both sharded EBC forwards are now expected to be PipelinedForward.
    # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`.
    self.assertIsInstance(sharded_model.module.sparse.ebc.forward, PipelinedForward)
    self.assertIsInstance(
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `sparse`.
        sharded_model.module.sparse.weighted_ebc.forward,
        PipelinedForward,
    )
    # The first step of each pipelined forward's first positional arg should
    # reference the same postproc module that is attached to the model.
    self.assertEqual(
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `sparse`.
        sharded_model.module.sparse.ebc.forward._args.args[0]
        .steps[0]
        .postproc_module,
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `postproc_module`.
        sharded_model.module.postproc_module,
    )
    self.assertEqual(
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `sparse`.
        sharded_model.module.sparse.weighted_ebc.forward._args.args[0]
        .steps[0]
        .postproc_module,
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `postproc_module`.
        sharded_model.module.postproc_module,
    )
    # Round-trip the state dict through the rewritten model: loading its own
    # state dict must report no missing and no unexpected keys.
    state_dict = sharded_model.state_dict()
    missing_keys, unexpected_keys = sharded_model.load_state_dict(state_dict)
    self.assertEqual(missing_keys, [])
    self.assertEqual(unexpected_keys, [])

torchrec/distributed/train_pipeline/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ def _rewrite_model( # noqa C901
361361
model, context, pipeline_postproc, default_stream, dist_stream
362362
)
363363

364+
logger.info(
365+
f"pipeline_postproc is {'enabled' if pipeline_postproc else 'disabled'}"
366+
)
364367
for node in graph.nodes:
365368
# only work on the call_module node which is also a sharded module
366369
if node.op != "call_module" or node.target not in sharded_modules:

0 commit comments

Comments
 (0)