Skip to content

Commit a1abd98

Browse files
committed
fix: stabilize function call IDs across streaming events
When models don't provide function call IDs, ADK generates client-side IDs via populate_client_function_call_id(). In streaming mode, partial and final events for the same logical function call each get a fresh uuid4, causing an ID mismatch that breaks HITL (human-in-the-loop) workflows and SSE consumers that correlate function calls across chunks. Root cause: _finalize_model_response_event creates a new Event object for each llm_response chunk, and populate_client_function_call_id generates a brand-new ID every time without knowledge of prior IDs. Fix: Add an optional function_call_id_cache dict that maps (name, index) keys to previously generated IDs. The streaming loop in _run_async creates the cache before iteration and threads it through _postprocess_async → _finalize_model_response_event → populate_client_function_call_id, ensuring the same logical function call gets a stable ID across all streaming events. The cache is keyed by (name:index) to correctly handle multiple calls to the same function within a single response. Fixes #4609
1 parent 8ddddc0 commit a1abd98

File tree

3 files changed

+232
-6
lines changed

3 files changed

+232
-6
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _finalize_model_response_event(
8181
llm_request: LlmRequest,
8282
llm_response: LlmResponse,
8383
model_response_event: Event,
84+
function_call_id_cache: Optional[dict[str, str]] = None,
8485
) -> Event:
8586
"""Finalize and build the model response event from LLM response.
8687
@@ -91,6 +92,9 @@ def _finalize_model_response_event(
9192
llm_request: The original LLM request.
9293
llm_response: The LLM response from the model.
9394
model_response_event: The base event to populate.
95+
function_call_id_cache: Optional dict mapping '{name}:{index}' keys to
96+
previously generated IDs. Used to keep IDs stable across partial
97+
and final streaming events.
9498
9599
Returns:
96100
The finalized Event with LLM response data merged in.
@@ -103,7 +107,9 @@ def _finalize_model_response_event(
103107
if finalized_event.content:
104108
function_calls = finalized_event.get_function_calls()
105109
if function_calls:
106-
functions.populate_client_function_call_id(finalized_event)
110+
functions.populate_client_function_call_id(
111+
finalized_event, function_call_id_cache
112+
)
107113
finalized_event.long_running_tool_ids = (
108114
functions.get_long_running_function_calls(
109115
function_calls, llm_request.tools_dict
@@ -827,6 +833,9 @@ async def _run_one_step_async(
827833
author=invocation_context.agent.name,
828834
branch=invocation_context.branch,
829835
)
836+
# Cache maps '{name}:{index}' keys to generated IDs so that partial and
837+
# final streaming events for the same logical call share a stable ID.
838+
function_call_id_cache: dict[str, str] = {}
830839
async with Aclosing(
831840
self._call_llm_async(
832841
invocation_context, llm_request, model_response_event
@@ -840,6 +849,7 @@ async def _run_one_step_async(
840849
llm_request,
841850
llm_response,
842851
model_response_event,
852+
function_call_id_cache,
843853
)
844854
) as agen:
845855
async for event in agen:
@@ -886,6 +896,7 @@ async def _postprocess_async(
886896
llm_request: LlmRequest,
887897
llm_response: LlmResponse,
888898
model_response_event: Event,
899+
function_call_id_cache: Optional[dict[str, str]] = None,
889900
) -> AsyncGenerator[Event, None]:
890901
"""Postprocess after calling the LLM.
891902
@@ -894,6 +905,9 @@ async def _postprocess_async(
894905
llm_request: The original LLM request.
895906
llm_response: The LLM response from the LLM call.
896907
model_response_event: A mutable event for the LLM response.
908+
function_call_id_cache: Optional dict mapping '{name}:{index}' keys to
909+
previously generated IDs. Keeps IDs stable across partial and final
910+
streaming events.
897911
898912
Yields:
899913
A generator of events.
@@ -917,7 +931,8 @@ async def _postprocess_async(
917931

918932
# Builds the event.
919933
model_response_event = self._finalize_model_response_event(
920-
llm_request, llm_response, model_response_event
934+
llm_request, llm_response, model_response_event,
935+
function_call_id_cache,
921936
)
922937
yield model_response_event
923938

@@ -1197,9 +1212,11 @@ def _finalize_model_response_event(
11971212
llm_request: LlmRequest,
11981213
llm_response: LlmResponse,
11991214
model_response_event: Event,
1215+
function_call_id_cache: Optional[dict[str, str]] = None,
12001216
) -> Event:
12011217
return _finalize_model_response_event(
1202-
llm_request, llm_response, model_response_event
1218+
llm_request, llm_response, model_response_event,
1219+
function_call_id_cache,
12031220
)
12041221

12051222
async def _resolve_toolset_auth(

src/google/adk/flows/llm_flows/functions.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,25 @@ def generate_client_function_call_id() -> str:
181181
return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
182182

183183

184-
def populate_client_function_call_id(
    model_response_event: Event,
    function_call_id_cache: Optional[dict[str, str]] = None,
) -> None:
  """Assigns client-generated IDs to function calls that lack one.

  Args:
    model_response_event: Event whose function calls are mutated in place.
    function_call_id_cache: Optional mapping from ``'{name}:{index}'`` keys
      to previously generated IDs. When supplied, the same logical function
      call receives the same ID across partial and final streaming events.
  """
  function_calls = model_response_event.get_function_calls()
  for idx, function_call in enumerate(function_calls):
    if function_call.id:
      # Respect IDs already provided by the model; never overwrite or cache.
      continue
    if function_call_id_cache is None:
      function_call.id = generate_client_function_call_id()
      continue
    # Key by (name, index) so that two calls to the same function within a
    # single response keep separate, stable IDs.
    cache_key = f'{function_call.name}:{idx}'
    cached_id = function_call_id_cache.get(cache_key)
    if cached_id is None:
      cached_id = generate_client_function_call_id()
      function_call_id_cache[cache_key] = cached_id
    function_call.id = cached_id
190203

191204

192205
def remove_client_function_call_id(content: Optional[types.Content]) -> None:
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests that function call IDs stay stable across streaming events."""
16+
17+
from google.adk.events.event import Event
18+
from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event
19+
from google.adk.flows.llm_flows.functions import populate_client_function_call_id
20+
from google.adk.models.llm_request import LlmRequest
21+
from google.adk.models.llm_response import LlmResponse
22+
from google.genai import types
23+
import pytest
24+
25+
26+
def _make_fc_response(name: str, args: dict | None = None, partial: bool = False) -> LlmResponse:
  """Build an LlmResponse whose content carries a single function call."""
  call = types.FunctionCall(name=name, args=args or {})
  part = types.Part(function_call=call)
  content = types.Content(role='model', parts=[part])
  return LlmResponse(content=content, partial=partial)
33+
34+
35+
def _make_multi_fc_response(calls: list[tuple[str, dict]], partial: bool = False) -> LlmResponse:
  """Build an LlmResponse whose content carries several function calls."""
  parts = []
  for call_name, call_args in calls:
    fc = types.FunctionCall(name=call_name, args=call_args)
    parts.append(types.Part(function_call=fc))
  return LlmResponse(
      content=types.Content(role='model', parts=parts),
      partial=partial,
  )
45+
46+
47+
class TestPopulateClientFunctionCallIdWithCache:
  """Tests for populate_client_function_call_id with ID caching."""

  def _single_call_event(self, name, call_id=None):
    """Return an Event whose content holds one function call named *name*."""
    call = types.FunctionCall(name=name, args={}, id=call_id)
    evt = Event(author='agent')
    evt.content = types.Content(
        role='model', parts=[types.Part(function_call=call)]
    )
    return evt

  def test_generates_id_and_stores_in_cache(self):
    id_cache: dict[str, str] = {}
    evt = self._single_call_event('get_weather')
    populate_client_function_call_id(evt, id_cache)
    call = evt.get_function_calls()[0]
    assert call.id.startswith('adk-')
    # The generated ID must be recorded under its '{name}:{index}' key.
    assert id_cache.get('get_weather:0') == call.id

  def test_reuses_cached_id(self):
    id_cache = {'get_weather:0': 'adk-cached-id-123'}
    evt = self._single_call_event('get_weather')
    populate_client_function_call_id(evt, id_cache)
    assert evt.get_function_calls()[0].id == 'adk-cached-id-123'

  def test_no_cache_generates_new_id_each_time(self):
    first = self._single_call_event('get_weather')
    second = self._single_call_event('get_weather')
    populate_client_function_call_id(first)
    populate_client_function_call_id(second)
    generated = {e.get_function_calls()[0].id for e in (first, second)}
    assert len(generated) == 2

  def test_multiple_calls_same_name_get_separate_ids(self):
    evt = Event(author='agent')
    evt.content = types.Content(
        role='model',
        parts=[
            types.Part(
                function_call=types.FunctionCall(name='search', args={'q': q})
            )
            for q in ('a', 'b')
        ],
    )
    id_cache: dict[str, str] = {}
    populate_client_function_call_id(evt, id_cache)
    first, second = evt.get_function_calls()
    assert first.id != second.id
    assert id_cache['search:0'] == first.id
    assert id_cache['search:1'] == second.id

  def test_skips_function_calls_that_already_have_ids(self):
    id_cache: dict[str, str] = {}
    evt = self._single_call_event('get_weather', call_id='server-provided-id')
    populate_client_function_call_id(evt, id_cache)
    # A server-provided ID is neither replaced nor cached.
    assert evt.get_function_calls()[0].id == 'server-provided-id'
    assert not id_cache
117+
118+
class TestFinalizeModelResponseEventWithCache:
  """Tests that _finalize_model_response_event preserves IDs via cache."""

  def _base_event(self):
    """Return a fresh base model-response event."""
    return Event(author='agent', invocation_id='inv-1')

  def test_partial_and_final_share_same_function_call_id(self):
    base_event = self._base_event()
    request = LlmRequest(model='mock', contents=[])
    id_cache: dict[str, str] = {}

    # Partial streaming chunk.
    partial_event = _finalize_model_response_event(
        request,
        _make_fc_response('get_weather', partial=True),
        base_event,
        id_cache,
    )
    first_id = partial_event.get_function_calls()[0].id
    assert first_id.startswith('adk-')

    # Final chunk: the same logical call must keep the same ID.
    final_event = _finalize_model_response_event(
        request,
        _make_fc_response('get_weather', partial=False),
        base_event,
        id_cache,
    )
    assert final_event.get_function_calls()[0].id == first_id

  def test_without_cache_ids_differ(self):
    base_event = self._base_event()
    request = LlmRequest(model='mock', contents=[])

    generated_ids = []
    for is_partial in (True, False):
      evt = _finalize_model_response_event(
          request,
          _make_fc_response('get_weather', partial=is_partial),
          base_event,
      )
      generated_ids.append(evt.get_function_calls()[0].id)

    # Without a cache each chunk gets a fresh ID (the original bug scenario).
    assert generated_ids[0] != generated_ids[1]

  def test_multi_function_call_streaming_preserves_all_ids(self):
    base_event = self._base_event()
    request = LlmRequest(model='mock', contents=[])
    id_cache: dict[str, str] = {}
    calls = [('search', {'q': 'weather'}), ('lookup', {'id': '42'})]

    partial_event = _finalize_model_response_event(
        request,
        _make_multi_fc_response(calls, partial=True),
        base_event,
        id_cache,
    )
    final_event = _finalize_model_response_event(
        request,
        _make_multi_fc_response(calls, partial=False),
        base_event,
        id_cache,
    )

    ids_before = [fc.id for fc in partial_event.get_function_calls()]
    ids_after = [fc.id for fc in final_event.get_function_calls()]
    assert ids_before == ids_after
    # Distinct logical calls must still receive distinct IDs.
    assert ids_before[0] != ids_before[1]

0 commit comments

Comments
 (0)