3 changes: 3 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -606,6 +606,9 @@ algorithm:
  lam: 1.0
  adv_estimator: gae
  norm_adv_by_std_in_grpo: true
  f_grpo_gamma: 1.0
  f_grpo_reward_correct: 1.0
  f_grpo_reward_wrong: 0.0
  use_kl_in_reward: false
  kl_penalty: kl
  kl_ctrl:
3 changes: 3 additions & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
@@ -576,6 +576,9 @@ algorithm:
  lam: 1.0
  adv_estimator: gae
  norm_adv_by_std_in_grpo: true
  f_grpo_gamma: 1.0
  f_grpo_reward_correct: 1.0
  f_grpo_reward_wrong: 0.0
  use_kl_in_reward: false
  kl_penalty: kl
  kl_ctrl:
6 changes: 6 additions & 0 deletions verl/trainer/config/algorithm.py
@@ -575,6 +575,9 @@ class AlgoConfig(BaseConfig):
        lam (float): Trade-off between bias and variance in the GAE estimator.
        adv_estimator (str): Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc.
        norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO).
        f_grpo_gamma (Optional[float]): F-GRPO focal exponent γ (>= 0). Required when adv_estimator="f_grpo".
        f_grpo_reward_correct (Optional[float]): Reward value for correct rollouts (R_c) used in μ̂_pos(x).
        f_grpo_reward_wrong (Optional[float]): Reward value for wrong rollouts (R_w) used in μ̂_pos(x).
        use_kl_in_reward (bool): Whether to enable in-reward KL penalty.
        kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full".
        kl_ctrl (KLControlConfig): KL control configuration.
@@ -603,6 +606,9 @@ class AlgoConfig(BaseConfig):
    lam: float = 1.0
    adv_estimator: str = "gae"
    norm_adv_by_std_in_grpo: bool = True
    f_grpo_gamma: Optional[float] = None
    f_grpo_reward_correct: Optional[float] = None
    f_grpo_reward_wrong: Optional[float] = None
    use_kl_in_reward: bool = False
    kl_penalty: str = "kl"
    kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)
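For reference, a minimal sketch of setting the new fields programmatically; AlgoConfig is the dataclass defined in verl/trainer/config/algorithm.py, and in practice these values are normally populated from the Hydra YAML rather than constructed by hand:

from verl.trainer.config.algorithm import AlgoConfig

# Illustrative values; the f_grpo_* fields are only consulted when adv_estimator == "f_grpo".
algo = AlgoConfig(
    adv_estimator="f_grpo",
    norm_adv_by_std_in_grpo=True,
    f_grpo_gamma=2.0,           # stronger down-weighting of easy prompts
    f_grpo_reward_correct=1.0,  # R_c
    f_grpo_reward_wrong=0.0,    # R_w
)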
11 changes: 10 additions & 1 deletion verl/trainer/config/ppo_trainer.yaml
@@ -69,12 +69,21 @@ algorithm:
  # Trade-off between bias and variance in the GAE estimator
  lam: 1.0

  # Advantage estimator type: "gae", "grpo", "f_grpo", "reinforce_plus_plus", etc.
  adv_estimator: gae

  # Whether to normalize advantages by std (specific to GRPO)
  norm_adv_by_std_in_grpo: True

  # F-GRPO focal exponent γ (>= 0), only used when adv_estimator: f_grpo
  f_grpo_gamma: 1.0

  # F-GRPO reward value for correct rollouts (R_c), used in μ̂_pos(x)
  f_grpo_reward_correct: 1.0

  # F-GRPO reward value for wrong rollouts (R_w), used in μ̂_pos(x)
  f_grpo_reward_wrong: 0.0

  # Whether to enable in-reward KL penalty
  use_kl_in_reward: False
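As a concrete reading of the three knobs above: with f_grpo_reward_correct R_c = 1, f_grpo_reward_wrong R_w = 0, and f_grpo_gamma γ = 2, a prompt whose group-mean reward is 0.75 gets μ̂_pos = (0.75 - 0) / (1 - 0) = 0.75 and focal weight g(x) = (1 - 0.75)^2 = 0.0625, while a harder prompt at mean reward 0.25 keeps g(x) = (1 - 0.25)^2 = 0.5625, roughly nine times more gradient weight.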
85 changes: 85 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -96,6 +96,7 @@ class AdvantageEstimator(str, Enum):

GAE = "gae"
GRPO = "grpo"
F_GRPO = "f_grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
@@ -357,6 +358,90 @@ def compute_grpo_vectorized_outcome_advantage(
    return advantages, advantages


@register_adv_est(AdvantageEstimator.F_GRPO)
def compute_f_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    *,
    focal_gamma: Optional[float] = None,
    reward_correct: Optional[float] = None,
    reward_wrong: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
F-GRPO (outcome-only): difficulty-aware (focal) scaling of group-relative advantages.

We scale the GRPO group-relative advantage with a per-prompt weight g(x):

\hat{A}_i^{F-GRPO} = g(x) \cdot \hat{A}_i^{GRPO}

where

g(x) = (1 - \hat{\mu}_{pos}(x))^{\gamma}, \gamma \ge 0
\hat{\mu}_{pos}(x) = (\bar{R}(x) - R_w) / (R_c - R_w)

Config keys (when called via trainer with `config`):
- algorithm.f_grpo_gamma (float): focal exponent \gamma controlling down-weighting strength.
- algorithm.f_grpo_reward_correct (float): reward value for correct rollouts (R_c).
- algorithm.f_grpo_reward_wrong (float): reward value for wrong rollouts (R_w).

Notes:
- Implementation of F-GRPO paper https://arxiv.org/abs/2602.06717
"""
    if config is not None:
        norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", norm_adv_by_std_in_grpo)
        if focal_gamma is None:
            focal_gamma = config.get("f_grpo_gamma", None)
        if reward_correct is None:
            reward_correct = config.get("f_grpo_reward_correct", None)
        if reward_wrong is None:
            reward_wrong = config.get("f_grpo_reward_wrong", None)

    if focal_gamma is None or reward_correct is None or reward_wrong is None:
        raise ValueError(
            "F-GRPO requires `focal_gamma`, `reward_correct`, and `reward_wrong` to be provided. "
            "Set them in config or pass them directly to "
            "`compute_f_grpo_outcome_advantage(..., focal_gamma=..., reward_correct=..., reward_wrong=...)`."
        )

    focal_gamma = float(focal_gamma)
    reward_correct = float(reward_correct)
    reward_wrong = float(reward_wrong)

    if focal_gamma < 0:
        raise ValueError(f"focal_gamma must be >= 0. Got {focal_gamma}.")
    denom = reward_correct - reward_wrong
    if denom <= 0:
        raise ValueError(
            f"reward_correct must be > reward_wrong for F-GRPO. Got reward_correct={reward_correct}, "
            f"reward_wrong={reward_wrong}."
        )

    with torch.no_grad():
        scores = token_level_rewards.sum(dim=-1)  # (bs,) outcome reward per rollout

        gidx = as_torch_index(index, device=scores.device)
        mean_g, std_g, _ = group_mean_std(scores, gidx, eps=epsilon, device=scores.device)

        # Standard GRPO group-relative advantage, optionally normalized by the group std.
        if norm_adv_by_std_in_grpo:
            scalars = (scores - mean_g[gidx]) / (std_g[gidx] + epsilon)
        else:
            scalars = scores - mean_g[gidx]

        # Per-prompt success-rate estimate \hat{\mu}_{pos}(x), clamped to [0, 1].
        mu_hat = (mean_g - reward_wrong) / denom
        mu_hat = torch.clamp(mu_hat, 0.0, 1.0)
        # Focal weight g(x) = (1 - \hat{\mu}_{pos})^{\gamma}: down-weights easy prompts.
        weights = torch.pow(1.0 - mu_hat, focal_gamma)  # (G,)

        # Broadcast each group's weight back to its rollouts.
        scalars = scalars * weights[gidx]

    advantages = scalars.unsqueeze(-1) * response_mask
    return advantages, advantages
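A quick sanity check of the new estimator through the direct-kwarg path that the error message above points to; a sketch assuming binary 0/1 outcome rewards placed on the final response token:

import numpy as np
import torch
from verl.trainer.ppo.core_algos import compute_f_grpo_outcome_advantage

# Two prompts with 4 rollouts each, 5 response tokens per rollout.
token_level_rewards = torch.zeros(8, 5)
token_level_rewards[:, -1] = torch.tensor([1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
response_mask = torch.ones(8, 5)
index = np.array([0, 0, 0, 0, 1, 1, 1, 1])

advantages, _ = compute_f_grpo_outcome_advantage(
    token_level_rewards, response_mask, index,
    focal_gamma=2.0, reward_correct=1.0, reward_wrong=0.0,
)
# Group 0 (75% correct) is scaled by (1 - 0.75)^2 = 0.0625;
# group 1 (25% correct) by (1 - 0.25)^2 = 0.5625.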


@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk")
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,