Skip to content

Commit 960581a

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
refactor PostProc tracing and debug message (#3379)
Summary: Pull Request resolved: #3379 # context * There's quite some limitations on the postproc support for TorchRec's train pipeline * add better warning message for debugging ## symptoms * unable to run input_dist in "-1" batch with the `SparseDistTrainPipeline`, AKA, SDD (Sparse Data Dist) pipeline * warning in log: `Module '{node.target}' will NOT be pipelined, due to input modifications` ## typical issues * root cause: input KJT is modified or passed through some module/function potentially modifies the KJT * pipeline_postproc is not enabled * check the error message for `fx node {child_node.name, child_node.op, child_node.target} can't be handled correctly for postproc module` * postproc module has trainable weights (sorry we don't support this) * a postproc function modifies the input KJT * two postproc modules have certain execution order ## workaround * make the postproc function a nn.Module * put order-dependent functions/modules under the same nn.Module to preserve the order. Reviewed By: spmex Differential Revision: D82591429 fbshipit-source-id: 86ca706a508e250a42e3c37aa622582372d229f3
1 parent 9d6ce01 commit 960581a

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

torchrec/distributed/train_pipeline/tracing.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,11 @@ def _handle_module(
302302

303303
if not self._pipeline_postproc:
304304
logger.warning(
305-
f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`"
305+
f"Found module {postproc_module} that potentially modifies input KJT. "
306+
"Train pipeline initialized with `pipeline_postproc=False` (default), "
307+
"so we assume KJT input modification. "
308+
"To allow torchrec to check if this module can be safely pipelined, "
309+
"please set `pipeline_postproc=True`"
306310
)
307311
return None
308312

@@ -341,11 +345,10 @@ def _handle_module(
341345
)
342346
if num_found_safe_postproc_args == total_num_args:
343347
logger.info(
344-
f"""Module {postproc_module} is a valid postproc module (no
345-
trainable params and inputs can be derived from train batch input
346-
via a series of either valid postproc modules or non-modifying
347-
transformations) and will be applied during sparse data dist
348-
stage"""
348+
f"Module {postproc_module} is a valid postproc module (no "
349+
"trainable params and inputs can be derived from train batch input "
350+
"via a series of either valid postproc modules or non-modifying "
351+
"transformations) and will be applied during sparse data dist stage"
349352
)
350353

351354
pipelined_postproc_module = PipelinedPostproc(
@@ -449,6 +452,10 @@ def _get_node_args_helper_inner(
449452
arg_info.add_step(ArgInfoStepFactory.get_item(child_node.args[1]))
450453
arg = child_node.args[0]
451454
else:
455+
logger.warning(
456+
f"fx node {child_node.name, child_node.op, child_node.target} "
457+
"can't be handled correctly for postproc module"
458+
)
452459
break
453460

454461
# if we couldn't hit one of the "decisive" outcomes (constant, placeholder or module), return "not found"

0 commit comments

Comments
 (0)