diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 2039fe56f62..458734b3207 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -1194,18 +1194,13 @@ def compute_policy_loss_vanilla(
     clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
     clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
     clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
-    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
-        "clip_ratio_c", 3.0
-    )
+
     cliprange = clip_ratio
     cliprange_low = clip_ratio_low
     cliprange_high = clip_ratio_high
-    assert clip_ratio_c > 1.0, (
-        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
-        + f" but get the value: {clip_ratio_c}."
-    )
+
     negative_approx_kl = log_prob - old_log_prob
     # Clamp negative_approx_kl for stability
@@ -1249,6 +1244,92 @@ def compute_policy_loss_vanilla(
     }
     return pg_loss, pg_metrics
 
 
+@register_policy_loss("cfpo")
+def compute_policy_loss_cfpo(
+    old_log_prob: torch.Tensor,
+    log_prob: torch.Tensor,
+    advantages: torch.Tensor,
+    response_mask: torch.Tensor,
+    loss_agg_mode: str = "token-mean",
+    config: Optional[ActorConfig] = None,
+    rollout_is_weights: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, dict[str, Any]]:
+    """Compute the policy loss for CFPO.
+
+    CFPO uses a quadratic penalty instead of clipping:
+        L = -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)
+
+    Args:
+        old_log_prob (torch.Tensor):
+            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
+        log_prob (torch.Tensor):
+            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
+        advantages (torch.Tensor):
+            Advantage estimates for each action, shape (batch_size, response_length).
+        response_mask (torch.Tensor):
+            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
+        loss_agg_mode (str, optional):
+            Aggregation mode for `agg_loss`. Defaults to "token-mean".
+        config (Optional[ActorConfig]):
+            Actor configuration object; `clip_ratio` (with optional low/high overrides) supplies epsilon.
+        rollout_is_weights (Optional[torch.Tensor]):
+            Optional importance-sampling weights multiplied into the per-token loss.
+
+    Returns:
+        pg_loss (torch.Tensor): The aggregated policy gradient loss.
+        pg_metrics (dict[str, Any]): Monitoring metrics (clip fractions and approximate KL divergence).
+    """
+    assert config is not None
+    assert not isinstance(config, AlgoConfig)
+    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
+    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
+    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
+
+    cliprange = clip_ratio
+    cliprange_low = clip_ratio_low
+    cliprange_high = clip_ratio_high
+
+    negative_approx_kl = log_prob - old_log_prob
+    # Clamp negative_approx_kl for stability
+    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+    ratio = torch.exp(negative_approx_kl)
+    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
+
+    # CFPO loss: -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)
+    pg_losses = -(
+        advantages * ratio
+        - torch.abs(advantages) * torch.pow(ratio - 1, 2) / (2 * cliprange)
+    )
+
+    if rollout_is_weights is not None:
+        pg_losses = pg_losses * rollout_is_weights
+
+    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+
+    # Clip-fraction metrics for monitoring only; CFPO itself does not clip the ratio.
+    is_ratio_high = ratio > 1 + cliprange_high
+    is_ratio_low = ratio < 1 - cliprange_low
+
+    pg_clipfrac = verl_F.masked_mean(
+        ((is_ratio_high & (advantages > 0)) | (is_ratio_low & (advantages < 0))).float(),
+        response_mask,
+    )
+    pg_clipfrac_lower = verl_F.masked_mean(
+        (is_ratio_low & (advantages < 0)).float(),
+        response_mask,
+    )
+
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+
+    return pg_loss, pg_metrics
+
+
 @register_policy_loss("gspo")
 def compute_policy_loss_gspo(
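For reference, below is a minimal, self-contained sketch (not part of the patch) that contrasts the CFPO quadratic-penalty objective added here with the standard PPO clipped objective. It has no verl dependencies; the helper names, dummy tensor shapes, and the `epsilon=0.2` value are illustrative assumptions only, with `epsilon` standing in for `config.clip_ratio`.

```python
import torch


def cfpo_token_loss(old_log_prob, log_prob, advantages, epsilon=0.2):
    """Per-token CFPO loss: -A * r + |A| * (r - 1)^2 / (2 * epsilon)."""
    ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20.0, 20.0))
    return -(advantages * ratio - advantages.abs() * (ratio - 1) ** 2 / (2 * epsilon))


def ppo_clip_token_loss(old_log_prob, log_prob, advantages, epsilon=0.2):
    """Per-token PPO clipped loss: -min(A * r, A * clip(r, 1 - eps, 1 + eps))."""
    ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20.0, 20.0))
    return -torch.min(advantages * ratio, advantages * ratio.clamp(1 - epsilon, 1 + epsilon))


torch.manual_seed(0)
old_lp = torch.randn(2, 4)                 # dummy rollout-policy log-probs
new_lp = old_lp + 0.3 * torch.randn(2, 4)  # dummy current-policy log-probs
adv = torch.randn(2, 4)
mask = torch.ones(2, 4)                    # stand-in for response_mask

# "token-mean" aggregation, mirroring loss_agg_mode="token-mean"
cfpo_loss = (cfpo_token_loss(old_lp, new_lp, adv) * mask).sum() / mask.sum()
ppo_loss = (ppo_clip_token_loss(old_lp, new_lp, adv) * mask).sum() / mask.sum()
print(f"CFPO: {cfpo_loss.item():.4f}  PPO-clip: {ppo_loss.item():.4f}")
```

One way to read the penalty: the per-token derivative with respect to the ratio is -A + |A| * (r - 1) / epsilon, which vanishes at r = 1 + epsilon for positive advantages and r = 1 - epsilon for negative ones. So epsilon plays a role analogous to the PPO clip range, but the objective stays smooth and keeps pulling the ratio back toward 1 instead of zeroing the gradient outside the clip region.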