Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/google/adk/a2a/converters/request_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..experimental import a2a_experimental
from .part_converter import A2APartToGenAIPartConverter
from .part_converter import convert_a2a_part_to_genai_part
from .utils import _from_a2a_context_id


@a2a_experimental
Expand Down Expand Up @@ -70,6 +71,10 @@ def _get_user_id(request: RequestContext) -> str:
):
return request.call_context.user.user_name

_, user_id, _ = _from_a2a_context_id(request.context_id)
if user_id:
return user_id

# Get user from context id
return f'A2A_USER_{request.context_id}'

Expand Down Expand Up @@ -106,9 +111,11 @@ def convert_a2a_request_to_agent_run_request(
genai_parts = [genai_parts] if genai_parts else []
output_parts.extend(genai_parts)

_, _, session_id = _from_a2a_context_id(request.context_id)

return AgentRunRequest(
user_id=_get_user_id(request),
session_id=request.context_id,
session_id=session_id,
new_message=genai_types.Content(
role='user',
parts=output_parts,
Expand Down
74 changes: 52 additions & 22 deletions src/google/adk/a2a/executor/a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from ..converters.request_converter import AgentRunRequest
from ..converters.request_converter import convert_a2a_request_to_agent_run_request
from ..converters.utils import _get_adk_metadata_key
from ..converters.utils import _to_a2a_context_id
from ..experimental import a2a_experimental
from .task_result_aggregator import TaskResultAggregator

Expand Down Expand Up @@ -135,21 +136,6 @@ async def execute(
if not context.message:
raise ValueError('A2A request must have a message')

# for new task, create a task submitted event
if not context.current_task:
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=TaskState.submitted,
message=context.message,
timestamp=datetime.now(timezone.utc).isoformat(),
),
context_id=context.context_id,
final=False,
)
)

# Handle the request and publish updates to the event queue
try:
await self._handle_request(context, event_queue)
Expand Down Expand Up @@ -194,6 +180,27 @@ async def _handle_request(

# ensure the session exists
session = await self._prepare_session(context, run_request, runner)
response_context_id = self._get_response_context_id(
context=context,
runner=runner,
run_request=run_request,
session_id=session.id,
)

# for new task, create a task submitted event
if not context.current_task:
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=TaskState.submitted,
message=context.message,
timestamp=datetime.now(timezone.utc).isoformat(),
),
context_id=response_context_id,
final=False,
)
)

