From 78322906872d7daafa002eaa49d89cbf0d7c43ad Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Thu, 13 Feb 2025 11:51:09 +0100
Subject: [PATCH] [Rewards] add kimi len_reward (#292)

* add kimi len_reward

* add to REWARD_FUNCS_REGISTRY

* fix formatting

* Update src/open_r1/grpo.py

Co-authored-by: lewtun

* Update src/open_r1/grpo.py

Co-authored-by: lewtun

* Update src/open_r1/grpo.py

Co-authored-by: lewtun

* Update src/open_r1/rewards.py

Co-authored-by: lewtun

* Update src/open_r1/rewards.py

Co-authored-by: lewtun

* Update src/open_r1/rewards.py

Co-authored-by: lewtun

* Update src/open_r1/rewards.py

Co-authored-by: lewtun

* Update src/open_r1/rewards.py

Co-authored-by: lewtun

* missing import

---------

Co-authored-by: lewtun
---
 src/open_r1/grpo.py    |  6 ++--
 src/open_r1/rewards.py | 74 ++++++++++++++++++++++++++++++++++++++++++
 tests/test_rewards.py  | 70 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 148 insertions(+), 2 deletions(-)

diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py
index 128375db..916be06e 100644
--- a/src/open_r1/grpo.py
+++ b/src/open_r1/grpo.py
@@ -30,6 +30,7 @@
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward,
+    len_reward,
     reasoning_steps_reward,
 )
 from open_r1.utils.callbacks import get_callbacks
@@ -47,7 +48,7 @@ class GRPOScriptArguments(ScriptArguments):
 
     Args:
         reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'.
+            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
         cosine_min_value_wrong (`float`):
             Minimum reward for cosine scaling for wrong answers.
         cosine_max_value_wrong (`float`):
@@ -63,7 +64,7 @@ class GRPOScriptArguments(ScriptArguments):
     reward_funcs: list[str] = field(
         default_factory=lambda: ["accuracy", "format"],
         metadata={
-            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
+            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
         },
     )
     cosine_min_value_wrong: float = field(
@@ -162,6 +163,7 @@ def main(script_args, training_args, model_args):
             ngram_size=script_args.repetition_n_grams,
             max_penalty=script_args.repetition_max_penalty,
         ),
+        "length": len_reward,
     }
     reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
 
diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py
index bec3d11c..27962784 100644
--- a/src/open_r1/rewards.py
+++ b/src/open_r1/rewards.py
@@ -2,6 +2,7 @@
 
 import math
 import re
+from typing import Dict
 
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify
@@ -74,6 +75,79 @@ def reasoning_steps_reward(completions, **kwargs):
     return [min(1.0, count / 3) for count in matches]
 
 
+def len_reward(completions: list[Dict[str, str]], solutions: list[str], **kwargs) -> list[float]:
+    """Compute length-based rewards to discourage overthinking and promote token efficiency.
+
+    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
+
+    Args:
+        completions: List of model completions
+        solutions: List of ground truth solutions
+
+    Returns:
+        List of rewards where:
+        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
+        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
+    """
+    contents = [completion[0]["content"] for completion in completions]
+
+    # First check correctness of answers
+    correctness = []
+    for content, sol in zip(contents, solutions):
+        gold_parsed = parse(
+            sol,
+            extraction_mode="first_match",
+            extraction_config=[LatexExtractionConfig()],
+        )
+        if len(gold_parsed) == 0:
+            # Skip unparseable examples
+            correctness.append(True)  # Treat as correct to avoid penalizing
+            print("Failed to parse gold solution: ", sol)
+            continue
+
+        answer_parsed = parse(
+            content,
+            extraction_config=[
+                LatexExtractionConfig(
+                    normalization_config=NormalizationConfig(
+                        nits=False,
+                        malformed_operators=False,
+                        basic_latex=True,
+                        equations=True,
+                        boxed=True,
+                        units=True,
+                    ),
+                    boxed_match_priority=0,
+                    try_extract_without_anchor=False,
+                )
+            ],
+            extraction_mode="first_match",
+        )
+        correctness.append(verify(answer_parsed, gold_parsed))
+
+    # Calculate lengths
+    lengths = [len(content) for content in contents]
+    min_len = min(lengths)
+    max_len = max(lengths)
+
+    # If all responses have the same length, return zero rewards
+    if max_len == min_len:
+        return [0.0] * len(completions)
+
+    rewards = []
+    for length, is_correct in zip(lengths, correctness):
+        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
+
+        if is_correct:
+            reward = lambda_val
+        else:
+            reward = min(0, lambda_val)
+
+        rewards.append(float(reward))
+
+    return rewards
+
+
 def get_cosine_scaled_reward(
     min_value_wrong: float = -1.0,
     max_value_wrong: float = -0.5,
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 7f0cbfa9..9e41bdb0 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -5,6 +5,7 @@
     format_reward,
     get_cosine_scaled_reward,
     get_repetition_penalty_reward,
+    len_reward,
     reasoning_steps_reward,
 )
 
@@ -110,6 +111,75 @@ def test_format_reward_specific_multiline(self):
         rewards = format_reward(completion)
         self.assertEqual(rewards[0], 1.0)
 
+    def test_same_length_responses(self):
+        """Test len_reward when all responses have the same length."""
+        completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertEqual(rewards, [0.0, 0.0])
+
+    def test_different_lengths_correct_answers(self):
+        """Test len_reward with different length correct answers."""
+        completions = [
+            [{"content": r"\boxed{\frac{63}{400}}"}],  # shorter
+            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # longer
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should get higher reward
+        self.assertAlmostEqual(rewards[0], 0.5)  # shortest correct answer gets maximum reward
+
+    def test_different_lengths_incorrect_answers(self):
+        """Test len_reward with different length incorrect answers."""
+        completions = [
+            [{"content": r"\boxed{\frac{64}{400}}"}],  # shorter
+            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # longer
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertLessEqual(rewards[0], 0.0)  # incorrect answers should get non-positive rewards
+        self.assertLessEqual(rewards[1], 0.0)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still be penalized less
+
+    def test_mixed_correctness(self):
+        """Test len_reward with mix of correct and incorrect answers of different lengths."""
+        completions = [
+            [{"content": r"\boxed{\frac{63}{400}}"}],  # correct, shorter
+            [{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}],  # correct, longer
+            [{"content": r"\boxed{\frac{64}{400}}"}],  # incorrect, shorter
+            [{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}],  # incorrect, longer
+        ]
+        solutions = [r"\frac{63}{400}"] * 4
+
+        rewards = len_reward(completions, solutions)
+
+        # Shortest correct answer should get positive reward
+        self.assertGreater(rewards[0], 0.0)
+
+        # Longer correct answer might get negative reward:
+        self.assertGreater(rewards[2], rewards[1])
+        self.assertGreaterEqual(rewards[1], rewards[3])
+
+        # Incorrect answers should get non-positive rewards
+        self.assertLessEqual(rewards[2], 0.0)
+        self.assertLessEqual(rewards[3], 0.0)
+
+        # Shorter answers should get better rewards within their correctness category
+        self.assertGreater(rewards[0], rewards[1])  # correct answers
+        self.assertGreater(rewards[2], rewards[3])  # incorrect answers
+
+    def test_unparseable_solution(self):
+        """Test len_reward with unparseable solution."""
+        completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
+        solutions = ["unparseable_latex", "unparseable_latex"]
+
+        rewards = len_reward(completions, solutions)
+        self.assertGreater(rewards[0], rewards[1])  # shorter answer should still get better reward
+        self.assertAlmostEqual(rewards[0], 0.5)  # treated as correct, shortest gets maximum reward
+
 
 class TestRepetitionPenaltyReward(unittest.TestCase):
     def test_positive_max_penalty_raises_value_error(self):
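For reference, below is a minimal standalone sketch of the length scaling that len_reward applies, assuming correctness has already been decided: the flags are hard-coded here instead of being computed with math_verify, so the snippet runs without the open_r1 dependencies. The helper name length_scaled_rewards and the toy lengths/correct values are illustrative only.

def length_scaled_rewards(lengths: list[int], correct: list[bool]) -> list[float]:
    """Apply the Kimi 1.5 length scaling to pre-computed correctness flags."""
    min_len, max_len = min(lengths), max(lengths)
    if max_len == min_len:
        # Equal lengths carry no signal, so every reward is zero
        return [0.0] * len(lengths)
    rewards = []
    for length, is_correct in zip(lengths, correct):
        # lambda ranges from 0.5 (shortest response) down to -0.5 (longest response)
        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
        # Incorrect answers are never rewarded for brevity, only penalized for length
        rewards.append(float(lambda_val if is_correct else min(0.0, lambda_val)))
    return rewards


if __name__ == "__main__":
    # Toy batch mirroring test_mixed_correctness: correct/incorrect answers, short/long
    lengths = [20, 31, 20, 31]
    correct = [True, True, False, False]
    print(length_scaled_rewards(lengths, correct))  # [0.5, -0.5, 0.0, -0.5]

Since the patch registers len_reward under "length" in REWARD_FUNCS_REGISTRY, enabling it in training should only require adding "length" to reward_funcs alongside the existing values.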