Binary file modified docs/sphinx_doc/assets/agentscope_gsm8k_reward.png
4 changes: 2 additions & 2 deletions trinity/common/workflows/envs/email_searcher/utils.py
@@ -288,7 +288,7 @@ def read_email_tool(message_id: str) -> Optional[Email]:
 ############ LLM-as-a-judge ############
 
 
-def judge_correctness(
+async def judge_correctness(
     answer: str,
     query: QueryModel,
     judger: Any,
@@ -318,7 +318,7 @@ def judge_correctness(
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-    completion = judger.chat.completions.create(
+    completion = await judger.chat.completions.create(
         model=judger.model_path, messages=messages, stream=False
     )
     result = completion.choices[0].message.content
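Note for reviewers: making judge_correctness a coroutine means every call site must be awaited from an async context, and judger must be an async client whose chat.completions.create returns an awaitable. A minimal sketch of the new call pattern, assuming an OpenAI-compatible AsyncOpenAI client; the base URL, model name, model_path value, and QueryModel field below are illustrative placeholders, not taken from this PR:

    import asyncio

    from openai import AsyncOpenAI

    from trinity.common.workflows.envs.email_searcher.utils import QueryModel, judge_correctness


    async def main() -> None:
        # Any OpenAI-compatible async client should work, as long as it carries
        # the model identifier that judge_correctness reads (model_path).
        judger = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
        judger.model_path = "Qwen/Qwen2.5-7B-Instruct"  # hypothetical value

        query = QueryModel(question="When was the offsite moved?")  # hypothetical field

        # judge_correctness is now a coroutine; forgetting the await would pass
        # a coroutine object (not a verdict) into rubric.answer_correct downstream.
        verdict = await judge_correctness(
            answer="The offsite was moved to Friday.",
            query=query,
            judger=judger,
        )
        print(verdict)


    asyncio.run(main())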
8 changes: 4 additions & 4 deletions trinity/common/workflows/envs/email_searcher/workflow.py
@@ -92,7 +92,7 @@ async def run_async(self):
             experiences
         )  # NOTE: this metrics works only if the agent calls model once in each turn
 
-        reward_dict = self.calculate_reward(answer_and_sources)
+        reward_dict = await self.calculate_reward(answer_and_sources)
         reward = sum(reward_dict.values())
 
         for i, experience in enumerate(experiences):
@@ -107,7 +107,7 @@ async def run_async(self):
         )
         return experiences
 
-    def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
+    async def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
         """Ref: calculate_reward in https://github.com/OpenPipe/ART/blob/main/dev/art-e/art_e/rollout.py#L64"""
         try:
             answer = answer_and_sources.get("answer", None)
@@ -140,7 +140,7 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
 
         try:
             judge_model = self.auxiliary_models[0] if self.auxiliary_models else None
-            judge_response = judge_correctness(answer, self.query, judge_model)
+            judge_response = await judge_correctness(answer, self.query, judge_model)
             rubric.answer_correct = judge_response
 
         except Exception as e:
@@ -179,4 +179,4 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
             return result
 
         self.logger.error(f"Rubric {rubric} not handled properly")
-        raise ValueError("Rubric is not handled properly")
+        return {"accuracy": 0.0, "format": 0.0}