Skip to content

Commit 7f04f1a

Browse files
fix(models): avoid duplicate interaction final responses
1 parent bb89466 commit 7f04f1a

File tree

4 files changed

+184
-14
lines changed

4 files changed

+184
-14
lines changed

src/google/adk/events/event.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def is_final_response(self) -> bool:
9090
if self.actions.skip_summarization or self.long_running_tool_ids:
9191
return True
9292
return (
93-
not self.get_function_calls()
93+
self.turn_complete is not False
94+
and not self.get_function_calls()
9495
and not self.get_function_responses()
9596
and not self.partial
9697
and not self.has_trailing_code_execution_result()

src/google/adk/models/interactions_utils.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,33 @@
5656
_NEW_LINE = '\n'
5757

5858

59+
def _merge_text_with_overlap(existing: str, incoming: str) -> str:
60+
"""Merge streamed text fragments while avoiding overlap duplication."""
61+
if not existing:
62+
return incoming
63+
if not incoming:
64+
return existing
65+
66+
max_overlap = min(len(existing), len(incoming))
67+
for size in range(max_overlap, 0, -1):
68+
if existing.endswith(incoming[:size]):
69+
return existing + incoming[size:]
70+
return existing + incoming
71+
72+
73+
def _append_delta_text_part(
    aggregated_parts: list[types.Part], text: str
) -> None:
  """Fold a streamed text delta into the aggregated parts list, in place.

  If the last aggregated part is a plain text part, it is replaced by a new
  text part holding the overlap-merged text. Otherwise a new text part is
  appended. Empty `text` is ignored.

  A trailing part flagged as `thought` is never merged into: rebuilding it
  with `Part.from_text` would drop the flag and mix answer text into it.
  NOTE(review): assumes `types.Part.thought` marks model reasoning text, as
  in google-genai; confirm whether thought deltas can reach this list.

  Args:
    aggregated_parts: Parts accumulated for the current response; mutated.
    text: Text of the incoming delta.
  """
  if not text:
    return

  last_part = aggregated_parts[-1] if aggregated_parts else None
  is_mergeable_text = (
      last_part is not None
      and last_part.text is not None
      and not getattr(last_part, 'thought', None)
  )
  if is_mergeable_text:
    merged_text = _merge_text_with_overlap(last_part.text, text)
    # Replace rather than mutate the existing part object.
    aggregated_parts[-1] = types.Part.from_text(text=merged_text)
    return

  aggregated_parts.append(types.Part.from_text(text=text))
84+
85+
5986
def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]:
6087
"""Convert a types.Part to an interaction content dict.
6188
@@ -487,8 +514,8 @@ def convert_interaction_event_to_llm_response(
487514
if delta_type == 'text':
488515
text = delta.text or ''
489516
if text:
517+
_append_delta_text_part(aggregated_parts, text)
490518
part = types.Part.from_text(text=text)
491-
aggregated_parts.append(part)
492519
return LlmResponse(
493520
content=types.Content(role='model', parts=[part]),
494521
partial=True,
@@ -539,18 +566,15 @@ def convert_interaction_event_to_llm_response(
539566
)
540567

541568
elif event_type == 'content.stop':
542-
# Content streaming finished, return aggregated content
543-
if aggregated_parts:
544-
return LlmResponse(
545-
content=types.Content(role='model', parts=list(aggregated_parts)),
546-
partial=False,
547-
turn_complete=False,
548-
interaction_id=interaction_id,
549-
)
569+
# Content.stop is a stream boundary marker.
570+
# Final content emission happens at interaction.status_update or stream end
571+
# to avoid duplicate final responses.
572+
return None
550573

551574
elif event_type == 'interaction':
552-
# Final interaction event with complete data
553-
return convert_interaction_to_llm_response(event)
575+
# We intentionally do not emit from this event in streaming mode because
576+
# interaction outputs can duplicate already aggregated content deltas.
577+
return None
554578

555579
elif event_type == 'interaction.status_update':
556580
status = getattr(event, 'status', None)
@@ -992,6 +1016,7 @@ async def generate_content_via_interactions(
9921016
)
9931017

9941018
aggregated_parts: list[types.Part] = []
1019+
has_emitted_turn_complete = False
9951020
async for event in responses:
9961021
# Log the streaming event
9971022
logger.debug(build_interactions_event_log(event))
@@ -1003,10 +1028,13 @@ async def generate_content_via_interactions(
10031028
event, aggregated_parts, current_interaction_id
10041029
)
10051030
if llm_response:
1031+
if llm_response.turn_complete:
1032+
has_emitted_turn_complete = True
10061033
yield llm_response
10071034

1008-
# Final aggregated response
1009-
if aggregated_parts:
1035+
# Final aggregated response fallback if the stream never emitted a
1036+
# completion event (e.g., missing interaction.status_update).
1037+
if aggregated_parts and not has_emitted_turn_complete:
10101038
yield LlmResponse(
10111039
content=types.Content(role='model', parts=aggregated_parts),
10121040
partial=False,

tests/unittests/events_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Event model helpers."""
16+
17+
from google.adk.events.event import Event
18+
from google.genai import types
19+
20+
21+
def test_is_final_response_false_when_turn_incomplete():
  """A text event with turn_complete=False must not count as final."""
  streamed_content = types.Content(
      role='model', parts=[types.Part(text='partial')]
  )
  incomplete_event = Event(
      author='agent',
      turn_complete=False,
      content=streamed_content,
  )

  is_final = incomplete_event.is_final_response()

  assert not is_final
