Skip to content

Conversation

terrykong
Copy link
Contributor

@terrykong terrykong commented Sep 30, 2025

closes #1227

########################################################
# QWEN
########################################################
# TP=1
uv run --env-file .env ./examples/run_rm.py \
    cluster.gpus_per_node=8 \
    +policy.dtensor_cfg._v2=true \
    policy.dtensor_cfg.tensor_parallel_size=1 \
    rm.val_at_start=true \
    rm.max_num_steps=0 \
    checkpointing.enabled=false \
    policy.model_name=Skywork/Skywork-Reward-V2-Qwen3-0.6B
# Validation accuracy: 0.6331


########################################################

# TP=4
uv run --env-file .env ./examples/run_rm.py \
    cluster.gpus_per_node=8 \
    +policy.dtensor_cfg._v2=true \
    policy.dtensor_cfg.tensor_parallel_size=4 \
    rm.val_at_start=true \
    rm.max_num_steps=0 \
    checkpointing.enabled=false \
    policy.model_name=Skywork/Skywork-Reward-V2-Qwen3-0.6B
# Validation accuracy: 0.6777


# TP=4 (custom parallel plan)
uv run --env-file .env ./examples/run_rm.py \
    cluster.gpus_per_node=8 \
    +policy.dtensor_cfg._v2=true \
    policy.dtensor_cfg.tensor_parallel_size=4 \
    rm.val_at_start=true \
    rm.max_num_steps=0 \
    checkpointing.enabled=false \
    policy.model_name=Skywork/Skywork-Reward-V2-Qwen3-0.6B \
    policy.dtensor_cfg.custom_parallel_plan=examples.custom_parallel.qwen_model_tp_plan_stable
• Validation accuracy: 0.6331

Summary by CodeRabbit

  • New Features

    • Introduced an optional, numerically stable parallel plan configuration (qwen_model_tp_plan_stable) for model execution. Enable via policy.dtensor_cfg.custom_parallel_plan.
  • Documentation

    • Added inline guidance on when and how to enable the new plan, including notes on layout settings and usage considerations.

@terrykong terrykong requested a review from joyang-nv September 30, 2025 06:37
@terrykong terrykong requested a review from a team as a code owner September 30, 2025 06:37
Signed-off-by: Terry Kong <[email protected]>
Copy link
Contributor

coderabbitai bot commented Sep 30, 2025

📝 Walkthrough

Walkthrough

Introduces a new public configuration variable qwen_model_tp_plan_stable in examples/custom_parallel.py that defines a numerically stable tensor-parallel plan with adjusted per-layer layout settings. Existing custom_parallel_plan and imports remain unchanged. A note explains default plan instability and how to enable the new plan via policy.dtensor_cfg.custom_parallel_plan.

Changes

Cohort / File(s) Summary
Config: Stable TP plan addition
`examples/custom_parallel.py`
Added public variable `qwen_model_tp_plan_stable` defining a numerically stable tensor-parallel plan (adjusted input/output layouts, use_local_output). Added note on instability of default plans and how to enable via `policy.dtensor_cfg.custom_parallel_plan`. No changes to existing `custom_parallel_plan` or imports.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks and finishing touches

✅ Passed checks (4 passed)
Check name Status Explanation
Title Check ✅ Passed The title succinctly describes the primary change by indicating that a more numerically stable custom plan for the Qwen model has been introduced, follows the ‘feat:’ prefix convention, and is clear and specific enough for teammates to understand the update at a glance.
Docstring Coverage ✅ Passed No functions found in the changes. Docstring coverage check skipped.
Test Results For Major Changes ✅ Passed This PR introduces a major change affecting numerical stability in tensor-parallel configurations for the Qwen model. The PR description includes comprehensive test results showing validation accuracy across three configurations: TP=1 baseline (0.6331), TP=4 with default plan (0.6777 - showing divergence), and TP=4 with the new stable plan (0.6331 - matching baseline). The testing information includes exact commands, model specifications, and complete configuration flags, clearly demonstrating that the new parallel plan prevents numerical regression when scaling across multiple GPUs.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch tk/numerically-stable-qwen-plan

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🧪 Early access (Sonnet 4.5): enabled

We are currently testing the Sonnet 4.5 model, which is expected to improve code review quality. However, this model may lead to increased noise levels in the review comments. Please disable the early access features if the noise level causes any inconvenience.

Note:

  • Public repositories are always opted into early access features.
  • You can enable or disable early access features from the CodeRabbit UI or by updating the CodeRabbit configuration file.

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
examples/custom_parallel.py (1)

30-39: Clarify what "increases" with larger TP.

The phrase "numerically unstable and tend to increase with larger TP" is ambiguous—does instability increase, or errors/variance increase?

Consider rephrasing for clarity:

-- Default plans that keep attention output proj and mlp downproj RowwiseParallel are numerically 
-  unstable and tend to increase with larger TP (e.g., TP >= 4).
+- Default plans that keep attention output proj and mlp downproj RowwiseParallel exhibit numerical 
+  instability that worsens with larger TP (e.g., TP >= 4).
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c2b36f2 and 8528150.

📒 Files selected for processing (1)
  • examples/custom_parallel.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • examples/custom_parallel.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (3)
examples/custom_parallel.py (3)

1-13: LGTM!

Copyright header follows NVIDIA standards with correct year.


29-29: LGTM!

Good visual separation between the two configuration blocks.


40-60: Approve tensor-parallel plan configuration.

The change to ColwiseParallel for o_proj and down_proj addresses numerical stability and restores TP=4 accuracy to the TP=1 baseline. Using snake_case for qwen_model_tp_plan_stable aligns with existing examples (custom_parallel_plan) and requires no renaming.

Signed-off-by: Terry Kong <[email protected]>
Copy link
Member

@joyang-nv joyang-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verified locally. Looks good to me!

@terrykong terrykong added the CI:docs Run doctest label Oct 3, 2025
@terrykong terrykong enabled auto-merge (squash) October 3, 2025 21:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CI:docs Run doctest
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Inconsistent DTensorPolicyWorker logits with TP>1 and bf16
2 participants