# create invocation context
invocation_context = runner._new_invocation_context(
Expand All @@ -210,7 +217,7 @@ async def _handle_request(
state=TaskState.working,
timestamp=datetime.now(timezone.utc).isoformat(),
),
context_id=context.context_id,
context_id=response_context_id,
final=False,
metadata={
_get_adk_metadata_key('app_name'): runner.app_name,
Expand All @@ -227,7 +234,7 @@ async def _handle_request(
adk_event,
invocation_context,
context.task_id,
context.context_id,
response_context_id,
self._config.gen_ai_part_converter,
):
task_result_aggregator.process_event(a2a_event)
Expand All @@ -245,7 +252,7 @@ async def _handle_request(
TaskArtifactUpdateEvent(
task_id=context.task_id,
last_chunk=True,
context_id=context.context_id,
context_id=response_context_id,
artifact=Artifact(
artifact_id=str(uuid.uuid4()),
parts=task_result_aggregator.task_status_message.parts,
Expand All @@ -260,7 +267,7 @@ async def _handle_request(
state=TaskState.completed,
timestamp=datetime.now(timezone.utc).isoformat(),
),
context_id=context.context_id,
context_id=response_context_id,
final=True,
)
)
Expand All @@ -273,21 +280,44 @@ async def _handle_request(
timestamp=datetime.now(timezone.utc).isoformat(),
message=task_result_aggregator.task_status_message,
),
context_id=context.context_id,
context_id=response_context_id,
final=True,
)
)

def _get_response_context_id(
self,
*,
context: RequestContext,
runner: Runner,
run_request: AgentRunRequest,
session_id: str,
) -> str:
try:
return _to_a2a_context_id(
runner.app_name, run_request.user_id, session_id
)
except ValueError:
return context.context_id

async def _prepare_session(
self,
context: RequestContext,
run_request: AgentRunRequest,
runner: Runner,
):

session_id = run_request.session_id
# create a new session if not exists
user_id = run_request.user_id
if not session_id:
session = await runner.session_service.create_session(
app_name=runner.app_name,
user_id=user_id,
state={},
)
run_request.session_id = session.id
return session

# create a new session if not exists
session = await runner.session_service.get_session(
app_name=runner.app_name,
user_id=user_id,
Expand Down
44 changes: 38 additions & 6 deletions tests/unittests/a2a/converters/test_request_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from a2a.server.agent_execution import RequestContext
from google.adk.a2a.converters.request_converter import _get_user_id
from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request
from google.adk.a2a.converters.utils import _to_a2a_context_id
from google.adk.runners import RunConfig
from google.genai import types as genai_types
import pytest
Expand Down Expand Up @@ -58,6 +59,16 @@ def test_get_user_id_from_context_when_no_call_context(self):
# Assert
assert result == "A2A_USER_test_context"

def test_get_user_id_from_adk_context_id(self):
"""Test getting user ID from ADK-formatted context id."""
request = Mock(spec=RequestContext)
request.call_context = None
request.context_id = _to_a2a_context_id("app", "user-123", "session-456")

result = _get_user_id(request)

assert result == "user-123"

def test_get_user_id_from_context_when_call_context_has_no_user(self):
"""Test getting user ID from context when call context has no user."""
# Arrange
Expand Down Expand Up @@ -129,6 +140,27 @@ def test_get_user_id_with_none_context_id(self):
class TestConvertA2aRequestToAgentRunRequest:
"""Test cases for convert_a2a_request_to_agent_run_request function."""

def test_convert_a2a_request_with_adk_context_id(self):
"""Test conversion uses ADK context id for user/session."""
mock_message = Mock()
mock_message.parts = [Mock()]

request = Mock(spec=RequestContext)
request.message = mock_message
request.context_id = _to_a2a_context_id("app", "user-1", "session-1")
request.call_context = None
request.metadata = {}

mock_genai_part = genai_types.Part(text="test part")
mock_convert_part = Mock(return_value=mock_genai_part)

result = convert_a2a_request_to_agent_run_request(
request, mock_convert_part
)

assert result.user_id == "user-1"
assert result.session_id == "session-1"

def test_convert_a2a_request_basic(self):
"""Test basic conversion of A2A request to ADK AgentRunRequest."""
# Arrange
Expand Down Expand Up @@ -164,7 +196,7 @@ def test_convert_a2a_request_basic(self):
# Assert
assert result is not None
assert result.user_id == "test_user"
assert result.session_id == "test_context_123"
assert result.session_id is None
assert isinstance(result.new_message, genai_types.Content)
assert result.new_message.role == "user"
assert result.new_message.parts == [mock_genai_part1, mock_genai_part2]
Expand Down Expand Up @@ -213,7 +245,7 @@ def test_convert_a2a_request_multiple_parts(self):
# Assert
assert result is not None
assert result.user_id == "test_user"
assert result.session_id == "test_context_123"
assert result.session_id is None
assert isinstance(result.new_message, genai_types.Content)
assert result.new_message.role == "user"
assert result.new_message.parts == [
Expand Down Expand Up @@ -261,7 +293,7 @@ def test_convert_a2a_request_empty_parts(self):
# Assert
assert result is not None
assert result.user_id == "A2A_USER_test_context_123"
assert result.session_id == "test_context_123"
assert result.session_id is None
assert isinstance(result.new_message, genai_types.Content)
assert result.new_message.role == "user"
assert result.new_message.parts == []
Expand Down Expand Up @@ -328,7 +360,7 @@ def test_convert_a2a_request_no_auth(self):
# Assert
assert result is not None
assert result.user_id == "A2A_USER_session_123"
assert result.session_id == "session_123"
assert result.session_id is None
assert isinstance(result.new_message, genai_types.Content)
assert result.new_message.role == "user"
assert result.new_message.parts == [mock_genai_part]
Expand Down Expand Up @@ -370,7 +402,7 @@ def test_end_to_end_conversion_with_auth_user(self):
# Assert
assert result is not None
assert result.user_id == "auth_user" # Should use authenticated user
assert result.session_id == "mysession"
assert result.session_id is None
assert isinstance(result.new_message, genai_types.Content)
assert result.new_message.role == "user"
assert result.new_message.parts == [mock_genai_part]
Expand Down Expand Up @@ -404,7 +436,7 @@ def test_end_to_end_conversion_with_fallback_user(self):
assert (
result.user_id == "A2A_USER_test_session_456"
) # Should fall back to context ID
assert result.session_id == "test_session_456"
assert result.session_id is None
assert isinstance(result.new_message, genai_types.Content)
assert result.new_message.role == "user"
assert result.new_message.parts == [mock_genai_part]
Expand Down
Loading