[Rewards] add kimi len_reward (#292)
* add kimi len_reward

* add to REWARD_FUNCS_REGISTRY

* fix formatting

* Update src/open_r1/grpo.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/grpo.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/grpo.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <[email protected]>

* Update src/open_r1/rewards.py

Co-authored-by: lewtun <[email protected]>

* missing import

---------

Co-authored-by: lewtun <[email protected]>
kashif and lewtun authored Feb 13, 2025
1 parent 80e7e7b commit 7832290
Showing 3 changed files with 148 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/open_r1/grpo.py
@@ -30,6 +30,7 @@
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)
from open_r1.utils.callbacks import get_callbacks
@@ -47,7 +48,7 @@ class GRPOScriptArguments(ScriptArguments):
Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'.
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
@@ -63,7 +64,7 @@ class GRPOScriptArguments(ScriptArguments):
reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
},
)
cosine_min_value_wrong: float = field(
@@ -162,6 +163,7 @@ def main(script_args, training_args, model_args):
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

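Note: the following is a minimal sketch of how the registry above resolves a requested reward list into callables. The accuracy_reward and format_reward entries already exist in open_r1.rewards; only the "length" key is added by this commit, and the surrounding argument parsing is assumed rather than shown.

from open_r1.rewards import accuracy_reward, format_reward, len_reward

# Sketch of the lookup performed in main(); the requested names would normally
# come from GRPOScriptArguments.reward_funcs (e.g. --reward_funcs accuracy format length).
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "length": len_reward,  # new entry from this commit
}
requested = ["accuracy", "format", "length"]
reward_funcs = [REWARD_FUNCS_REGISTRY[name] for name in requested]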
74 changes: 74 additions & 0 deletions src/open_r1/rewards.py
@@ -2,6 +2,7 @@

import math
import re
from typing import Dict

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
@@ -74,6 +75,79 @@ def reasoning_steps_reward(completions, **kwargs):
return [min(1.0, count / 3) for count in matches]


def len_reward(completions: list[Dict[str, str]], solutions: list[str], **kwargs) -> list[float]:
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Args:
completions: List of model completions
solutions: List of ground truth solutions
Returns:
List of rewards where:
- For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
- For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
"""
contents = [completion[0]["content"] for completion in completions]

# First check correctness of answers
correctness = []
for content, sol in zip(contents, solutions):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) == 0:
# Skip unparseable examples
correctness.append(True) # Treat as correct to avoid penalizing
print("Failed to parse gold solution: ", sol)
continue

answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
correctness.append(verify(answer_parsed, gold_parsed))

# Calculate lengths
lengths = [len(content) for content in contents]
min_len = min(lengths)
max_len = max(lengths)

# If all responses have the same length, return zero rewards
if max_len == min_len:
return [0.0] * len(completions)

rewards = []
for length, is_correct in zip(lengths, correctness):
lambda_val = 0.5 - (length - min_len) / (max_len - min_len)

if is_correct:
reward = lambda_val
else:
reward = min(0, lambda_val)

rewards.append(float(reward))

return rewards


def get_cosine_scaled_reward(
min_value_wrong: float = -1.0,
max_value_wrong: float = -0.5,
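Note: a worked, hand-computed example of the len_reward scaling above, using made-up completion lengths and correctness flags (not values from the repository) to show how the formula behaves.

# Toy batch: three completions of length 10, 20 and 30 characters.
lengths = [10, 20, 30]
correct = [True, True, False]
min_len, max_len = min(lengths), max(lengths)

rewards = []
for length, is_correct in zip(lengths, correct):
    lam = 0.5 - (length - min_len) / (max_len - min_len)
    rewards.append(lam if is_correct else min(0.0, lam))

print(rewards)  # [0.5, 0.0, -0.5]: the shortest correct answer gets the maximum reward of 0.5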
70 changes: 70 additions & 0 deletions tests/test_rewards.py
@@ -5,6 +5,7 @@
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
)

@@ -110,6 +111,75 @@ def test_format_reward_specific_multiline(self):
rewards = format_reward(completion)
self.assertEqual(rewards[0], 1.0)

def test_same_length_responses(self):
"""Test len_reward when all responses have the same length."""
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = len_reward(completions, solutions)
self.assertEqual(rewards, [0.0, 0.0])

def test_different_lengths_correct_answers(self):
"""Test len_reward with different length correct answers."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}], # shorter
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # longer
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = len_reward(completions, solutions)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward
self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward

def test_different_lengths_incorrect_answers(self):
"""Test len_reward with different length incorrect answers."""
completions = [
[{"content": r"\boxed{\frac{64}{400}}"}], # shorter
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # longer
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = len_reward(completions, solutions)
self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards
self.assertLessEqual(rewards[1], 0.0)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less

def test_mixed_correctness(self):
"""Test len_reward with mix of correct and incorrect answers of different lengths."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}], # correct, shorter
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # correct, longer
[{"content": r"\boxed{\frac{64}{400}}"}], # incorrect, shorter
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # incorrect, longer
]
solutions = [r"\frac{63}{400}"] * 4

rewards = len_reward(completions, solutions)

# Shortest correct answer should get positive reward
self.assertGreater(rewards[0], 0.0)

# The longer correct answer (index 1) may score below the shorter incorrect answer (index 2),
# but should not score below the longer incorrect one (index 3)
self.assertGreater(rewards[2], rewards[1])
self.assertGreaterEqual(rewards[1], rewards[3])

# Incorrect answers should get non-positive rewards
self.assertLessEqual(rewards[2], 0.0)
self.assertLessEqual(rewards[3], 0.0)

# Shorter answers should get better rewards within their correctness category
self.assertGreater(rewards[0], rewards[1]) # correct answers
self.assertGreater(rewards[2], rewards[3]) # incorrect answers

def test_unparseable_solution(self):
"""Test len_reward with unparseable solution."""
completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
solutions = ["unparseable_latex", "unparseable_latex"]

rewards = len_reward(completions, solutions)
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward
self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward


class TestRepetitionPenaltyReward(unittest.TestCase):
def test_positive_max_penalty_raises_value_error(self):
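Note: a hedged usage sketch showing the input shape len_reward expects (the same chat-style completions used in the tests above); it assumes open_r1 and its math_verify / latex2sympy2_extended dependencies are installed.

from open_r1.rewards import len_reward

completions = [
    [{"content": r"\boxed{\frac{63}{400}}"}],              # correct, shorter
    [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # correct, longer
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = len_reward(completions, solutions)
print(rewards)  # the shortest correct completion scores 0.5; the longer one scores lower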
