Skip to content

Commit 6af25a2

Browse files
peterfu0 authored and facebook-github-bot committed
add prefetch
Differential Revision: D79404930
1 parent a07bc63 commit 6af25a2

File tree

3 files changed

+99
-6
lines changed

3 files changed

+99
-6
lines changed

torchrec/distributed/train_pipeline/runtime_forwards.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def name(self) -> str:
5858
def args(self) -> CallArgs:
5959
return self._args
6060

61+
@classmethod
def prefetch(cls) -> bool:
    """
    Tell whether this forward class takes part in embedding prefetch.

    Returns:
        bool: always ``False`` here; this is the default answer, and
        classes that do prefetch override the method to return ``True``.
    """
    return False
64+
6165
def set_context(self, context: TForwardContext) -> None:
6266
self._context = context
6367

@@ -220,6 +224,72 @@ def detach_embeddings(
220224
pass
221225

222226

227+
class PrefetchPipelinedForwardCustomizedOrder(
    BaseForward[PrefetchTrainPipelineContext]
):
    """
    This pipeline is used in TrainPipelineCustomizedOrderSparseDist
    compute_and_output_dist for batch N is called at the end of step N - 1

    The work is split in two phases: ``compute_and_output_dist`` launches the
    sharded module's ``compute_and_output_dist`` and stores the awaitable it
    returns; ``__call__`` (the replacement for the module's ``forward``) only
    hands that stored awaitable back.
    """

    def __init__(
        self,
        name: str,
        args: CallArgs,
        module: ShardedModule,
        context: PrefetchTrainPipelineContext,
        prefetch_stream: Optional[torch.Stream] = None,
    ) -> None:
        """
        Args:
            name: key used to look this module up in the context's
                ``module_input_post_prefetch`` and
                ``module_contexts_post_prefetch`` mappings.
            args: call arguments forwarded to the base class.
            module: the sharded module whose ``compute_and_output_dist``
                is invoked.
            context: the prefetch pipeline context holding the prefetched
                inputs and module contexts.
            prefetch_stream: passed to the base class as ``stream``; when
                not ``None`` the current stream waits on it before the
                prefetched data is consumed.
        """
        super().__init__(
            name=name,
            args=args,
            module=module,
            context=context,
            stream=prefetch_stream,
        )
        # Filled in by compute_and_output_dist(), returned by __call__().
        self._compute_and_output_dist_awaitable: Optional[
            Awaitable[Multistreamable]
        ] = None

    @classmethod
    def prefetch(cls) -> bool:
        """Return ``True``: this forward consumes prefetched inputs."""
        return True

    def compute_and_output_dist(self) -> None:
        """
        Launch the module's ``compute_and_output_dist`` on the prefetched input.

        Pops this module's entry from ``module_input_post_prefetch`` and
        ``module_contexts_post_prefetch``, synchronizes with the prefetch
        stream when one was given, and stores the resulting awaitable for
        ``__call__`` to return.

        Raises:
            AssertionError: if no prefetched input is registered under this
                module's name, or if the input is neither a ``torch.Tensor``
                nor a ``Multistreamable``.
        """
        assert (
            self._name in self._context.module_input_post_prefetch
        ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()"
        data = self._context.module_input_post_prefetch.pop(self._name)
        ctx = self._context.module_contexts_post_prefetch.pop(self._name)

        # Make sure that both result of input_dist and context
        # are properly transferred to the current stream.
        if self._stream is not None:
            cur_stream = torch.get_device_module(self._device).current_stream()
            cur_stream.wait_stream(self._stream)

            assert isinstance(
                data, (torch.Tensor, Multistreamable)
            ), f"{type(data)} must implement Multistreamable interface"
            data.record_stream(cur_stream)

            ctx.record_stream(cur_stream)

        self._compute_and_output_dist_awaitable = self._module.compute_and_output_dist(
            ctx, data
        )

    # pyre-ignore [2, 24]
    def __call__(self, *input, **kwargs) -> Awaitable:
        """
        Return the awaitable produced by the last ``compute_and_output_dist``.

        The positional and keyword arguments are accepted for call
        compatibility with ``forward`` and are not used.

        Raises:
            RuntimeError: if ``compute_and_output_dist`` has not been called.
        """
        # Compare against None explicitly: a truthiness test on an awaitable
        # could invoke its __bool__/__len__ instead of checking presence.
        if self._compute_and_output_dist_awaitable is None:
            raise RuntimeError(
                "compute_and_output_dist must be called before __call__",
            )
        return self._compute_and_output_dist_awaitable
291+
292+
223293
class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
224294
"""
225295
This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -241,6 +311,10 @@ def __init__(
241311
stream=prefetch_stream,
242312
)
243313

314+
@classmethod
def prefetch(cls) -> bool:
    """
    Tell whether this forward class takes part in embedding prefetch.

    Returns:
        bool: always ``True`` for this class.
    """
    return True
317+
244318
# pyre-ignore [2, 24]
245319
def __call__(self, *input, **kwargs) -> Awaitable:
246320
assert (

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
583583
batch, context = self.copy_batch_to_gpu(dataloader_iter)
584584
if batch is None:
585585
return False
586+
586587
self.batches.append(batch)
587588
# pyre-ignore [6]
588589
self.contexts.append(context)
@@ -732,6 +733,7 @@ def _pipeline_model(
732733
batch: Optional[In],
733734
context: TrainPipelineContext,
734735
pipelined_forward: Type[PipelinedForward] = PipelinedForward,
736+
prefetch_stream: Optional[torch.Stream] = None,
735737
) -> None:
736738
(
737739
self._pipelined_modules,
@@ -742,7 +744,9 @@ def _pipeline_model(
742744
) = _rewrite_model(
743745
model=self._model,
744746
context=context,
745-
dist_stream=self._data_dist_stream,
747+
dist_stream=(
748+
self._data_dist_stream if prefetch_stream is None else prefetch_stream
749+
),
746750
default_stream=torch.get_device_module(self._device).current_stream(),
747751
batch=batch,
748752
apply_jit=self._apply_jit,
@@ -845,9 +849,11 @@ def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
845849
"""
846850
with record_function(f"## wait_sparse_data_dist {context.index} ##"):
847851
with self._stream_context(self._data_dist_stream):
852+
# fused_splits_awaitables is empty
848853
for names, awaitable in context.fused_splits_awaitables:
849854
for name, request in zip(names, awaitable.wait()):
850855
context.input_dist_tensors_requests[name] = request
856+
851857
context.input_dist_splits_requests.clear()
852858
context.fused_splits_awaitables.clear()
853859

@@ -1495,6 +1501,10 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
14951501
self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
14961502
self._start_sparse_data_dist(self._batch_ip1)
14971503

1504+
# i: prefetch is done
1505+
# ip1: input_dist is done, need to prefetch
1506+
# ip2: does not exist yet, needs to be copied and then have input_dist started
1507+
# how about: ip2': memcpy is done, needs input_dist; ip3': does not exist yet, needs memcpy
14981508
def progress(self, dataloader_iter: Iterator[In]) -> Out:
14991509
self._fill_pipeline(dataloader_iter)
15001510

@@ -1507,12 +1517,12 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
15071517

15081518
self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)
15091519

1510-
self._wait_sparse_data_dist()
1520+
self._wait_sparse_data_dist() # it waits for both i and ip1, as ip1(ip2 in previous round) started
15111521
# forward
15121522
with record_function("## forward ##"):
15131523
losses, output = self._model_fwd(self._batch_i)
15141524

1515-
self._prefetch(self._batch_ip1)
1525+
self._prefetch(self._batch_ip1) # prefetch 1
15161526

15171527
if self._model.training:
15181528
# backward

torchrec/distributed/train_pipeline/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
import torch
3131
from torch.profiler import record_function
32-
3332
from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
3433
from torchrec.distributed.embedding_sharding import (
3534
FusedKJTListSplitsAwaitable,
@@ -53,6 +52,7 @@
5352
KJTAllToAllForward,
5453
PipelinedForward,
5554
PrefetchPipelinedForward,
55+
PrefetchPipelinedForwardCustomizedOrder,
5656
TForwardContext,
5757
)
5858
from torchrec.distributed.train_pipeline.tracing import (
@@ -61,7 +61,7 @@
6161
Tracer,
6262
)
6363
from torchrec.distributed.train_pipeline.types import CallArgs # noqa
64-
from torchrec.distributed.types import Awaitable
64+
from torchrec.distributed.types import Awaitable, LazyAwaitable
6565
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
6666
from torchrec.streamable import Multistreamable, Pipelineable
6767

@@ -138,6 +138,7 @@ def _start_data_dist(
138138
PrefetchPipelinedForward,
139139
EmbeddingPipelinedForward,
140140
InSyncEmbeddingPipelinedForward,
141+
PrefetchPipelinedForwardCustomizedOrder,
141142
),
142143
)
143144