30+
31+
32+
def test_is_final_response_true_when_turn_complete():
  """A plain text event with turn_complete=True counts as final."""
  finished_content = types.Content(
      role='model', parts=[types.Part(text='done')]
  )
  complete_event = Event(
      author='agent',
      turn_complete=True,
      content=finished_content,
  )

  is_final = complete_event.is_final_response()

  assert is_final

tests/unittests/models/test_interactions_utils.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
"""Tests for interactions_utils.py conversion functions."""
1616

17+
import asyncio
1718
import json
19+
from unittest.mock import AsyncMock
1820
from unittest.mock import MagicMock
1921

2022
from google.adk.models import interactions_utils
@@ -759,3 +761,102 @@ def test_full_conversation(self):
759761
assert len(result) == 2
760762
assert result[0].parts[0].text == 'Great'
761763
assert result[1].parts[0].text == 'Tell me more'
764+
765+
766+
class TestGenerateContentViaInteractionsStreaming:
  """Streaming-mode tests for generate_content_via_interactions."""

  @staticmethod
  def _text_delta_event(interaction_id, text):
    """Build a mock 'content.delta' event carrying a text delta."""
    event = MagicMock()
    event.event_type = 'content.delta'
    event.id = interaction_id
    event.delta = MagicMock(type='text', text=text)
    return event

  @staticmethod
  def _collect_streamed_responses(events):
    """Replay `events` through a mocked client and return all responses."""

    async def _replay_events():
      for event in events:
        yield event

    api_client = MagicMock()
    api_client.aio.interactions.create = AsyncMock(
        return_value=_replay_events()
    )

    llm_request = LlmRequest(
        model='gemini-2.5-flash',
        contents=[types.Content(role='user', parts=[types.Part(text='hi')])],
    )

    async def _drain():
      collected = []
      stream = interactions_utils.generate_content_via_interactions(
          api_client=api_client, llm_request=llm_request, stream=True
      )
      async for response in stream:
        collected.append(response)
      return collected

    return asyncio.run(_drain())

  def test_emits_single_final_response_with_status_update(self):
    """Two deltas plus a completed status yield exactly one final response."""
    completed = MagicMock()
    completed.event_type = 'interaction.status_update'
    completed.id = 'interaction_1'
    completed.status = 'completed'

    responses = self._collect_streamed_responses([
        self._text_delta_event('interaction_1', 'Hello '),
        self._text_delta_event('interaction_1', 'world'),
        completed,
    ])

    assert len(responses) == 3
    first, second, final = responses
    assert first.partial is True
    assert second.partial is True
    assert final.turn_complete is True
    assert final.content.parts[0].text == 'Hello world'

  def test_merges_overlapping_text_deltas_in_final_response(self):
    """Overlapping deltas are merged once in the fallback final response."""
    content_stop = MagicMock()
    content_stop.event_type = 'content.stop'
    content_stop.id = 'interaction_2'

    responses = self._collect_streamed_responses([
        self._text_delta_event('interaction_2', 'Hello wor'),
        self._text_delta_event('interaction_2', 'world'),
        content_stop,
    ])

    assert len(responses) == 3
    final_response = responses[-1]
    assert final_response.turn_complete is True
    assert final_response.content.parts[0].text == 'Hello world'
862+

0 commit comments

Comments
 (0)