Skip to content

Commit

Permalink
Testing MemoryAgent and timeouts on ldp agents (#375)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Sep 13, 2024
1 parent 72f13b5 commit 870e2f6
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 29 deletions.
35 changes: 18 additions & 17 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from rich.console import Console

from paperqa.docs import Docs
from paperqa.settings import Settings
from paperqa.types import Answer
from paperqa.utils import pqa_directory

Expand All @@ -64,7 +65,7 @@ async def agent_query(
query: str | QueryRequest,
docs: Docs | None = None,
agent_type: str | type = DEFAULT_AGENT_TYPE,
**env_kwargs,
**runner_kwargs,
) -> AnswerResponse:
if isinstance(query, str):
query = QueryRequest(query=query)
Expand All @@ -78,7 +79,7 @@ async def agent_query(
storage=SearchDocumentStorage.JSON_MODEL_DUMP,
)

response = await run_agent(docs, query, agent_type, **env_kwargs)
response = await run_agent(docs, query, agent_type, **runner_kwargs)
agent_logger.debug(f"agent_response: {response}")

agent_logger.info(f"[bold blue]Answer: {response.answer.answer}[/bold blue]")
Expand All @@ -96,7 +97,7 @@ async def agent_query(


def to_aviary_tool_selector(
agent_type: str | type, query: QueryRequest
agent_type: str | type, settings: Settings
) -> ToolSelector | None:
"""Attempt to convert the agent type to an aviary ToolSelector."""
if agent_type is ToolSelector or (
Expand All @@ -110,15 +111,15 @@ def to_aviary_tool_selector(
)
):
return ToolSelector(
model_name=query.settings.agent.agent_llm,
acompletion=query.settings.get_agent_llm().router.acompletion,
**(query.settings.agent.agent_config or {}),
model_name=settings.agent.agent_llm,
acompletion=settings.get_agent_llm().router.acompletion,
**(settings.agent.agent_config or {}),
)
return None


async def to_ldp_agent(
agent_type: str | type, query: QueryRequest
agent_type: str | type, settings: Settings
) -> "Agent[SimpleAgentState] | None":
"""Attempt to convert the agent type to an ldp Agent."""
if not isinstance(agent_type, str): # Convert to fully qualified name
Expand All @@ -133,7 +134,7 @@ async def to_ldp_agent(

# TODO: support general agents
agent_cls = cast(type[Agent], locate(agent_type))
agent_settings = query.settings.agent
agent_settings = settings.agent
agent_llm, config = agent_settings.agent_llm, agent_settings.agent_config or {}
if issubclass(agent_cls, ReActAgent | MemoryAgent):
if (
Expand All @@ -157,12 +158,12 @@ async def to_ldp_agent(
)
)
return agent_cls(
llm_model={"model": agent_llm, "temperature": query.settings.temperature},
llm_model={"model": agent_llm, "temperature": settings.temperature},
**config,
)
if issubclass(agent_cls, SimpleAgent):
return agent_cls(
llm_model={"model": agent_llm, "temperature": query.settings.temperature},
llm_model={"model": agent_llm, "temperature": settings.temperature},
sys_prompt=agent_settings.agent_system_prompt,
**config,
)
Expand All @@ -178,7 +179,7 @@ async def run_agent(
docs: Docs,
query: QueryRequest,
agent_type: str | type = DEFAULT_AGENT_TYPE,
**env_kwargs,
**runner_kwargs,
) -> AnswerResponse:
"""
Run an agent.
Expand All @@ -188,7 +189,7 @@ async def run_agent(
query: Query to answer.
agent_type: Agent type (or fully qualified name to the type) to pass to
AgentType.get_agent, or "fake" to TODOC.
env_kwargs: Keyword arguments to pass to Environment instantiation.
runner_kwargs: Keyword arguments to pass to the runner.
Returns:
Tuple of resultant answer, token counts, and agent status.
Expand All @@ -203,14 +204,14 @@ async def run_agent(
)

if agent_type == "fake":
answer, agent_status = await run_fake_agent(query, docs, **env_kwargs)
elif tool_selector_or_none := to_aviary_tool_selector(agent_type, query):
answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
elif tool_selector_or_none := to_aviary_tool_selector(agent_type, query.settings):
answer, agent_status = await run_aviary_agent(
query, docs, tool_selector_or_none, **env_kwargs
query, docs, tool_selector_or_none, **runner_kwargs
)
elif ldp_agent_or_none := await to_ldp_agent(agent_type, query):
elif ldp_agent_or_none := await to_ldp_agent(agent_type, query.settings):
answer, agent_status = await run_ldp_agent(
query, docs, ldp_agent_or_none, **env_kwargs
query, docs, ldp_agent_or_none, **runner_kwargs
)
else:
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")
Expand Down
32 changes: 27 additions & 5 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
logger = logging.getLogger(__name__)


def make_status(
total_paper_count: int, relevant_paper_count: int, evidence_count: int, cost: float
) -> str:
return (
f"Status: Paper Count={total_paper_count}"
f" | Relevant Papers={relevant_paper_count} | Current Evidence={evidence_count}"
f" | Current Cost=${cost:.4f}"
)


class EnvironmentState(BaseModel):
"""State here contains documents and answer being populated."""

Expand All @@ -35,11 +45,23 @@ class EnvironmentState(BaseModel):
@computed_field # type: ignore[prop-decorator]
@property
def status(self) -> str:
return (
f"Status: Paper Count={len(self.docs.docs)} | Relevant Papers="
f"{len({c.text.doc.dockey for c in self.answer.contexts if c.score > self.RELEVANT_SCORE_CUTOFF})}"
f" | Current Evidence={len([c for c in self.answer.contexts if c.score > self.RELEVANT_SCORE_CUTOFF])}"
f" | Current Cost=${self.answer.cost:.4f}"
return make_status(
total_paper_count=len(self.docs.docs),
relevant_paper_count=len(
{
c.text.doc.dockey
for c in self.answer.contexts
if c.score > self.RELEVANT_SCORE_CUTOFF
}
),
evidence_count=len(
[
c
for c in self.answer.contexts
if c.score > self.RELEVANT_SCORE_CUTOFF
]
),
cost=self.answer.cost,
)


Expand Down
9 changes: 8 additions & 1 deletion paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,9 @@ def setup_default_logs() -> None:
"stream": "ext://sys.stdout",
},
},
# Lower level for httpx and LiteLLM
# Lower level for verbose logs
"loggers": {
"httpcore": {"level": "WARNING"},
"httpx": {"level": "WARNING"},
# SEE: https://github.com/BerriAI/litellm/issues/2256
"LiteLLM": {"level": "WARNING"},
Expand All @@ -474,3 +475,9 @@ def setup_default_logs() -> None:
},
}
)


def extract_thought(content: str | None) -> str:
"""Extract an Anthropic thought from a message's content."""
# SEE: https://regex101.com/r/bpJt05/1
return re.sub(r"<\/?thinking>", "", content or "")
98 changes: 92 additions & 6 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import itertools
import json
import re
import time
from pathlib import Path
from typing import Any, cast
from unittest.mock import patch

import ldp.agent
import pytest
from aviary.tools import ToolsAdapter, ToolSelector
from ldp.agent import SimpleAgent
from aviary.tools import ToolRequestMessage, ToolsAdapter, ToolSelector
from ldp.agent import MemoryAgent, SimpleAgent
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.ops import OpResult
from ldp.llms import EmbeddingModel, MultipleCompletionLLMModel
from pydantic import ValidationError
from pytest_subtests import SubTests

Expand All @@ -22,11 +27,12 @@
GatherEvidence,
GenerateAnswer,
PaperSearch,
make_status,
)
from paperqa.docs import Docs
from paperqa.settings import AgentSettings, Settings
from paperqa.types import Answer, Context, Doc, Text
from paperqa.utils import get_year, md5sum
from paperqa.utils import extract_thought, get_year, md5sum


@pytest.mark.asyncio
Expand Down Expand Up @@ -81,7 +87,7 @@ async def test_agent_types(
agent_test_settings.llm = "gpt-4o-mini"
agent_test_settings.summary_llm = "gpt-4o-mini"
agent_test_settings.agent.agent_prompt += (
"\n\n Call each tool once in appropriate order and "
"\n\nCall each tool once in appropriate order and"
" accept the answer for now, as we're in debug mode."
)
request = QueryRequest(query=question, settings=agent_test_settings)
Expand All @@ -104,15 +110,95 @@ async def test_agent_types(


@pytest.mark.asyncio
async def test_timeout(agent_test_settings: Settings) -> None:
async def test_successful_memory_agent(agent_test_settings: Settings) -> None:
tic = time.perf_counter()
memory_id = "call_Wtmv95JbNcQ2nRQCZBoOfcJy" # Stub value
memory = Memory(
query=(
"Use the tools to answer the question: How can you use XAI for chemical"
" property prediction?\n\nThe gen_answer tool output is visible to the"
" user, so you do not need to restate the answer and can simply"
" terminate if the answer looks sufficient. The current status of"
" evidence/papers/cost is "
f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, cost=0.0)}" # Started 0
"\n\nTool request message '' for tool calls: paper_search(query='XAI for"
" chemical property prediction', min_year='2018', max_year='2024')"
f" [id={memory_id}]\n\nTool response message '"
f"{make_status(total_paper_count=2, relevant_paper_count=0, evidence_count=0, cost=0.0)}" # Found 2
f"' for tool call ID {memory_id} of tool 'paper_search'"
),
input=(
"Use the tools to answer the question: How can you use XAI for chemical"
" property prediction?\n\nThe gen_answer tool output is visible to the"
" user, so you do not need to restate the answer and can simply terminate"
" if the answer looks sufficient. The current status of"
" evidence/papers/cost is "
f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, cost=0.0)}"
),
output=(
"Tool request message '' for tool calls: paper_search(query='XAI for"
" chemical property prediction', min_year='2018', max_year='2024')"
f" [id={memory_id}]"
),
value=5.0, # Stub value
template="Input: {input}\n\nOutput: {output}\n\nDiscounted Reward: {value}",
)
memory_model = UIndexMemoryModel(
embedding_model=EmbeddingModel.from_name("text-embedding-3-small")
)
await memory_model.add_memory(memory)
serialized_memory_model = memory_model.model_dump(exclude_none=True)
query = QueryRequest(
query="How can you use XAI for chemical property prediction?",
settings=agent_test_settings,
)
# NOTE: use Claude 3 for its <thinking> feature, testing regex replacement of it
query.settings.agent.agent_llm = "claude-3-5-sonnet-20240620"
query.settings.agent.agent_config = {
"memories": serialized_memory_model.pop("memories"),
"memory_model": serialized_memory_model,
}

thoughts: list[str] = []
orig_llm_model_call = MultipleCompletionLLMModel.call

async def on_agent_action(action: OpResult[ToolRequestMessage], *_) -> None:
thoughts.append(extract_thought(content=action.value.content))

async def llm_model_call(*args, **kwargs):
# NOTE: "required" will not lead to thoughts being emitted, it has to be "auto"
# https://docs.anthropic.com/en/docs/build-with-claude/tool-use#chain-of-thought
kwargs.pop("tool_choice", MultipleCompletionLLMModel.TOOL_CHOICE_REQUIRED)
return await orig_llm_model_call(*args, tool_choice="auto", **kwargs) # type: ignore[misc]

with patch.object(
MultipleCompletionLLMModel, "call", side_effect=llm_model_call, autospec=True
):
response = await agent_query(
query,
Docs(),
agent_type=f"{ldp.agent.__name__}.{MemoryAgent.__name__}",
on_agent_action_callback=on_agent_action,
)
assert response.status == AgentStatus.SUCCESS, "Agent did not succeed"
assert (
time.perf_counter() - tic <= query.settings.agent.timeout
), "Agent should not have timed out"
assert all(thought and "<thinking>" not in thought for thought in thoughts)


@pytest.mark.parametrize("agent_type", [ToolSelector, SimpleAgent])
@pytest.mark.asyncio
async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> None:
agent_test_settings.prompts.pre = None
agent_test_settings.agent.timeout = 0.001
agent_test_settings.llm = "gpt-4o-mini"
agent_test_settings.agent.tool_names = {"gen_answer"}
response = await agent_query(
QueryRequest(
query="Are COVID-19 vaccines effective?", settings=agent_test_settings
)
),
agent_type=agent_type,
)
# ensure that GenerateAnswerTool was called
assert response.status == AgentStatus.TIMEOUT, "Agent did not timeout"
Expand Down

0 comments on commit 870e2f6

Please sign in to comment.