diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index c080b6c88..8c33a1cdf 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -1,8 +1,7 @@ import logging from typing import cast -from aviary.env import Environment as _Environment -from aviary.env import Frame +from aviary.env import Environment, Frame from aviary.message import Message from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage @@ -91,7 +90,7 @@ def settings_to_tools( return tools -class Environment(_Environment[EnvironmentState]): +class PaperQAEnvironment(Environment[EnvironmentState]): """Environment connecting paper-qa's tools with state.""" def __init__( diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index fa1d4fd3d..7443b15ce 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -46,7 +46,7 @@ from paperqa.types import Answer from paperqa.utils import pqa_directory -from .env import Environment +from .env import PaperQAEnvironment from .helpers import litellm_get_search_query, table_formatter from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler from .search import SearchDocumentStorage, SearchIndex @@ -235,7 +235,7 @@ async def run_fake_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = Environment(query, docs, **env_kwargs) + env = PaperQAEnvironment(query, docs, **env_kwargs) _, tools = await env.reset() if on_env_reset_callback: await on_env_reset_callback(env.state) @@ -281,7 +281,7 @@ async def run_aviary_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = Environment(query, docs, **env_kwargs) + env = PaperQAEnvironment(query, docs, **env_kwargs) done = False try: @@ -345,7 +345,7 @@ async def run_ldp_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: - env = Environment(query, docs, **env_kwargs) + env = PaperQAEnvironment(query, docs, **env_kwargs) done = False try: diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index aaa218331..bb4c7c9fb 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -1,12 +1,18 @@ +__all__ = [ + "ENV_NAME", + "TASK_DATASET_NAME", + "GradablePaperQAEnvironment", + "LitQATaskDataset", + "LitQAv2TaskDataset", + "LitQAv2TaskSplit", +] + from abc import ABC from collections.abc import Awaitable, Callable, Sequence from enum import StrEnum from typing import TYPE_CHECKING, assert_never -from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame -from aviary.env import ( - TaskDataset as _TaskDataset, -) +from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset from aviary.message import Message from aviary.tools import ToolRequestMessage, ToolResponseMessage @@ -29,8 +35,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel from paperqa.types import Answer -from .env import POPULATE_FROM_SETTINGS -from .env import Environment as _Environment +from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment from .models import QueryRequest from .tools import GenerateAnswer @@ -38,7 +43,7 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef] from ldp.data_structures import Trajectory -class GradableEnvironment(_Environment): +class GradablePaperQAEnvironment(PaperQAEnvironment): """Extended environment that can grade answers.""" def __init__( @@ -96,11 +101,14 @@ def export_frame(self) -> Frame: ENV_NAME = "paperqa-local" -ENV_REGISTRY[ENV_NAME] = GradableEnvironment.__module__, GradableEnvironment.__name__ +ENV_REGISTRY[ENV_NAME] = ( + GradablePaperQAEnvironment.__module__, + GradablePaperQAEnvironment.__name__, +) class LitQATaskDataset( - _TaskDataset[GradableEnvironment], ComputeTrajectoryMetricsMixin, ABC + TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC ): """ Abstract base class for a task dataset of LitQA v1 or v2 questions. @@ -127,7 +135,7 @@ def _make_gradable_environment( distractors: str | list[str], question: str, use_unsure: bool = True, - ) -> GradableEnvironment: + ) -> GradablePaperQAEnvironment: qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question( ideal=ideal, distractors=distractors, @@ -137,7 +145,7 @@ def _make_gradable_environment( ) query_request = self._base_query_request.model_copy() query_request.query = qa_prompt - return GradableEnvironment( + return GradablePaperQAEnvironment( query=query_request, evaluation_from_answer=evaluation_from_answer, rewards=self._rewards, @@ -183,7 +191,7 @@ def __init__( else: assert_never(split) - def get_new_env_by_idx(self, idx: int) -> GradableEnvironment: + def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: return self._make_gradable_environment( ideal=self.data.iloc[idx].ideal, distractors=self.data.iloc[idx].distractors, diff --git a/tests/test_task.py b/tests/test_task.py index b99d8a169..5f8fb17bd 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -6,7 +6,7 @@ from paperqa import Docs, QueryRequest, Settings from paperqa.agents.task import ( - GradableEnvironment, + GradablePaperQAEnvironment, LitQATaskDataset, LitQAv2TaskDataset, LitQAv2TaskSplit, @@ -40,7 +40,7 @@ def __init__(self, *args, **kwargs): ), ] - def get_new_env_by_idx(self, idx: int) -> GradableEnvironment: + def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: return self._make_gradable_environment( ideal=self.data[idx][0], distractors=self.data[idx][1],