diff --git a/tests/test_agents.py b/tests/test_agents.py
index 2c18d176..4b89e475 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -3,13 +3,18 @@
 import itertools
 import json
 import re
+import time
 from pathlib import Path
 from typing import Any, cast
 from unittest.mock import patch
 
+import ldp.agent
 import pytest
-from aviary.tools import ToolsAdapter, ToolSelector
+from aviary.tools import ToolRequestMessage, ToolsAdapter, ToolSelector
 from ldp.agent import SimpleAgent
+from ldp.graph.memory import Memory, UIndexMemoryModel
+from ldp.graph.ops import OpResult
+from ldp.llms import EmbeddingModel
 from pydantic import ValidationError
 from pytest_subtests import SubTests
 
@@ -26,7 +31,7 @@
 from paperqa.docs import Docs
 from paperqa.settings import AgentSettings, Settings
 from paperqa.types import Answer, Context, Doc, Text
-from paperqa.utils import get_year, md5sum
+from paperqa.utils import extract_thought, get_year, md5sum
 
 
 @pytest.mark.asyncio
@@ -103,6 +108,91 @@ async def test_agent_types(
         assert response.answer.cost > 0, "Expected nonzero cost"
 
 
+@pytest.mark.asyncio
+async def test_successful_memory_agent() -> None:
+    """Run a MemoryAgent seeded with one memory and check it succeeds in time."""
+    tic = time.perf_counter()
+    memory_id = "call_Wtmv95JbNcQ2nRQCZBoOfcJy"  # Stub value
+    memory = Memory(
+        query=(
+            "Use the tools to answer the question: Q: Acinetobacter lwoffii has"
+            " been evolved in the lab to be resistant to which of these"
+            " antibiotics?\n\nOptions:\nA) gentamicin\nB) Insufficient information"
+            " to answer this question\nC) meropenem\nD) ciproflaxin\nE)"
+            " ampicillin\n\nThe gen_answer tool output is visible to the user, so"
+            " you do not need to restate the answer and can simply terminate if"
+            " the answer looks sufficient. The current status of"
+            " evidence/papers/cost is Status: Paper Count=0 | Relevant Papers=0 |"
+            " Current Evidence=0 | Current Cost=$0.00\n\nTool request message ''"
+            f" for tool call: {memory_id} with content"
+            " 'paper_search(query='Acinetobacter lwoffii antibiotic resistance"
+            " evolution', min_year='None', max_year='None')'\n\nTool response"
+            " message 'Status: Paper Count=11 | Relevant Papers=0 | Current"
+            f" Evidence=0 | Current Cost=$0.0000' for tool call ID {memory_id} of"
+            " tool 'paper_search'"
+        ),
+        input=(
+            "Use the tools to answer the question: Q: Acinetobacter lwoffii has"
+            " been evolved in the lab to be resistant to which of these"
+            " antibiotics?\n\nOptions:\nA) gentamicin\nB) Insufficient information"
+            " to answer this question\nC) meropenem\nD) ciproflaxin\nE)"
+            " ampicillin\n\nThe gen_answer tool output is visible to the user, so"
+            " you do not need to restate the answer and can simply terminate if the"
+            " answer looks sufficient. The current status of evidence/papers/cost"
+            " is Status: Paper Count=0 | Relevant Papers=0 | Current Evidence=0 |"
+            " Current Cost=$0.0000"
+        ),
+        output=(
+            f"Tool request message '' for tool call: {memory_id} with content"
+            " 'paper_search(query='Acinetobacter lwoffii"
+            " antibiotic resistance evolution', min_year='None',"
+            " max_year='None')'"
+        ),
+        value=5.61,
+        template="Input: {input}\n\nOutput: {output}\n\nDiscounted Reward: {value}",
+    )
+    memory_model = UIndexMemoryModel(
+        embedding_model=EmbeddingModel.from_name("text-embedding-3-small")
+    )
+    await memory_model.add_memory(memory)
+    serialized_memory_model = memory_model.model_dump(exclude_none=True)
+    query = QueryRequest(
+        query=(
+            "Use the tools to answer the question: Q: Acinetobacter lwoffii has"
+            " been evolved in the lab to be resistant to which of these"
+            " antibiotics?\n\nOptions:\nA) gentamicin\nB) Insufficient information"
+            " to answer this question\nC) meropenem\nD) ciproflaxin\nE) ampicillin"
+        )
+    )
+    # NOTE: use Claude 3 for its <thinking> feature, testing regex replacement of it
+    query.settings.agent.agent_llm = "claude-3-opus-20240229"
+    # Pass the stored memories separately from the rest of the serialized model
+    query.settings.agent.agent_config = {
+        "memories": serialized_memory_model.pop("memories"),
+        "memory_model": serialized_memory_model,
+    }
+
+    thoughts: list[str] = []
+
+    # Record the extracted thought each time the action callback is invoked
+    async def on_agent_action(action: OpResult[ToolRequestMessage], *_) -> None:
+        thoughts.append(extract_thought(content=action.value.content))
+
+    response = await agent_query(
+        query,
+        Docs(),
+        agent_type=f"{ldp.agent.__name__}.MemoryAgent",
+        on_agent_action_callback=on_agent_action,
+    )
+    assert response.status == AgentStatus.SUCCESS, "Agent did not succeed"
+    assert (
+        time.perf_counter() - tic <= query.settings.agent.timeout
+    ), "Agent should not have timed out"
+    # "" is a substring of every str, so test for the literal tag, which
+    # extract_thought presumably strips from the action content
+    assert all("<thinking>" not in thought for thought in thoughts)
+
+
 @pytest.mark.parametrize("agent_type", [ToolSelector, SimpleAgent])
 @pytest.mark.asyncio
 async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> None: