Background
In torch.distributed.pipelining, each schedule is represented by a table of schedule operations, for example:
from torch.distributed.pipelining._schedule_visualizer import get_schedule_ops, visualize_schedule

schedule_ops = get_schedule_ops(schedule="GPipe", pp_degree=4, num_microbatches=4)
for row in schedule_ops:
    print(row)
visualize_schedule(schedule_ops, "GPipe.png")
# output
[0F0, 0F1, 0F2, 0F3, None, None, None, None, None, None, None, None, None, 0B0, 0B1, 0B2, 0B3]
[None, 1F0, 1F1, 1F2, 1F3, None, None, None, None, None, None, 1B0, 1B1, 1B2, 1B3]
[None, None, 2F0, 2F1, 2F2, 2F3, None, None, None, 2B0, 2B1, 2B2, 2B3]
[None, None, None, 3F0, 3F1, 3F2, 3F3, 3B0, 3B1, 3B2, 3B3]
Each row represents the operations performed by that rank over time. We can also see this visually:

The DualPipeV schedule has now been added to torch.distributed.pipelining:

This schedule introduces a new operation type: a combined forward and backward across two stages. For example,
(2F5;5B1)OVERLAP_F_B
represents stage 2 forward on microbatch 5 together with stage 5 backward on microbatch 1. However, there is a gap: we currently do not actually overlap these stage operations as in the original DeepSeek paper.

By default, we currently just run the forward and then the backward sequentially, without overlapping the all-to-alls required in dispatch and combine. This RFC details the steps to optimize this.
Execution Plan
Step 1
We need to update the schedule execution runtime to allow us to call custom functions.
WIP PR (pytorch/pytorch#162016)
This will allow us to override the current naive implementation of OVERLAP_F_B with a custom implementation. We will implement this new method in the torchtitan deepseek_v3 model implementation.
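As a rough illustration of how Step 1 is meant to be used, the sketch below shows a custom OVERLAP_F_B function being supplied to the schedule. The registration call, the stage method names, and the signature are all hypothetical; the actual API is being worked out in the WIP PR above.

```python
# Hypothetical sketch only -- the real override/registration API is being
# defined in pytorch/pytorch#162016 and may differ in names and signatures.

def overlapped_forward_backward(fwd_stage, fwd_microbatch, bwd_stage, bwd_grads):
    """Custom implementation of the OVERLAP_F_B action.

    Today's naive behavior is "run the forward, then run the backward".
    The custom version (Step 2) interleaves the two so that each stage's
    all-to-all hides behind the other stage's compute.
    """
    # Placeholder body equivalent to the current naive behavior:
    out = fwd_stage.run_forward(fwd_microbatch)   # hypothetical stage API
    bwd_stage.run_backward(bwd_grads)             # hypothetical stage API
    return out

# Hypothetical registration on the schedule (e.g. a ScheduleDualPipeV instance):
# schedule.register_custom_function(OVERLAP_F_B, overlapped_forward_backward)
```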
Step 2
Implementing the OVERLAP_F_B custom function.
Each stage in PP consists of one or more TransformerBlocks. A TransformerBlock in DeepSeek looks like this*:

*The exceptions are stage 0 and stage N-1, which also hold the embedding and output layers respectively.
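Since the block diagram is not reproduced here, the sketch below gives a rough skeleton of the structure being referred to: attention compute, a dispatch all-to-all, expert compute, and a combine all-to-all. Module and function names are illustrative, not the actual torchtitan deepseek_v3 code.

```python
import torch.nn as nn

def dispatch_a2a(x, scores):
    """Illustrative stand-in for the token-dispatch all-to-all."""
    return x  # real code would launch an all-to-all collective here

def combine_a2a(x):
    """Illustrative stand-in for the token-combine all-to-all."""
    return x  # real code would launch an all-to-all collective here

class MoETransformerBlock(nn.Module):
    """Rough skeleton of a DeepSeek-style MoE transformer block.

    The two all-to-alls are the communication we want to hide behind the
    other stage's compute when executing OVERLAP_F_B.
    """

    def __init__(self, attention: nn.Module, router: nn.Module, experts: nn.Module):
        super().__init__()
        self.attention = attention
        self.router = router
        self.experts = experts

    def forward(self, x):
        x = x + self.attention(x)           # compute
        scores = self.router(x)             # compute (routing decision)
        tokens = dispatch_a2a(x, scores)    # A2A #1: route tokens to experts
        tokens = self.experts(tokens)       # compute (expert MLPs)
        x = x + combine_a2a(tokens)         # A2A #2: return tokens to their owners
        return x
```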
To effectively overlap the A2A communication, we have two options for updating the TransformerBlock implementation:
2a. Use current forward/backward methods.
One way of doing this is to add locks to the current forward methods. This lets us essentially "pause" in the middle of the forward computation of stage 1 while its A2A runs, and run the computation of stage 2 in the meantime. Stage 2 then signals stage 1 when it is performing its own A2A, so stage 1 can resume running, and so on. A minimal sketch of this handoff is shown after the pros and cons below.
Pro: Uses the existing forward; the backward pass should work with slight modification.
Con: Not compilable, and it muddles non-PP logic.
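A minimal toy sketch of the lock/signal handoff described above, using threading.Event purely to show the control flow (a real version would coordinate CUDA streams and async collectives rather than printing):

```python
import threading

# Toy illustration of the "pause/resume" handoff between the two halves of an
# OVERLAP_F_B action. threading.Event stands in for the real synchronization.
stage1_may_run = threading.Event()
stage2_may_run = threading.Event()
stage1_may_run.set()  # stage 1's forward starts first

def stage1_forward():
    stage1_may_run.wait()
    print("stage 1 fwd: compute up to the dispatch A2A")
    print("stage 1 fwd: launch A2A (async)")
    stage2_may_run.set()       # let stage 2 compute while our A2A is in flight
    stage1_may_run.clear()
    stage1_may_run.wait()      # "pause" until stage 2 reaches its own A2A
    print("stage 1 fwd: remaining compute + combine A2A")

def stage2_backward():
    stage2_may_run.wait()
    print("stage 2 bwd: compute up to its A2A")
    print("stage 2 bwd: launch A2A (async)")
    stage1_may_run.set()       # signal stage 1 so it can resume its compute
    print("stage 2 bwd: remaining compute")

t1 = threading.Thread(target=stage1_forward)
t2 = threading.Thread(target=stage2_backward)
t1.start(); t2.start()
t1.join(); t2.join()
```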
2b. Rewrite forward/backward for this case.
The interleaving described above can instead be implemented as a separate set of methods, which we call from the custom callback logic implemented in Step 1 (see the sketch after this list).
Pro: More flexible to make arbitrary changes.
Con: An additional implementation specific to the PP case only.
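The same interleaving, written as dedicated split-phase methods and driven from the Step 1 callback, might look roughly like the following. forward_before_a2a, dispatch_a2a, backward_before_a2a, and the other method names are hypothetical PP-only additions alongside the normal forward/backward.

```python
def overlap_f_b(fwd_block, fwd_input, bwd_block, bwd_grad):
    """Illustrative split-phase interleaving of one block's forward with
    another block's backward, so each A2A overlaps the other's compute.

    All method names here are hypothetical PP-only additions to the block.
    """
    # Forward compute up to the dispatch all-to-all, then launch it async.
    fwd_state = fwd_block.forward_before_a2a(fwd_input)
    fwd_a2a = fwd_block.dispatch_a2a(fwd_state, async_op=True)

    # While the forward A2A is in flight, run backward compute up to its A2A.
    bwd_state = bwd_block.backward_before_a2a(bwd_grad)
    bwd_a2a = bwd_block.backward_a2a(bwd_state, async_op=True)

    # Finish the forward while the backward A2A is still in flight.
    fwd_a2a.wait()
    fwd_out = fwd_block.forward_after_a2a(fwd_state)

    # Finish the backward.
    bwd_a2a.wait()
    grads = bwd_block.backward_after_a2a(bwd_state)
    return fwd_out, grads
```

This keeps the regular non-PP forward untouched and is straightforward to call from the custom OVERLAP_F_B function registered in Step 1.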
I am working on option 2a first to prototype quickly and get perf numbers, but if it becomes too restrictive I will switch to option 2b.
Step 3
We hope to have a working version of overlap_f_b within the next few weeks, at which point we need to validate:
- Numerics
- Performance
- Memory
In particular, we need to confirm that comm and compute are actually overlapped, using a DeepSeekV3 profile as a reference. Also, in PP we need to be very careful about not holding on to memory longer than needed.
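For the overlap check, one option (a sketch assuming a standard torch.profiler setup around a few training steps; train_step is a placeholder) is to capture a trace and verify that the all-to-all kernels sit under concurrent compute rather than in idle gaps:

```python
from torch.profiler import profile, ProfilerActivity

# Sketch: profile a few pipeline-parallel steps and export a trace for
# chrome://tracing / Perfetto. In the trace, the A2A (all-to-all) kernels
# should overlap with compute kernels from the other stage.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,   # also records allocations, useful for the memory check
) as prof:
    for _ in range(3):
        train_step()       # placeholder for one training step with the PP schedule

prof.export_chrome_trace("dualpipev_overlap_trace.json")
```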
cc @fegin @tianyu-l @wwwjn @vwxyzjn @rakkit @wconstab @vishal9-team