From ddea8c370f44860e4a8011b93f68e876c4ae265b Mon Sep 17 00:00:00 2001 From: James Braza Date: Wed, 18 Sep 2024 13:02:37 -0700 Subject: [PATCH] Allowing case insensitive `"fake"` agent type (#437) --- paperqa/agents/main.py | 5 ++++- tests/test_agents.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 45097315..f2f968af 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -75,6 +75,9 @@ async def agent_query( return response +FAKE_AGENT_TYPE = "fake" # No agent, just invoke tools in deterministic order + + async def run_agent( docs: Docs, query: QueryRequest, @@ -103,7 +106,7 @@ async def run_agent( f" query {query.model_dump()}." ) - if agent_type == "fake": + if isinstance(agent_type, str) and agent_type.lower() == FAKE_AGENT_TYPE: answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs) elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type): answer, agent_status = await run_aviary_agent( diff --git a/tests/test_agents.py b/tests/test_agents.py index 7d5f7011..ebc241fe 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -20,6 +20,7 @@ from paperqa.agents import agent_query from paperqa.agents.env import settings_to_tools +from paperqa.agents.main import FAKE_AGENT_TYPE from paperqa.agents.models import AgentStatus, AnswerResponse, QueryRequest from paperqa.agents.search import FAILED_DOCUMENT_ADD_ID, get_directory_index from paperqa.agents.tools import ( @@ -78,7 +79,7 @@ async def test_get_directory_index_w_manifest( @pytest.mark.flaky(reruns=2, only_rerun=["AssertionError", "httpx.RemoteProtocolError"]) -@pytest.mark.parametrize("agent_type", ["fake", ToolSelector, SimpleAgent]) +@pytest.mark.parametrize("agent_type", [FAKE_AGENT_TYPE, ToolSelector, SimpleAgent]) @pytest.mark.asyncio async def test_agent_types( agent_test_settings: Settings, agent_type: str | type @@ -102,8 +103,7 @@ async def test_agent_types( assert response.answer.question == question agent_llm = request.settings.agent.agent_llm # TODO: once LDP can track tokens, we can remove this check - if agent_type not in {"fake", SimpleAgent}: - print(response.answer.token_counts) + if agent_type not in {FAKE_AGENT_TYPE, SimpleAgent}: assert ( response.answer.token_counts[agent_llm][0] > 1000 ), "Expected many prompt tokens" @@ -232,7 +232,7 @@ async def test_propagate_options(agent_test_settings: Settings) -> None: query = QueryRequest( query="What is is a self-explanatory model?", settings=agent_test_settings ) - response = await agent_query(query, agent_type="fake") + response = await agent_query(query, agent_type=FAKE_AGENT_TYPE) assert response.status == AgentStatus.SUCCESS, "Agent did not succeed" result = response.answer assert len(result.answer) > 200, "Answer did not return any results"