adding CFPO to verl #5252
base: main
Changes from 1 commit
@@ -1249,6 +1249,121 @@ def compute_policy_loss_vanilla(

```python
    }
    return pg_loss, pg_metrics


@register_policy_loss("cfpo")
def compute_policy_loss_cfpo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the policy loss for CFPO.

    CFPO uses a quadratic penalty instead of clipping:
        L = -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config (Optional[ActorConfig]):
            Actor configuration object.
        rollout_is_weights (torch.Tensor, optional):
            Per-token importance-sampling weights multiplied into the loss, shape (batch_size, response_length).

    Returns:
        pg_loss (torch.Tensor): The aggregated policy gradient loss.
        pg_metrics (dict[str, Any]): Logging metrics: the clip fractions and the approximate KL.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)
    clip_ratio = config.clip_ratio  # Clipping parameter epsilon for standard PPO. See https://arxiv.org/abs/1707.06347.
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
        "clip_ratio_c", 3.0
    )

    cliprange = clip_ratio
    cliprange_low = clip_ratio_low
    cliprange_high = clip_ratio_high

    assert clip_ratio_c > 1.0, (
        "The lower bound of clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but got the value: {clip_ratio_c}."
    )

    negative_approx_kl = log_prob - old_log_prob
    # Clamp the log-ratio before exponentiation for numerical stability.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
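
    # Note: ppo_kl is the simple k1 estimator of KL(pi_old || pi_new): the
    # masked mean of -(log_prob - old_log_prob) over response tokens. Individual
    # batches can yield negative values; it is reported only as a diagnostic.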

    ### This code is for logging purposes (the clip fractions are recomputed below)
    pg_losses1 = -advantages * ratio
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    pg_losses2 = -advantages * torch.clamp(
        ratio, 1 - cliprange_low, 1 + cliprange_high
    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clip_pg_losses1 = torch.maximum(
        pg_losses1, pg_losses2
    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    pg_losses3 = -advantages * clip_ratio_c

    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )
    ### End of logging-only code

    # CFPO loss: -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)
    pg_losses = -(
        advantages * ratio
        - torch.abs(advantages) * torch.pow(ratio - 1, 2) / (2 * cliprange)
    )
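
    # Intuition: the per-token loss l(r) = -A * r + |A| * (r - 1)^2 / (2 * eps)
    # has gradient dl/dr = -A + |A| * (r - 1) / eps, which vanishes at
    # r = 1 + sign(A) * eps. The quadratic penalty thus places the per-token
    # optimum exactly at the clip boundary, and past it the gradient pushes the
    # ratio back, where hard clipping would instead give a zero gradient.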

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Compute clip fraction metrics for monitoring
    is_ratio_high = ratio > 1 + cliprange_high
    is_ratio_low = ratio < 1 - cliprange_low

    pg_clipfrac = verl_F.masked_mean(
        ((is_ratio_high & (advantages > 0)) | (is_ratio_low & (advantages < 0))).float(),
        response_mask,
    )

    pg_clipfrac_lower = verl_F.masked_mean(
        (is_ratio_low & (advantages < 0)).float(),
        response_mask,
    )

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }

    return pg_loss, pg_metrics


@register_policy_loss("gspo")
def compute_policy_loss_gspo(
```
Review comment: The `clip_ratio_c` parameter and its associated assertion are specific to dual-clip PPO and are not used in the CFPO loss calculation or its metrics. This code should be removed to avoid confusion and keep the implementation clean.
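
For intuition about the objective itself, here is a minimal standalone sketch, assuming only `torch` and a symmetric clip range `eps` (it is an illustration, not part of the diff), comparing the per-token CFPO penalty with the PPO clipped loss at a fixed positive advantage:

```python
import torch

eps = 0.2                    # symmetric clip range (cliprange in the diff)
adv = torch.tensor(1.0)      # fixed positive advantage for illustration
ratio = torch.linspace(0.5, 1.5, 5)  # policy ratios pi_new / pi_old

# CFPO: quadratic penalty, minimized exactly at ratio = 1 + eps when adv > 0
cfpo_loss = -(adv * ratio - adv.abs() * (ratio - 1).pow(2) / (2 * eps))

# PPO: hard clipping, flat (zero gradient) once ratio leaves [1 - eps, 1 + eps]
ppo_loss = torch.maximum(-adv * ratio, -adv * ratio.clamp(1 - eps, 1 + eps))

for r, c, p in zip(ratio.tolist(), cfpo_loss.tolist(), ppo_loss.tolist()):
    print(f"ratio={r:.2f}  cfpo={c:+.4f}  ppo={p:+.4f}")
```

Both losses bottom out around the clip boundary, but past `1 + eps` the PPO loss stays flat at `-1.2` while the CFPO loss climbs again (about `-1.09` at ratio `1.25` and `-0.875` at `1.5`), so oversized updates are actively penalized rather than merely ignored. Once merged, the loss should be reachable under its registered name `"cfpo"` wherever the actor config selects a policy-loss mode; the exact config key is not shown in this diff.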