Commit c77bb54
conver334 authored and sunnweiwei committed
[perf, data] feat: DP workload balance (volcengine#3605)
### What does this PR do?

Mitigate workload imbalance in DP. As shown in the figures below, all ranks must synchronize after each mini batch in DP, so stragglers with longer sequences delay all workers.

![Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-9_page-0001](https://github.com/user-attachments/assets/f5bffd63-cb00-40df-96e0-5042e81400b8)
![Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-10_page-0001](https://github.com/user-attachments/assets/165b8cc1-ec1d-4c6c-9151-674d53172bc4)
![Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-12_page-0001](https://github.com/user-attachments/assets/3f79b371-c102-4596-b5a4-fb8348eb75e3)

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: volcengine#3401
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

The line with the suffix `Balance` in the figure below achieves better MFU in Qwen2.5-Math-7B GRPO.

<img width="5056" height="2656" alt="W B Chart 2025_9_24 16_52_24" src="https://github.com/user-attachments/assets/b83bd7a2-3c74-4a09-8212-2f9b754c4ef1" />

### API and Usage Example

Split data into `n` workload-balanced chunks:

```python
_balance_data_proto(DataProto_obj, chunks)
```

### Design & Code Changes

As shown in the figure below, the leftmost side shows the unsplit data with a global batch size of 16. When DP = 2, existing methods split the batch across the two ranks sequentially; in that case, rank 0 receives more tokens than rank 1. The rightmost side shows our design: we model the workload generated by each data entry and use the Karmarkar-Karp algorithm to split the batch into two parts whose total workloads are as close as possible. The workload can be calculated with the FLOPs formula in verl; here we roughly estimate and hardcode the FLOPs as `seqlens**2 + seqlens * 24576` (attention + MLP of a 7B model).

![Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-21_page-0001](https://github.com/user-attachments/assets/30d3376c-7970-4d62-947c-f25c6d6224d4)

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
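The PR's actual partitioner is the Karmarkar-Karp implementation already in `verl/utils/seqlen_balancing.py`. As a rough sketch of the idea only, the toy below balances a skewed batch across two ranks using a simpler greedy longest-processing-time heuristic with the same hardcoded workload model; all names here are illustrative, not the PR's API, and unlike the PR's `equal_size=True` mode it does not force equal per-rank sample counts:

```python
import heapq

def workload(seqlen: int) -> int:
    # Hardcoded 7B-model FLOPs estimate from the PR:
    # attention ~ seqlen**2, MLP ~ 6 * 4096 * seqlen.
    return seqlen**2 + 24576 * seqlen

def balance(seqlens: list[int], k: int) -> list[list[int]]:
    # Greedy LPT stand-in for Karmarkar-Karp: give the heaviest
    # remaining sample to the currently lightest rank.
    heap = [(0, rank, []) for rank in range(k)]
    heapq.heapify(heap)
    for idx in sorted(range(len(seqlens)), key=lambda i: -workload(seqlens[i])):
        load, rank, members = heapq.heappop(heap)
        members.append(idx)
        heapq.heappush(heap, (load + workload(seqlens[idx]), rank, members))
    return [m for _, _, m in sorted(heap)]

# Skewed global batch: a sequential split would overload one rank.
seqlens = [4096, 3072, 2048, 1024, 512, 384, 256, 128]
for rank, part in enumerate(balance(seqlens, k=2)):
    # The two rank loads end up within about 1.5% of each other.
    print(rank, part, sum(workload(seqlens[i]) for i in part))
```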
1 parent 0640ff5 commit c77bb54

2 files changed: 42 additions, 11 deletions

verl/trainer/ppo/ray_trainer.py

Lines changed: 26 additions & 7 deletions
```diff
@@ -57,7 +57,7 @@
 from verl.utils.debug import marked_timer
 from verl.utils.metric import reduce_metrics
 from verl.utils.rollout_skip import RolloutSkip
-from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
+from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance
 from verl.utils.torch_functional import masked_mean
 from verl.utils.tracking import ValidationGenerationsLogger

@@ -914,15 +914,35 @@ def _stop_profiling(self, do_profile: bool) -> None:
         if self.use_rm:
             self.rm_wg.stop_profile()

-    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
+    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
         """Reorder the data on single controller such that each dp rank gets similar total tokens"""
         attention_mask = batch.batch["attention_mask"]
         batch_size = attention_mask.shape[0]
-        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
+        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1)  # (train_batch_size,)
+        global_seqlen_lst = calculate_workload(global_seqlen_lst)
         world_size = self.actor_rollout_wg.world_size
-        global_partition_lst = get_seqlen_balanced_partitions(
-            global_seqlen_lst, k_partitions=world_size, equal_size=True
-        )
+        if keep_minibatch:
+            # Decouple the DP balancing and mini-batching.
+            minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
+            minibatch_num = len(global_seqlen_lst) // minibatch_size
+            global_partition_lst = [[] for _ in range(world_size)]
+            for i in range(minibatch_num):
+                rearrange_minibatch_lst = get_seqlen_balanced_partitions(
+                    global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
+                    k_partitions=world_size,
+                    equal_size=True,
+                )
+                for j, part in enumerate(rearrange_minibatch_lst):
+                    global_partition_lst[j].extend([x + minibatch_size * i for x in part])
+        else:
+            global_partition_lst = get_seqlen_balanced_partitions(
+                global_seqlen_lst, k_partitions=world_size, equal_size=True
+            )
+        # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
+        for idx, partition in enumerate(global_partition_lst):
+            partition.sort(key=lambda x: (global_seqlen_lst[x], x))
+            ordered_partition = partition[::2] + partition[1::2][::-1]
+            global_partition_lst[idx] = ordered_partition
         # reorder based on index. The data will be automatically equally partitioned by dispatch function
         global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
         batch.reorder(global_idx)

@@ -1103,7 +1123,6 @@ def fit(self):
                 # NOTE: This usually changes the order of data in the `batch`,
                 # which won't affect the advantage calculation (since it's based on uid),
                 # but might affect the loss calculation (due to the change of mini-batching).
-                # TODO: Decouple the DP balancing and mini-batching.
                 if self.config.trainer.balance_batch:
                     self._balance_batch(batch, metrics=metrics)
```
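To make the `keep_minibatch` bookkeeping concrete, here is a hypothetical walk-through (toy partitions, not from the commit) of the index shift `x + minibatch_size * i`: each mini batch is balanced across ranks independently, and the local indices are mapped back to global positions, so mini-batch boundaries survive the reorder.

```python
# Hypothetical demo of the keep_minibatch index bookkeeping:
# two mini batches of size 4 balanced across world_size = 2 ranks.
minibatch_size, world_size = 4, 2
# Toy per-mini-batch partitions of local indices 0..3, e.g. as
# returned by get_seqlen_balanced_partitions(..., equal_size=True).
per_minibatch_parts = [[[0, 3], [1, 2]], [[1, 2], [0, 3]]]

global_partition_lst = [[] for _ in range(world_size)]
for i, parts in enumerate(per_minibatch_parts):
    for j, part in enumerate(parts):
        # Shift local indices into the global batch, as in _balance_batch.
        global_partition_lst[j].extend(x + minibatch_size * i for x in part)
print(global_partition_lst)  # [[0, 3, 5, 6], [1, 2, 4, 7]]
```

Each rank still holds exactly `minibatch_size / world_size = 2` samples from every mini batch, so the mini-batch composition seen by the optimizer is unchanged; only the per-rank workloads are balanced.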

verl/utils/seqlen_balancing.py

Lines changed: 16 additions & 4 deletions
```diff
@@ -24,6 +24,16 @@
 from verl.utils.device import get_device_name


+def calculate_workload(seqlen_list: list[int]):
+    """
+    Calculate the workload for a dense transformer block based on sequence length.
+    FLOPs = 12 * hidden_size^2 * seqlen + 2 * hidden_size * seqlen^2
+    Hardcodes the constants for a 7B model (hidden_size=4096),
+    so the FLOPs are proportional to (6 * 4096 * seqlen + seqlen^2).
+    """
+    return 24576 * seqlen_list + seqlen_list**2
+
+
 def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
```
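To see how steep this workload model is, here is a hypothetical check (not part of the commit) that applies the same formula to a tensor of skewed sequence lengths, as `_balance_batch` now does:

```python
import torch

# Hypothetical check: apply the hardcoded workload model,
# 24576 * seqlen + seqlen**2, to a skewed batch of lengths.
seqlens = torch.tensor([128, 512, 2048, 8192])
flops = 24576 * seqlens + seqlens**2
print(flops.tolist())  # [3162112, 12845056, 54525952, 268435456]
```

The 8192-token sample is 64x longer than the 128-token one but costs roughly 85x the estimated FLOPs, which is why balancing on raw token counts underestimates long-sequence stragglers.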
```diff
@@ -298,20 +308,22 @@ def rearrange_micro_batches(
     if num_batches_divided_by is not None:
         num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by)

-    seq_len_effective = seq_len_effective.tolist()
     assert num_micro_batches <= len(seq_len_effective)

-    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
+    workloads = calculate_workload(seq_len_effective)
+    micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False)

     if use_dynamic_bsz_balance:
         # Use the sum of squared sequence lengths to approximate attention computation workload
         micro_bsz_idx.sort(
             key=lambda partition: (
-                sum(seq_len_effective[idx] ** 2 for idx in partition),
-                min(partition) if partition else 0,
+                sum(workloads[idx] for idx in partition),
+                partition[0] if partition else 0,
             ),
             reverse=True,
         )
+        # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down.
+        micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2]

     micro_batches = []
```
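A quick illustration (hypothetical numbers, not from the commit) of what this zig-zag reordering does to micro-batch workloads that have already been sorted heavy-to-light:

```python
# Hypothetical per-micro-batch workloads, sorted descending as after
# the sort(..., reverse=True) above.
loads = [90, 70, 50, 30, 20, 10]
order = loads[::2][::-1] + loads[1::2]
print(order)  # [20, 50, 90, 70, 30, 10]
```

The heaviest micro-batch lands in the middle and the lightest ones at both ends, so the pipeline warm-up and cool-down phases process the cheapest work.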
