Commit 61b5c20

peterfu0 authored and facebook-github-bot committed

add prefetch (pytorch#3349)

Summary: Pull Request resolved: pytorch#3349
Differential Revision: D79404930
1 parent 60f7f87 commit 61b5c20

File tree: 3 files changed (+193 lines, −3 lines)

torchrec/distributed/train_pipeline/runtime_forwards.py

Lines changed: 74 additions & 0 deletions
@@ -58,6 +58,10 @@ def name(self) -> str:
     def args(self) -> CallArgs:
         return self._args
 
+    @classmethod
+    def prefetch(cls) -> bool:
+        return False
+
     def set_context(self, context: TForwardContext) -> None:
         self._context = context

@@ -220,6 +224,72 @@ def detach_embeddings(
         pass
 
 
+class PrefetchPipelinedForwardCustomizedOrder(
+    BaseForward[PrefetchTrainPipelineContext]
+):
+    """
+    This pipeline is used in TrainPipelineCustomizedOrderSparseDist;
+    compute_and_output_dist for batch N is called at the end of step N - 1.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        args: CallArgs,
+        module: ShardedModule,
+        context: PrefetchTrainPipelineContext,
+        prefetch_stream: Optional[torch.Stream] = None,
+    ) -> None:
+        super().__init__(
+            name=name,
+            args=args,
+            module=module,
+            context=context,
+            stream=prefetch_stream,
+        )
+        self._compute_and_output_dist_awaitable: Optional[
+            Awaitable[Multistreamable]
+        ] = None
+
+    @classmethod
+    def prefetch(cls) -> bool:
+        return True
+
+    def compute_and_output_dist(self) -> None:
+        assert (
+            self._name in self._context.module_input_post_prefetch
+        ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()"
+        data = self._context.module_input_post_prefetch.pop(self._name)
+        ctx = self._context.module_contexts_post_prefetch.pop(self._name)
+
+        # Make sure that both the result of input_dist and the context
+        # are properly transferred to the current stream.
+        if self._stream is not None:
+            torch.get_device_module(self._device).current_stream().wait_stream(
+                self._stream
+            )
+            cur_stream = torch.get_device_module(self._device).current_stream()
+
+            assert isinstance(
+                data, (torch.Tensor, Multistreamable)
+            ), f"{type(data)} must implement Multistreamable interface"
+            data.record_stream(cur_stream)
+
+            ctx.record_stream(cur_stream)
+
+        self._compute_and_output_dist_awaitable = self._module.compute_and_output_dist(
+            ctx, data
+        )
+
+    # pyre-ignore [2, 24]
+    def __call__(self, *input, **kwargs) -> Awaitable:
+        if not self._compute_and_output_dist_awaitable:
+            raise Exception(
+                "compute_and_output_dist must be called before __call__",
+            )
+        return self._compute_and_output_dist_awaitable
+
+
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -241,6 +311,10 @@ def __init__(
             stream=prefetch_stream,
         )
 
+    @classmethod
+    def prefetch(cls) -> bool:
+        return True
+
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
         assert (
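The docstring above captures the contract: the pipeline, not the model, drives compute_and_output_dist, calling it for batch N at the end of step N - 1 after prefetch has populated module_input_post_prefetch and module_contexts_post_prefetch; the model's forward in step N then just returns the stored awaitable. Below is a minimal sketch of that calling order, using MagicMock stand-ins for the sharded module and the prefetched data (the name "sparse_arch" is illustrative; this mirrors the new unit test that follows rather than a real training loop):

from unittest.mock import MagicMock

from torchrec.distributed.train_pipeline.pipeline_context import (
    PrefetchTrainPipelineContext,
)
from torchrec.distributed.train_pipeline.runtime_forwards import (
    PrefetchPipelinedForwardCustomizedOrder,
)
from torchrec.distributed.train_pipeline.types import CallArgs

ctx = PrefetchTrainPipelineContext()
module = MagicMock()  # stand-in for a real ShardedModule
forward = PrefetchPipelinedForwardCustomizedOrder(
    name="sparse_arch",
    args=CallArgs(args=[], kwargs={}),
    module=module,
    context=ctx,
)

# Prefetch must have populated the post-prefetch maps before this point.
ctx.module_input_post_prefetch = {"sparse_arch": MagicMock()}
ctx.module_contexts_post_prefetch = {"sparse_arch": MagicMock()}

# End of step N - 1: kick off compute and output dist for batch N.
forward.compute_and_output_dist()

# Step N: the pipelined forward returns the stored awaitable.
awaitable = forward()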
New test file (path not captured in this view)

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from unittest.mock import MagicMock
+
+from torchrec.distributed.train_pipeline.pipeline_context import (
+    PrefetchTrainPipelineContext,
+)
+from torchrec.distributed.train_pipeline.runtime_forwards import (
+    PrefetchPipelinedForwardCustomizedOrder,
+)
+from torchrec.distributed.train_pipeline.types import CallArgs
+from torchrec.distributed.train_pipeline.utils import _prefetch_enabled
+
+
+class TestPrefetchPipelinedForwardCustomizedOrder(unittest.TestCase):
+    """Test PrefetchPipelinedForwardCustomizedOrder key functionality."""
+
+    def setUp(self) -> None:
+        """Set up test fixtures."""
+        self.mock_module = MagicMock()
+        self.prefetch_context = PrefetchTrainPipelineContext()
+        self.mock_args = CallArgs(args=[], kwargs={})
+
+    def test_prefetch_returns_true(self) -> None:
+        """Test that prefetch() returns True."""
+        forward = PrefetchPipelinedForwardCustomizedOrder(
+            name="test_prefetch",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Test that prefetch returns True
+        self.assertTrue(forward.prefetch())
+
+    def test_prefetch_enabled_returns_true(self) -> None:
+        """Test that _prefetch_enabled returns True for PrefetchPipelinedForwardCustomizedOrder."""
+        forward = PrefetchPipelinedForwardCustomizedOrder(
+            name="test_prefetch_enabled",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Test that _prefetch_enabled returns True
+        result = _prefetch_enabled(forward)
+        self.assertTrue(result)
+
+    def test_call_fails_without_compute_and_output_dist(self) -> None:
+        """Test that __call__ fails if compute_and_output_dist is not called first."""
+        forward = PrefetchPipelinedForwardCustomizedOrder(
+            name="test_call_error",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Should raise exception when called without compute_and_output_dist
+        with self.assertRaises(Exception) as context:
+            forward()
+
+        self.assertIn(
+            "compute_and_output_dist must be called before __call__",
+            str(context.exception),
+        )
+
+    def test_call_succeeds_after_compute_and_output_dist(self) -> None:
+        """Test that __call__ succeeds when compute_and_output_dist is called first."""
+        forward = PrefetchPipelinedForwardCustomizedOrder(
+            name="test_call_success",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Set up mock data in context
+        test_data = MagicMock()
+        test_ctx = MagicMock()
+        self.prefetch_context.module_input_post_prefetch = {
+            "test_call_success": test_data
+        }
+        self.prefetch_context.module_contexts_post_prefetch = {
+            "test_call_success": test_ctx
+        }
+
+        # Mock the module's compute_and_output_dist method
+        mock_awaitable = MagicMock()
+        self.mock_module.compute_and_output_dist.return_value = mock_awaitable
+
+        # Call compute_and_output_dist first
+        forward.compute_and_output_dist()
+
+        # Now __call__ should succeed and return the awaitable
+        result = forward()
+        self.assertEqual(result, mock_awaitable)
+
+
+if __name__ == "__main__":
+    unittest.main()

torchrec/distributed/train_pipeline/utils.py

Lines changed: 12 additions & 3 deletions
@@ -29,7 +29,6 @@
 
 import torch
 from torch.profiler import record_function
-
 from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
 from torchrec.distributed.embedding_sharding import (
     FusedKJTListSplitsAwaitable,
@@ -53,6 +52,7 @@
     KJTAllToAllForward,
     PipelinedForward,
     PrefetchPipelinedForward,
+    PrefetchPipelinedForwardCustomizedOrder,
     TForwardContext,
 )
 from torchrec.distributed.train_pipeline.tracing import (
@@ -61,7 +61,7 @@
     Tracer,
 )
 from torchrec.distributed.train_pipeline.types import CallArgs  # noqa
-from torchrec.distributed.types import Awaitable
+from torchrec.distributed.types import Awaitable, LazyAwaitable
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable, Pipelineable

@@ -138,6 +138,7 @@ def _start_data_dist(
             PrefetchPipelinedForward,
             EmbeddingPipelinedForward,
             InSyncEmbeddingPipelinedForward,
+            PrefetchPipelinedForwardCustomizedOrder,
         ),
     )

@@ -539,6 +540,10 @@ def get_next_batch(self, none_throws: bool = False) -> Optional[In]:
         return batch
 
 
+def _prefetch_enabled(forward: LazyAwaitable[Out]) -> bool:
+    return isinstance(forward, BaseForward) and forward.prefetch()
+
+
 def _prefetch_embeddings(
     batch: In,
     context: PrefetchTrainPipelineContext,
@@ -551,7 +556,11 @@ def _prefetch_embeddings(
     data_per_sharded_module = {}
     for sharded_module in pipelined_modules:
         forward = sharded_module.forward
-        assert isinstance(forward, PrefetchPipelinedForward)
+        # for backward compatibility, consider it valid if it is PrefetchPipelinedForward
+        # because the class might not have the prefetch method
+        assert isinstance(forward, PrefetchPipelinedForward) or _prefetch_enabled(
+            forward
+        )
         assert forward._name in context.input_dist_tensors_requests
         request = context.input_dist_tensors_requests.pop(forward._name)
         assert isinstance(request, Awaitable)
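Taken together, the dispatch works like this: BaseForward.prefetch() now defaults to False, both prefetch-capable forwards override it to return True, and _prefetch_enabled accepts any BaseForward whose class opts in, while the isinstance check on PrefetchPipelinedForward is kept for backward compatibility with forwards that predate the new classmethod. A small illustration of the resulting class-level behavior (PipelinedForward is assumed to inherit the BaseForward default, since this diff does not touch it):

from torchrec.distributed.train_pipeline.runtime_forwards import (
    PipelinedForward,
    PrefetchPipelinedForward,
    PrefetchPipelinedForwardCustomizedOrder,
)

# prefetch() is a classmethod, so no instance (stream, context, module)
# is needed to query whether a forward type participates in prefetch.
assert PipelinedForward.prefetch() is False  # BaseForward default
assert PrefetchPipelinedForward.prefetch() is True
assert PrefetchPipelinedForwardCustomizedOrder.prefetch() is True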
