Skip to content

Commit

Permalink
Including paper and evidence counts in metrics (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Sep 18, 2024
1 parent 1e3a4b9 commit e8678a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
35 changes: 32 additions & 3 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"LitQAv2TaskSplit",
]

import logging
import re
from abc import ABC
from collections.abc import Awaitable, Callable, Sequence
from enum import StrEnum
Expand Down Expand Up @@ -42,6 +44,8 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
if TYPE_CHECKING:
from ldp.data_structures import Trajectory

logger = logging.getLogger(__name__)


class GradablePaperQAEnvironment(PaperQAEnvironment):
"""Extended environment that can grade answers."""
Expand Down Expand Up @@ -158,13 +162,38 @@ def _make_gradable_environment(
def compute_trajectory_metrics(
    self, trajectories: "Sequence[Trajectory]"
) -> dict[str, list[float]]:
    """Compute per-trajectory evaluation metrics.

    Extends the parent mixin's metrics with, per trajectory:
    - ``total_paper_count``: mean paper count reported in final answers.
    - ``relevant_paper_count``: mean relevant-paper count.
    - ``evidence_count``: mean evidence count.
    - ``correct``: 1 if the final reward equals the "correct" reward.
    - ``correct_unsure``: 1 if the final reward is "correct" or "unsure".

    Args:
        trajectories: Sequence of completed trajectories to score.

    Returns:
        Mapping of metric name to one value per trajectory.
    """
    total_paper_count: list[float] = []
    relevant_paper_count: list[float] = []
    evidence_count: list[float] = []
    for t in trajectories:
        # Extract the status fields appended after each generated answer by
        # splitting on the answer/status delimiter regex.
        split_answers = [
            re.split(
                pattern=GenerateAnswer.ANSWER_SPLIT_REGEX_PATTERN,
                string=obs.content,
            )
            for obs in t.steps[-1].next_observation
            if (
                isinstance(obs, ToolResponseMessage)
                and obs.name == GenerateAnswer.TOOL_FN_NAME
            )
        ]
        if not split_answers:
            # No answer tool call in the final step: record NaN (not 0) so
            # the metric lists stay aligned one-to-one with trajectories
            # without skewing downstream means toward zero.
            logger.warning(
                "Trajectory has no %r tool responses in its final step;"
                " recording NaN for paper/evidence counts.",
                GenerateAnswer.TOOL_FN_NAME,
            )
            for metric_list in (
                total_paper_count,
                relevant_paper_count,
                evidence_count,
            ):
                metric_list.append(float("nan"))
            continue
        for i, metric_list in enumerate(
            (total_paper_count, relevant_paper_count, evidence_count),
            start=1,  # Regex extraction of status starts after answer
        ):
            metric_list.append(  # Use mean to allow for multiple answers
                sum(int(sa[i]) for sa in split_answers) / len(split_answers)
            )
    return super().compute_trajectory_metrics(trajectories) | {
        "total_paper_count": total_paper_count,
        "relevant_paper_count": relevant_paper_count,
        "evidence_count": evidence_count,
        "correct": [
            int(t.steps[-1].reward == self._rewards[0]) for t in trajectories
        ],
        "correct_unsure": [
            int(t.steps[-1].reward in {self._rewards[0], self._rewards[1]})
            for t in trajectories
        ],
    }

Expand Down
1 change: 1 addition & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,4 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
), "Should not have mutated query in base request"
assert not docs.docs, "Should not have mutated docs in base docs"
assert isinstance(metrics_callback.eval_means["reward"], float)
assert isinstance(metrics_callback.eval_means["total_paper_count"], float)

0 comments on commit e8678a3

Please sign in to comment.