Skip to content

Commit d8076b6

Browse files
committed
fix(a2a): avoid UUID session IDs by mapping A2A context IDs
1 parent 798d005 commit d8076b6

File tree

4 files changed

+127
-44
lines changed

4 files changed

+127
-44
lines changed

src/google/adk/a2a/converters/request_converter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..experimental import a2a_experimental
2727
from .part_converter import A2APartToGenAIPartConverter
2828
from .part_converter import convert_a2a_part_to_genai_part
29+
from .utils import _from_a2a_context_id
2930

3031

3132
@a2a_experimental
@@ -70,6 +71,10 @@ def _get_user_id(request: RequestContext) -> str:
7071
):
7172
return request.call_context.user.user_name
7273

74+
_, user_id, _ = _from_a2a_context_id(request.context_id)
75+
if user_id:
76+
return user_id
77+
7378
# Get user from context id
7479
return f'A2A_USER_{request.context_id}'
7580

@@ -106,9 +111,11 @@ def convert_a2a_request_to_agent_run_request(
106111
genai_parts = [genai_parts] if genai_parts else []
107112
output_parts.extend(genai_parts)
108113

114+
_, _, session_id = _from_a2a_context_id(request.context_id)
115+
109116
return AgentRunRequest(
110117
user_id=_get_user_id(request),
111-
session_id=request.context_id,
118+
session_id=session_id,
112119
new_message=genai_types.Content(
113120
role='user',
114121
parts=output_parts,

src/google/adk/a2a/executor/a2a_agent_executor.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from ..converters.request_converter import AgentRunRequest
5050
from ..converters.request_converter import convert_a2a_request_to_agent_run_request
5151
from ..converters.utils import _get_adk_metadata_key
52+
from ..converters.utils import _to_a2a_context_id
5253
from ..experimental import a2a_experimental
5354
from .task_result_aggregator import TaskResultAggregator
5455

@@ -135,21 +136,6 @@ async def execute(
135136
if not context.message:
136137
raise ValueError('A2A request must have a message')
137138

138-
# for new task, create a task submitted event
139-
if not context.current_task:
140-
await event_queue.enqueue_event(
141-
TaskStatusUpdateEvent(
142-
task_id=context.task_id,
143-
status=TaskStatus(
144-
state=TaskState.submitted,
145-
message=context.message,
146-
timestamp=datetime.now(timezone.utc).isoformat(),
147-
),
148-
context_id=context.context_id,
149-
final=False,
150-
)
151-
)
152-
153139
# Handle the request and publish updates to the event queue
154140
try:
155141
await self._handle_request(context, event_queue)
@@ -194,6 +180,27 @@ async def _handle_request(
194180

195181
# ensure the session exists
196182
session = await self._prepare_session(context, run_request, runner)
183+
response_context_id = self._get_response_context_id(
184+
context=context,
185+
runner=runner,
186+
run_request=run_request,
187+
session_id=session.id,
188+
)
189+
190+
# for new task, create a task submitted event
191+
if not context.current_task:
192+
await event_queue.enqueue_event(
193+
TaskStatusUpdateEvent(
194+
task_id=context.task_id,
195+
status=TaskStatus(
196+
state=TaskState.submitted,
197+
message=context.message,
198+
timestamp=datetime.now(timezone.utc).isoformat(),
199+
),
200+
context_id=response_context_id,
201+
final=False,
202+
)
203+
)
197204

198205
# create invocation context
199206
invocation_context = runner._new_invocation_context(
@@ -210,7 +217,7 @@ async def _handle_request(
210217
state=TaskState.working,
211218
timestamp=datetime.now(timezone.utc).isoformat(),
212219
),
213-
context_id=context.context_id,
220+
context_id=response_context_id,
214221
final=False,
215222
metadata={
216223
_get_adk_metadata_key('app_name'): runner.app_name,
@@ -227,7 +234,7 @@ async def _handle_request(
227234
adk_event,
228235
invocation_context,
229236
context.task_id,
230-
context.context_id,
237+
response_context_id,
231238
self._config.gen_ai_part_converter,
232239
):
233240
task_result_aggregator.process_event(a2a_event)
@@ -245,7 +252,7 @@ async def _handle_request(
245252
TaskArtifactUpdateEvent(
246253
task_id=context.task_id,
247254
last_chunk=True,
248-
context_id=context.context_id,
255+
context_id=response_context_id,
249256
artifact=Artifact(
250257
artifact_id=str(uuid.uuid4()),
251258
parts=task_result_aggregator.task_status_message.parts,
@@ -260,7 +267,7 @@ async def _handle_request(
260267
state=TaskState.completed,
261268
timestamp=datetime.now(timezone.utc).isoformat(),
262269
),
263-
context_id=context.context_id,
270+
context_id=response_context_id,
264271
final=True,
265272
)
266273
)
@@ -273,21 +280,44 @@ async def _handle_request(
273280
timestamp=datetime.now(timezone.utc).isoformat(),
274281
message=task_result_aggregator.task_status_message,
275282
),
276-
context_id=context.context_id,
283+
context_id=response_context_id,
277284
final=True,
278285
)
279286
)
280287

288+
def _get_response_context_id(
289+
self,
290+
*,
291+
context: RequestContext,
292+
runner: Runner,
293+
run_request: AgentRunRequest,
294+
session_id: str,
295+
) -> str:
296+
try:
297+
return _to_a2a_context_id(
298+
runner.app_name, run_request.user_id, session_id
299+
)
300+
except ValueError:
301+
return context.context_id
302+
281303
async def _prepare_session(
282304
self,
283305
context: RequestContext,
284306
run_request: AgentRunRequest,
285307
runner: Runner,
286308
):
287-
288309
session_id = run_request.session_id
289-
# create a new session if not exists
290310
user_id = run_request.user_id
311+
if not session_id:
312+
session = await runner.session_service.create_session(
313+
app_name=runner.app_name,
314+
user_id=user_id,
315+
state={},
316+
)
317+
run_request.session_id = session.id
318+
return session
319+
320+
# create a new session if not exists
291321
session = await runner.session_service.get_session(
292322
app_name=runner.app_name,
293323
user_id=user_id,

tests/unittests/a2a/converters/test_request_converter.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from a2a.server.agent_execution import RequestContext
1919
from google.adk.a2a.converters.request_converter import _get_user_id
2020
from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request
21+
from google.adk.a2a.converters.utils import _to_a2a_context_id
2122
from google.adk.runners import RunConfig
2223
from google.genai import types as genai_types
2324
import pytest
@@ -58,6 +59,16 @@ def test_get_user_id_from_context_when_no_call_context(self):
5859
# Assert
5960
assert result == "A2A_USER_test_context"
6061

62+
def test_get_user_id_from_adk_context_id(self):
63+
"""Test getting user ID from ADK-formatted context id."""
64+
request = Mock(spec=RequestContext)
65+
request.call_context = None
66+
request.context_id = _to_a2a_context_id("app", "user-123", "session-456")
67+
68+
result = _get_user_id(request)
69+
70+
assert result == "user-123"
71+
6172
def test_get_user_id_from_context_when_call_context_has_no_user(self):
6273
"""Test getting user ID from context when call context has no user."""
6374
# Arrange
@@ -129,6 +140,27 @@ def test_get_user_id_with_none_context_id(self):
129140
class TestConvertA2aRequestToAgentRunRequest:
130141
"""Test cases for convert_a2a_request_to_agent_run_request function."""
131142

143+
def test_convert_a2a_request_with_adk_context_id(self):
144+
"""Test conversion uses ADK context id for user/session."""
145+
mock_message = Mock()
146+
mock_message.parts = [Mock()]
147+
148+
request = Mock(spec=RequestContext)
149+
request.message = mock_message
150+
request.context_id = _to_a2a_context_id("app", "user-1", "session-1")
151+
request.call_context = None
152+
request.metadata = {}
153+
154+
mock_genai_part = genai_types.Part(text="test part")
155+
mock_convert_part = Mock(return_value=mock_genai_part)
156+
157+
result = convert_a2a_request_to_agent_run_request(
158+
request, mock_convert_part
159+
)
160+
161+
assert result.user_id == "user-1"
162+
assert result.session_id == "session-1"
163+
132164
def test_convert_a2a_request_basic(self):
133165
"""Test basic conversion of A2A request to ADK AgentRunRequest."""
134166
# Arrange
@@ -164,7 +196,7 @@ def test_convert_a2a_request_basic(self):
164196
# Assert
165197
assert result is not None
166198
assert result.user_id == "test_user"
167-
assert result.session_id == "test_context_123"
199+
assert result.session_id is None
168200
assert isinstance(result.new_message, genai_types.Content)
169201
assert result.new_message.role == "user"
170202
assert result.new_message.parts == [mock_genai_part1, mock_genai_part2]
@@ -213,7 +245,7 @@ def test_convert_a2a_request_multiple_parts(self):
213245
# Assert
214246
assert result is not None
215247
assert result.user_id == "test_user"
216-
assert result.session_id == "test_context_123"
248+
assert result.session_id is None
217249
assert isinstance(result.new_message, genai_types.Content)
218250
assert result.new_message.role == "user"
219251
assert result.new_message.parts == [
@@ -261,7 +293,7 @@ def test_convert_a2a_request_empty_parts(self):
261293
# Assert
262294
assert result is not None
263295
assert result.user_id == "A2A_USER_test_context_123"
264-
assert result.session_id == "test_context_123"
296+
assert result.session_id is None
265297
assert isinstance(result.new_message, genai_types.Content)
266298
assert result.new_message.role == "user"
267299
assert result.new_message.parts == []
@@ -328,7 +360,7 @@ def test_convert_a2a_request_no_auth(self):
328360
# Assert
329361
assert result is not None
330362
assert result.user_id == "A2A_USER_session_123"
331-
assert result.session_id == "session_123"
363+
assert result.session_id is None
332364
assert isinstance(result.new_message, genai_types.Content)
333365
assert result.new_message.role == "user"
334366
assert result.new_message.parts == [mock_genai_part]
@@ -370,7 +402,7 @@ def test_end_to_end_conversion_with_auth_user(self):
370402
# Assert
371403
assert result is not None
372404
assert result.user_id == "auth_user" # Should use authenticated user
373-
assert result.session_id == "mysession"
405+
assert result.session_id is None
374406
assert isinstance(result.new_message, genai_types.Content)
375407
assert result.new_message.role == "user"
376408
assert result.new_message.parts == [mock_genai_part]
@@ -404,7 +436,7 @@ def test_end_to_end_conversion_with_fallback_user(self):
404436
assert (
405437
result.user_id == "A2A_USER_test_session_456"
406438
) # Should fall back to context ID
407-
assert result.session_id == "test_session_456"
439+
assert result.session_id is None
408440
assert isinstance(result.new_message, genai_types.Content)
409441
assert result.new_message.role == "user"
410442
assert result.new_message.parts == [mock_genai_part]

0 commit comments

Comments
 (0)