@@ -539,6 +540,10 @@ def get_next_batch(self, none_throws: bool = False) -> Optional[In]:
539540
return batch
540541

541542

543+
def _prefetch_enabled(forward: LazyAwaitable[Out]) -> bool:
    """
    Report whether a pipelined forward opts into the embedding prefetch stage.

    Args:
        forward: the object installed as a sharded module's ``forward``.

    Returns:
        bool: ``True`` when ``forward`` is a ``BaseForward`` whose
        ``prefetch()`` classmethod returns ``True``; ``False`` otherwise.
    """
    # ``prefetch`` is a classmethod and must be *called*: the bound method
    # object itself is always truthy. The result must also be returned,
    # because a function that only asserts yields ``None``, which is falsy
    # wherever the return value is tested.
    return isinstance(forward, BaseForward) and forward.prefetch()
545+
546+
542547
def _prefetch_embeddings(
543548
batch: In,
544549
context: PrefetchTrainPipelineContext,
@@ -551,7 +556,11 @@ def _prefetch_embeddings(
551556
data_per_sharded_module = {}
552557
for sharded_module in pipelined_modules:
553558
forward = sharded_module.forward
554-
assert isinstance(forward, PrefetchPipelinedForward)
559+
# for backward compatibility, consider it valid if it is PrefetchPipelinedForward
560+
# because the class might not have prefetch method
561+
assert isinstance(forward, PrefetchPipelinedForward) or _prefetch_enabled(
562+
forward
563+
)
555564
assert forward._name in context.input_dist_tensors_requests
556565
request = context.input_dist_tensors_requests.pop(forward._name)
557566
assert isinstance(request, Awaitable)

0 commit comments

Comments
 (0)