Skip to content

Commit ef33c80

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
add postproc module name in trace (#3384)
Summary: Pull Request resolved: #3384 # context * simple BE work to add postproc module name to the trace Reviewed By: spmex Differential Revision: D82700705 fbshipit-source-id: 572b120d008bfe85eb25b7488c6f46830e5e382f
1 parent ced0adf commit ef33c80

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

torchrec/distributed/train_pipeline/postproc.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,13 @@ def __init__(
7979
self._dist_stream = dist_stream
8080
if not default_stream:
8181
logger.warning(
82-
f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!"
82+
f"Postproc module {fqn} has no default stream. "
83+
"This may cause race conditions and NaNs during training!"
8384
)
8485
if not dist_stream:
8586
logger.warning(
86-
f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
87+
f"Postproc module {fqn} has no dist stream. "
88+
"This may cause race conditions and NaNs during training!"
8789
)
8890

8991
if self._dist_stream:
@@ -139,7 +141,9 @@ def forward(self, *input, **kwargs) -> Any:
139141
# Use input[0] as _start_data_dist only passes 1 arg
140142
args, kwargs = self._args.build_args_kwargs(input[0])
141143

142-
with record_function(f"## sdd_input_postproc {self._context.index} ##"):
144+
with record_function(
145+
f"## input_postproc {type(self.postproc_module)} {self._context.index} ##"
146+
):
143147
# should be no-op as we call this in dist stream
144148
with self._stream_context(self._dist_stream):
145149
res = self._postproc_module(*args, **kwargs)
@@ -160,7 +164,11 @@ def forward(self, *input, **kwargs) -> Any:
160164
PipelinedPostproc.recursive_record_stream(res, self._default_stream)
161165
elif self._context.index == 0:
162166
logger.warning(
163-
f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
167+
f"Result of postproc module {self._fqn} is of type {type(res)}. "
168+
"We currently expect it to be a Tensor, Pipelineable, Iterable, "
169+
"or Dict to handle memory safety. If your output is not of this "
170+
"type, please add support for it above. Otherwise you might run "
171+
"into NaNs or CUDA Illegal Memory issues during training!"
164172
)
165173

166174
with self._stream_context(self._default_stream):

0 commit comments

Comments
 (0)