Skip to content

Commit

Permalink
Preventing Environments from sharing one Docs (#425)
Browse files: browse the repository at this point in the history.
  • Loading branch information
jamesbraza committed Sep 17, 2024
1 parent fd1f8e9 commit 3327821
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
13 changes: 8 additions & 5 deletions paperqa/agents/task.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -119,12 +119,14 @@ class LitQATaskDataset(

def __init__(
self,
base_query_request: QueryRequest,
base_query: QueryRequest | None = None,
base_docs: Docs | None = None,
rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION,
eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
**env_kwargs,
):
self._base_query_request = base_query_request
self._base_query = base_query or QueryRequest()
self._base_docs = base_docs or Docs()
self._rewards = rewards
self._env_kwargs = env_kwargs
self._eval_model = eval_model
Expand All @@ -143,10 +145,11 @@ def _make_gradable_environment(
use_unsure=use_unsure,
eval_model=self._eval_model,
)
query_request = self._base_query_request.model_copy()
query_request.query = qa_prompt
query = self._base_query.model_copy()
query.query = qa_prompt
return GradablePaperQAEnvironment(
query=query_request,
query=query,
docs=self._base_docs.model_copy(),
evaluation_from_answer=evaluation_from_answer,
rewards=self._rewards,
**self._env_kwargs,
Expand Down
10 changes: 3 additions & 7 deletions tests/test_task.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -69,19 +69,15 @@ def test___len__(
expected_length: int,
base_query_request: QueryRequest,
) -> None:
task_dataset = LitQAv2TaskDataset(
base_query_request=base_query_request, split=split
)
task_dataset = LitQAv2TaskDataset(base_query=base_query_request, split=split)
assert len(task_dataset) == expected_length

@pytest.mark.asyncio
async def test_evaluation(self, base_query_request: QueryRequest) -> None:
agent = SimpleAgent()
docs = Docs()
dataset = TaskDataset.from_name(
STUB_TASK_DATASET_NAME,
base_query_request=base_query_request,
docs=docs,
STUB_TASK_DATASET_NAME, base_query=base_query_request, base_docs=docs
)
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)

Expand All @@ -96,5 +92,5 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
assert (
not base_query_request.query
), "Should not have mutated query in base request"
assert docs.docs, "Expected to have added content"
assert not docs.docs, "Should not have mutated docs in base docs"
assert isinstance(metrics_callback.eval_means["reward"], float)

0 comments on commit 3327821

Please sign in to comment.