diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py index 12274cd3b63fb..d6c3bbfd20a1f 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py @@ -1,14 +1,22 @@ """Test Middleware handling of tools in agents.""" from collections.abc import Callable +from typing import Any import pytest -from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.messages import HumanMessage, ToolMessage from langchain_core.tools import tool +from langchain_core.tools.base import BaseTool from langgraph.prebuilt.tool_node import ToolNode from langchain.agents.factory import create_agent -from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ModelCallResult, + ModelRequest, + ModelResponse, +) from tests.unit_tests.agents.model import FakeToolCallingModel @@ -30,8 +38,8 @@ class RequestCapturingMiddleware(AgentMiddleware): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: captured_requests.append(request) return handler(request) @@ -51,7 +59,15 @@ def wrap_model_call( request = captured_requests[0] assert isinstance(request.tools, list) assert len(request.tools) == 2 - assert {t.name for t in request.tools} == {"search_tool", "calculator"} + + tools = [] + for t in request.tools: + assert isinstance(t, BaseTool) + tools.append(t.name) + assert set(tools) == { + "search_tool", + "calculator", + } def test_middleware_can_modify_tools() -> None: @@ -76,10 +92,14 @@ class ToolFilteringMiddleware(AgentMiddleware): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: # Only allow tool_a and tool_b - filtered_tools = [t for t in request.tools if t.name in {"tool_a", "tool_b"}] + filtered_tools: list[BaseTool | dict[str, Any]] = [] + for t in request.tools: + assert isinstance(t, BaseTool) + if t.name in {"tool_a", "tool_b"}: + filtered_tools.append(t) return handler(request.override(tools=filtered_tools)) # Model will try to call tool_a @@ -120,8 +140,8 @@ class BadMiddleware(AgentMiddleware): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: # Add an unknown tool return handler(request.override(tools=[*request.tools, unknown_tool])) @@ -149,7 +169,7 @@ def admin_tool(command: str) -> str: """Admin-only tool.""" return f"Admin: {command}" - class AdminState(AgentState): + class AdminState(AgentState[Any]): is_admin: bool class ConditionalToolMiddleware(AgentMiddleware[AdminState]): @@ -158,11 +178,15 @@ class ConditionalToolMiddleware(AgentMiddleware[AdminState]): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: # Remove admin_tool if not admin if not request.state.get("is_admin", False): - filtered_tools = [t for t in request.tools if t.name != "admin_tool"] + filtered_tools: list[BaseTool | dict[str, Any]] = [] + for t in request.tools: + assert isinstance(t, BaseTool) + if t.name != "admin_tool": + filtered_tools.append(t) request = request.override(tools=filtered_tools) return handler(request) @@ -197,8 +221,8 @@ class NoToolsMiddleware(AgentMiddleware): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: # Remove all tools request = request.override(tools=[]) return handler(request) @@ -240,11 +264,17 @@ class FirstMiddleware(AgentMiddleware): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: - modification_order.append([t.name for t in request.tools]) - # Remove tool_c - filtered_tools = [t for t in request.tools if t.name != "tool_c"] + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + tools: list[str] = [] + filtered_tools: list[BaseTool | dict[str, Any]] = [] + for t in request.tools: + assert isinstance(t, BaseTool) + tools.append(t.name) + # Remove tool_c + if t.name != "tool_c": + filtered_tools.append(t) + modification_order.append(tools) request = request.override(tools=filtered_tools) return handler(request) @@ -252,13 +282,19 @@ class SecondMiddleware(AgentMiddleware): def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: - modification_order.append([t.name for t in request.tools]) - # Should not see tool_c here - assert all(t.name != "tool_c" for t in request.tools) - # Remove tool_b - filtered_tools = [t for t in request.tools if t.name != "tool_b"] + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + tools: list[str] = [] + filtered_tools: list[BaseTool | dict[str, Any]] = [] + for t in request.tools: + assert isinstance(t, BaseTool) + # Should not see tool_c here + assert t.name != "tool_c" + tools.append(t.name) + # Remove tool_b + if t.name != "tool_b": + filtered_tools.append(t) + modification_order.append(tools) request = request.override(tools=filtered_tools) return handler(request) @@ -317,6 +353,7 @@ class ToolProvidingMiddleware(AgentMiddleware): tool_messages = [m for m in messages if isinstance(m, ToolMessage)] assert len(tool_messages) == 1 assert tool_messages[0].name == "middleware_tool" + assert isinstance(tool_messages[0].content, str) assert "middleware" in tool_messages[0].content.lower()