95 changes: 88 additions & 7 deletions verl/trainer/ppo/core_algos.py
@@ -1194,18 +1194,13 @@ def compute_policy_loss_vanilla(
clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
"clip_ratio_c", 3.0
)


cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high

assert clip_ratio_c > 1.0, (
"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ f" but get the value: {clip_ratio_c}."
)


negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for stability
@@ -1249,6 +1244,92 @@ def compute_policy_loss_vanilla(
}
return pg_loss, pg_metrics

@register_policy_loss("cfpo")
def compute_policy_loss_cfpo(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorConfig] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Compute the policy loss for CFPO.

CFPO replaces PPO's ratio clipping with a quadratic penalty on the importance ratio:
L = -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)

Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
config (Optional[ActorConfig]):
Actor configuration providing `clip_ratio`, `clip_ratio_low`, and `clip_ratio_high`.
rollout_is_weights (torch.Tensor, optional):
Importance-sampling weights applied to the per-token losses, shape (batch_size, response_length).

Returns:
pg_loss (torch.Tensor): The aggregated policy gradient loss.
pg_metrics (dict[str, Any]): Monitoring metrics: "actor/pg_clipfrac", "actor/ppo_kl", and "actor/pg_clipfrac_lower".
"""
assert config is not None
assert not isinstance(config, AlgoConfig)
clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio

cliprange = clip_ratio
cliprange_low = clip_ratio_low
cliprange_high = clip_ratio_high

negative_approx_kl = log_prob - old_log_prob
# Clamp negative_approx_kl for numerical stability before exponentiating
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

# CFPO loss: -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)
pg_losses = -(
advantages * ratio -
torch.abs(advantages) * torch.pow(ratio - 1, 2) / (2 * cliprange)
)
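# Note: the per-token loss is minimized at ratio = 1 + cliprange for positive advantages and
# ratio = 1 - cliprange for negative advantages, so the quadratic penalty targets the same
# trust region PPO enforces by clipping, but with a smooth gradient instead of a hard cutoff.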

if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights

pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

# CFPO does not clip the ratio; the clip-fraction metrics below only report how often the
# ratio leaves the [1 - cliprange_low, 1 + cliprange_high] band where PPO would have clipped
is_ratio_high = ratio > 1 + cliprange_high
is_ratio_low = ratio < 1 - cliprange_low

pg_clipfrac = verl_F.masked_mean(
((is_ratio_high & (advantages > 0)) | (is_ratio_low & (advantages < 0))).float(),
response_mask
)

pg_clipfrac_lower = verl_F.masked_mean(
(is_ratio_low & (advantages < 0)).float(),
response_mask
)
pg_metrics = {
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
}

return pg_loss, pg_metrics
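
# A minimal usage sketch (illustrative only, not part of this PR or verl's public API):
# calls the CFPO loss on toy tensors. The SimpleNamespace below merely stands in for an
# ActorConfig-like object exposing clip_ratio / clip_ratio_low / clip_ratio_high.
def _example_cfpo_usage() -> tuple[torch.Tensor, dict[str, Any]]:
    from types import SimpleNamespace

    old_log_prob = torch.zeros(2, 4)
    log_prob = torch.full((2, 4), 0.05)  # new policy slightly more confident than the old one
    advantages = torch.ones(2, 4)
    response_mask = torch.ones(2, 4)
    cfg = SimpleNamespace(clip_ratio=0.2, clip_ratio_low=None, clip_ratio_high=None)
    return compute_policy_loss_cfpo(old_log_prob, log_prob, advantages, response_mask, config=cfg)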



@register_policy_loss("gspo")
def compute_policy_loss_gspo(