Skip to content

Conversation

@conver334
Copy link
Contributor

@conver334 conver334 commented Sep 25, 2025

What does this PR do?

Mitigate workload imbalance in DP.

As shown in the figure below, all ranks must synchronize after mini batch in DP. Stragglers with longer sequences delay all workers.

Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-9_page-0001

Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-10_page-0001

Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-12_page-0001

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: [model] feat: polish megatron engine #3401
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

The line with the suffix Balance in the figure below can get better MFU in Qwen2.5-Math-7 GRPO.
W B Chart 2025_9_24 16_52_24

API and Usage Example

split Data to n workload balanced chunks

_balance_data_proto(DataProto_obj, chunks)

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

As shown in the figure, the leftmost side shows the unsplit data with a global batch size of 16.

When DP = 2, existing methods directly split the batch into two ranks sequentially. You can see that in this case, rank 0 receives more tokens than rank 1.

The rightmost side shows our design. We model the workload generated by each data entry and use the Karmarkar-Karp algorithm to split the batch into two equal parts, ensuring that the total workload of each part is as close as possible.

The workload can be calculated using the FLOPS formula in verl. Here, we roughly estimate and hardcode the FLOPs by seqlens**2 + seqlens * 24576 (Attention+MLP of 7B model).

Workload_balance_for_skewed_data length_in_RL_training-SimiaoZhang-21_page-0001

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@CLAassistant
Copy link

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


root seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces workload balancing for Data Parallelism to mitigate the impact of stragglers with longer sequences. The core change is the _balance_data_proto function, which reorders data within a batch using the Karmarkar-Karp algorithm to equalize workload across DP ranks before splitting. My review has identified a critical bug that breaks data splitting when auto-padding is enabled, and a high-severity issue related to a hardcoded value that limits the feature's general applicability. Addressing these points will improve the correctness and maintainability of the implementation.

Comment on lines 81 to 82
# approximate workload of transformer block
workloads = seqlens**2 + seqlens * 33024
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The workload calculation uses a hardcoded magic number 33024. According to the PR description, this is specific to a 7B model's MLP layer. This limits the general applicability of this workload balancing feature to other model architectures.

The function signature includes an unused model_config parameter, which should be used to pass model-specific configuration, such as the MLP size, to make the workload calculation more flexible and accurate for different models.

Consider refactoring this to use the model_config parameter. For example:

        # approximate workload of transformer block
        mlp_factor = model_config.get("mlp_workload_factor", 33024) if model_config else 33024
        workloads = seqlens**2 + seqlens * mlp_factor

@ISEEKYAN ISEEKYAN merged commit 62221fa into volcengine:main Oct 23, 2025
66 of 67 checks passed
vermouth1992 added a commit that referenced this pull request Oct 23, 2025
sunnweiwei pushed a commit to sunnweiwei/verl that referenced this pull request Oct 23, 2025
### What does this PR do?

Mitigate workload imbalance in DP.

