diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index bb4c7c9f..414aef78 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -119,12 +119,14 @@ class LitQATaskDataset( def __init__( self, - base_query_request: QueryRequest, + base_query: QueryRequest | None = None, + base_docs: Docs | None = None, rewards: Sequence[float] = DEFAULT_REWARD_DISTRIBUTION, eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME, **env_kwargs, ): - self._base_query_request = base_query_request + self._base_query = base_query or QueryRequest() + self._base_docs = base_docs or Docs() self._rewards = rewards self._env_kwargs = env_kwargs self._eval_model = eval_model @@ -143,10 +145,11 @@ def _make_gradable_environment( use_unsure=use_unsure, eval_model=self._eval_model, ) - query_request = self._base_query_request.model_copy() - query_request.query = qa_prompt + query = self._base_query.model_copy() + query.query = qa_prompt return GradablePaperQAEnvironment( - query=query_request, + query=query, + docs=self._base_docs.model_copy(), evaluation_from_answer=evaluation_from_answer, rewards=self._rewards, **self._env_kwargs, diff --git a/tests/test_task.py b/tests/test_task.py index 5f8fb17b..f6455ff9 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -69,9 +69,7 @@ def test___len__( expected_length: int, base_query_request: QueryRequest, ) -> None: - task_dataset = LitQAv2TaskDataset( - base_query_request=base_query_request, split=split - ) + task_dataset = LitQAv2TaskDataset(base_query=base_query_request, split=split) assert len(task_dataset) == expected_length @pytest.mark.asyncio @@ -79,9 +77,7 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None: agent = SimpleAgent() docs = Docs() dataset = TaskDataset.from_name( - STUB_TASK_DATASET_NAME, - base_query_request=base_query_request, - docs=docs, + STUB_TASK_DATASET_NAME, base_query=base_query_request, base_docs=docs ) metrics_callback = MeanMetricsCallback(eval_dataset=dataset) @@ -96,5 +92,5 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None: assert ( not base_query_request.query ), "Should not have mutated query in base request" - assert docs.docs, "Expected to have added content" + assert not docs.docs, "Should not have mutated docs in base docs" assert isinstance(metrics_callback.eval_means["reward"], float)