
Commit c72fb39

peterfu0 authored and facebook-github-bot committed

add prefetch (#3349)

Summary: Pull Request resolved: #3349
Differential Revision: D79404930

1 parent 85ec396 · commit c72fb39

3 files changed: +192 -1 lines changed

torchrec/distributed/train_pipeline/runtime_forwards.py

Lines changed: 74 additions & 0 deletions
@@ -58,6 +58,10 @@ def name(self) -> str:
     def args(self) -> CallArgs:
         return self._args
 
+    @classmethod
+    def prefetch(cls) -> bool:
+        return False
+
     def set_context(self, context: TForwardContext) -> None:
         self._context = context
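The new classmethod gives every forward class a uniform opt-in flag: the BaseForward default is False, and prefetch-capable subclasses override it to return True. A minimal sketch of that dispatch, assuming only the names added or touched in this commit (not part of the PR itself):

# prefetch() is a classmethod, so the flag can be read off the class
# itself without constructing a forward object.
from torchrec.distributed.train_pipeline.runtime_forwards import (
    PipelinedForward,
    PrefetchEmbeddingPipelinedForward,
    PrefetchPipelinedForward,
)

assert PipelinedForward.prefetch() is False  # inherits the BaseForward default
assert PrefetchEmbeddingPipelinedForward.prefetch() is True
assert PrefetchPipelinedForward.prefetch() is True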

@@ -220,9 +224,75 @@ def detach_embeddings(
         pass
 
 
+class PrefetchEmbeddingPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
+    """
+    This pipeline is used in TrainPipelineCustomizedOrderSparseDist when
+    prefetch is enabled and pipelined_sparse_lookup_fwd is enabled.
+    compute_and_output_dist for batch N is called at the end of step N - 1.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        args: CallArgs,
+        module: ShardedModule,
+        context: PrefetchTrainPipelineContext,
+        prefetch_stream: Optional[torch.Stream] = None,
+    ) -> None:
+        super().__init__(
+            name=name,
+            args=args,
+            module=module,
+            context=context,
+            stream=prefetch_stream,
+        )
+        self._compute_and_output_dist_awaitable: Optional[
+            Awaitable[Multistreamable]
+        ] = None
+
+    @classmethod
+    def prefetch(cls) -> bool:
+        return True
+
+    def compute_and_output_dist(self) -> None:
+        assert (
+            self._name in self._context.module_input_post_prefetch
+        ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()"
+        data = self._context.module_input_post_prefetch.pop(self._name)
+        ctx = self._context.module_contexts_post_prefetch.pop(self._name)
+
+        # Make sure that both the result of input_dist and the context
+        # are properly transferred to the current stream.
+        if self._stream is not None:
+            torch.get_device_module(self._device).current_stream().wait_stream(
+                self._stream
+            )
+            cur_stream = torch.get_device_module(self._device).current_stream()
+
+            assert isinstance(
+                data, (torch.Tensor, Multistreamable)
+            ), f"{type(data)} must implement Multistreamable interface"
+            data.record_stream(cur_stream)
+
+            ctx.record_stream(cur_stream)
+
+        self._compute_and_output_dist_awaitable = self._module.compute_and_output_dist(
+            ctx, data
+        )
+
+    # pyre-ignore [2, 24]
+    def __call__(self, *input, **kwargs) -> Awaitable:
+        if not self._compute_and_output_dist_awaitable:
+            raise Exception(
+                "compute_and_output_dist must be called before __call__",
+            )
+        return self._compute_and_output_dist_awaitable
+
+
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
+    OR in TrainPipelineCustomizedOrderSparseDist, when prefetch is enabled but pipeline_embedding_lookup_fwd is disabled
     """
 
     def __init__(
@@ -241,6 +311,10 @@ def __init__(
             stream=prefetch_stream,
         )
 
+    @classmethod
+    def prefetch(cls) -> bool:
+        return True
+
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
         assert (
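The docstring on PrefetchEmbeddingPipelinedForward describes a two-phase protocol: the pipeline calls compute_and_output_dist for batch N at the end of step N - 1, and the model's forward for step N then returns the stored awaitable. Below is a minimal sketch of that call order, mirroring the mock setup used in the new test file; the name "sparse_arch" and the MagicMock stand-ins are illustrative, not from the commit.

from unittest.mock import MagicMock

from torchrec.distributed.train_pipeline.pipeline_context import (
    PrefetchTrainPipelineContext,
)
from torchrec.distributed.train_pipeline.runtime_forwards import (
    PrefetchEmbeddingPipelinedForward,
)
from torchrec.distributed.train_pipeline.types import CallArgs

context = PrefetchTrainPipelineContext()
forward = PrefetchEmbeddingPipelinedForward(
    name="sparse_arch",
    args=CallArgs(args=[], kwargs={}),
    module=MagicMock(),  # stands in for a real ShardedModule
    context=context,
)

# End of step N - 1: prefetch has populated the post-prefetch maps keyed
# by module name, so the embedding lookup for batch N can be kicked off.
context.module_input_post_prefetch = {"sparse_arch": MagicMock()}
context.module_contexts_post_prefetch = {"sparse_arch": MagicMock()}
forward.compute_and_output_dist()

# Step N: the model's forward reaches this callable, which simply returns
# the awaitable created above. Calling it before compute_and_output_dist
# would raise instead.
awaitable = forward()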
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from unittest.mock import MagicMock
+
+from torchrec.distributed.train_pipeline.pipeline_context import (
+    PrefetchTrainPipelineContext,
+)
+from torchrec.distributed.train_pipeline.runtime_forwards import (
+    PrefetchEmbeddingPipelinedForward,
+)
+from torchrec.distributed.train_pipeline.types import CallArgs
+from torchrec.distributed.train_pipeline.utils import _prefetch_enabled
+
+
+class TestPrefetchEmbeddingPipelinedForward(unittest.TestCase):
+    """Test PrefetchEmbeddingPipelinedForward key functionality"""
+
+    def setUp(self) -> None:
+        """Set up test fixtures."""
+        self.mock_module = MagicMock()
+        self.prefetch_context = PrefetchTrainPipelineContext()
+        self.mock_args = CallArgs(args=[], kwargs={})
+
+    def test_prefetch_returns_true(self) -> None:
+        """Test that prefetch() returns True."""
+        forward = PrefetchEmbeddingPipelinedForward(
+            name="test_prefetch",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Test that prefetch returns True
+        self.assertTrue(forward.prefetch())
+
+    def test_prefetch_enabled_returns_true(self) -> None:
+        """Test that _prefetch_enabled returns True for PrefetchEmbeddingPipelinedForward."""
+        forward = PrefetchEmbeddingPipelinedForward(
+            name="test_prefetch_enabled",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Test that _prefetch_enabled returns True
+        result = _prefetch_enabled(forward)
+        self.assertTrue(result)
+
+    def test_call_fails_without_compute_and_output_dist(self) -> None:
+        """Test that __call__ fails if compute_and_output_dist is not called first."""
+        forward = PrefetchEmbeddingPipelinedForward(
+            name="test_call_error",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Should raise exception when called without compute_and_output_dist
+        with self.assertRaises(Exception) as context:
+            forward()
+
+        self.assertIn(
+            "compute_and_output_dist must be called before __call__",
+            str(context.exception),
+        )
+
+    def test_call_succeeds_after_compute_and_output_dist(self) -> None:
+        """Test that __call__ succeeds when compute_and_output_dist is called first."""
+        forward = PrefetchEmbeddingPipelinedForward(
+            name="test_call_success",
+            args=self.mock_args,
+            module=self.mock_module,
+            context=self.prefetch_context,
+        )
+
+        # Set up mock data in context
+        test_data = MagicMock()
+        test_ctx = MagicMock()
+        self.prefetch_context.module_input_post_prefetch = {
+            "test_call_success": test_data
+        }
+        self.prefetch_context.module_contexts_post_prefetch = {
+            "test_call_success": test_ctx
+        }
+
+        # Mock the module's compute_and_output_dist method
+        mock_awaitable = MagicMock()
+        self.mock_module.compute_and_output_dist.return_value = mock_awaitable
+
+        # Call compute_and_output_dist first
+        forward.compute_and_output_dist()
+
+        # Now __call__ should succeed and return the awaitable
+        result = forward()
+        self.assertEqual(result, mock_awaitable)
+
+
+if __name__ == "__main__":
+    unittest.main()

torchrec/distributed/train_pipeline/utils.py

Lines changed: 11 additions & 1 deletion
@@ -52,6 +52,7 @@
     InSyncEmbeddingPipelinedForward,
     KJTAllToAllForward,
     PipelinedForward,
+    PrefetchEmbeddingPipelinedForward,
     PrefetchPipelinedForward,
     TForwardContext,
 )
@@ -138,6 +139,7 @@ def _start_data_dist(
             PrefetchPipelinedForward,
             EmbeddingPipelinedForward,
             InSyncEmbeddingPipelinedForward,
+            PrefetchEmbeddingPipelinedForward,
         ),
     )
@@ -539,6 +541,10 @@ def get_next_batch(self, none_throws: bool = False) -> Optional[In]:
         return batch
 
 
+def _prefetch_enabled(forward: BaseForward[TForwardContext]) -> bool:
+    return isinstance(forward, BaseForward) and forward.prefetch()
+
+
 def _prefetch_embeddings(
     batch: In,
     context: PrefetchTrainPipelineContext,
@@ -551,7 +557,11 @@ def _prefetch_embeddings(
     data_per_sharded_module = {}
     for sharded_module in pipelined_modules:
         forward = sharded_module.forward
-        assert isinstance(forward, PrefetchPipelinedForward)
+        # For backward compatibility, consider it valid if it is
+        # PrefetchPipelinedForward, because the class might not have a prefetch method.
+        assert isinstance(forward, PrefetchPipelinedForward) or _prefetch_enabled(
+            forward
+        )
         assert forward._name in context.input_dist_tensors_requests
         request = context.input_dist_tensors_requests.pop(forward._name)
         assert isinstance(request, Awaitable)