Commit b5daf19
fix: preserve function call IDs across SSE streaming partial/final events
In SSE streaming mode, _finalize_model_response_event creates a brand-new Event object for every LlmResponse chunk, partial and final alike. Because each new Event's function calls start with an empty .id, populate_client_function_call_id generates a fresh adk-{uuid} every time. As a result, the partial event and the final event for the same function call end up with different IDs, breaking LongRunningFunctionTool / HITL workflows that match responses by ID.

Fix: add an optional function_call_ids dict parameter to _finalize_model_response_event. The dict maps (function_name, index) to the ID that was assigned the first time that function call was seen. Before populate_client_function_call_id runs, any previously stored ID is restored onto the function call so the guard 'if not function_call.id' keeps it. After population, newly generated IDs are written back into the dict.

_run_one_step_async creates one such dict per LLM call and threads it through _postprocess_async for the lifetime of the streaming sequence, so all partial and final events share the same stable IDs.

Fixes #4609
1 parent 8f54281 commit b5daf19

2 files changed: 218 additions & 2 deletions
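To see the mechanism in isolation, here is a minimal, self-contained sketch. It is not the ADK implementation: plain dicts stand in for genai FunctionCall objects, and populate_ids is a hypothetical stand-in for functions.populate_client_function_call_id with the fix applied.

import uuid

def populate_ids(function_calls, function_call_ids=None):
  # Hypothetical mini-model of populate_client_function_call_id plus the fix.
  for i, fc in enumerate(function_calls):
    key = (fc["name"], i)
    # Restore a previously assigned ID so the 'if not id' guard below
    # keeps it instead of minting a new one.
    if function_call_ids is not None and key in function_call_ids:
      fc["id"] = function_call_ids[key]
    if not fc.get("id"):
      fc["id"] = f"adk-{uuid.uuid4()}"
    # Persist the ID for later events in the same streaming sequence.
    if function_call_ids is not None and key not in function_call_ids:
      function_call_ids[key] = fc["id"]

ids: dict[tuple[str, int], str] = {}  # one dict per LLM call
partial = [{"name": "get_weather"}]   # function call from a partial chunk
final = [{"name": "get_weather"}]     # same call on the final event
populate_ids(partial, ids)
populate_ids(final, ids)
assert partial[0]["id"] == final[0]["id"]  # stable across the stream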


src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 32 additions & 2 deletions
@@ -81,6 +81,7 @@ def _finalize_model_response_event(
     llm_request: LlmRequest,
     llm_response: LlmResponse,
     model_response_event: Event,
+    function_call_ids: Optional[dict[tuple[str, int], str]] = None,
 ) -> Event:
   """Finalize and build the model response event from LLM response.

@@ -91,6 +92,11 @@ def _finalize_model_response_event(
     llm_request: The original LLM request.
     llm_response: The LLM response from the model.
     model_response_event: The base event to populate.
+    function_call_ids: Optional mutable dict mapping (function_name, index) to
+      previously assigned client function call IDs. Used during SSE streaming
+      to ensure partial and final events for the same function call share the
+      same ID. When provided, newly generated IDs are stored back into this
+      dict for reuse by subsequent events in the same streaming sequence.

   Returns:
     The finalized Event with LLM response data merged in.
@@ -103,7 +109,23 @@
   if finalized_event.content:
     function_calls = finalized_event.get_function_calls()
     if function_calls:
+      # Restore previously assigned IDs before populating new ones so that
+      # partial and final events in an SSE stream share the same IDs.
+      if function_call_ids is not None:
+        for i, fc in enumerate(function_calls):
+          key = (fc.name, i)
+          if key in function_call_ids:
+            fc.id = function_call_ids[key]
+
       functions.populate_client_function_call_id(finalized_event)
+
+      # Persist any newly generated IDs for subsequent events.
+      if function_call_ids is not None:
+        for i, fc in enumerate(function_calls):
+          key = (fc.name, i)
+          if fc.id and key not in function_call_ids:
+            function_call_ids[key] = fc.id
+
       finalized_event.long_running_tool_ids = (
           functions.get_long_running_function_calls(
               function_calls, llm_request.tools_dict
@@ -821,6 +843,9 @@ async def _run_one_step_async(
         author=invocation_context.agent.name,
         branch=invocation_context.branch,
     )
+    # Track function call IDs across partial/final events in SSE streaming
+    # so that the same function call keeps the same client-generated ID.
+    function_call_ids: dict[tuple[str, int], str] = {}
     async with Aclosing(
         self._call_llm_async(
             invocation_context, llm_request, model_response_event
@@ -834,6 +859,7 @@
             llm_request,
             llm_response,
             model_response_event,
+            function_call_ids,
         )
     ) as agen:
       async for event in agen:
@@ -880,6 +906,7 @@ async def _postprocess_async(
       llm_request: LlmRequest,
       llm_response: LlmResponse,
       model_response_event: Event,
+      function_call_ids: Optional[dict[tuple[str, int], str]] = None,
   ) -> AsyncGenerator[Event, None]:
     """Postprocess after calling the LLM.

@@ -888,6 +915,8 @@
       llm_request: The original LLM request.
       llm_response: The LLM response from the LLM call.
       model_response_event: A mutable event for the LLM response.
+      function_call_ids: Optional mutable dict for preserving function call IDs
+        across partial and final events in an SSE streaming sequence.

     Yields:
       A generator of events.
@@ -911,7 +940,7 @@

     # Builds the event.
     model_response_event = self._finalize_model_response_event(
-        llm_request, llm_response, model_response_event
+        llm_request, llm_response, model_response_event, function_call_ids
     )
     yield model_response_event

@@ -1191,9 +1220,10 @@ def _finalize_model_response_event(
       llm_request: LlmRequest,
       llm_response: LlmResponse,
       model_response_event: Event,
+      function_call_ids: Optional[dict[tuple[str, int], str]] = None,
   ) -> Event:
     return _finalize_model_response_event(
-        llm_request, llm_response, model_response_event
+        llm_request, llm_response, model_response_event, function_call_ids
     )

   async def _resolve_toolset_auth(
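Downstream, a LongRunningFunctionTool / HITL consumer keys pending work off the first ID it sees, which is why the ID must stay stable across partial and final events. A sketch of such a consumer, hypothetical rather than ADK API, with plain dicts again standing in for FunctionCall objects:

# Hypothetical HITL-style consumer that matches streaming events by
# function call ID; names and structure are illustrative only.
pending: dict[str, dict] = {}

def on_event(function_calls, *, is_partial: bool) -> None:
  for fc in function_calls:
    if is_partial:
      # Register the approval request the first time the call appears.
      pending.setdefault(fc["id"], {"name": fc["name"], "approved": None})
    elif fc["id"] not in pending:
      # Before this commit, the final event carried a fresh adk-{uuid},
      # orphaning the pending entry created from the partial event.
      raise KeyError(f"ID drift between partial and final event: {fc['id']}")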
Lines changed: 186 additions & 0 deletions
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that SSE streaming preserves function call IDs across partial/final events.

Regression test for https://github.com/google/adk-python/issues/4609
"""

from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types
import pytest


def _make_base_event() -> Event:
  return Event(
      id=Event.new_id(),
      invocation_id="test-inv",
      author="test-agent",
  )


def _make_llm_response(*, partial: bool, fc_name: str = "get_weather") -> LlmResponse:
  return LlmResponse(
      content=types.Content(
          role="model",
          parts=[
              types.Part(
                  function_call=types.FunctionCall(
                      name=fc_name,
                      args={"location": "NYC"},
                  )
              )
          ],
      ),
      partial=partial,
  )


def _make_llm_request() -> LlmRequest:
  req = LlmRequest()
  req.tools_dict = {}
  return req


class TestStreamingFunctionCallIdConsistency:
  """Ensure partial and final events share the same function call ID."""

  def test_partial_and_final_share_same_id(self):
    """The core regression: the partial event's ID must equal the final event's ID."""
    llm_request = _make_llm_request()
    function_call_ids: dict[tuple[str, int], str] = {}

    # Simulate a partial event.
    partial_event = _finalize_model_response_event(
        llm_request,
        _make_llm_response(partial=True),
        _make_base_event(),
        function_call_ids,
    )
    partial_fc_id = partial_event.get_function_calls()[0].id
    assert partial_fc_id is not None
    assert partial_fc_id.startswith("adk-")

    # Simulate the final event (new Event object, same streaming sequence).
    final_event = _finalize_model_response_event(
        llm_request,
        _make_llm_response(partial=False),
        _make_base_event(),
        function_call_ids,
    )
    final_fc_id = final_event.get_function_calls()[0].id

    assert final_fc_id == partial_fc_id

  def test_without_function_call_ids_dict_generates_different_ids(self):
    """Without the ID-tracking dict, each event gets a fresh ID (old behaviour)."""
    llm_request = _make_llm_request()

    partial_event = _finalize_model_response_event(
        llm_request,
        _make_llm_response(partial=True),
        _make_base_event(),
    )
    final_event = _finalize_model_response_event(
        llm_request,
        _make_llm_response(partial=False),
        _make_base_event(),
    )

    # Without the dict, IDs differ (demonstrating the old bug).
    assert (
        partial_event.get_function_calls()[0].id
        != final_event.get_function_calls()[0].id
    )

  def test_multiple_function_calls_preserve_ids(self):
    """Each function call in a multi-call response keeps its own stable ID."""
    llm_request = _make_llm_request()
    function_call_ids: dict[tuple[str, int], str] = {}

    def make_multi_fc_response(partial: bool) -> LlmResponse:
      return LlmResponse(
          content=types.Content(
              role="model",
              parts=[
                  types.Part(
                      function_call=types.FunctionCall(
                          name="get_weather",
                          args={"location": "NYC"},
                      )
                  ),
                  types.Part(
                      function_call=types.FunctionCall(
                          name="get_time",
                          args={"timezone": "EST"},
                      )
                  ),
              ],
          ),
          partial=partial,
      )

    partial_event = _finalize_model_response_event(
        llm_request,
        make_multi_fc_response(partial=True),
        _make_base_event(),
        function_call_ids,
    )
    partial_ids = [fc.id for fc in partial_event.get_function_calls()]

    final_event = _finalize_model_response_event(
        llm_request,
        make_multi_fc_response(partial=False),
        _make_base_event(),
        function_call_ids,
    )
    final_ids = [fc.id for fc in final_event.get_function_calls()]

    assert partial_ids == final_ids
    # The two function calls should have different IDs from each other.
    assert partial_ids[0] != partial_ids[1]

  def test_server_provided_id_is_preserved(self):
    """If the server already provides an ID, it should not be overwritten."""
    llm_request = _make_llm_request()
    function_call_ids: dict[tuple[str, int], str] = {}

    server_id = "server-provided-id-123"
    response = LlmResponse(
        content=types.Content(
            role="model",
            parts=[
                types.Part(
                    function_call=types.FunctionCall(
                        id=server_id,
                        name="get_weather",
                        args={"location": "NYC"},
                    )
                )
            ],
        ),
        partial=False,
    )

    event = _finalize_model_response_event(
        llm_request,
        response,
        _make_base_event(),
        function_call_ids,
    )

    assert event.get_function_calls()[0].id == server_id
