
[RFC] PP and EP overlap in DualPipeV #1682

@H-Huang


Background

In torch.distributed.pipelining, each schedule is represented by a table of schedule operations, for example:

from torch.distributed.pipelining._schedule_visualizer import get_schedule_ops, visualize_schedule
schedule_ops = get_schedule_ops(schedule="GPipe", pp_degree=4, num_microbatches=4)
for row in schedule_ops:
   print(row)
visualize_schedule(schedule_ops, "GPipe.png")
# output
[0F0, 0F1, 0F2, 0F3, None, None, None, None, None, None, None, None, None, 0B0, 0B1, 0B2, 0B3]
[None, 1F0, 1F1, 1F2, 1F3, None, None, None, None, None, None, 1B0, 1B1, 1B2, 1B3]
[None, None, 2F0, 2F1, 2F2, 2F3, None, None, None, 2B0, 2B1, 2B2, 2B3]
[None, None, None, 3F0, 3F1, 3F2, 3F3, 3B0, 3B1, 3B2, 3B3]

Each row represents the operations performed by that rank. We can also see this visually:

[Image: GPipe schedule visualization]

Currently, the DualPipeV schedule has been added to torch.distributed.pipelining:

[Image: DualPipeV schedule visualization]

This introduces a new operation that combines the forward of one stage with the backward of another. For example,

(2F5;5B1)OVERLAP_F_B

represents stage 2's forward of microbatch 5 and stage 5's backward of microbatch 1. However, there is a gap: we do not currently overlap the two stages' operations as is done in the original DeepSeek paper.

[Image: overlapped forward/backward from the DeepSeek paper]

By default, we currently run the forward and then the backward sequentially, without overlapping the all-to-alls required for dispatch and combine. This RFC details the steps to optimize this.

Execution Plan

Step 1

We need to update the schedule execution runtime to allow us to call custom functions.

WIP PR (pytorch/pytorch#162016)

This will allow us to override the current naive implementation of OVERLAP_F_B with a custom implementation. We will implement this new method in torchtitan's deepseek_v3 model implementation.
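
As a rough illustration, the custom action could look like the sketch below. The callback name and signature are assumptions about what the WIP PR will expose, not the final API:

# Hypothetical sketch: the signature is an assumption, not the API defined in pytorch/pytorch#162016.
def custom_overlap_f_b(fwd_stage, fwd_mb_index, bwd_stage, bwd_mb_index):
    """Custom OVERLAP_F_B action replacing the naive "forward, then backward".

    The model-specific interleaving of compute and all-to-all (Step 2) lives here.
    """
    raise NotImplementedError("implemented by the torchtitan deepseek_v3 code in Step 2")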

Step 2

Implementing the OVERLAP_F_B custom function.

Each stage in PP consists of one or more TransformerBlocks. A TransformerBlock in DeepSeek looks like this*:

[Image: DeepSeek TransformerBlock structure]

*The exceptions are stage 0 and stage N-1, which hold the embedding and output layers, respectively.
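
For reference, here is a simplified sketch of where the dispatch and combine all-to-alls sit inside such a block. The function and module names, as well as the routing details, are illustrative and not torchtitan's actual implementation:

import torch
import torch.distributed as dist

def moe_block_forward(x, attn, router, experts, ep_group):
    # 1. Attention: pure compute
    h = x + attn(x)
    # 2. Router scores each token's experts: pure compute
    #    (expert weighting and token permutation are omitted in this sketch)
    _scores = router(h)
    # 3. Dispatch: all-to-all sending tokens to the EP ranks that own their experts
    dispatched = torch.empty_like(h)
    dist.all_to_all_single(dispatched, h, group=ep_group)
    # 4. Expert MLPs: pure compute on the locally received tokens
    expert_out = experts(dispatched)
    # 5. Combine: all-to-all returning expert outputs to their source ranks
    combined = torch.empty_like(expert_out)
    dist.all_to_all_single(combined, expert_out, group=ep_group)
    return h + combined

The dispatch and combine all-to-alls are the communication we want to hide behind the other stage's compute.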

To effectively overlap the A2A communication, we have two options for updating the TransformerBlock implementation:

2a. Use the current forward/backward methods.

One way of doing this is to add locks to the current forward methods. This lets us essentially "pause" the forward computation of stage 1 mid-way so that its A2A can run while stage 2's computation proceeds. Stage 2 then signals stage 1 when it is performing its own A2A, so stage 1 can resume running, and so on (a rough sketch follows the pros/cons below).

Pro: Uses the existing forward; the backward pass should work with slight modification.
Con: Not compilable, and it muddles non-PP logic.
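
A minimal sketch of the signaling idea, assuming the forward and backward run on two Python threads on the same rank. All names are illustrative, not the actual torchtitan code:

import threading

# Events used by the two stages to hand off execution around their all-to-alls.
fwd_a2a_in_flight = threading.Event()
bwd_a2a_in_flight = threading.Event()

def forward_with_pause(run_attention, launch_dispatch_a2a, run_experts):
    run_attention()            # compute
    launch_dispatch_a2a()      # async all-to-all issued on a comm stream
    fwd_a2a_in_flight.set()    # let the other stage's backward start computing
    bwd_a2a_in_flight.wait()   # "pause" until the backward is in its own all-to-all
    run_experts()              # resume compute, overlapping the backward's comm

# The other stage's backward mirrors this handshake: it waits for fwd_a2a_in_flight,
# runs its compute, then sets bwd_a2a_in_flight while its own all-to-all is in flight.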

2b. Rewrite forward/backward for this case.

The interleaving described above can instead be implemented as a separate set of methods, which we would call from the custom callback logic implemented in Step 1 (a rough sketch follows the pros/cons below).

Pro: More flexible for making arbitrary changes.
Con: An additional implementation that is specific to PP.
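
Written as a dedicated method, the same interleaving might look like the following, assuming each phase of the forward and backward is exposed as a separate callable. The phase split and exact ordering are illustrative only:

def overlap_forward_backward(fwd, bwd):
    # Forward runs ATTN -> DISPATCH -> MLP -> COMBINE; backward runs the same blocks
    # in reverse. Each side's all-to-all is launched asynchronously so the other
    # side's compute can run underneath it.
    bwd.combine_a2a_launch()      # backward comm ...
    fwd.attention()               # ... overlapped with forward compute
    bwd.combine_a2a_wait()
    fwd.dispatch_a2a_launch()     # forward comm ...
    bwd.mlp_backward()            # ... overlapped with backward compute
    fwd.dispatch_a2a_wait()
    bwd.dispatch_a2a_launch()     # backward comm ...
    fwd.mlp()                     # ... overlapped with forward compute
    bwd.dispatch_a2a_wait()
    fwd.combine_a2a_launch()      # forward comm ...
    bwd.attention_backward()      # ... overlapped with backward compute
    fwd.combine_a2a_wait()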

I am working on option 2a to prototype more quickly and get perf numbers, but if it becomes too restrictive I will switch to option 2b.

Step 3

We hope to have a working version of overlap_f_b within the next few weeks, at which point we need to validate:

  • Numerics
  • Performance
  • Memory

In particular, we need to ensure the comm and compute are actually overlapped, using a DeepSeekV3 profile as a reference. Also, in PP we need to be very careful about holding on to memory for longer than needed.
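
For the overlap check, a trace captured with torch.profiler is usually enough to see whether the all-to-all kernels run concurrently with compute kernels. Here training_step is a placeholder for whatever runs one schedule step:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    training_step()  # placeholder: one PP step containing OVERLAP_F_B actions
prof.export_chrome_trace("overlap_f_b_trace.json")
# In the trace, all-to-all kernels on the comm stream should run concurrently
# with GEMM kernels on the compute stream.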

cc @fegin @tianyu-l @wwwjn @vwxyzjn @rakkit @wconstab @vishal9-team
