Skip to content

Commit

Permalink
add text similarity for more common accuracy reward
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulingChen committed Feb 14, 2025
1 parent 7041fbc commit 7309c79
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/open_r1/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
similarity_accuracy_reward
)
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
Expand Down
27 changes: 25 additions & 2 deletions src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,34 @@
import math
import re
from typing import Dict

from sentence_transformers import SentenceTransformer
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify


sentence_model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
similarity_scale=2.0
similarity_threshold=0.5
def similarity_func(text1,text2):
'''
calculate text similarity by using BAAI/bge-large-zh-v1.5 ie.
use similarity instead of parsing answer for more common task
'''

embeddings_1 = sentence_model.encode(text1, normalize_embeddings=True)
embeddings_2 = sentence_model.encode(text2, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
return similarity
def similarity_accuracy_reward(completions,answer,**kwargs):
contents=[completion[0]["content"] for completion in completions]
rewards=[]
for content,ans in zip(contents,answer):
sim=similarity_func(content,ans)
if sim<similarity_threshold:
rewards.append(0.0) #reward 0
else:
rewards.append(1.0) #reward 1 or just similarity? maybe an scale for different similarity threshold
return rewards

def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
Expand Down

0 comments on commit 7309c79

Please sign in to comment.