Creating LitQAv2TaskDataset for agent training/evaluation (#401)
jamesbraza committed Sep 14, 2024
1 parent b1745b0 commit 49273f2
Showing 4 changed files with 316 additions and 8 deletions.
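
The new module is meant to be driven by aviary/ldp training and evaluation loops. A rough usage sketch, mirroring the added tests/test_task.py (the eval split, SimpleAgent, and batch size are illustrative choices, not requirements of this commit; it also assumes the ldp package is installed, which task.py otherwise treats as optional):

# Sketch (mirrors the added tests/test_task.py): evaluate an ldp SimpleAgent on the
# LitQA v2 eval split. Split, agent, and batch size are illustrative choices.
import asyncio

from ldp.agent import SimpleAgent
from ldp.alg.callbacks import MeanMetricsCallback
from ldp.alg.runners import Evaluator, EvaluatorConfig

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit


async def main() -> None:
    dataset = LitQAv2TaskDataset(
        base_query_request=QueryRequest(settings=Settings()),
        split=LitQAv2TaskSplit.EVAL,
        docs=Docs(),  # forwarded to each GradablePaperQAEnvironment via **env_kwargs
    )
    metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
    evaluator = Evaluator(
        config=EvaluatorConfig(batch_size=3),
        agent=SimpleAgent(),
        dataset=dataset,
        callbacks=[metrics_callback],
    )
    await evaluator.evaluate()
    print(metrics_callback.eval_means["reward"])


asyncio.run(main())
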
7 changes: 3 additions & 4 deletions paperqa/agents/env.py
@@ -1,8 +1,7 @@
import logging
from typing import cast

-from aviary.env import Environment as _Environment
-from aviary.env import Frame
+from aviary.env import Environment, Frame
from aviary.message import Message
from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage

@@ -91,8 +90,8 @@ def settings_to_tools(
return tools


-class Environment(_Environment[EnvironmentState]):
-    """Environment to connect agents with paper-qa."""
+class PaperQAEnvironment(Environment[EnvironmentState]):
+    """Environment connecting paper-qa's tools with state."""

def __init__(
self,
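
The rename is mechanical; a minimal sketch of constructing the renamed environment directly, following the same pattern as the run_* helpers in main.py below (the settings and query are placeholders):

# Sketch: construct the renamed environment directly, as main.py's run_* helpers do.
import asyncio

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.env import PaperQAEnvironment


async def demo() -> None:
    env = PaperQAEnvironment(QueryRequest(settings=Settings()), Docs())
    _, tools = await env.reset()  # initial observations and the paper-qa tools
    print(f"{len(tools)} tools available")


asyncio.run(demo())
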
8 changes: 4 additions & 4 deletions paperqa/agents/main.py
@@ -46,7 +46,7 @@
from paperqa.types import Answer
from paperqa.utils import pqa_directory

-from .env import Environment
+from .env import PaperQAEnvironment
from .helpers import litellm_get_search_query, table_formatter
from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
from .search import SearchDocumentStorage, SearchIndex
@@ -235,7 +235,7 @@ async def run_fake_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
_, tools = await env.reset()
if on_env_reset_callback:
await on_env_reset_callback(env.state)
@@ -281,7 +281,7 @@ async def run_aviary_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
done = False

try:
@@ -345,7 +345,7 @@ async def run_ldp_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
done = False

try:
209 changes: 209 additions & 0 deletions paperqa/agents/task.py
@@ -0,0 +1,209 @@
__all__ = [
"ENV_NAME",
"TASK_DATASET_NAME",
"GradablePaperQAEnvironment",
"LitQATaskDataset",
"LitQAv2TaskDataset",
"LitQAv2TaskSplit",
]

from abc import ABC
from collections.abc import Awaitable, Callable, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, assert_never

from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage

try:
from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin
except ImportError:

class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
"""Placeholder for when ldp isn't installed."""


from paperqa.docs import Docs
from paperqa.litqa import (
DEFAULT_EVAL_MODEL_NAME,
DEFAULT_LABBENCH_HF_HUB_NAME,
DEFAULT_REWARD_DISTRIBUTION,
LitQAEvaluation,
read_litqa_v2_from_hub,
)
from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel
from paperqa.types import Answer

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .tools import GenerateAnswer

if TYPE_CHECKING:
from ldp.data_structures import Trajectory


class GradablePaperQAEnvironment(PaperQAEnvironment):
"""Extended environment that can grade answers."""

def __init__(
self,
query: QueryRequest,
docs: Docs,
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
evaluation_from_answer: (
Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None
) = None,
rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
**env_kwargs,
):
super().__init__(
query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
)
self._evaluation_from_answer = evaluation_from_answer
self._evaluation_callback = evaluation_callback
self._rewards = rewards

async def step(
self, action: ToolRequestMessage
) -> tuple[list[Message], float, bool, bool]:
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
return messages, reward, done, truncated
# Filter out non-answer messages (in case parallel tool calls)
answer_tool_messages = [
m
for m in messages
if isinstance(m, ToolResponseMessage)
and m.name == GenerateAnswer.gen_answer.__name__
]
if not answer_tool_messages: # No answer, so no positive reward
return messages, reward, done, truncated
if len(answer_tool_messages) != 1:
raise NotImplementedError(
f"Expected just one answer message, got {messages}."
)
answer = GenerateAnswer.extract_answer_from_message(
content=answer_tool_messages[0].content
)
if not answer:
return messages, reward, done, truncated
evaluation = await self._evaluation_from_answer(answer)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
return messages, reward + self._rewards[evaluation.value], done, truncated

def export_frame(self) -> Frame:
raise NotImplementedError("Didn't yet need to export a frame.")


ENV_NAME = "paperqa-local"
ENV_REGISTRY[ENV_NAME] = (
GradablePaperQAEnvironment.__module__,
GradablePaperQAEnvironment.__name__,
)


class LitQATaskDataset(
TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
"""
Abstract base class for a task dataset of LitQA v1 or v2 questions.
This is an ABC because it's non-specific to a LitQA version.
Examples include LitQA v1, v2, or a test stub version of LitQA.
"""

def __init__(
self,
base_query_request: QueryRequest,
rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
**env_kwargs,
):
self._base_query_request = base_query_request
self._rewards = rewards
self._env_kwargs = env_kwargs
self._eval_model = eval_model

def _make_gradable_environment(
self,
ideal: str,
distractors: str | list[str],
question: str,
use_unsure: bool = True,
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
question=question,
use_unsure=use_unsure,
eval_model=self._eval_model,
)
query_request = self._base_query_request.model_copy()
query_request.query = qa_prompt
return GradablePaperQAEnvironment(
query=query_request,
evaluation_from_answer=evaluation_from_answer,
rewards=self._rewards,
**self._env_kwargs,
)

def compute_trajectory_metrics(
self, trajectories: "Sequence[Trajectory]"
) -> dict[str, list[float]]:
return super().compute_trajectory_metrics(trajectories) | {
"correct": [
int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories
],
"correct_unsure": [
int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]})
for traj in trajectories
],
}


class LitQAv2TaskSplit(StrEnum):
TRAIN = "train"
EVAL = "eval"


class LitQAv2TaskDataset(LitQATaskDataset):
"""Task dataset of LitQA v2 questions."""

def __init__(
self,
*args,
labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
**kwargs,
):
super().__init__(*args, **kwargs)
train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
split = LitQAv2TaskSplit(split)
if split == LitQAv2TaskSplit.TRAIN:
self.data = train_df
elif split == LitQAv2TaskSplit.EVAL:
self.data = eval_df
else:
assert_never(split)

def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
return self._make_gradable_environment(
ideal=self.data.iloc[idx].ideal,
distractors=self.data.iloc[idx].distractors,
question=self.data.iloc[idx].question,
)

def __len__(self) -> int:
return len(self.data)


TASK_DATASET_NAME = "litqa-v2"
TASK_DATASET_REGISTRY[TASK_DATASET_NAME] = (
LitQAv2TaskDataset.__module__,
LitQAv2TaskDataset.__name__,
)
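
With the registration above, the dataset can also be built by name through aviary's registry, the same way the added test constructs its stub dataset; roughly:

# Sketch: build the registered "litqa-v2" dataset by name via aviary's registry.
from aviary.env import TaskDataset

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import TASK_DATASET_NAME

dataset = TaskDataset.from_name(
    TASK_DATASET_NAME,  # "litqa-v2"
    base_query_request=QueryRequest(settings=Settings()),
    docs=Docs(),
)
env = dataset.get_new_env_by_idx(0)  # a GradablePaperQAEnvironment for the first question
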
100 changes: 100 additions & 0 deletions tests/test_task.py
@@ -0,0 +1,100 @@
import pytest
from aviary.env import TASK_DATASET_REGISTRY, TaskDataset
from ldp.agent import SimpleAgent
from ldp.alg.callbacks import MeanMetricsCallback
from ldp.alg.runners import Evaluator, EvaluatorConfig

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import (
GradablePaperQAEnvironment,
LitQATaskDataset,
LitQAv2TaskDataset,
LitQAv2TaskSplit,
)


@pytest.fixture(name="base_query_request")
def fixture_base_query_request(agent_test_settings: Settings) -> QueryRequest:
return QueryRequest(settings=agent_test_settings)


class StubLitQADataset(LitQATaskDataset):
"""Made up dataset of questions answerable from this repo's stub_data."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data: list[tuple[str, str | list[str], str]] = [
("Politician", ["Technologist", "Plumber"], "Who is Frederick Bates?"),
(
"Make molecular counterfactuals",
[
"Generating images of cats",
"Simple explanations of internet searches",
],
"How can you use XAI for chemical property prediction?",
),
(
"Maple Leaf",
["The Stars and Stripes", "The Blue and Yellow", "The Southern Cross"],
"What is the national flag of Canada?",
),
]

def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
return self._make_gradable_environment(
ideal=self.data[idx][0],
distractors=self.data[idx][1],
question=self.data[idx][2],
)

def __len__(self) -> int:
return len(self.data)


STUB_TASK_DATASET_NAME = "stub-litqa"
TASK_DATASET_REGISTRY[STUB_TASK_DATASET_NAME] = (
StubLitQADataset.__module__,
StubLitQADataset.__name__,
)


class TestTaskDataset:
@pytest.mark.parametrize(
("split", "expected_length"),
[(LitQAv2TaskSplit.TRAIN, 159), (LitQAv2TaskSplit.EVAL, 40)],
)
def test___len__(
self,
split: str | LitQAv2TaskSplit,
expected_length: int,
base_query_request: QueryRequest,
) -> None:
task_dataset = LitQAv2TaskDataset(
base_query_request=base_query_request, split=split
)
assert len(task_dataset) == expected_length

@pytest.mark.asyncio
async def test_evaluation(self, base_query_request: QueryRequest) -> None:
agent = SimpleAgent()
docs = Docs()
dataset = TaskDataset.from_name(
STUB_TASK_DATASET_NAME,
base_query_request=base_query_request,
docs=docs,
)
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)

evaluator = Evaluator(
config=EvaluatorConfig(batch_size=3),
agent=agent,
dataset=dataset,
callbacks=[metrics_callback],
)
await evaluator.evaluate()

assert (
not base_query_request.query
), "Should not have mutated query in base request"
assert docs.docs, "Expected to have added content"
assert isinstance(metrics_callback.eval_means["reward"], float)
