Conversation

@garyzhang99
Collaborator

@garyzhang99 garyzhang99 commented Oct 22, 2025

Description

The Truncate Large IS feature addresses training instability in PPO caused by numerical discrepancies between the vLLM inference engine (which provides old_logprob) and the transformer forward pass (which computes logprob).

Problem Statement

When probabilities are very small, even minor numerical discrepancies between vLLM and the transformer forward pass can lead to extremely large importance sampling ratios. This is particularly problematic when:

  1. Both old_logprob and logprob are small (representing low probabilities)
  2. The advantage is negative
  3. The ratio becomes very large (e.g., tens or hundreds) due to computation discrepancies

In the standard PPO implementation with one-sided clipping (multiply by the advantage first, then clip and take the max), a negative advantage combined with a very large ratio passes through the max unclipped (see the numeric sketch after the list below). This can cause:

  • Large negative gradients
  • Training instability
  • Issues like repetition in generated text
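
As a rough numeric sketch of how this happens (the probabilities below are invented for illustration and are not taken from this PR), suppose both backends agree a token is unlikely but disagree slightly on how unlikely:

import math

# Hypothetical per-token log-probabilities for the same token:
old_logprob = math.log(1e-6)   # vLLM (rollout):        p_old ~ 1e-6
logprob = math.log(1e-4)       # transformer (trainer): p_new ~ 1e-4

ratio = math.exp(logprob - old_logprob)
print(ratio)  # ~100.0

# With a negative advantage, -advantage * ratio is a large positive loss term,
# and the max(...) in standard PPO keeps the unclipped term, so the gradient blows up.
advantage = -1.0
print(-advantage * ratio)  # ~100.0, a huge per-token loss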

Solution

The Truncate Large IS feature adds a pre-clipping step that truncates the importance sampling (IS) ratio to a configurable range before it is used in the loss computation. This is similar in spirit to how CISPO handles importance sampling ratios.

Configuration

Parameters

  • truncate_large_is (bool, default: False): Enable/disable the truncate large IS feature
  • truncate_is_range_low (float, default: 0.0): Lower bound for IS ratio truncation; since the ratio exp(logprob - old_logprob) is always nonnegative, the default of 0.0 leaves the lower side effectively unconstrained
  • truncate_is_range_high (float, default: 2.0): Upper bound for IS ratio truncation

Usage in YAML Config

algorithm:
  algorithm_type: grpo
  ...
  policy_loss_fn_config:
    clip_range: 0.2
    loss_agg_mode: token-mean
    # Enable truncate large IS
    truncate_large_is: true
    truncate_is_range_low: 0.0
    truncate_is_range_high: 2.0

Recommended Settings

For general use where vLLM/transformer numerical discrepancies may occur:

truncate_large_is: true
truncate_is_range_low: 0.0
truncate_is_range_high: 2.0

Implementation Details

The truncation is applied as follows:

# Compute the importance sampling ratio between the trained model and the rollout policy
ratio = torch.exp(logprob - old_logprob)

# Pre-clip: bound the ratio before it interacts with the advantage
if truncate_large_is:
    ratio = torch.clamp(ratio, truncate_is_range_low, truncate_is_range_high)

# Continue with the standard PPO clipped surrogate loss
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range_low, 1.0 + clip_range_high)
pg_loss = torch.max(pg_losses, pg_losses2)

This ensures that extremely large (or small) ratios are bounded before they interact with the advantage values.
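
For reference, the snippet above can be assembled into a self-contained sketch. The function name, the token-mean aggregation, and the definition of is_truncate_frac below are illustrative assumptions and need not match the repository's actual implementation:

import torch

def ppo_policy_loss_with_truncate_is(
    logprob: torch.Tensor,        # per-token log-probs from the trained model (transformer)
    old_logprob: torch.Tensor,    # per-token log-probs recorded at rollout time (vLLM)
    advantages: torch.Tensor,
    clip_range: float = 0.2,
    truncate_large_is: bool = False,
    truncate_is_range_low: float = 0.0,
    truncate_is_range_high: float = 2.0,
):
    raw_ratio = torch.exp(logprob - old_logprob)

    if truncate_large_is:
        # Pre-clip the ratio so vLLM/transformer discrepancies cannot blow up the loss.
        truncated = (raw_ratio < truncate_is_range_low) | (raw_ratio > truncate_is_range_high)
        ratio = torch.clamp(raw_ratio, truncate_is_range_low, truncate_is_range_high)
    else:
        truncated = torch.zeros_like(raw_ratio, dtype=torch.bool)
        ratio = raw_ratio

    # Standard PPO clipped surrogate (symmetric clip bounds for simplicity).
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    pg_loss = torch.max(pg_losses, pg_losses2).mean()  # token-mean aggregation, no padding mask

    # Fraction of tokens whose ratio was pre-clipped (one plausible definition of is_truncate_frac).
    is_truncate_frac = truncated.float().mean()
    return pg_loss, is_truncate_frac

In an actual trainer the loss would also respect the response/padding mask and may use asymmetric clip bounds; those details are omitted here for brevity.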

When to Use

Consider enabling this feature when:

  1. Training exhibits instability (loss spikes, large grad norms)
  2. Generated text shows repetition or degeneration
  3. You observe extremely large IS ratios in training (see the diagnostic sketch after this list)
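
For item 3, a quick way to decide whether truncation is worth enabling is to log the extremes of the raw ratio and how often it would be clamped. A minimal sketch (the helper name and the returned keys are made up for illustration):

import torch

def is_ratio_stats(logprob, old_logprob, low=0.0, high=2.0):
    """Summarize raw IS ratios to decide whether truncation is needed."""
    ratio = torch.exp(logprob - old_logprob)
    return {
        "is_ratio/max": ratio.max().item(),
        "is_ratio/mean": ratio.mean().item(),
        # Fraction of tokens that would be clamped with the given bounds.
        "is_truncate_frac": ((ratio < low) | (ratio > high)).float().mean().item(),
    }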

Checklist

Please check the following items before the code is ready to be reviewed.

  • Code has passed all tests
  • Docstrings have been added/updated in Google Style
  • Documentation has been updated
  • Code is ready for review

@gemini-code-assist
Contributor

Summary of Changes

Hello @garyzhang99, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial feature to address training instability in Proximal Policy Optimization (PPO) algorithms, particularly when dealing with discrepancies in probability calculations between different model components like VLLM and transformer. By implementing a pre-clipping mechanism for the importance sampling ratio, the feature prevents extremely large or small ratios from causing large negative gradients and subsequent training issues such as repetition in generated text. This configurable truncation provides a robust way to stabilize PPO training, especially in scenarios where minor computational errors can have significant impacts.

Highlights

  • PPO Policy Loss Stabilization: Implemented a 'Truncate Large IS' feature for PPO policy loss to enhance training stability by addressing computation errors between VLLM and transformer calculations.
  • Configurable IS Ratio Truncation: Introduced new configurable parameters: 'truncate_large_is' (boolean), 'truncate_is_range_low' (float), and 'truncate_is_range_high' (float) to control the importance sampling ratio truncation bounds.
  • IS Truncation Metric: Added a new metric, 'is_truncate_frac', to monitor the frequency and impact of importance sampling ratio truncation during training.
  • New Test Case: Extended the PPO policy loss test suite with a new test case, 'test_ppo_policy_loss_with_truncate_is', to validate the functionality of the IS truncation feature.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a 'Truncate Large IS' feature to the PPO policy loss function, aimed at improving training stability. The implementation of the feature itself is sound, adding the necessary configuration and logic for clamping the importance sampling ratio. However, a critical issue is that the unit test for this new feature is incomplete, with the core assertions commented out. This means the correctness of the feature is not being verified. I have also provided a couple of suggestions to improve code robustness and clarity in the implementation. The incomplete test must be addressed before this PR can be merged.

@garyzhang99
Collaborator Author

/unittest-module-algorithm

@github-actions

Summary

Tests 📝: 14 · Passed ✅: 14 · Failed ❌: 0 · Skipped ⏭️: 0 · Other ❓: 0 · Flaky 🍂: 0 · Duration ⏱️: 4ms

Tests

Test Name Status Flaky Duration
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_batch_level_std_grpo 1ms
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_batch_level_step_wise_grpo_advantage 1ms
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_duplicate_grpo 1ms
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_advantage 1ms
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_correct_bias 1ms
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_reward_std 1ms
tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_step_wise_grpo_advantage 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_dpo_policy_loss 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_gspo_policy_loss 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_mix_policy_loss 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_opmd_policy_loss 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_ppo_policy_loss 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_ppo_policy_loss_with_truncate_is 1ms
tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_sft_policy_loss 1ms

Github Test Reporter by CTRF 💚
