Skip to content

Conversation

@puneeshkhanna
Copy link
Contributor

@puneeshkhanna puneeshkhanna commented Oct 7, 2025

What does this PR do?

Fixes data balance dp tokens logic when sequence parallelism is enabled. valid token this rank is all reduced over all ranks (including SP ranks as an example considering same tensor values twice when SP is 2) so we need to scale by full world size else loss plot will be incorrectly half of the values for SP-2 in comparison to SP-1 loss plot when data balance dp token is set to True.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

Verified loss plot of SP-1 vs SP-2 with data balance DP tokens and they match.
Without the fix, SP-2 loss will be half of SP-1 loss.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@puneeshkhanna
Copy link
Contributor Author

@vermouth1992 - please review.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a bug in the loss calculation when data_balance_dp_token and sequence parallelism are enabled. The original logic used an incorrect scaling factor for the loss, causing it to be smaller when sequence parallelism was active. The fix simplifies the code by consistently using the total world size for scaling, which ensures that the calculated loss and resulting gradients are correct and consistent regardless of the sequence parallelism size. My analysis confirms that with this change, both the gradient updates and the logged loss values will be consistent across different sequence parallelism configurations, resolving the issue described.

@puneeshkhanna
Copy link
Contributor Author

@vermouth1992 - Did you get a chance to check this ? I think very important for loss plot to be correct with data balance DP tokens when SP is enabled.

@puneeshkhanna
Copy link
Contributor Author

@vermouth1992 @eric-haibin-lin - Did you get a chance to review this one ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant