26 changes: 26 additions & 0 deletions tests/algorithm/policy_loss_test.py
@@ -108,3 +108,29 @@ def test_mix_policy_loss(self):
self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

def test_ppo_policy_loss_with_truncate_is(self):
"""Test PPO policy loss with truncate large IS enabled."""
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
policy_loss_fn_args = policy_loss_fn_cls.default_args()
# Enable truncate large IS with custom bounds
policy_loss_fn_args["truncate_large_is"] = True
policy_loss_fn_args["truncate_is_range_low"] = 0.0
policy_loss_fn_args["truncate_is_range_high"] = 2.0
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
        loss, metrics = policy_loss_fn(logprob=self.logprob, **self.input_data.batch)

        # TODO: derive verified reference values for the truncated loss and
        # metrics, then enable the assertions below.
# ppo_loss_truncated = torch.tensor(0.27213451266288757)
# pg_clipfrac_truncated = torch.tensor(0.36458331346511841)
# ppo_kl_truncated = torch.tensor(-0.21663446724414825)

# self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
# self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated))
# self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated))
# self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
# Check that IS truncation metric is present
self.assertIn("is_truncate_frac", metrics)
self.assertGreaterEqual(metrics["is_truncate_frac"], 0.0)
self.assertLessEqual(metrics["is_truncate_frac"], 1.0)
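
For reference, a minimal standalone sketch (not part of the PR) of what the truncation path computes; the toy ratio tensor and all-ones action mask stand in for the test fixtures, and the bounds are the PR's defaults.

import torch

# Toy IS ratios exp(logprob - old_logprob); 3.7 exceeds the upper bound of 2.0.
ratio = torch.tensor([0.5, 1.2, 3.7, 0.9])
action_mask = torch.ones_like(ratio)  # every token counts toward the mean

low, high = 0.0, 2.0
truncated = torch.clamp(ratio, low, high)          # -> [0.5, 1.2, 2.0, 0.9]
is_truncated = torch.ne(truncated, ratio).float()  # -> [0., 0., 1., 0.]
# masked_mean with a full mask reduces to a plain mean: 1 of 4 tokens clamped.
is_truncate_frac = (is_truncated * action_mask).sum() / action_mask.sum()
assert abs(is_truncate_frac.item() - 0.25) < 1e-6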
54 changes: 54 additions & 0 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -20,7 +20,24 @@ def __init__(
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
loss_agg_mode: Optional[str] = "token-mean",
truncate_large_is: bool = False,
truncate_is_range_low: Optional[float] = None,
truncate_is_range_high: Optional[float] = None,
) -> None:
"""
Initialize PPO policy loss function.

Args:
            backend: Backend framework (default: "verl")
            clip_range: Symmetric clipping range for PPO; used for both bounds
                when clip_range_low / clip_range_high are not given
            clip_range_low: Lower clipping bound (ratio clipped at 1.0 - clip_range_low)
            clip_range_high: Upper clipping bound (ratio clipped at 1.0 + clip_range_high)
            loss_agg_mode: Loss aggregation mode (default: "token-mean")
            truncate_large_is: Whether to truncate large importance sampling (IS)
                ratios, mitigating numerical discrepancies between vLLM rollout
                log-probs and the log-probs recomputed by the training backend
            truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
            truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
"""
super().__init__(backend=backend)
if clip_range_low is None:
self.clip_range_low = clip_range
@@ -34,6 +51,23 @@ def __init__(
assert self.clip_range_high is not None, "clip_range_high must be specified."
self.loss_agg_mode = loss_agg_mode

# Truncate large IS configuration
self.truncate_large_is = truncate_large_is
if truncate_large_is:
self.truncate_is_range_low = (
truncate_is_range_low if truncate_is_range_low is not None else 0.0
)
self.truncate_is_range_high = (
truncate_is_range_high if truncate_is_range_high is not None else 2.0
)
assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative."
assert (
self.truncate_is_range_high > self.truncate_is_range_low
), "truncate_is_range_high must be greater than truncate_is_range_low."
else:
self.truncate_is_range_low = None
self.truncate_is_range_high = None

def __call__( # type: ignore
self,
logprob: torch.Tensor,
@@ -46,6 +80,18 @@ def __call__( # type: ignore
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, action_mask)

        # Truncate large IS ratios if enabled. This stabilizes training when the
        # rollout engine (vLLM) and the training-side forward pass disagree
        # numerically, which is most pronounced for very small probabilities.
if self.truncate_large_is:
ratio_before_truncate = ratio.clone()
ratio = torch.clamp(ratio, self.truncate_is_range_low, self.truncate_is_range_high)
# Track how often truncation occurs
is_truncated = torch.ne(ratio, ratio_before_truncate).float()
is_truncate_frac = masked_mean(is_truncated, action_mask)
else:
is_truncate_frac = torch.tensor(0.0)

pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
@@ -60,11 +106,19 @@ def __call__( # type: ignore
"ppo_kl": ppo_kl.detach().item(),
"pg_loss": pg_loss.detach().item(),
}

# Add IS truncation metrics if enabled
if self.truncate_large_is:
metrics["is_truncate_frac"] = is_truncate_frac.detach().item()

return pg_loss, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
"loss_agg_mode": "token-mean",
"truncate_large_is": False,
"truncate_is_range_low": 0.0,
"truncate_is_range_high": 2.0,
}
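
A hedged usage sketch, mirroring the test above: the registry key "ppo" and default_args() come from this PR's test file, while the import path is an assumption and may differ in the actual package layout.

from trinity.algorithm import POLICY_LOSS_FN  # import path is an assumption

args = POLICY_LOSS_FN.get("ppo").default_args()
args["truncate_large_is"] = True       # opt in; False leaves PPO behavior unchanged
args["truncate_is_range_high"] = 5.0   # hypothetical: widen the upper bound
loss_fn = POLICY_LOSS_FN.get("ppo")(**args)
# loss, metrics = loss_fn(logprob=..., old_logprob=..., advantages=..., action_mask=...)
# When enabled, metrics["is_truncate_frac"] reports the fraction of masked tokens
# whose ratio was clamped into [truncate_is_range_low, truncate_is_range_high].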