From 7309c7912d086e7676f4d24d1f238d9ac0ff7cbc Mon Sep 17 00:00:00 2001 From: atlas <903216099@qq.com> Date: Fri, 14 Feb 2025 10:23:51 +0800 Subject: [PATCH] add text similarity for more common accuracy reward --- src/open_r1/grpo.py | 1 + src/open_r1/rewards.py | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py index 1970f8ef..d0f8f7f9 100644 --- a/src/open_r1/grpo.py +++ b/src/open_r1/grpo.py @@ -32,6 +32,7 @@ get_repetition_penalty_reward, len_reward, reasoning_steps_reward, + similarity_accuracy_reward ) from open_r1.utils.callbacks import get_callbacks from open_r1.utils.wandb_logging import init_wandb_training diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py index 27962784..a8125adf 100644 --- a/src/open_r1/rewards.py +++ b/src/open_r1/rewards.py @@ -3,11 +3,34 @@ import math import re from typing import Dict - +from sentence_transformers import SentenceTransformer from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify - +sentence_model = SentenceTransformer('BAAI/bge-large-zh-v1.5') +similarity_scale=2.0 +similarity_threshold=0.5 +def similarity_func(text1,text2): + ''' + calculate text similarity by using BAAI/bge-large-zh-v1.5 ie. + use similarity instead of parsing answer for more common task + ''' + + embeddings_1 = sentence_model.encode(text1, normalize_embeddings=True) + embeddings_2 = sentence_model.encode(text2, normalize_embeddings=True) + similarity = embeddings_1 @ embeddings_2.T + return similarity +def similarity_accuracy_reward(completions,answer,**kwargs): + contents=[completion[0]["content"] for completion in completions] + rewards=[] + for content,ans in zip(contents,answer): + sim=similarity_func(content,ans) + if sim