Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
"""Test Middleware handling of tools in agents."""

from collections.abc import Callable
from typing import Any

import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import BaseTool
from langgraph.prebuilt.tool_node import ToolNode

from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
)
from tests.unit_tests.agents.model import FakeToolCallingModel


Expand All @@ -30,8 +38,8 @@ class RequestCapturingMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
captured_requests.append(request)
return handler(request)

Expand All @@ -51,7 +59,15 @@ def wrap_model_call(
request = captured_requests[0]
assert isinstance(request.tools, list)
assert len(request.tools) == 2
assert {t.name for t in request.tools} == {"search_tool", "calculator"}

tools = []
for t in request.tools:
assert isinstance(t, BaseTool)
tools.append(t.name)
assert set(tools) == {
"search_tool",
"calculator",
}


def test_middleware_can_modify_tools() -> None:
Expand All @@ -76,10 +92,14 @@ class ToolFilteringMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Only allow tool_a and tool_b
filtered_tools = [t for t in request.tools if t.name in {"tool_a", "tool_b"}]
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
if t.name in {"tool_a", "tool_b"}:
filtered_tools.append(t)
return handler(request.override(tools=filtered_tools))

# Model will try to call tool_a
Expand Down Expand Up @@ -120,8 +140,8 @@ class BadMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Add an unknown tool
return handler(request.override(tools=[*request.tools, unknown_tool]))

Expand Down Expand Up @@ -149,7 +169,7 @@ def admin_tool(command: str) -> str:
"""Admin-only tool."""
return f"Admin: {command}"

class AdminState(AgentState):
class AdminState(AgentState[Any]):
is_admin: bool

class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
Expand All @@ -158,11 +178,15 @@ class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Remove admin_tool if not admin
if not request.state.get("is_admin", False):
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
if t.name != "admin_tool":
filtered_tools.append(t)
request = request.override(tools=filtered_tools)
return handler(request)

Expand Down Expand Up @@ -197,8 +221,8 @@ class NoToolsMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Remove all tools
request = request.override(tools=[])
return handler(request)
Expand Down Expand Up @@ -240,25 +264,37 @@ class FirstMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
# Remove tool_c
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
tools: list[str] = []
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
tools.append(t.name)
# Remove tool_c
if t.name != "tool_c":
filtered_tools.append(t)
modification_order.append(tools)
request = request.override(tools=filtered_tools)
return handler(request)

class SecondMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
# Should not see tool_c here
assert all(t.name != "tool_c" for t in request.tools)
# Remove tool_b
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
tools: list[str] = []
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
# Should not see tool_c here
assert t.name != "tool_c"
tools.append(t.name)
# Remove tool_b
if t.name != "tool_b":
filtered_tools.append(t)
modification_order.append(tools)
request = request.override(tools=filtered_tools)
return handler(request)

Expand Down Expand Up @@ -317,6 +353,7 @@ class ToolProvidingMiddleware(AgentMiddleware):
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) == 1
assert tool_messages[0].name == "middleware_tool"
assert isinstance(tool_messages[0].content, str)
assert "middleware" in tool_messages[0].content.lower()


Expand Down