Skip to content

Commit

Permalink
Allowing case insensitive "fake" agent type (#437)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Sep 18, 2024
1 parent e8678a3 commit ddea8c3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
5 changes: 4 additions & 1 deletion paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ async def agent_query(
return response


FAKE_AGENT_TYPE = "fake" # No agent, just invoke tools in deterministic order


async def run_agent(
docs: Docs,
query: QueryRequest,
Expand Down Expand Up @@ -103,7 +106,7 @@ async def run_agent(
f" query {query.model_dump()}."
)

if agent_type == "fake":
if isinstance(agent_type, str) and agent_type.lower() == FAKE_AGENT_TYPE:
answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type):
answer, agent_status = await run_aviary_agent(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from paperqa.agents import agent_query
from paperqa.agents.env import settings_to_tools
from paperqa.agents.main import FAKE_AGENT_TYPE
from paperqa.agents.models import AgentStatus, AnswerResponse, QueryRequest
from paperqa.agents.search import FAILED_DOCUMENT_ADD_ID, get_directory_index
from paperqa.agents.tools import (
Expand Down Expand Up @@ -78,7 +79,7 @@ async def test_get_directory_index_w_manifest(


@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError", "httpx.RemoteProtocolError"])
@pytest.mark.parametrize("agent_type", ["fake", ToolSelector, SimpleAgent])
@pytest.mark.parametrize("agent_type", [FAKE_AGENT_TYPE, ToolSelector, SimpleAgent])
@pytest.mark.asyncio
async def test_agent_types(
agent_test_settings: Settings, agent_type: str | type
Expand All @@ -102,8 +103,7 @@ async def test_agent_types(
assert response.answer.question == question
agent_llm = request.settings.agent.agent_llm
# TODO: once LDP can track tokens, we can remove this check
if agent_type not in {"fake", SimpleAgent}:
print(response.answer.token_counts)
if agent_type not in {FAKE_AGENT_TYPE, SimpleAgent}:
assert (
response.answer.token_counts[agent_llm][0] > 1000
), "Expected many prompt tokens"
Expand Down Expand Up @@ -232,7 +232,7 @@ async def test_propagate_options(agent_test_settings: Settings) -> None:
query = QueryRequest(
query="What is is a self-explanatory model?", settings=agent_test_settings
)
response = await agent_query(query, agent_type="fake")
response = await agent_query(query, agent_type=FAKE_AGENT_TYPE)
assert response.status == AgentStatus.SUCCESS, "Agent did not succeed"
result = response.answer
assert len(result.answer) > 200, "Answer did not return any results"
Expand Down

0 comments on commit ddea8c3

Please sign in to comment.