refactor PostProc tracing and debug message (#3379)
Summary:
Pull Request resolved: #3379
# context
* There are quite a few limitations on postproc support in TorchRec's train pipeline
* add better warning messages for debugging
## symptoms
* unable to run input_dist in "-1" batch with the `SparseDistTrainPipeline`, AKA, SDD (Sparse Data Dist) pipeline
* warning in log: `Module '{node.target}' will NOT be pipelined, due to input modifications`
## typical issues
* root cause: the input KJT is modified, or is passed through some module/function that potentially modifies the KJT
* pipeline_postproc is not enabled
* check the error message for `fx node {child_node.name, child_node.op, child_node.target} can't be handled correctly for postproc module`
* postproc module has trainable weights (sorry we don't support this)
* a postproc function modifies the input KJT
* two postproc modules depend on a certain execution order
## workaround
* make the postproc function an `nn.Module`
* put order-dependent functions/modules under the same `nn.Module` to preserve the order
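
The workaround above can be sketched roughly as follows. This is only an illustration under stated assumptions: `clamp_ids`, `shift_ids`, and `OrderedPostproc` are hypothetical names, and a plain `torch.Tensor` stands in for the KJT values. It shows two order-dependent steps grouped inside one `nn.Module` that has no trainable weights and does not modify its input in place.

```python
import torch
from torch import nn


def clamp_ids(values: torch.Tensor, max_id: int) -> torch.Tensor:
    # Hypothetical step 1: returns a new tensor and leaves the input untouched.
    return values.clamp(max=max_id)


def shift_ids(values: torch.Tensor, offset: int) -> torch.Tensor:
    # Hypothetical step 2: expected to run after clamp_ids.
    return values + offset


class OrderedPostproc(nn.Module):
    """Hypothetical wrapper: no trainable weights, fixed internal order."""

    def __init__(self, max_id: int, offset: int) -> None:
        super().__init__()
        self.max_id = max_id
        self.offset = offset

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        # The order is fixed here instead of relying on how separate
        # modules would be scheduled relative to each other.
        clamped = clamp_ids(values, self.max_id)
        return shift_ids(clamped, self.offset)


postproc = OrderedPostproc(max_id=5, offset=100)
ids = torch.tensor([1, 7, 3])  # stand-in for KJT values (assumption)
out = postproc(ids)

# Count trainable parameters; a postproc module should have none.
num_trainable = sum(p.numel() for p in postproc.parameters() if p.requires_grad)
```

Whether a given module is actually pipelined still depends on the tracing checks described in this commit; the sketch only mirrors the constraints listed under "typical issues".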
Reviewed By: spmex
Differential Revision: D82591429
fbshipit-source-id: 86ca706a508e250a42e3c37aa622582372d229f3
torchrec/distributed/train_pipeline/tracing.py
13 additions & 6 deletions
@@ -302,7 +302,11 @@ def _handle_module(
 
         if not self._pipeline_postproc:
             logger.warning(
-                f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`"
+                f"Found module {postproc_module} that potentially modifies input KJT. "
+                "Train pipeline initialized with `pipeline_postproc=False` (default), "
+                "so we assume KJT input modification. "
+                "To allow torchrec to check if this module can be safely pipelined, "
+                "please set `pipeline_postproc=True`"
             )
             return None
 
@@ -341,11 +345,10 @@ def _handle_module(
             )
         if num_found_safe_postproc_args == total_num_args:
             logger.info(
-                f"""Module {postproc_module} is a valid postproc module (no
-                trainable params and inputs can be derived from train batch input
-                via a series of either valid postproc modules or non-modifying
-                transformations) and will be applied during sparse data dist
-                stage"""
+                f"Module {postproc_module} is a valid postproc module (no "
+                "trainable params and inputs can be derived from train batch input "
+                "via a series of either valid postproc modules or non-modifying "
+                "transformations) and will be applied during sparse data dist stage"
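
The logging changes in this diff replace a triple-quoted f-string with adjacent string literals. The commit does not state the motivation, but one practical difference is that a triple-quoted string keeps its line breaks and indentation in the emitted message, while adjacent literals are joined into a single line. A minimal sketch with a shortened message and a hypothetical module name:

```python
name = "MyPostproc"  # hypothetical module name, for illustration only

# Old style: the triple-quoted f-string keeps newlines and leading spaces.
old_msg = f"""Module {name} is a valid postproc module (no
            trainable params) and will be applied during sparse data dist
            stage"""

# New style: adjacent literals are concatenated into one string.
# Only the first piece needs the f prefix, since only it has a placeholder.
new_msg = (
    f"Module {name} is a valid postproc module (no "
    "trainable params) and will be applied during sparse data dist stage"
)

# Collapsing whitespace in the old message yields the new single-line form.
normalized_old = " ".join(old_msg.split())
```

Note the trailing space at the end of each piece in the new style; without it, words at the seams would run together.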