
Commit ab69ef8

hangfei authored and copybara-github committed
feat: Move live/bidi agents, esp. multi-agent, to use session/events
The old live/bidi agents used a cache to store context/history during agent transfer and similar flows. Now that session support has been added for live/bidi, we are migrating that context/history cache to the session. This improves scalability, efficiency and maintainability. It introduces several changes:

* AudioTranscriber support is removed; we now rely on the models' native transcription.
* Transcription is returned in the input_transcription/output_transcription fields and no longer as contents.
* A new event with artifact references to files of type audio/pcm is returned, in addition to the existing audio response event, so users of this API need to filter events accordingly (see the sketch below).

PiperOrigin-RevId: 805997675
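For illustration, here is a minimal consumer-side sketch of that filtering. It assumes events are consumed from an ADK live run (for example the async iterator returned by Runner.run_live); the input_transcription/output_transcription fields come from this change, while the file_data shape of the artifact-reference event is an assumption made for the sketch.

async def consume_live_events(live_events):
  """Sketch: filter events from a live run (e.g. the Runner.run_live iterator)."""
  audio_buffer = bytearray()
  async for event in live_events:
    # Transcriptions now arrive as dedicated event fields, not as contents.
    if getattr(event, 'input_transcription', None):
      print('user transcript:', event.input_transcription.text)
    elif getattr(event, 'output_transcription', None):
      print('model transcript:', event.output_transcription.text)
    elif event.content and event.content.parts:
      part = event.content.parts[0]
      if part.inline_data and part.inline_data.mime_type.startswith('audio'):
        # Existing behavior: raw audio chunks streamed back from the model.
        audio_buffer.extend(part.inline_data.data)
      elif part.file_data:
        # Assumed shape of the new event: a reference to the saved audio/pcm
        # artifact rather than raw bytes; skip it or fetch it from the
        # artifact service as needed.
        continue
      elif part.text:
        print('model text:', part.text)
  return bytes(audio_buffer)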
1 parent 873551d commit ab69ef8

12 files changed (+474, −456 lines)

contributing/samples/live_bidi_streaming_multi_agent/agent.py

Lines changed: 2 additions & 2 deletions
@@ -100,8 +100,8 @@ def get_current_weather(location: str):
 
 root_agent = Agent(
     # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
-    model="gemini-2.0-flash-live-preview-04-09",  # for Vertex project
-    # model="gemini-live-2.5-flash-preview", # for AI studio key
+    # model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    model="gemini-live-2.5-flash-preview",  # for AI studio key
     name="root_agent",
     instruction="""
       You are a helpful assistant that can check time, roll dice and check if numbers are prime.

src/google/adk/agents/run_config.py

Lines changed: 13 additions & 2 deletions
@@ -22,6 +22,7 @@
 from google.genai import types
 from pydantic import BaseModel
 from pydantic import ConfigDict
+from pydantic import Field
 from pydantic import field_validator
 
 logger = logging.getLogger('google_adk.' + __name__)
@@ -64,10 +65,14 @@ class RunConfig(BaseModel):
   streaming_mode: StreamingMode = StreamingMode.NONE
   """Streaming mode, None or StreamingMode.SSE or StreamingMode.BIDI."""
 
-  output_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
+  output_audio_transcription: Optional[types.AudioTranscriptionConfig] = Field(
+      default_factory=types.AudioTranscriptionConfig
+  )
   """Output transcription for live agents with audio response."""
 
-  input_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
+  input_audio_transcription: Optional[types.AudioTranscriptionConfig] = Field(
+      default_factory=types.AudioTranscriptionConfig
+  )
   """Input transcription for live agents with audio input from user."""
 
   realtime_input_config: Optional[types.RealtimeInputConfig] = None
@@ -82,6 +87,12 @@ class RunConfig(BaseModel):
   session_resumption: Optional[types.SessionResumptionConfig] = None
   """Configures session resumption mechanism. Only support transparent session resumption mode now."""
 
+  save_live_audio: bool = False
+  """Saves live video and audio data to session and artifact service.
+
+  Right now, only audio is supported.
+  """
+
   max_llm_calls: int = 500
   """
   A limit on the total number of llm calls for a given run.
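For reference, a hedged sketch of how a caller might opt in to the new audio-saving behavior using the fields above; the import path matches this file, and everything else about a real live run (Runner, session, request queue) is omitted.

from google.adk.agents.run_config import RunConfig, StreamingMode

# Sketch: save_live_audio opts in to persisting live audio to the artifact
# service and emitting artifact-reference events. The transcription configs
# now default to enabled via default_factory, so no explicit value is needed;
# pass None to opt out.
run_config = RunConfig(
    streaming_mode=StreamingMode.BIDI,
    save_live_audio=True,
)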

src/google/adk/flows/llm_flows/audio_cache_manager.py

Lines changed: 8 additions & 14 deletions
@@ -87,7 +87,7 @@ async def flush_caches(
       flush_user_audio: bool = True,
       flush_model_audio: bool = True,
   ) -> None:
-    """Flush audio caches to session and artifact services.
+    """Flush audio caches to artifact services.
 
     The multimodality data is saved in artifact service in the format of
     audio file. The file data reference is added to the session as an event.
@@ -103,32 +103,31 @@ async def flush_caches(
       flush_model_audio: Whether to flush the output (model) audio cache.
     """
     if flush_user_audio and invocation_context.input_realtime_cache:
-      success = await self._flush_cache_to_services(
+      flush_success = await self._flush_cache_to_services(
          invocation_context,
          invocation_context.input_realtime_cache,
          'input_audio',
      )
-      if success:
+      if flush_success:
        invocation_context.input_realtime_cache = []
-        logger.debug('Flushed input audio cache')
 
    if flush_model_audio and invocation_context.output_realtime_cache:
-      success = await self._flush_cache_to_services(
+      logger.debug('Flushed output audio cache')
+      flush_success = await self._flush_cache_to_services(
          invocation_context,
          invocation_context.output_realtime_cache,
          'output_audio',
      )
-      if success:
+      if flush_success:
        invocation_context.output_realtime_cache = []
-        logger.debug('Flushed output audio cache')
 
  async def _flush_cache_to_services(
      self,
      invocation_context: InvocationContext,
      audio_cache: list[RealtimeCacheEntry],
      cache_type: str,
  ) -> bool:
-    """Flush a list of audio cache entries to session and artifact services.
+    """Flush a list of audio cache entries to artifact services.
 
    The artifact service stores the actual blob. The session stores the
    reference to the stored blob.
@@ -191,19 +190,14 @@ async def _flush_cache_to_services(
          timestamp=audio_cache[0].timestamp,
      )
 
-      # Add to session
-      await invocation_context.session_service.append_event(
-          invocation_context.session, audio_event
-      )
-
      logger.debug(
          'Successfully flushed %s cache: %d chunks, %d bytes, saved as %s',
          cache_type,
          len(audio_cache),
          len(combined_audio_data),
          filename,
      )
-      return True
+      return audio_event
 
    except Exception as e:
      logger.error('Failed to flush %s cache: %s', cache_type, e)
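In other words, a successful flush now hands the artifact-reference event back to the caller instead of appending it to the session here. A hedged sketch of the resulting calling pattern follows; the function name is illustrative, and the flush/yield wiring mirrors _handle_control_event_flush and _postprocess_live in base_llm_flow.py below.

async def emit_saved_audio_events(
    audio_cache_manager, invocation_context, llm_response
):
  """Sketch: yield the artifact-reference event produced by a cache flush."""
  if llm_response.turn_complete:
    # Turn complete: flush both the user and the model audio caches.
    audio_event = await audio_cache_manager.flush_caches(
        invocation_context,
        flush_user_audio=True,
        flush_model_audio=True,
    )
    if audio_event:
      # The event references the stored audio/pcm artifact; it is yielded to
      # API consumers rather than written to the session by the cache manager.
      yield audio_event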

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 83 additions & 62 deletions
@@ -69,6 +69,42 @@
 DEFAULT_ENABLE_CACHE_STATISTICS = False
 
 
+def _get_audio_transcription_from_session(
+    invocation_context: InvocationContext,
+) -> list[types.Content]:
+  """Get audio and transcription content from session events.
+
+  Collects audio file references and transcription text from session events
+  to reconstruct the conversation history including multimodal content.
+  Args:
+    invocation_context: The invocation context containing session data.
+  Returns:
+    A list of Content objects containing audio files and transcriptions.
+  """
+  contents = []
+
+  for event in invocation_context.session.events:
+    # Collect transcription text events
+    if hasattr(event, 'input_transcription') and event.input_transcription:
+      contents.append(
+          types.Content(
+              role='user',
+              parts=[types.Part.from_text(text=event.input_transcription.text)],
+          )
+      )
+
+    if hasattr(event, 'output_transcription') and event.output_transcription:
+      contents.append(
+          types.Content(
+              role='model',
+              parts=[
+                  types.Part.from_text(text=event.output_transcription.text)
+              ],
+          )
+      )
+  return contents
+
+
 class BaseLlmFlow(ABC):
   """A basic flow that calls the LLM in a loop until a final response is generated.
 
@@ -129,25 +165,12 @@ async def run_live(
       if llm_request.contents:
         # Sends the conversation history to the model.
         with tracer.start_as_current_span('send_data'):
-          if invocation_context.transcription_cache:
-            from . import audio_transcriber
-
-            audio_transcriber = audio_transcriber.AudioTranscriber(
-                init_client=True
-                if invocation_context.run_config.input_audio_transcription
-                is None
-                else False
-            )
-            contents = audio_transcriber.transcribe_file(invocation_context)
-            logger.debug('Sending history to model: %s', contents)
-            await llm_connection.send_history(contents)
-            invocation_context.transcription_cache = None
-            trace_send_data(invocation_context, event_id, contents)
-          else:
-            await llm_connection.send_history(llm_request.contents)
-            trace_send_data(
-                invocation_context, event_id, llm_request.contents
-            )
+          # Combine regular contents with audio/transcription from session
+          logger.debug('Sending history to model: %s', llm_request.contents)
+          await llm_connection.send_history(llm_request.contents)
+          trace_send_data(
+              invocation_context, event_id, llm_request.contents
+          )
 
       send_task = asyncio.create_task(
           self._send_to_model(llm_connection, invocation_context)
@@ -324,22 +347,6 @@ def get_author_for_event(llm_response):
             author=get_author_for_event(llm_response),
         )
 
-        # Handle transcription events ONCE per llm_response, outside the event loop
-        if llm_response.input_transcription:
-          await self.transcription_manager.handle_input_transcription(
-              invocation_context, llm_response.input_transcription
-          )
-
-        if llm_response.output_transcription:
-          await self.transcription_manager.handle_output_transcription(
-              invocation_context, llm_response.output_transcription
-          )
-
-        # Flush audio caches based on control events using configurable settings
-        await self._handle_control_event_flush(
-            invocation_context, llm_response
-        )
-
         async with Aclosing(
             self._postprocess_live(
                 invocation_context,
@@ -349,28 +356,11 @@ def get_author_for_event(llm_response):
            )
        ) as agen:
          async for event in agen:
-            if (
-                event.content
-                and event.content.parts
-                and event.content.parts[0].inline_data is None
-                and not event.partial
-            ):
-              # This can be either user data or transcription data.
-              # when output transcription enabled, it will contain model's
-              # transcription.
-              # when input transcription enabled, it will contain user
-              # transcription.
-              if not invocation_context.transcription_cache:
-                invocation_context.transcription_cache = []
-              invocation_context.transcription_cache.append(
-                  TranscriptionEntry(
-                      role=event.content.role, data=event.content
-                  )
-              )
            # Cache output audio chunks from model responses
            # TODO: support video data
            if (
-                event.content
+                invocation_context.run_config.save_live_audio
+                and event.content
                and event.content.parts
                and event.content.parts[0].inline_data
                and event.content.parts[0].inline_data.mime_type.startswith(
@@ -578,6 +568,36 @@ async def _postprocess_live(
     ):
       return
 
+    # Handle transcription events ONCE per llm_response, outside the event loop
+    if llm_response.input_transcription:
+      input_transcription_event = (
+          await self.transcription_manager.handle_input_transcription(
+              invocation_context, llm_response.input_transcription
+          )
+      )
+      yield input_transcription_event
+      return
+
+    if llm_response.output_transcription:
+      output_transcription_event = (
+          await self.transcription_manager.handle_output_transcription(
+              invocation_context, llm_response.output_transcription
+          )
+      )
+      yield output_transcription_event
+      return
+
+    # Flush audio caches based on control events using configurable settings
+    if invocation_context.run_config.save_live_audio:
+      _handle_control_event_flush_event = (
+          await self._handle_control_event_flush(
+              invocation_context, llm_response
+          )
+      )
+      if _handle_control_event_flush_event:
+        yield _handle_control_event_flush_event
+      return
+
     # Builds the event.
     model_response_event = self._finalize_model_response_event(
         llm_request, llm_response, model_response_event
@@ -877,33 +897,34 @@ async def _handle_control_event_flush(
       invocation_context: The invocation context containing audio caches.
       llm_response: The LLM response containing control event information.
     """
+
+    # Log cache statistics if enabled
+    if DEFAULT_ENABLE_CACHE_STATISTICS:
+      stats = self.audio_cache_manager.get_cache_stats(invocation_context)
+      logger.debug('Audio cache stats: %s', stats)
+
     if llm_response.interrupted:
       # user interrupts so the model will stop. we can flush model audio here
-      await self.audio_cache_manager.flush_caches(
+      return await self.audio_cache_manager.flush_caches(
          invocation_context,
          flush_user_audio=False,
          flush_model_audio=True,
      )
    elif llm_response.turn_complete:
      # turn completes so we can flush both user and model
-      await self.audio_cache_manager.flush_caches(
+      return await self.audio_cache_manager.flush_caches(
          invocation_context,
          flush_user_audio=True,
          flush_model_audio=True,
      )
    elif getattr(llm_response, 'generation_complete', False):
      # model generation complete so we can flush model audio
-      await self.audio_cache_manager.flush_caches(
+      return await self.audio_cache_manager.flush_caches(
          invocation_context,
          flush_user_audio=False,
          flush_model_audio=True,
      )
 
-    # Log cache statistics if enabled
-    if DEFAULT_ENABLE_CACHE_STATISTICS:
-      stats = self.audio_cache_manager.get_cache_stats(invocation_context)
-      logger.debug('Audio cache stats: %s', stats)
-
  async def _run_and_handle_error(
      self,
      response_generator: AsyncGenerator[LlmResponse, None],
