From e8678a3aac3ad8ba5fb187845fd89902b328e7bb Mon Sep 17 00:00:00 2001
From: James Braza
Date: Wed, 18 Sep 2024 12:32:20 -0700
Subject: [PATCH] Including paper and evidence counts in metrics (#435)

---
 paperqa/agents/task.py | 35 ++++++++++++++++++++++++++++++++---
 tests/test_task.py     |  1 +
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py
index 414aef78..34a3a1c8 100644
--- a/paperqa/agents/task.py
+++ b/paperqa/agents/task.py
@@ -7,6 +7,8 @@
     "LitQAv2TaskSplit",
 ]
 
+import logging
+import re
 from abc import ABC
 from collections.abc import Awaitable, Callable, Sequence
 from enum import StrEnum
@@ -42,6 +44,8 @@ class ComputeTrajectoryMetricsMixin:  # type: ignore[no-redef]
 if TYPE_CHECKING:
     from ldp.data_structures import Trajectory
 
+logger = logging.getLogger(__name__)
+
 
 class GradablePaperQAEnvironment(PaperQAEnvironment):
     """Extended environment that can grade answers."""
@@ -158,13 +162,38 @@ def _make_gradable_environment(
     def compute_trajectory_metrics(
         self, trajectories: "Sequence[Trajectory]"
     ) -> dict[str, list[float]]:
+        total_paper_count: list[float] = []
+        relevant_paper_count: list[float] = []
+        evidence_count: list[float] = []
+        for t in trajectories:
+            split_answers = [
+                re.split(
+                    pattern=GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN,
+                    string=obs.content,
+                )
+                for obs in t.steps[-1].next_observation
+                if (
+                    isinstance(obs, ToolResponseMessage)
+                    and obs.name == GenerateAnswer.TOOL_FN_NAME
+                )
+            ]
+            for i, metric_list in enumerate(
+                (total_paper_count, relevant_paper_count, evidence_count),
+                start=1,  # Regex extraction of status starts after answer
+            ):
+                metric_list.append(  # Use mean to allow for multiple answers
+                    sum(int(sa[i]) for sa in split_answers) / len(split_answers)
+                )
         return super().compute_trajectory_metrics(trajectories) | {
+            "total_paper_count": total_paper_count,
+            "relevant_paper_count": relevant_paper_count,
+            "evidence_count": evidence_count,
             "correct": [
-                int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories
+                int(t.steps[-1].reward == self._rewards[0]) for t in trajectories
             ],
             "correct_unsure": [
-                int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]})
-                for traj in trajectories
+                int(t.steps[-1].reward in {self._rewards[0], self._rewards[1]})
+                for t in trajectories
             ],
         }
 
diff --git a/tests/test_task.py b/tests/test_task.py
index f6455ff9..1f9bdbd9 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -94,3 +94,4 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
         ), "Should not have mutated query in base request"
         assert not docs.docs, "Should not have mutated docs in base docs"
         assert isinstance(metrics_callback.eval_means["reward"], float)
+        assert isinstance(metrics_callback.eval_means["total_paper_count"], float)
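
Note (not part of the patch): a minimal sketch of the parsing the new metrics rely on. The patch assumes each gen_answer tool response embeds a status whose regex split yields the answer text at index 0 followed by three numeric fields, which is why the enumerate starts at 1. The ANSWER_SPLIT_REGEX_PATTERN value and status wording below are illustrative assumptions, not the actual paperqa constants.

    import re

    # Hypothetical stand-ins for GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN and a
    # gen_answer tool response; the real pattern and status text live in paperqa.
    ANSWER_SPLIT_REGEX_PATTERN = r" \| "
    tool_response = (
        "Answer: Insulin regulates blood glucose."
        " | 12 | 3 | 5"  # assumed order: total papers | relevant papers | evidence
    )

    split_answer = re.split(ANSWER_SPLIT_REGEX_PATTERN, tool_response)
    # Index 0 is the answer itself; indices 1-3 are the counts (hence start=1 above).
    total_papers, relevant_papers, evidence = (int(x) for x in split_answer[1:4])
    print(total_papers, relevant_papers, evidence)  # 12 3 5

Because a trajectory's final step may contain several gen_answer responses, the patch averages each count over those responses (the sum(...) / len(split_answers) step) before reporting it per trajectory.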