Creating LitQAv2TaskDataset for agent training/evaluation #401

Merged 3 commits on Sep 14, 2024
7 changes: 3 additions & 4 deletions paperqa/agents/env.py
@@ -1,8 +1,7 @@
import logging
from typing import cast

-from aviary.env import Environment as _Environment
-from aviary.env import Frame
+from aviary.env import Environment, Frame
from aviary.message import Message
from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage

@@ -91,8 +90,8 @@ def settings_to_tools(
    return tools


-class Environment(_Environment[EnvironmentState]):
-    """Environment to connect agents with paper-qa."""
+class PaperQAEnvironment(Environment[EnvironmentState]):
+    """Environment connecting paper-qa's tools with state."""

    def __init__(
        self,
8 changes: 4 additions & 4 deletions paperqa/agents/main.py
@@ -46,7 +46,7 @@
from paperqa.types import Answer
from paperqa.utils import pqa_directory

-from .env import Environment
+from .env import PaperQAEnvironment
from .helpers import litellm_get_search_query, table_formatter
from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
from .search import SearchDocumentStorage, SearchIndex
@@ -235,7 +235,7 @@ async def run_fake_agent(
    ) = None,
    **env_kwargs,
) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
    _, tools = await env.reset()
    if on_env_reset_callback:
        await on_env_reset_callback(env.state)
@@ -281,7 +281,7 @@ async def run_aviary_agent(
    ) = None,
    **env_kwargs,
) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
    done = False

    try:
@@ -345,7 +345,7 @@ async def run_ldp_agent(
    ) = None,
    **env_kwargs,
) -> tuple[Answer, AgentStatus]:
-    env = Environment(query, docs, **env_kwargs)
+    env = PaperQAEnvironment(query, docs, **env_kwargs)
    done = False

    try:
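For orientation on the renamed class, below is a minimal rollout sketch against PaperQAEnvironment. It relies only on the reset/step signatures visible in this diff; choose_action is a hypothetical stand-in for an agent policy (the runners above use aviary/ldp agents instead).

from aviary.tools import ToolRequestMessage

from paperqa import Docs, QueryRequest
from paperqa.agents.env import PaperQAEnvironment


async def choose_action(obs, tools) -> ToolRequestMessage:
    """Hypothetical agent policy; a real agent (e.g. ldp's SimpleAgent) goes here."""
    raise NotImplementedError


async def rollout(query: QueryRequest, docs: Docs) -> None:
    env = PaperQAEnvironment(query, docs)
    obs, tools = await env.reset()  # initial observations and available tools
    done = False
    while not done:
        action = await choose_action(obs, tools)
        obs, reward, done, truncated = await env.step(action)

Once an episode finishes, the GradablePaperQAEnvironment subclass introduced below can augment the final reward with an evaluation of the generated answer.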
209 changes: 209 additions & 0 deletions paperqa/agents/task.py
@@ -0,0 +1,209 @@
__all__ = [
    "ENV_NAME",
    "TASK_DATASET_NAME",
    "GradablePaperQAEnvironment",
    "LitQATaskDataset",
    "LitQAv2TaskDataset",
    "LitQAv2TaskSplit",
]

from abc import ABC
from collections.abc import Awaitable, Callable, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, assert_never

from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage

try:
    from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin
except ImportError:

    class ComputeTrajectoryMetricsMixin:  # type: ignore[no-redef]
        """Placeholder for when ldp isn't installed."""


from paperqa.docs import Docs
from paperqa.litqa import (
    DEFAULT_EVAL_MODEL_NAME,
    DEFAULT_LABBENCH_HF_HUB_NAME,
    DEFAULT_REWARD_DISTRIBUTION,
    LitQAEvaluation,
    read_litqa_v2_from_hub,
)
from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel
from paperqa.types import Answer

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .tools import GenerateAnswer

if TYPE_CHECKING:
    from ldp.data_structures import Trajectory


class GradablePaperQAEnvironment(PaperQAEnvironment):
    """Extended environment that can grade answers."""

    def __init__(
        self,
        query: QueryRequest,
        docs: Docs,
        llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
        evaluation_from_answer: (
            Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None
        ) = None,
        rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
        evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
        **env_kwargs,
    ):
        super().__init__(
            query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
        )
        self._evaluation_from_answer = evaluation_from_answer
        self._evaluation_callback = evaluation_callback
        self._rewards = rewards

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[list[Message], float, bool, bool]:
        messages, reward, done, truncated = await super().step(action)
        if not done or not self._evaluation_from_answer:
            return messages, reward, done, truncated
        # Filter out non-answer messages (in case of parallel tool calls)
        answer_tool_messages = [
            m
            for m in messages
            if isinstance(m, ToolResponseMessage)
            and m.name == GenerateAnswer.gen_answer.__name__
        ]
        if not answer_tool_messages:  # No answer, so no positive reward
            return messages, reward, done, truncated
        if len(answer_tool_messages) != 1:
            raise NotImplementedError(
                f"Expected just one answer message, got {messages}."
            )
        answer = GenerateAnswer.extract_answer_from_message(
            content=answer_tool_messages[0].content
        )
        if not answer:
            return messages, reward, done, truncated
        evaluation = await self._evaluation_from_answer(answer)
        if evaluation_callback := self._evaluation_callback:
            await evaluation_callback(evaluation)
        return messages, reward + self._rewards[evaluation.value], done, truncated

    def export_frame(self) -> Frame:
        raise NotImplementedError("Didn't yet need to export a frame.")


ENV_NAME = "paperqa-local"
ENV_REGISTRY[ENV_NAME] = (
    GradablePaperQAEnvironment.__module__,
    GradablePaperQAEnvironment.__name__,
)


class LitQATaskDataset(
    TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
    """
    Abstract base class for a task dataset of LitQA v1 or v2 questions.

    This is an ABC because it's non-specific to a LitQA version.
    Examples include LitQA v1, v2, or a test stub version of LitQA.
    """

    def __init__(
        self,
        base_query_request: QueryRequest,
        rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
        eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
        **env_kwargs,
    ):
        self._base_query_request = base_query_request
        self._rewards = rewards
        self._env_kwargs = env_kwargs
        self._eval_model = eval_model

    def _make_gradable_environment(
        self,
        ideal: str,
        distractors: str | list[str],
        question: str,
        use_unsure: bool = True,
    ) -> GradablePaperQAEnvironment:
        qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
            ideal=ideal,
            distractors=distractors,
            question=question,
            use_unsure=use_unsure,
            eval_model=self._eval_model,
        )
        query_request = self._base_query_request.model_copy()
        query_request.query = qa_prompt
        return GradablePaperQAEnvironment(
            query=query_request,
            evaluation_from_answer=evaluation_from_answer,
            rewards=self._rewards,
            **self._env_kwargs,
        )

    def compute_trajectory_metrics(
        self, trajectories: "Sequence[Trajectory]"
    ) -> dict[str, list[float]]:
        return super().compute_trajectory_metrics(trajectories) | {
            "correct": [
                int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories
            ],
            "correct_unsure": [
                int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]})
                for traj in trajectories
            ],
        }


class LitQAv2TaskSplit(StrEnum):
    TRAIN = "train"
    EVAL = "eval"


class LitQAv2TaskDataset(LitQATaskDataset):
    """Task dataset of LitQA v2 questions."""

    def __init__(
        self,
        *args,
        labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
        split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
        split = LitQAv2TaskSplit(split)
        if split == LitQAv2TaskSplit.TRAIN:
            self.data = train_df
        elif split == LitQAv2TaskSplit.EVAL:
            self.data = eval_df
        else:
            assert_never(split)

    def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
        return self._make_gradable_environment(
            ideal=self.data.iloc[idx].ideal,
            distractors=self.data.iloc[idx].distractors,
            question=self.data.iloc[idx].question,
        )

    def __len__(self) -> int:
        return len(self.data)


TASK_DATASET_NAME = "litqa-v2"
TASK_DATASET_REGISTRY[TASK_DATASET_NAME] = (
    LitQAv2TaskDataset.__module__,
    LitQAv2TaskDataset.__name__,
)
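As a usage sketch of the new module (assuming paperqa is installed with its aviary/ldp extras and the default LAB-Bench dataset is reachable on the Hugging Face Hub; the Settings() defaults here are illustrative), the dataset can be built directly or looked up through aviary's task dataset registry:

from aviary.env import TaskDataset

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import (
    TASK_DATASET_NAME,
    LitQAv2TaskDataset,
    LitQAv2TaskSplit,
)

base_request = QueryRequest(settings=Settings())

# Direct construction, selecting the train split; extra kwargs (e.g. docs) are
# forwarded to each GradablePaperQAEnvironment via **env_kwargs
train_dataset = LitQAv2TaskDataset(
    base_query_request=base_request,
    split=LitQAv2TaskSplit.TRAIN,
    docs=Docs(),
)
env = train_dataset.get_new_env_by_idx(0)  # a GradablePaperQAEnvironment

# Equivalent lookup through aviary's registry, using the default eval split
eval_dataset = TaskDataset.from_name(
    TASK_DATASET_NAME, base_query_request=base_request, docs=Docs()
)

When trajectories are rolled out (for example with ldp's Evaluator, as in the test below), compute_trajectory_metrics additionally reports per-trajectory correct and correct_unsure indicators derived from the reward distribution.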
100 changes: 100 additions & 0 deletions tests/test_task.py
@@ -0,0 +1,100 @@
import pytest
from aviary.env import TASK_DATASET_REGISTRY, TaskDataset
from ldp.agent import SimpleAgent
from ldp.alg.callbacks import MeanMetricsCallback
from ldp.alg.runners import Evaluator, EvaluatorConfig

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import (
    GradablePaperQAEnvironment,
    LitQATaskDataset,
    LitQAv2TaskDataset,
    LitQAv2TaskSplit,
)


@pytest.fixture(name="base_query_request")
def fixture_base_query_request(agent_test_settings: Settings) -> QueryRequest:
    return QueryRequest(settings=agent_test_settings)


class StubLitQADataset(LitQATaskDataset):
    """Made up dataset of questions answerable from this repo's stub_data."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data: list[tuple[str, str | list[str], str]] = [
            ("Politician", ["Technologist", "Plumber"], "Who is Frederick Bates?"),
            (
                "Make molecular counterfactuals",
                [
                    "Generating images of cats",
                    "Simple explanations of internet searches",
                ],
                "How can you use XAI for chemical property prediction?",
            ),
            (
                "Maple Leaf",
                ["The Stars and Stripes", "The Blue and Yellow", "The Southern Cross"],
                "What is the national flag of Canada?",
            ),
        ]

    def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
        return self._make_gradable_environment(
            ideal=self.data[idx][0],
            distractors=self.data[idx][1],
            question=self.data[idx][2],
        )

    def __len__(self) -> int:
        return len(self.data)


STUB_TASK_DATASET_NAME = "stub-litqa"
TASK_DATASET_REGISTRY[STUB_TASK_DATASET_NAME] = (
    StubLitQADataset.__module__,
    StubLitQADataset.__name__,
)


class TestTaskDataset:
    @pytest.mark.parametrize(
        ("split", "expected_length"),
        [(LitQAv2TaskSplit.TRAIN, 159), (LitQAv2TaskSplit.EVAL, 40)],
    )
    def test___len__(
        self,
        split: str | LitQAv2TaskSplit,
        expected_length: int,
        base_query_request: QueryRequest,
    ) -> None:
        task_dataset = LitQAv2TaskDataset(
            base_query_request=base_query_request, split=split
        )
        assert len(task_dataset) == expected_length

    @pytest.mark.asyncio
    async def test_evaluation(self, base_query_request: QueryRequest) -> None:
        agent = SimpleAgent()
        docs = Docs()
        dataset = TaskDataset.from_name(
            STUB_TASK_DATASET_NAME,
            base_query_request=base_query_request,
            docs=docs,
        )
        metrics_callback = MeanMetricsCallback(eval_dataset=dataset)

        evaluator = Evaluator(
            config=EvaluatorConfig(batch_size=3),
            agent=agent,
            dataset=dataset,
            callbacks=[metrics_callback],
        )
        await evaluator.evaluate()

        assert (
            not base_query_request.query
        ), "Should not have mutated query in base request"
        assert docs.docs, "Expected to have added content"
        assert isinstance(metrics_callback.eval_means["reward"], float)