From 17c06a4733d79892826dcaf4ae7984f8b6bd87d2 Mon Sep 17 00:00:00 2001 From: James Braza Date: Wed, 11 Sep 2024 20:40:36 -0700 Subject: [PATCH] Exposed 'make_status' helper function, and integrated into tests --- paperqa/agents/tools.py | 32 +++++++++++++++++++++++++++----- tests/test_agents.py | 17 +++++++++-------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py index 0453cb1f..51716958 100644 --- a/paperqa/agents/tools.py +++ b/paperqa/agents/tools.py @@ -18,6 +18,16 @@ logger = logging.getLogger(__name__) +def make_status( + total_paper_count: int, relevant_paper_count: int, evidence_count: int, cost: float +) -> str: + return ( + f"Status: Paper Count={total_paper_count}" + f" | Relevant Papers={relevant_paper_count} | Current Evidence={evidence_count}" + f" | Current Cost=${cost:.4f}" + ) + + class EnvironmentState(BaseModel): """State here contains documents and answer being populated.""" @@ -35,11 +45,23 @@ class EnvironmentState(BaseModel): @computed_field # type: ignore[prop-decorator] @property def status(self) -> str: - return ( - f"Status: Paper Count={len(self.docs.docs)} | Relevant Papers=" - f"{len({c.text.doc.dockey for c in self.answer.contexts if c.score > self.RELEVANT_SCORE_CUTOFF})}" - f" | Current Evidence={len([c for c in self.answer.contexts if c.score > self.RELEVANT_SCORE_CUTOFF])}" - f" | Current Cost=${self.answer.cost:.4f}" + return make_status( + total_paper_count=len(self.docs.docs), + relevant_paper_count=len( + { + c.text.doc.dockey + for c in self.answer.contexts + if c.score > self.RELEVANT_SCORE_CUTOFF + } + ), + evidence_count=len( + [ + c + for c in self.answer.contexts + if c.score > self.RELEVANT_SCORE_CUTOFF + ] + ), + cost=self.answer.cost, ) diff --git a/tests/test_agents.py b/tests/test_agents.py index 4b89e475..81a8ae52 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -27,6 +27,7 @@ GatherEvidence, GenerateAnswer, PaperSearch, + make_status, ) from paperqa.docs import Docs from paperqa.settings import AgentSettings, Settings @@ -121,14 +122,14 @@ async def test_successful_memory_agent() -> None: " ampicillin\n\nThe gen_answer tool output is visible to the user, so" " you do not need to restate the answer and can simply terminate if" " the answer looks sufficient. The current status of" - " evidence/papers/cost is Status: Paper Count=0 | Relevant Papers=0 |" - " Current Evidence=0 | Current Cost=$0.00\n\nTool request message ''" + " evidence/papers/cost is Status: " + f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, cost=0.0)}" # Started 0 + "\n\nTool request message ''" f" for tool call: {memory_id} with content" " 'paper_search(query='Acinetobacter lwoffii antibiotic resistance" - " evolution', min_year='None', max_year='None')'\n\nTool response" - " message 'Status: Paper Count=11 | Relevant Papers=0 | Current" - f" Evidence=0 | Current Cost=$0.0000' for tool call ID {memory_id} of" - " tool 'paper_search'" + " evolution', min_year='None', max_year='None')'\n\nTool response message '" + f"{make_status(total_paper_count=11, relevant_paper_count=0, evidence_count=0, cost=0.0)}" # Found 11 + f"' for tool call ID {memory_id} of tool 'paper_search'" ), input=( "Use the tools to answer the question: Q: Acinetobacter lwoffii has" @@ -138,8 +139,8 @@ async def test_successful_memory_agent() -> None: " ampicillin\n\nThe gen_answer tool output is visible to the user, so" " you do not need to restate the answer and can simply terminate if the" " answer looks sufficient. The current status of evidence/papers/cost" - " is Status: Paper Count=0 | Relevant Papers=0 | Current Evidence=0 |" - " Current Cost=$0.0000" + " is Status: " + f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, cost=0.0)}" ), output=( f"Tool request message '' for tool call: {memory_id} with content"