Skip to content

Commit

Permalink
Renamed Environment to PaperQAEnvironment and defined __all__ to enco…
Browse files Browse the repository at this point in the history
…urage correct import paths
  • Loading branch information
jamesbraza committed Sep 14, 2024
1 parent fc658d3 commit c41f48f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
5 changes: 2 additions & 3 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from typing import cast

from aviary.env import Environment as _Environment
from aviary.env import Frame
from aviary.env import Environment, Frame
from aviary.message import Message
from aviary.tools import Tool, ToolRequestMessage, ToolResponseMessage

Expand Down Expand Up @@ -91,7 +90,7 @@ def settings_to_tools(
return tools


class Environment(_Environment[EnvironmentState]):
class PaperQAEnvironment(Environment[EnvironmentState]):
"""Environment connecting paper-qa's tools with state."""

def __init__(
Expand Down
8 changes: 4 additions & 4 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from paperqa.types import Answer
from paperqa.utils import pqa_directory

from .env import Environment
from .env import PaperQAEnvironment
from .helpers import litellm_get_search_query, table_formatter
from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
from .search import SearchDocumentStorage, SearchIndex
Expand Down Expand Up @@ -235,7 +235,7 @@ async def run_fake_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
env = Environment(query, docs, **env_kwargs)
env = PaperQAEnvironment(query, docs, **env_kwargs)
_, tools = await env.reset()
if on_env_reset_callback:
await on_env_reset_callback(env.state)
Expand Down Expand Up @@ -281,7 +281,7 @@ async def run_aviary_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
env = Environment(query, docs, **env_kwargs)
env = PaperQAEnvironment(query, docs, **env_kwargs)
done = False

try:
Expand Down Expand Up @@ -345,7 +345,7 @@ async def run_ldp_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
env = Environment(query, docs, **env_kwargs)
env = PaperQAEnvironment(query, docs, **env_kwargs)
done = False

try:
Expand Down
32 changes: 20 additions & 12 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
__all__ = [
"ENV_NAME",
"TASK_DATASET_NAME",
"GradablePaperQAEnvironment",
"LitQATaskDataset",
"LitQAv2TaskDataset",
"LitQAv2TaskSplit",
]

from abc import ABC
from collections.abc import Awaitable, Callable, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, assert_never

from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame
from aviary.env import (
TaskDataset as _TaskDataset,
)
from aviary.env import ENV_REGISTRY, TASK_DATASET_REGISTRY, Frame, TaskDataset
from aviary.message import Message
from aviary.tools import ToolRequestMessage, ToolResponseMessage

Expand All @@ -29,16 +35,15 @@ class ComputeTrajectoryMetricsMixin: # type: ignore[no-redef]
from paperqa.llms import EmbeddingModel, LiteLLMModel, LLMModel
from paperqa.types import Answer

from .env import POPULATE_FROM_SETTINGS
from .env import Environment as _Environment
from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .tools import GenerateAnswer

if TYPE_CHECKING:
from ldp.data_structures import Trajectory


class GradableEnvironment(_Environment):
class GradablePaperQAEnvironment(PaperQAEnvironment):
"""Extended environment that can grade answers."""

def __init__(
Expand Down Expand Up @@ -96,11 +101,14 @@ def export_frame(self) -> Frame:


ENV_NAME = "paperqa-local"
ENV_REGISTRY[ENV_NAME] = GradableEnvironment.__module__, GradableEnvironment.__name__
ENV_REGISTRY[ENV_NAME] = (
GradablePaperQAEnvironment.__module__,
GradablePaperQAEnvironment.__name__,
)


class LitQATaskDataset(
_TaskDataset[GradableEnvironment], ComputeTrajectoryMetricsMixin, ABC
TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
"""
Abstract base class for a task dataset of LitQA v1 or v2 questions.
Expand All @@ -127,7 +135,7 @@ def _make_gradable_environment(
distractors: str | list[str],
question: str,
use_unsure: bool = True,
) -> GradableEnvironment:
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
Expand All @@ -137,7 +145,7 @@ def _make_gradable_environment(
)
query_request = self._base_query_request.model_copy()
query_request.query = qa_prompt
return GradableEnvironment(
return GradablePaperQAEnvironment(
query=query_request,
evaluation_from_answer=evaluation_from_answer,
rewards=self._rewards,
Expand Down Expand Up @@ -183,7 +191,7 @@ def __init__(
else:
assert_never(split)

def get_new_env_by_idx(self, idx: int) -> GradableEnvironment:
def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
return self._make_gradable_environment(
ideal=self.data.iloc[idx].ideal,
distractors=self.data.iloc[idx].distractors,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents.task import (
GradableEnvironment,
GradablePaperQAEnvironment,
LitQATaskDataset,
LitQAv2TaskDataset,
LitQAv2TaskSplit,
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(self, *args, **kwargs):
),
]

def get_new_env_by_idx(self, idx: int) -> GradableEnvironment:
def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
return self._make_gradable_environment(
ideal=self.data[idx][0],
distractors=self.data[idx][1],
Expand Down

0 comments on commit c41f48f

Please sign in to comment.