Expanded test_evaluation to have a possible zero-shot evaluation
jamesbraza committed Nov 1, 2024
1 parent 89f2c49 commit d276e8b
Showing 1 changed file with 31 additions and 2 deletions.
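
In short: test_evaluation already runs a full agentic evaluation over LitQA v2, and this commit adds a second, zero-shot pass in which the agent is restricted to the gen_answer tool, so it must answer without searching papers or gathering evidence. The sketch below pulls that new pass together as standalone code. It mirrors the added test body; the one assumption is the stand-in QueryRequest(settings=Settings()), since the test's base_query_request fixture is not shown in this diff.

    # Condensed sketch of the zero-shot pass added to test_evaluation.
    # The real test uses a base_query_request fixture (not in this diff);
    # a default QueryRequest stands in for it here (an assumption).
    from ldp.agent import SimpleAgent
    from ldp.alg.callbacks import StoreTrajectoriesCallback
    from ldp.alg.runners import Evaluator, EvaluatorConfig

    from paperqa import QueryRequest, Settings
    from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit
    from paperqa.agents.tools import GenerateAnswer


    async def zero_shot_litqa_eval() -> None:
        request = QueryRequest(settings=Settings())  # stand-in for the fixture
        # "Zero-shot": the agent may only call gen_answer, never search or
        # gather evidence, and must not backfill evidence when contexts are empty.
        request.settings.agent.tool_names = {GenerateAnswer.gen_answer.__name__}
        request.settings.answer.get_evidence_if_no_contexts = False

        dataset = LitQAv2TaskDataset(base_query=request, split=LitQAv2TaskSplit.EVAL)
        dataset.data = dataset.data[:2]  # keep the run tiny, as in the test

        storage_callback = StoreTrajectoriesCallback()
        evaluator = Evaluator(
            config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
            agent=SimpleAgent(),
            dataset=dataset,
            callbacks=[storage_callback],
        )
        await evaluator.evaluate()

        # Every action in every stored trajectory should be a gen_answer tool call.
        for traj in storage_callback.eval_trajectories:
            for step in traj.steps:
                assert all(
                    tc.function.name == GenerateAnswer.gen_answer.__name__
                    for tc in step.action.value.tool_calls
                )

The two-question slice keeps the run small (per the test's own comment), and the low max_rollout_steps presumably reflects that a zero-shot episode needs only a gen_answer call before terminating.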
tests/test_task.py (33 changes: 31 additions & 2 deletions)
@@ -5,8 +5,9 @@
 import pytest
 from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset
 from ldp.agent import SimpleAgent
-from ldp.alg.callbacks import MeanMetricsCallback
+from ldp.alg.callbacks import MeanMetricsCallback, StoreTrajectoriesCallback
 from ldp.alg.runners import Evaluator, EvaluatorConfig
+from pytest_subtests import SubTests

 from paperqa import Docs, QueryRequest, Settings
 from paperqa.agents import get_directory_index
@@ -16,6 +17,7 @@
     LitQAv2TaskDataset,
     LitQAv2TaskSplit,
 )
+from paperqa.agents.tools import GenerateAnswer


 @pytest.fixture(name="base_query_request")
@@ -106,7 +108,9 @@ async def test_can_validate_stub_dataset_sources(
         )

     @pytest.mark.asyncio
-    async def test_evaluation(self, base_query_request: QueryRequest) -> None:
+    async def test_evaluation(
+        self, subtests: SubTests, base_query_request: QueryRequest
+    ) -> None:
         await get_directory_index(settings=base_query_request.settings)  # Build
         docs = Docs()
         # Why are we constructing a TaskConfig here using a serialized QueryRequest and
@@ -150,6 +154,31 @@ async def test_evaluation(self, base_query_request: QueryRequest) -> None:
             isinstance(metrics_callback.eval_means["reward"], float) > 0
         ), "Expected some wins"

+        with subtests.test(msg="zero-shot"):
+            # Confirm we can just directly call gen_answer
+            base_query_request.settings.agent.tool_names = {
+                GenerateAnswer.gen_answer.__name__
+            }
+            base_query_request.settings.answer.get_evidence_if_no_contexts = False
+            dataset = LitQAv2TaskDataset(
+                base_query=base_query_request, split=LitQAv2TaskSplit.EVAL
+            )
+            dataset.data = dataset.data[:2]  # Save the world: just use two questions
+            storage_callback = StoreTrajectoriesCallback()
+            evaluator = Evaluator(
+                config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
+                agent=SimpleAgent(),
+                dataset=dataset,
+                callbacks=[storage_callback],
+            )
+            await evaluator.evaluate()
+            for traj in storage_callback.eval_trajectories:
+                for step in traj.steps:
+                    assert all(
+                        tc.function.name == GenerateAnswer.gen_answer.__name__
+                        for tc in step.action.value.tool_calls
+                    )
+
     @pytest.mark.vcr
     @pytest.mark.asyncio
     async def test_tool_failure(self, base_query_request: QueryRequest) -> None:
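One dependency note: the test now takes the subtests fixture from pytest-subtests, so the zero-shot pass is grouped and reported under its own label within test_evaluation. A minimal sketch of that pattern follows; the test names are hypothetical, and only the SubTests type and the subtests.test(msg=...) usage come from this diff.

    # Hypothetical standalone example of the pytest-subtests pattern used above.
    from pytest_subtests import SubTests


    def test_two_independent_checks(subtests: SubTests) -> None:
        with subtests.test(msg="first check"):
            assert 1 + 1 == 2
        with subtests.test(msg="second check"):
            # A failure here is reported under its own label instead of
            # masking or being masked by the other sub-test.
            assert "a" in "abc"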
