Merged pull request, changes from all commits. (First file in the diff; file path not captured in this view.)
@@ -2,34 +2,42 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from langchain_core.language_models.fake_chat_models import FakeChatModel
 from langchain_core.messages import (
     AIMessage,
+    AnyMessage,
+    BaseMessage,
     MessageLikeRepresentation,
     ToolMessage,
 )
 from typing_extensions import override
 
 from langchain.agents.middleware.context_editing import (
     ClearToolUsesEdit,
     ContextEditingMiddleware,
 )
-from langchain.agents.middleware.types import AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentState,
+    ModelRequest,
+    ModelResponse,
+)
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable
+    from collections.abc import Sequence
 
     from langgraph.runtime import Runtime
 
 
 class _TokenCountingChatModel(FakeChatModel):
     """Fake chat model that counts tokens deterministically for tests."""
 
     @override
     def get_num_tokens_from_messages(
         self,
-        messages: list[MessageLikeRepresentation],
-        tools: Iterable | None = None,
+        messages: list[BaseMessage],
+        tools: Sequence | None = None,
     ) -> int:
         return sum(_count_message_tokens(message) for message in messages)
@@ -46,7 +54,7 @@ def _count_content(content: MessageLikeRepresentation) -> int:
     if isinstance(content, str):
        return len(content)
     if isinstance(content, list):
-        return sum(_count_content(block) for block in content)  # type: ignore[arg-type]
+        return sum(_count_content(block) for block in content)
     if isinstance(content, dict):
         return len(str(content))
     return len(str(content))
@@ -56,10 +64,10 @@ def _make_state_and_request(
     messages: list[AIMessage | ToolMessage],
     *,
     system_prompt: str | None = None,
-) -> tuple[AgentState, ModelRequest]:
+) -> tuple[AgentState[Any], ModelRequest]:
     model = _TokenCountingChatModel()
-    conversation = list(messages)
-    state = cast("AgentState", {"messages": conversation})
+    conversation: list[AnyMessage] = list(messages)
+    state = cast("AgentState[Any]", {"messages": conversation})
     request = ModelRequest(
         model=model,
         system_prompt=system_prompt,
@@ -89,10 +97,10 @@ def test_no_edit_when_below_trigger() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request
     middleware.wrap_model_call(request, mock_handler)
@@ -129,10 +137,10 @@ def test_clear_tool_outputs_and_inputs() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request with edits
     middleware.wrap_model_call(request, mock_handler)
@@ -152,7 +160,9 @@ def mock_handler(req: ModelRequest) -> AIMessage:
     assert context_meta["cleared_tool_inputs"] == [tool_call_id]
 
     # Original request should be unchanged
-    assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
+    request_ai_message = request.messages[0]
+    assert isinstance(request_ai_message, AIMessage)
+    assert request_ai_message.tool_calls[0]["args"] == {"query": "foo"}
     assert request.messages[1].content == "x" * 200
 
 
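Note: the rewritten assertion above narrows request.messages[0] with isinstance before reading tool_calls, since tool_calls is defined only on AIMessage. A minimal sketch of the narrowing pattern, assuming the messages list is typed as the AnyMessage union (the new imports suggest this, though the diff does not show the annotation itself):

# Sketch only: isinstance narrowing so the tool_calls access type-checks.
from typing import Any

from langchain_core.messages import AIMessage, AnyMessage


def first_tool_call_args(messages: list[AnyMessage]) -> dict[str, Any]:
    message = messages[0]  # statically typed as the AnyMessage union
    assert isinstance(message, AIMessage)  # narrows the type for the checker
    return message.tool_calls[0]["args"]  # tool_calls exists on AIMessage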
@@ -190,10 +200,10 @@ def test_respects_keep_last_tool_results() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request with edits
     middleware.wrap_model_call(request, mock_handler)
@@ -243,10 +253,10 @@ def test_exclude_tools_prevents_clearing() -> None:
 
     modified_request = None
 
-    def mock_handler(req: ModelRequest) -> AIMessage:
+    def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call wrap_model_call which creates a new request with edits
     middleware.wrap_model_call(request, mock_handler)
@@ -282,10 +292,10 @@ async def test_no_edit_when_below_trigger_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request
     await middleware.awrap_model_call(request, mock_handler)
@@ -323,10 +333,10 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request with edits
     await middleware.awrap_model_call(request, mock_handler)
@@ -346,7 +356,9 @@ async def mock_handler(req: ModelRequest) -> AIMessage:
     assert context_meta["cleared_tool_inputs"] == [tool_call_id]
 
     # Original request should be unchanged
-    assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
+    request_ai_message = request.messages[0]
+    assert isinstance(request_ai_message, AIMessage)
+    assert request_ai_message.tool_calls[0]["args"] == {"query": "foo"}
     assert request.messages[1].content == "x" * 200
 
 
@@ -385,10 +397,10 @@ async def test_respects_keep_last_tool_results_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request with edits
     await middleware.awrap_model_call(request, mock_handler)
@@ -439,10 +451,10 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
 
     modified_request = None
 
-    async def mock_handler(req: ModelRequest) -> AIMessage:
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
         nonlocal modified_request
         modified_request = req
-        return AIMessage(content="mock response")
+        return ModelResponse(result=[AIMessage(content="mock response")])
 
     # Call awrap_model_call which creates a new request with edits
     await middleware.awrap_model_call(request, mock_handler)
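Note: the recurring change in the file above is the mock handler contract for wrap_model_call and awrap_model_call: handlers now return a ModelResponse instead of a bare AIMessage. A minimal before/after sketch, assuming ModelResponse simply carries the produced messages in its result list (which is all the replaced test lines rely on):

# Sketch only: the handler shape before and after this migration.
from langchain_core.messages import AIMessage

from langchain.agents.middleware.types import ModelRequest, ModelResponse


def old_handler(req: ModelRequest) -> AIMessage:
    # Before: the model's message was returned directly.
    return AIMessage(content="mock response")


def new_handler(req: ModelRequest) -> ModelResponse:
    # After: the same message is wrapped in a ModelResponse result list.
    return ModelResponse(result=[AIMessage(content="mock response")])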
libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py (15 changes: 11 additions & 4 deletions)
@@ -6,7 +6,11 @@
 from __future__ import annotations
 
 import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    ToolCall,
+)
 from langchain_core.tools import tool
 
 from langchain.agents import create_agent
@@ -21,8 +25,9 @@ def simple_tool(x: int) -> str:
 
 def test_agent_name_set_on_ai_message() -> None:
     """Test that agent name is set on AIMessage when name is provided."""
+    tool_calls: list[list[ToolCall]] = [[]]
     agent = create_agent(
-        model=FakeToolCallingModel(tool_calls=[[]]),
+        model=FakeToolCallingModel(tool_calls=tool_calls),
         name="test_agent",
     )
 
@@ -35,8 +40,9 @@ def test_agent_name_not_set_when_none() -> None:
 
 def test_agent_name_not_set_when_none() -> None:
     """Test that AIMessage.name is not set when name is not provided."""
+    tool_calls: list[list[ToolCall]] = [[]]
     agent = create_agent(
-        model=FakeToolCallingModel(tool_calls=[[]]),
+        model=FakeToolCallingModel(tool_calls=tool_calls),
     )
 
     result = agent.invoke({"messages": [HumanMessage("Hello")]})
@@ -67,8 +73,9 @@ def test_agent_name_on_multiple_iterations() -> None:
 @pytest.mark.asyncio
 async def test_agent_name_async() -> None:
     """Test that agent name is set on AIMessage in async execution."""
+    tool_calls: list[list[ToolCall]] = [[]]
     agent = create_agent(
-        model=FakeToolCallingModel(tool_calls=[[]]),
+        model=FakeToolCallingModel(tool_calls=tool_calls),
         name="async_agent",
     )
 
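Note: every change in test_agent_name.py is the same annotation fix before the nested empty list is passed to FakeToolCallingModel. A plausible motivation, assuming mypy-style inference (the diff itself does not state one): a bare [[]] gives the checker no element type for the inner list, so the explicit annotation pins it down:

# Sketch only: why annotating the nested empty list helps the type checker.
from langchain_core.messages import ToolCall

# Without the annotation the inner [] has no inferable element type, so the
# literal may not unify with the tool_calls parameter type that
# FakeToolCallingModel declares.
tool_calls: list[list[ToolCall]] = [[]]  # one model turn with no tool calls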
(Next file in the diff; file path not captured in this view.)
@@ -76,7 +76,7 @@ def get_weather(city: str) -> str:
 @pytest.mark.parametrize("use_responses_api", [False, True])
 def test_inference_to_native_output(*, use_responses_api: bool) -> None:
     """Test that native output is inferred when a model supports it."""
-    model_kwargs = {"model": "gpt-5", "use_responses_api": use_responses_api}
+    model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": use_responses_api}
 
     if "OPENAI_API_KEY" not in os.environ:
         model_kwargs["api_key"] = "foo"
@@ -111,7 +111,7 @@ def test_inference_to_native_output(*, use_responses_api: bool) -> None:
 @pytest.mark.parametrize("use_responses_api", [False, True])
 def test_inference_to_tool_output(*, use_responses_api: bool) -> None:
     """Test that tool output is inferred when a model supports it."""
-    model_kwargs = {"model": "gpt-5", "use_responses_api": use_responses_api}
+    model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": use_responses_api}
 
     if "OPENAI_API_KEY" not in os.environ:
         model_kwargs["api_key"] = "foo"
@@ -146,7 +146,7 @@ def test_inference_to_tool_output(*, use_responses_api: bool) -> None:
 @pytest.mark.vcr
 @pytest.mark.parametrize("use_responses_api", [False, True])
 def test_strict_mode(*, use_responses_api: bool) -> None:
-    model_kwargs = {"model": "gpt-5", "use_responses_api": use_responses_api}
+    model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": use_responses_api}
 
     if "OPENAI_API_KEY" not in os.environ:
         model_kwargs["api_key"] = "foo"
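Note: the three hunks above all add the same dict[str, Any] annotation. A likely reason, offered as an assumption rather than anything stated in the diff: without it the dict literal is inferred with a narrow or overly wide value type (str | bool or object, depending on the checker), and later uses of the kwargs, such as unpacking **model_kwargs into a constructor with typed parameters, can fail type checking:

# Sketch only: the inference issue the annotation sidesteps.
from typing import Any

# Likely inferred as dict[str, object] (or dict[str, str | bool]) without the
# annotation; **model_kwargs unpacking into typed parameters then fails, while
# dict[str, Any] keeps both the unpacking and later writes well-typed.
model_kwargs: dict[str, Any] = {"model": "gpt-5", "use_responses_api": True}
model_kwargs["api_key"] = "foo"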