diff --git a/contributing/samples/live_bidi_streaming_multi_agent/agent.py b/contributing/samples/live_bidi_streaming_multi_agent/agent.py index 413e33a727..ddb36b2845 100644 --- a/contributing/samples/live_bidi_streaming_multi_agent/agent.py +++ b/contributing/samples/live_bidi_streaming_multi_agent/agent.py @@ -16,6 +16,7 @@ from google.adk.agents.llm_agent import Agent from google.adk.examples.example import Example +from google.adk.models.google_llm import Gemini from google.adk.tools.example_tool import ExampleTool from google.genai import types @@ -28,6 +29,17 @@ def roll_die(sides: int) -> int: roll_agent = Agent( name="roll_agent", + model=Gemini( + # model="gemini-2.0-flash-live-preview-04-09", # for Vertex project + model="gemini-live-2.5-flash-preview", # for AI studio key + speech_config=types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Kore", + ) + ) + ), + ), description="Handles rolling dice of different sizes.", instruction=""" You are responsible for rolling dice based on the user's request. @@ -69,6 +81,17 @@ def check_prime(nums: list[int]) -> str: prime_agent = Agent( name="prime_agent", + model=Gemini( + # model="gemini-2.0-flash-live-preview-04-09", # for Vertex project + model="gemini-live-2.5-flash-preview", # for AI studio key + speech_config=types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Puck", + ) + ) + ), + ), description="Handles checking if numbers are prime.", instruction=""" You are responsible for checking whether numbers are prime. 
@@ -100,8 +123,17 @@ def get_current_weather(location: str): root_agent = Agent( # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/ - model="gemini-2.0-flash-live-preview-04-09", # for Vertex project - # model="gemini-live-2.5-flash-preview", # for AI studio key + model=Gemini( + # model="gemini-2.0-flash-live-preview-04-09", # for Vertex project + model="gemini-live-2.5-flash-preview", # for AI studio key + speech_config=types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Zephyr", + ) + ) + ), + ), name="root_agent", instruction=""" You are a helpful assistant that can check time, roll dice and check if numbers are prime. diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 754970537e..bfcf370676 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -35,7 +35,10 @@ class StreamingMode(Enum): class RunConfig(BaseModel): - """Configs for runtime behavior of agents.""" + """Configs for runtime behavior of agents. + + The configs here will be overridden by agent-specific configurations. + """ model_config = ConfigDict( extra='forbid', diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 411162bb0c..b96e56e169 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -60,6 +60,8 @@ class Gemini(BaseLlm): model: str = 'gemini-2.5-flash' + speech_config: Optional[types.SpeechConfig] = None + retry_options: Optional[types.HttpRetryOptions] = None """Allow Gemini to retry failed responses. 
@@ -261,6 +263,9 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: self._live_api_version ) + if self.speech_config is not None: + llm_request.live_connect_config.speech_config = self.speech_config + llm_request.live_connect_config.system_instruction = types.Content( role='system', parts=[ diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 1b5979bdf9..f3356975f5 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -1858,3 +1858,189 @@ def mock_model_dump(*args, **kwargs): # Should still succeed using repr() assert "Config:" in log_output assert "GenerateContentConfig" in log_output + + +@pytest.mark.asyncio +async def test_connect_uses_gemini_speech_config_when_request_is_none( + gemini_llm, llm_request +): + """Tests that Gemini's speech_config is used when live_connect_config's is None.""" + # Arrange: Set a speech_config on the Gemini instance with the voice "Kore" + gemini_llm.speech_config = types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Kore", + ) + ) + ) + llm_request.live_connect_config = ( + types.LiveConnectConfig() + ) # speech_config is None + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + # Act + async with gemini_llm.connect(llm_request) as connection: + # Assert + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify the speech_config from the Gemini instance was used + assert config_arg.speech_config is not None + assert ( + 
config_arg.speech_config.voice_config.prebuilt_voice_config.voice_name + == "Kore" + ) + assert isinstance(connection, GeminiLlmConnection) + + +@pytest.mark.asyncio +async def test_connect_uses_request_speech_config_when_gemini_is_none( + gemini_llm, llm_request +): + """Tests that request's speech_config is used when Gemini's is None.""" + # Arrange: Set a speech_config on the request instance with the voice "Kore" + gemini_llm.speech_config = None + request_speech_config = types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Kore", + ) + ) + ) + llm_request.live_connect_config = types.LiveConnectConfig( + speech_config=request_speech_config + ) + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + # Act + async with gemini_llm.connect(llm_request) as connection: + # Assert + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify the speech_config from the request instance was used + assert config_arg.speech_config is not None + assert ( + config_arg.speech_config.voice_config.prebuilt_voice_config.voice_name + == "Kore" + ) + assert isinstance(connection, GeminiLlmConnection) + + +@pytest.mark.asyncio +async def test_connect_request_gemini_config_overrides_speech_config( + gemini_llm, llm_request +): + """Tests that Gemini's speech_config overrides the request's speech_config.""" + # Arrange: Set different speech_configs on both the Gemini instance ("Puck") and the request ("Zephyr") + gemini_llm.speech_config = types.SpeechConfig( + voice_config=types.VoiceConfig( + 
prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Puck", + ) + ) + ) + request_speech_config = types.SpeechConfig( + voice_config=types.VoiceConfig( + prebuilt_voice_config=types.PrebuiltVoiceConfig( + voice_name="Zephyr", + ) + ) + ) + llm_request.live_connect_config = types.LiveConnectConfig( + speech_config=request_speech_config + ) + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + # Act + async with gemini_llm.connect(llm_request) as connection: + # Assert + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify the speech_config from the request ("Zephyr") was overwritten by Gemini's speech_config ("Puck") + assert config_arg.speech_config is not None + assert ( + config_arg.speech_config.voice_config.prebuilt_voice_config.voice_name + == "Puck" + ) + assert isinstance(connection, GeminiLlmConnection) + + +@pytest.mark.asyncio +async def test_connect_speech_config_remains_none_when_both_are_none( + gemini_llm, llm_request +): + """Tests that speech_config is None when neither Gemini nor the request has it.""" + # Arrange: Ensure both Gemini instance and request have no speech_config + gemini_llm.speech_config = None + llm_request.live_connect_config = ( + types.LiveConnectConfig() + ) # speech_config is None + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + # Act + async with 
gemini_llm.connect(llm_request) as connection: + # Assert + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify the final speech_config is still None + assert config_arg.speech_config is None + assert isinstance(connection, GeminiLlmConnection)