From 2531264af847734bb5f65246c9cac69f10863bf8 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 13 Sep 2024 14:01:12 -0700 Subject: [PATCH 1/3] Updated Environment's docstring to be more precise --- paperqa/agents/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 6171baf4..c080b6c8 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -92,7 +92,7 @@ def settings_to_tools( class Environment(_Environment[EnvironmentState]): - """Environment to connect agents with paper-qa.""" + """Environment connecting paper-qa's tools with state.""" def __init__( self, From fc658d338385256b9062f7b749dda2bfa86740f0 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 13 Sep 2024 15:48:06 -0700 Subject: [PATCH 2/3] Implemented LitQA and LitQA v2 task dataset, and made simple test suite for them --- paperqa/agents/task.py | 201 +++++++++++++++++++++++++++++++++++++++++ tests/test_task.py | 100 ++++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 paperqa/agents/task.py create mode 100644 tests/test_task.py diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py new file mode 100644 index 00000000..aaa21833 --- /dev/null +++ b/paperqa/agents/task.py @@ -0,0 +1,201 @@ +from abc import ABC +from collections.abc import Awaitable, Callable, Sequence +from enum import StrEnum +from typing import TYPE_CHECKING, assert_never + +from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame +from aviary.env import ( + TaskDataset as _TaskDataset, +) +from aviary.message import Message +from aviary.tools import ToolRequestMessage, ToolResponseMessage + +try: + from ldp.alg.callbacks import ComputeTrajectoryMetricsMixin +except ImportError: + + class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] + """Placeholder for when ldp isn't installed.""" + + +from paperqa.docs import Docs +from paperqa.litqa import ( + DEFAULT_EVAL_MODEL_NAME, + DEFAULT_LABBENCH_HF_HUB_NAME, + DEFAULT_REWARD_DISTRIBUTION, + LitQAEvaluation, + read_litqa_v2_from_hub, +) +from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel +from paperqa.types import Answer + +from .env import POPULATE_FROM_SETTINGS +from .env import Environment as _Environment +from .models import QueryRequest +from .tools import GenerateAnswer + +if TYPE_CHECKING: + from ldp.data_structures import Trajectory + + +class GradableEnvironment(_Environment): + """Extended environment that can grade answers.""" + + def __init__( + self, + query: QueryRequest, + docs: Docs, + llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, + summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, + embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS, + evaluation_from_answer: ( + Callable[[Answer | str], Awaitable[LitQAEvaluation]] | None + ) = None, + rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION, + evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None, + **env_kwargs, + ): + super().__init__( + query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs + ) + self._evaluation_from_answer = evaluation_from_answer + self._evaluation_callback = evaluation_callback + self._rewards = rewards + + async def step( + self, action: ToolRequestMessage + ) -> tuple[list[Message], float, bool, bool]: + messages, reward, done, truncated = await super().step(action) + if not done or not self._evaluation_from_answer: + return messages, reward, done, truncated + # Filter out non-answer messages (in case 
parallel tool calls) + answer_tool_messages = [ + m + for m in messages + if isinstance(m, ToolResponseMessage) + and m.name == GenerateAnswer.gen_answer.__name__ + ] + if not answer_tool_messages: # No answer, so no positive reward + return messages, reward, done, truncated + if len(answer_tool_messages) != 1: + raise NotImplementedError( + f"Expected just one answer message, got {messages}." + ) + answer = GenerateAnswer.extract_answer_from_message( + content=answer_tool_messages[0].content + ) + if not answer: + return messages, reward, done, truncated + evaluation = await self._evaluation_from_answer(answer) + if evaluation_callback := self._evaluation_callback: + await evaluation_callback(evaluation) + return messages, reward + self._rewards[evaluation.value], done, truncated + + def export_frame(self) -> Frame: + raise NotImplementedError("Didn't yet need to export a frame.") + + +ENV_NAME = "paperqa-local" +ENV_REGISTRY[ENV_NAME] = GradableEnvironment.__module__, GradableEnvironment.__name__ + + +class LitQATaskDataset( + _TaskDataset[GradableEnvironment], ComputeTrajectoryMetricsMixin, ABC +): + """ + Abstract base class for a task dataset of LitQA v1 or v2 questions. + + This is an ABC because it's non-specific to a LitQA version. + Examples include LitQA v1, v2, or a test stub version of LitQA. + """ + + def __init__( + self, + base_query_request: QueryRequest, + rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION, + eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME, + **env_kwargs, + ): + self._base_query_request = base_query_request + self._rewards = rewards + self._env_kwargs = env_kwargs + self._eval_model = eval_model + + def _make_gradable_environment( + self, + ideal: str, + distractors: str | list[str], + question: str, + use_unsure: bool = True, + ) -> GradableEnvironment: + qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question( + ideal=ideal, + distractors=distractors, + question=question, + use_unsure=use_unsure, + eval_model=self._eval_model, + ) + query_request = self._base_query_request.model_copy() + query_request.query = qa_prompt + return GradableEnvironment( + query=query_request, + evaluation_from_answer=evaluation_from_answer, + rewards=self._rewards, + **self._env_kwargs, + ) + + def compute_trajectory_metrics( + self, trajectories: "Sequence[Trajectory]" + ) -> dict[str, list[float]]: + return super().compute_trajectory_metrics(trajectories) | { + "correct": [ + int(traj.steps[-1].reward == self._rewards[0]) for traj in trajectories + ], + "correct_unsure": [ + int(traj.steps[-1].reward in {self._rewards[0], self._rewards[1]}) + for traj in trajectories + ], + } + + +class LitQAv2TaskSplit(StrEnum): + TRAIN = "train" + EVAL = "eval" + + +class LitQAv2TaskDataset(LitQATaskDataset): + """Task dataset of LitQA v2 questions.""" + + def __init__( + self, + *args, + labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME, + split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL, + **kwargs, + ): + super().__init__(*args, **kwargs) + train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset) + split = LitQAv2TaskSplit(split) + if split == LitQAv2TaskSplit.TRAIN: + self.data = train_df + elif split == LitQAv2TaskSplit.EVAL: + self.data = eval_df + else: + assert_never(split) + + def get_new_env_by_idx(self, idx: int) -> GradableEnvironment: + return self._make_gradable_environment( + ideal=self.data.iloc[idx].ideal, + distractors=self.data.iloc[idx].distractors, + question=self.data.iloc[idx].question, + ) + + def __len__(self) -> int: + return 
len(self.data) + + +TASK_DATASET_NAME = "litqa-v2" +TASK_DATASET_REGISTRY[TASK_DATASET_NAME] = ( + LitQAv2TaskDataset.__module__, + LitQAv2TaskDataset.__name__, +) diff --git a/tests/test_task.py b/tests/test_task.py new file mode 100644 index 00000000..b99d8a16 --- /dev/null +++ b/tests/test_task.py @@ -0,0 +1,100 @@ +import pytest +from aviary.env import TASK_DATASET_REGISTRY, TaskDataset +from ldp.agent import SimpleAgent +from ldp.alg.callbacks import MeanMetricsCallback +from ldp.alg.runners import Evaluator, EvaluatorConfig + +from paperqa import Docs, QueryRequest, Settings +from paperqa.agents.task import ( + GradableEnvironment, + LitQATaskDataset, + LitQAv2TaskDataset, + LitQAv2TaskSplit, +) + + +@pytest.fixture(name="base_query_request") +def fixture_base_query_request(agent_test_settings: Settings) -> QueryRequest: + return QueryRequest(settings=agent_test_settings) + + +class StubLitQADataset(LitQATaskDataset): + """Made up dataset of questions answerable from this repo's stub_data.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data: list[tuple[str, str | list[str], str]] = [ + ("Politician", ["Technologist", "Plumber"], "Who is Frederick Bates?"), + ( + "Make molecular counterfactuals", + [ + "Generating images of cats", + "Simple explanations of internet searches", + ], + "How can you use XAI for chemical property prediction?", + ), + ( + "Maple Leaf", + ["The Stars and Stripes", "The Blue and Yellow", "The Southern Cross"], + "What is the national flag of Canada?", + ), + ] + + def get_new_env_by_idx(self, idx: int) -> GradableEnvironment: + return self._make_gradable_environment( + ideal=self.data[idx][0], + distractors=self.data[idx][1], + question=self.data[idx][2], + ) + + def __len__(self) -> int: + return len(self.data) + + +STUB_TASK_DATASET_NAME = "stub-litqa" +TASK_DATASET_REGISTRY[STUB_TASK_DATASET_NAME] = ( + StubLitQADataset.__module__, + StubLitQADataset.__name__, +) + + +class TestTaskDataset: + @pytest.mark.parametrize( + ("split", "expected_length"), + [(LitQAv2TaskSplit.TRAIN, 159), (LitQAv2TaskSplit.EVAL, 40)], + ) + def test___len__( + self, + split: str | LitQAv2TaskSplit, + expected_length: int, + base_query_request: QueryRequest, + ) -> None: + task_dataset = LitQAv2TaskDataset( + base_query_request=base_query_request, split=split + ) + assert len(task_dataset) == expected_length + + @pytest.mark.asyncio + async def test_evaluation(self, base_query_request: QueryRequest) -> None: + agent = SimpleAgent() + docs = Docs() + dataset = TaskDataset.from_name( + STUB_TASK_DATASET_NAME, + base_query_request=base_query_request, + docs=docs, + ) + metrics_callback = MeanMetricsCallback(eval_dataset=dataset) + + evaluator = Evaluator( + config=EvaluatorConfig(batch_size=3), + agent=agent, + dataset=dataset, + callbacks=[metrics_callback], + ) + await evaluator.evaluate() + + assert ( + not base_query_request.query + ), "Should not have mutated query in base request" + assert docs.docs, "Expected to have added content" + assert isinstance(metrics_callback.eval_means["reward"], float) From c41f48f73d3ff90b1c5a09aa4443eed10fb79709 Mon Sep 17 00:00:00 2001 From: James Braza Date: Fri, 13 Sep 2024 16:31:22 -0700 Subject: [PATCH 3/3] Renamed Environment to PaperQAEnvironment and defined __all__ to encourage correct import paths --- paperqa/agents/env.py | 5 ++--- paperqa/agents/main.py | 8 ++++---- paperqa/agents/task.py | 32 ++++++++++++++++++++------------ tests/test_task.py | 4 ++-- 4 files changed, 28 insertions(+), 21 
deletions(-) diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index c080b6c8..8c33a1cd 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -1,8 +1,7 @@ import logging from typing import cast -from aviary.env import Environment as _Environment -from aviary.env import Frame +from aviary.env import Environment, Frame from aviary.message import Message from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage @@ -91,7 +90,7 @@ def settings_to_tools( return tools -class Environment(_Environment[EnvironmentState]): +class PaperQAEnvironment(Environment[EnvironmentState]): """Environment connecting paper-qa's tools with state.""" def __init__( diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index fa1d4fd3..7443b15c 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -46,7 +46,7 @@ from paperqa.types import Answer from paperqa.utils import pqa_directory -from .env import Environment +from .env import PaperQAEnvironment from .helpers import litellm_get_search_query, table_formatter from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler from .search import SearchDocumentStorage, SearchIndex @@ -235,7 +235,7 @@ async def run_fake_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = Environment(query, docs, **env_kwargs) + env = PaperQAEnvironment(query, docs, **env_kwargs) _, tools = await env.reset() if on_env_reset_callback: await on_env_reset_callback(env.state) @@ -281,7 +281,7 @@ async def run_aviary_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = Environment(query, docs, **env_kwargs) + env = PaperQAEnvironment(query, docs, **env_kwargs) done = False try: @@ -345,7 +345,7 @@ async def run_ldp_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = Environment(query, docs, **env_kwargs) + env = PaperQAEnvironment(query, docs, **env_kwargs) done = False try: diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index aaa21833..bb4c7c9f 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -1,12 +1,18 @@ +__all__ = [ + "ENV_NAME", + "TASK_DATASET_NAME", + "GradablePaperQAEnvironment", + "LitQATaskDataset", + "LitQAv2TaskDataset", + "LitQAv2TaskSplit", +] + from abc import ABC from collections.abc import Awaitable, Callable, Sequence from enum import StrEnum from typing import TYPE_CHECKING, assert_never -from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame -from aviary.env import ( - TaskDataset as _TaskDataset, -) +from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset from aviary.message import Message from aviary.tools import ToolRequestMessage, ToolResponseMessage @@ -29,8 +35,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel from paperqa.types import Answer -from .env import POPULATE_FROM_SETTINGS -from .env import Environment as _Environment +from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment from .models import QueryRequest from .tools import GenerateAnswer @@ -38,7 +43,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] from ldp.data_structures import Trajectory -class GradableEnvironment(_Environment): +class GradablePaperQAEnvironment(PaperQAEnvironment): """Extended environment that can grade answers.""" def __init__( @@ -96,11 +101,14 @@ def export_frame(self) -> Frame: ENV_NAME = "paperqa-local" -ENV_REGISTRY[ENV_NAME] = GradableEnvironment.__module__, 
GradableEnvironment.__name__ +ENV_REGISTRY[ENV_NAME] = ( + GradablePaperQAEnvironment.__module__, + GradablePaperQAEnvironment.__name__, +) class LitQATaskDataset( - _TaskDataset[GradableEnvironment], ComputeTrajectoryMetricsMixin, ABC + TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC ): """ Abstract base class for a task dataset of LitQA v1 or v2 questions. @@ -127,7 +135,7 @@ def _make_gradable_environment( distractors: str | list[str], question: str, use_unsure: bool = True, - ) -> GradableEnvironment: + ) -> GradablePaperQAEnvironment: qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question( ideal=ideal, distractors=distractors, @@ -137,7 +145,7 @@ def _make_gradable_environment( ) query_request = self._base_query_request.model_copy() query_request.query = qa_prompt - return GradableEnvironment( + return GradablePaperQAEnvironment( query=query_request, evaluation_from_answer=evaluation_from_answer, rewards=self._rewards, @@ -183,7 +191,7 @@ def __init__( else: assert_never(split) - def get_new_env_by_idx(self, idx: int) -> GradableEnvironment: + def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: return self._make_gradable_environment( ideal=self.data.iloc[idx].ideal, distractors=self.data.iloc[idx].distractors, diff --git a/tests/test_task.py b/tests/test_task.py index b99d8a16..5f8fb17b 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -6,7 +6,7 @@ from paperqa import Docs, QueryRequest, Settings from paperqa.agents.task import ( - GradableEnvironment, + GradablePaperQAEnvironment, LitQATaskDataset, LitQAv2TaskDataset, LitQAv2TaskSplit, @@ -40,7 +40,7 @@ def __init__(self, *args, **kwargs): ), ] - def get_new_env_by_idx(self, idx: int) -> GradableEnvironment: + def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: return self._make_gradable_environment( ideal=self.data[idx][0], distractors=self.data[idx][1],
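
Note (not part of the patch series): a minimal usage sketch of the new "litqa-v2" task dataset, driven the same way tests/test_task.py drives the stub dataset. The `main` wrapper and the bare `Settings()`/`Docs()` defaults are illustrative assumptions; running this needs LLM API credentials for the default models and Hugging Face Hub access to pull LitQA v2 (LAB-Bench).

    # Sketch only: evaluate an ldp SimpleAgent on the registered "litqa-v2" dataset.
    # Mirrors the Evaluator wiring used in tests/test_task.py::test_evaluation.
    import asyncio

    from aviary.env import TaskDataset
    from ldp.agent import SimpleAgent
    from ldp.alg.callbacks import MeanMetricsCallback
    from ldp.alg.runners import Evaluator, EvaluatorConfig

    from paperqa import Docs, QueryRequest, Settings


    async def main() -> None:
        # Extra kwargs (e.g. docs) are forwarded to each GradablePaperQAEnvironment.
        dataset = TaskDataset.from_name(
            "litqa-v2",
            base_query_request=QueryRequest(settings=Settings()),
            docs=Docs(),
        )
        metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
        evaluator = Evaluator(
            config=EvaluatorConfig(batch_size=3),
            agent=SimpleAgent(),
            dataset=dataset,
            callbacks=[metrics_callback],
        )
        await evaluator.evaluate()
        # Mean reward across evaluated trajectories, aggregated by the callback.
        print(metrics_callback.eval_means["reward"])


    if __name__ == "__main__":
        asyncio.run(main())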