diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py
index 27740c2cd88de..39f79ac74521a 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py
@@ -2,23 +2,30 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from langchain_core.language_models.fake_chat_models import FakeChatModel
 from langchain_core.messages import (
     AIMessage,
+    AnyMessage,
+    BaseMessage,
     MessageLikeRepresentation,
     ToolMessage,
 )
+from typing_extensions import override
 
 from langchain.agents.middleware.context_editing import (
     ClearToolUsesEdit,
     ContextEditingMiddleware,
 )
-from langchain.agents.middleware.types import AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentState,
+    ModelRequest,
+    ModelResponse,
+)
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable
+    from collections.abc import Sequence
 
     from langgraph.runtime import Runtime
 
@@ -26,10 +33,11 @@
 class _TokenCountingChatModel(FakeChatModel):
     """Fake chat model that counts tokens deterministically for tests."""
 
+    @override
     def get_num_tokens_from_messages(
         self,
-        messages: list[MessageLikeRepresentation],
-        tools: Iterable | None = None,
+        messages: list[BaseMessage],
+        tools: Sequence | None = None,
     ) -> int:
         return sum(_count_message_tokens(message) for message in messages)
 
@@ -46,7 +54,7 @@ def _count_content(content: MessageLikeRepresentation) -> int:
     if isinstance(content, str):
         return len(content)
     if isinstance(content, list):
-        return sum(_count_content(block) for block in content)  # type: ignore[arg-type]
+        return sum(_count_content(block) for block in content)
     if isinstance(content, dict):
         return len(str(content))
     return len(str(content))
@@ -56,10 +64,10 @@ def _make_state_and_request(
     messages: list[AIMessage | ToolMessage],
     *,
     system_prompt: str | None = None,
-) -> tuple[AgentState, ModelRequest]:
+) -> tuple[AgentState[Any], ModelRequest]:
     model = _TokenCountingChatModel()
-    conversation = list(messages)
-    state = cast("AgentState", {"messages": conversation})
+    conversation: list[AnyMessage] = list(messages)
+    state = cast("AgentState[Any]", {"messages": conversation})
     request = ModelRequest(
         model=model,
         system_prompt=system_prompt,
@@ -89,10 +97,10 @@ def test_no_edit_when_below_trigger() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request
     middleware.wrap_model_call(request, mock_handler)
@@ -129,10 +137,10 @@ def test_clear_tool_outputs_and_inputs() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request with edits
     middleware.wrap_model_call(request, mock_handler)
@@ -152,7 +160,9 @@ def mock_handler(req: ModelRequest) -> AIMessage:
     assert context_meta["cleared_tool_inputs"] == [tool_call_id]
 
     # Original request should be unchanged
-    assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
+    request_ai_message = request.messages[0]
+    assert isinstance(request_ai_message, AIMessage)
+    assert request_ai_message.tool_calls[0]["args"] == {"query": "foo"}
     assert request.messages[1].content == "x" * 200
 
 
@@ -190,10 +200,10 @@ def test_respects_keep_last_tool_results() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request with edits
     middleware.wrap_model_call(request, mock_handler)
@@ -243,10 +253,10 @@ def test_exclude_tools_prevents_clearing() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request with edits
     middleware.wrap_model_call(request, mock_handler)
@@ -282,10 +292,10 @@ async def test_no_edit_when_below_trigger_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request
     await middleware.awrap_model_call(request, mock_handler)
@@ -323,10 +333,10 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request with edits
     await middleware.awrap_model_call(request, mock_handler)
@@ -346,7 +356,9 @@ async def mock_handler(req: ModelRequest) -> AIMessage:
     assert context_meta["cleared_tool_inputs"] == [tool_call_id]
 
     # Original request should be unchanged
-    assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
+    request_ai_message = request.messages[0]
+    assert isinstance(request_ai_message, AIMessage)
+    assert request_ai_message.tool_calls[0]["args"] == {"query": "foo"}
     assert request.messages[1].content == "x" * 200
 
 
@@ -385,10 +397,10 @@ async def test_respects_keep_last_tool_results_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request with edits
     await middleware.awrap_model_call(request, mock_handler)
@@ -439,10 +451,10 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request with edits
     await middleware.awrap_model_call(request, mock_handler)
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py b/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py
index d1d932a5712be..536a03519b77c 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py
@@ -6,7 +6,11 @@
 from __future__ import annotations
 
 import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    ToolCall,
+)
 from langchain_core.tools import tool
 
 from langchain.agents import create_agent
@@ -21,8 +25,9 @@ def simple_tool(x: int) -> str:
 
 def test_agent_name_set_on_ai_message() -> None:
     """Test that agent name is set on AIMessage when name is provided."""
+    tool_calls: list[list[ToolCall]] = [[]]
     agent = create_agent(
-        model=FakeToolCallingModel(tool_calls=[[]]),
+        model=FakeToolCallingModel(tool_calls=tool_calls),
         name="test_agent",
     )
 
@@ -35,8 +40,9 @@ def test_agent_name_set_on_ai_message() -> None:
 
 def test_agent_name_not_set_when_none() -> None:
     """Test that AIMessage.name is not set when name is not provided."""
+    tool_calls: list[list[ToolCall]] = [[]]
     agent = create_agent(
-        model=FakeToolCallingModel(tool_calls=[[]]),
+        model=FakeToolCallingModel(tool_calls=tool_calls),
     )
 
     result = agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -67,8 +73,9 @@ def test_agent_name_on_multiple_iterations() -> None:
 @pytest.mark.asyncio
 async def test_agent_name_async() -> None:
     """Test that agent name is set on AIMessage in async execution."""
+    tool_calls: list[list[ToolCall]] = [[]]
     agent = create_agent(
-        model=FakeToolCallingModel(tool_calls=[[]]),
+        model=FakeToolCallingModel(tool_calls=tool_calls),
         name="async_agent",
     )
 
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format_integration.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format_integration.py
index 4e8416f3af67d..3c3c3541f5779 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format_integration.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format_integration.py
@@ -76,7 +76,7 @@ def get_weather(city: str) -> str:
 @pytest.mark.parametrize("use_responses_api", [False, True])
 def test_inference_to_native_output(*, use_responses_api: bool) -> None:
     """Test that native output is inferred when a model supports it."""
-    model_kwargs = {"model": "gpt-5", "use_responses_api": use_responses_api}
+    model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": use_responses_api}
     if "OPENAI_API_KEY" not in os.environ:
         model_kwargs["api_key"] = "foo"
 
@@ -111,7 +111,7 @@ def test_inference_to_native_output(*, use_responses_api: bool) -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
 def test_inference_to_tool_output(*, use_responses_api: bool) -> None:
     """Test that tool output is inferred when a model supports it."""
-    model_kwargs = {"model": "gpt-5", "use_responses_api": use_responses_api}
+    model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": use_responses_api}
     if "OPENAI_API_KEY" not in os.environ:
         model_kwargs["api_key"] = "foo"
 
@@ -146,7 +146,7 @@ def test_inference_to_tool_output(*, use_responses_api: bool) -> None:
 @pytest.mark.vcr
 @pytest.mark.parametrize("use_responses_api", [False, True])
 def test_strict_mode(*, use_responses_api: bool) -> None:
-    model_kwargs = {"model": "gpt-5", "use_responses_api": use_responses_api}
+    model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": use_responses_api}
     if "OPENAI_API_KEY" not in os.environ:
         model_kwargs["api_key"] = "foo"