115 changes: 115 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -1249,6 +1249,121 @@ def compute_policy_loss_vanilla(
}
return pg_loss, pg_metrics

@register_policy_loss("cfpo")
def compute_policy_loss_cfpo(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
advantages: torch.Tensor,
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[ActorConfig] = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:

"""

CFPO uses a quadratic penalty instead of clipping:
L = -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)

Args:
old_log_prob (torch.Tensor):
Log-probabilities of actions under the old policy, shape (batch_size, response_length).
log_prob (torch.Tensor):
Log-probabilities of actions under the current policy, shape (batch_size, response_length).
advantages (torch.Tensor):
Advantage estimates for each action, shape (batch_size, response_length).
response_mask (torch.Tensor):
Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
loss_agg_mode (str, optional):
Aggregation mode for `agg_loss`. Defaults to "token-mean".
config (Optional[AlgoConfig]):
Algorithm configuration object.

Returns:
pg_loss (torch.Tensor): The policy gradient loss.
pg_clipfrac (torch.Tensor): Fraction of tokens where ratio exceeded upper bound.
ppo_kl (torch.Tensor): Approximate KL divergence.
pg_clipfrac_lower (torch.Tensor): Fraction of tokens where ratio exceeded lower bound.
"""
    assert config is not None
    assert not isinstance(config, AlgoConfig)
    clip_ratio = config.clip_ratio  # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
        "clip_ratio_c", 3.0
    )

    cliprange = clip_ratio
    cliprange_low = clip_ratio_low
    cliprange_high = clip_ratio_high

    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )
Contributor comment (severity: high):

The clip_ratio_c parameter and its associated assertion are specific to dual-clip PPO and are not used in the CFPO loss calculation or its metrics. This code should be removed to avoid confusion and keep the implementation clean.

Suggested change:

-    clip_ratio_c = config.get(  # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729.
-        "clip_ratio_c", 3.0
-    )
-
-    cliprange = clip_ratio
-    cliprange_low = clip_ratio_low
-    cliprange_high = clip_ratio_high
-
-    assert clip_ratio_c > 1.0, (
-        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
-        + f" but get the value: {clip_ratio_c}."
-    )
+    cliprange = clip_ratio
+    cliprange_low = clip_ratio_low
+    cliprange_high = clip_ratio_high

    negative_approx_kl = log_prob - old_log_prob
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    ### This code is for logging purposes
    pg_losses1 = -advantages * ratio
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    pg_losses2 = -advantages * torch.clamp(
        ratio, 1 - cliprange_low, 1 + cliprange_high
    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clip_pg_losses1 = torch.maximum(
        pg_losses1, pg_losses2
    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    pg_losses3 = -advantages * clip_ratio_c

    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )
    ### This code is for logging purposes
Contributor comment (severity: high):

This block of code, marked for "logging purposes", appears to be dead code from a copy-paste. The variables pg_clipfrac and pg_clipfrac_lower are calculated here but are immediately overwritten by a new calculation later in the function. Other variables defined in this block are not used outside of it. This block should be removed to improve clarity and avoid unnecessary computations.
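
In suggested-change form, the cleanup this comment describes would drop the whole block between the two `### This code is for logging purposes` markers (a sketch of the deletion, not an actual suggestion posted in this thread):

-    ### This code is for logging purposes
-    pg_losses1 = -advantages * ratio
-    if cliprange_low is None:
-        cliprange_low = cliprange
-    if cliprange_high is None:
-        cliprange_high = cliprange
-    pg_losses2 = -advantages * torch.clamp(
-        ratio, 1 - cliprange_low, 1 + cliprange_high
-    )  # - clip(ratio, 1-cliprange, 1+cliprange) * A
-    clip_pg_losses1 = torch.maximum(
-        pg_losses1, pg_losses2
-    )  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
-    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
-
-    pg_losses3 = -advantages * clip_ratio_c
-
-    pg_clipfrac_lower = verl_F.masked_mean(
-        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
-    )
-    ### This code is for logging purposes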


    # CFPO loss: -advantage * ratio + |advantage| * (ratio - 1)^2 / (2 * epsilon)
    pg_losses = -(
        advantages * ratio
        - torch.abs(advantages) * torch.pow(ratio - 1, 2) / (2 * cliprange)
    )

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Compute clip fraction metrics for monitoring
    is_ratio_high = ratio > 1 + cliprange_high
    is_ratio_low = ratio < 1 - cliprange_low

    pg_clipfrac = verl_F.masked_mean(
        ((is_ratio_high & (advantages > 0)) | (is_ratio_low & (advantages < 0))).float(),
        response_mask,
    )

    pg_clipfrac_lower = verl_F.masked_mean(
        (is_ratio_low & (advantages < 0)).float(),
        response_mask,
    )

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }

    return pg_loss, pg_metrics
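
For intuition, here is a minimal standalone sketch contrasting the CFPO quadratic penalty with standard PPO clipping. It uses hypothetical toy values, and `cfpo_penalty_loss` / `ppo_clipped_loss` are illustrative names, not part of verl:

import torch

def cfpo_penalty_loss(advantages, ratio, epsilon=0.2):
    # CFPO: -A * r + |A| * (r - 1)^2 / (2 * eps). The quadratic term
    # penalizes drift of the ratio away from 1 instead of hard-clipping it,
    # so the gradient never flattens to zero outside the trust region.
    return -(advantages * ratio - advantages.abs() * (ratio - 1).pow(2) / (2 * epsilon))

def ppo_clipped_loss(advantages, ratio, epsilon=0.2):
    # Standard PPO: max(-A * r, -A * clip(r, 1 - eps, 1 + eps)).
    return torch.maximum(
        -advantages * ratio,
        -advantages * ratio.clamp(1 - epsilon, 1 + epsilon),
    )

ratio = torch.tensor([0.5, 0.9, 1.0, 1.1, 1.5])
advantages = torch.ones_like(ratio)  # all-positive advantages for illustration
print(cfpo_penalty_loss(advantages, ratio))  # tensor([ 0.1250, -0.8750, -1.0000, -1.0750, -0.8750])
print(ppo_clipped_loss(advantages, ratio))   # tensor([-0.5000, -0.9000, -1.0000, -1.1000, -1.2000])

Note how the clipped PPO loss saturates at -1.2 once the ratio passes 1 + epsilon (zero gradient beyond the clip), while the CFPO loss has already turned back upward at ratio 1.5, actively pulling the policy toward ratio = 1.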



@register_policy_loss("gspo")
def compute_policy_loss_gspo(