As shown in the figure below, all ranks must synchronize after mini
batch in DP. Stragglers with longer sequences delay all workers.

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-9_page-0001](https://github.com/user-attachments/assets/f5bffd63-cb00-40df-96e0-5042e81400b8)

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-10_page-0001](https://github.com/user-attachments/assets/165b8cc1-ec1d-4c6c-9151-674d53172bc4)

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-12_page-0001](https://github.com/user-attachments/assets/3f79b371-c102-4596-b5a4-fb8348eb75e3)



> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
volcengine#3401
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

The line with the suffix `Balance` in the figure below can get better
MFU in Qwen2.5-Math-7 GRPO.
<img width="5056" height="2656" alt="W B Chart 2025_9_24 16_52_24"
src="https://github.com/user-attachments/assets/b83bd7a2-3c74-4a09-8212-2f9b754c4ef1"
/>


### API and Usage Example

split Data to n workload balanced chunks 
```python
_balance_data_proto(DataProto_obj, chunks)
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

As shown in the figure, the leftmost side shows the unsplit data with a
global batch size of 16.

When DP = 2, existing methods directly split the batch into two ranks
sequentially. You can see that in this case, rank 0 receives more tokens
than rank 1.

The rightmost side shows our design. We model the workload generated by
each data entry and use the Karmarkar-Karp algorithm to split the batch
into two equal parts, ensuring that the total workload of each part is
as close as possible.

The workload can be calculated using the FLOPS formula in verl. Here, we
roughly estimate and hardcode the FLOPs by `seqlens**2 + seqlens *
24576` (Attention+MLP of 7B model).

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-21_page-0001](https://github.com/user-attachments/assets/30d3376c-7970-4d62-947c-f25c6d6224d4)

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
wangboxiong320 pushed a commit to wangboxiong320/verl that referenced this pull request Nov 1, 2025
### What does this PR do?

Mitigate workload imbalance in DP.

As shown in the figure below, all ranks must synchronize after mini
batch in DP. Stragglers with longer sequences delay all workers.

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-9_page-0001](https://github.com/user-attachments/assets/f5bffd63-cb00-40df-96e0-5042e81400b8)

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-10_page-0001](https://github.com/user-attachments/assets/165b8cc1-ec1d-4c6c-9151-674d53172bc4)

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-12_page-0001](https://github.com/user-attachments/assets/3f79b371-c102-4596-b5a4-fb8348eb75e3)



> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
volcengine#3401
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

The line with the suffix `Balance` in the figure below can get better
MFU in Qwen2.5-Math-7 GRPO.
<img width="5056" height="2656" alt="W B Chart 2025_9_24 16_52_24"
src="https://github.com/user-attachments/assets/b83bd7a2-3c74-4a09-8212-2f9b754c4ef1"
/>


### API and Usage Example

split Data to n workload balanced chunks 
```python
_balance_data_proto(DataProto_obj, chunks)
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

As shown in the figure, the leftmost side shows the unsplit data with a
global batch size of 16.

When DP = 2, existing methods directly split the batch into two ranks
sequentially. You can see that in this case, rank 0 receives more tokens
than rank 1.

The rightmost side shows our design. We model the workload generated by
each data entry and use the Karmarkar-Karp algorithm to split the batch
into two equal parts, ensuring that the total workload of each part is
as close as possible.

The workload can be calculated using the FLOPS formula in verl. Here, we
roughly estimate and hardcode the FLOPs by `seqlens**2 + seqlens *
24576` (Attention+MLP of 7B model).

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-21_page-0001](https://github.com/user-attachments/assets/30d3376c-7970-4d62-947c-f25c6d6224d4)

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
NenoL2001 pushed a commit to NenoL2001/verl that referenced this pull request Nov 3, 2025
### What does this PR do?

Mitigate workload imbalance in DP.

As shown in the figure below, all ranks must synchronize after mini
batch in DP. Stragglers with longer sequences delay all workers.

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-9_page-0001](https://github.com/user-attachments/assets/f5bffd63-cb00-40df-96e0-5042e81400b8)

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-10_page-0001](https://github.com/user-attachments/assets/165b8cc1-ec1d-4c6c-9151-674d53172bc4)

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-12_page-0001](https://github.com/user-attachments/assets/3f79b371-c102-4596-b5a4-fb8348eb75e3)



> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
volcengine#3401
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

The line with the suffix `Balance` in the figure below can get better
MFU in Qwen2.5-Math-7 GRPO.
<img width="5056" height="2656" alt="W B Chart 2025_9_24 16_52_24"
src="https://github.com/user-attachments/assets/b83bd7a2-3c74-4a09-8212-2f9b754c4ef1"
/>


### API and Usage Example

split Data to n workload balanced chunks 
```python
_balance_data_proto(DataProto_obj, chunks)
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

As shown in the figure, the leftmost side shows the unsplit data with a
global batch size of 16.

When DP = 2, existing methods directly split the batch into two ranks
sequentially. You can see that in this case, rank 0 receives more tokens
than rank 1.

The rightmost side shows our design. We model the workload generated by
each data entry and use the Karmarkar-Karp algorithm to split the batch
into two equal parts, ensuring that the total workload of each part is
as close as possible.

The workload can be calculated using the FLOPS formula in verl. Here, we
roughly estimate and hardcode the FLOPs by `seqlens**2 + seqlens *
24576` (Attention+MLP of 7B model).

![Workload_balance_for_skewed_data
length_in_RL_training-SimiaoZhang-21_page-0001](https://github.com/user-attachments/assets/30d3376c-7970-4d62-947c-f25c6d6224d4)

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants