Skip to content

Commit

Permalink
Created a test of MemoryAgent with thought extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Sep 11, 2024
1 parent 24746f2 commit 726c60d
Showing 1 changed file with 87 additions and 2 deletions.
89 changes: 87 additions & 2 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import itertools
import json
import re
import time
from pathlib import Path
from typing import Any, cast
from unittest.mock import patch

import ldp.agent
import pytest
from aviary.tools import ToolsAdapter, ToolSelector
from aviary.tools import ToolRequestMessage, ToolsAdapter, ToolSelector
from ldp.agent import SimpleAgent
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.ops import OpResult
from ldp.llms import EmbeddingModel
from pydantic import ValidationError
from pytest_subtests import SubTests

Expand All @@ -26,7 +31,7 @@
from paperqa.docs import Docs
from paperqa.settings import AgentSettings, Settings
from paperqa.types import Answer, Context, Doc, Text
from paperqa.utils import get_year, md5sum
from paperqa.utils import extract_thought, get_year, md5sum


@pytest.mark.asyncio
Expand Down Expand Up @@ -103,6 +108,86 @@ async def test_agent_types(
assert response.answer.cost > 0, "Expected nonzero cost"


@pytest.mark.asyncio
async def test_successful_memory_agent() -> None:
tic = time.perf_counter()
memory_id = "call_Wtmv95JbNcQ2nRQCZBoOfcJy" # Stub value
memory = Memory(
query=(
"Use the tools to answer the question: Q: Acinetobacter lwoffii has"
" been evolved in the lab to be resistant to which of these"
" antibiotics?\n\nOptions:\nA) gentamicin\nB) Insufficient information"
" to answer this question\nC) meropenem\nD) ciproflaxin\nE)"
" ampicillin\n\nThe gen_answer tool output is visible to the user, so"
" you do not need to restate the answer and can simply terminate if"
" the answer looks sufficient. The current status of"
" evidence/papers/cost is Status: Paper Count=0 | Relevant Papers=0 |"
" Current Evidence=0 | Current Cost=$0.00\n\nTool request message ''"
f" for tool call: {memory_id} with content"
" 'paper_search(query='Acinetobacter lwoffii antibiotic resistance"
" evolution', min_year='None', max_year='None')'\n\nTool response"
" message 'Status: Paper Count=11 | Relevant Papers=0 | Current"
f" Evidence=0 | Current Cost=$0.00' for tool call ID {memory_id} of"
" tool 'paper_search'"
),
input=(
"Use the tools to answer the question: Q: Acinetobacter lwoffii has"
" been evolved in the lab to be resistant to which of these"
" antibiotics?\n\nOptions:\nA) gentamicin\nB) Insufficient information"
" to answer this question\nC) meropenem\nD) ciproflaxin\nE)"
" ampicillin\n\nThe gen_answer tool output is visible to the user, so"
" you do not need to restate the answer and can simply terminate if the"
" answer looks sufficient. The current status of evidence/papers/cost"
" is Status: Paper Count=0 | Relevant Papers=0 | Current Evidence=0 |"
" Current Cost=$0.00"
),
output=(
f"Tool request message '' for tool call: {memory_id} with content"
" 'paper_search(query='Acinetobacter lwoffii"
" antibiotic resistance evolution', min_year='None',"
" max_year='None')'"
),
value=5.61,
template="Input: {input}\n\nOutput: {output}\n\nDiscounted Reward: {value}",
)
memory_model = UIndexMemoryModel(
embedding_model=EmbeddingModel.from_name("text-embedding-3-small")
)
await memory_model.add_memory(memory)
serialized_memory_model = memory_model.model_dump(exclude_none=True)
query = QueryRequest(
query=(
"Use the tools to answer the question: Q: Acinetobacter lwoffii has"
" been evolved in the lab to be resistant to which of these"
" antibiotics?\n\nOptions:\nA) gentamicin\nB) Insufficient information"
" to answer this question\nC) meropenem\nD) ciproflaxin\nE) ampicillin"
)
)
# NOTE: use Claude 3 for its <thinking> feature, testing regex replacement of it
query.settings.agent.agent_llm = "claude-3-opus-20240229"
query.settings.agent.agent_config = {
"memories": serialized_memory_model.pop("memories"),
"memory_model": serialized_memory_model,
}

thoughts: list[str] = []

async def on_agent_action(action: OpResult[ToolRequestMessage], *_) -> None:
thoughts.append(extract_thought(content=action.value.content))

response = await agent_query(
query,
Docs(),
agent_type=f"{ldp.agent.__name__}.MemoryAgent",
on_agent_action_callback=on_agent_action,
)
assert response.status == AgentStatus.SUCCESS, "Agent did not succeed"
assert (
time.perf_counter() - tic <= query.settings.agent.timeout
), "Agent should not have timed out"
assert all("<thinking>" not in thought for thought in thoughts)


@pytest.mark.parametrize("agent_type", [ToolSelector, SimpleAgent])
@pytest.mark.asyncio
async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> None:
Expand Down

0 comments on commit 726c60d

Please sign in to comment.