Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,10 @@ async def _run_async_impl(
message_id=str(uuid.uuid4()),
parts=message_parts,
role="user",
context_id=context_id,
# Use existing context_id if available (for conversation continuity),
# otherwise use the local session ID to maintain session identity
# across local and remote agents.
context_id=context_id if context_id else ctx.session.id,
)

logger.debug(build_a2a_request_log(a2a_request))
Expand Down
154 changes: 154 additions & 0 deletions tests/unittests/agents/test_remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,160 @@ async def test_run_async_impl_successful_request(self):
in mock_event.custom_metadata
)

@pytest.mark.asyncio
async def test_run_async_impl_uses_session_id_when_no_context_id(self):
"""Test that session ID is used as context_id when no existing context.

When _construct_message_parts_from_session returns None for context_id,
the agent should use ctx.session.id to maintain session identity across
local and remote agents.
"""
with patch.object(self.agent, "_ensure_resolved"):
with patch.object(
self.agent, "_create_a2a_request_for_user_function_response"
) as mock_create_func:
mock_create_func.return_value = None

with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
# Create proper A2A part mocks
from a2a.client import Client as A2AClient
from a2a.types import TextPart

mock_a2a_part = Mock(spec=TextPart)
# Return None for context_id to trigger session ID fallback
mock_construct.return_value = (
[mock_a2a_part],
None,
) # Tuple with parts and NO context_id

# Mock A2A client
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
mock_response = Mock()
mock_send_message = AsyncMock()
mock_send_message.__aiter__.return_value = [mock_response]
mock_a2a_client.send_message.return_value = mock_send_message
self.agent._a2a_client = mock_a2a_client

mock_event = Event(
author=self.agent.name,
invocation_id=self.mock_context.invocation_id,
branch=self.mock_context.branch,
)

with patch.object(self.agent, "_handle_a2a_response") as mock_handle:
mock_handle.return_value = mock_event

# Mock the logging functions to avoid iteration issues
with patch(
"google.adk.agents.remote_a2a_agent.build_a2a_request_log"
) as mock_req_log:
with patch(
"google.adk.agents.remote_a2a_agent.build_a2a_response_log"
) as mock_resp_log:
mock_req_log.return_value = "Mock request log"
mock_resp_log.return_value = "Mock response log"

# Mock the A2AMessage constructor to capture the arguments
with patch(
"google.adk.agents.remote_a2a_agent.A2AMessage"
) as mock_message_class:
mock_message = Mock(spec=A2AMessage)
mock_message_class.return_value = mock_message

# Add model_dump to mock_response for metadata
mock_response.model_dump.return_value = {"test": "response"}

# Execute
events = []
async for event in self.agent._run_async_impl(
self.mock_context
):
events.append(event)

# Verify A2AMessage was called with session ID as context_id
mock_message_class.assert_called_once()
call_kwargs = mock_message_class.call_args[1]
assert call_kwargs["context_id"] == self.mock_session.id

@pytest.mark.asyncio
async def test_run_async_impl_preserves_existing_context_id(self):
"""Test that existing context_id is preserved when available.

When _construct_message_parts_from_session returns a context_id from
a previous remote agent response, that context_id should be used
for conversation continuity.
"""
with patch.object(self.agent, "_ensure_resolved"):
with patch.object(
self.agent, "_create_a2a_request_for_user_function_response"
) as mock_create_func:
mock_create_func.return_value = None

with patch.object(
self.agent, "_construct_message_parts_from_session"
) as mock_construct:
# Create proper A2A part mocks
from a2a.client import Client as A2AClient
from a2a.types import TextPart

mock_a2a_part = Mock(spec=TextPart)
existing_context_id = "existing-context-456"
mock_construct.return_value = (
[mock_a2a_part],
existing_context_id,
) # Tuple with parts and existing context_id

# Mock A2A client
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
mock_response = Mock()
mock_send_message = AsyncMock()
mock_send_message.__aiter__.return_value = [mock_response]
mock_a2a_client.send_message.return_value = mock_send_message
self.agent._a2a_client = mock_a2a_client

mock_event = Event(
author=self.agent.name,
invocation_id=self.mock_context.invocation_id,
branch=self.mock_context.branch,
)

with patch.object(self.agent, "_handle_a2a_response") as mock_handle:
mock_handle.return_value = mock_event

# Mock the logging functions to avoid iteration issues
with patch(
"google.adk.agents.remote_a2a_agent.build_a2a_request_log"
) as mock_req_log:
with patch(
"google.adk.agents.remote_a2a_agent.build_a2a_response_log"
) as mock_resp_log:
mock_req_log.return_value = "Mock request log"
mock_resp_log.return_value = "Mock response log"

# Mock the A2AMessage constructor to capture the arguments
with patch(
"google.adk.agents.remote_a2a_agent.A2AMessage"
) as mock_message_class:
mock_message = Mock(spec=A2AMessage)
mock_message_class.return_value = mock_message

# Add model_dump to mock_response for metadata
mock_response.model_dump.return_value = {"test": "response"}

# Execute
events = []
async for event in self.agent._run_async_impl(
self.mock_context
):
events.append(event)

# Verify A2AMessage was called with existing context_id
mock_message_class.assert_called_once()
call_kwargs = mock_message_class.call_args[1]
assert call_kwargs["context_id"] == existing_context_id

@pytest.mark.asyncio
async def test_run_async_impl_a2a_client_error(self):
"""Test _run_async_impl when A2A send_message fails."""
Expand Down