diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 55d4b62e96..da0b3cc895 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -30,6 +30,8 @@ RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd]
 
 from typing import TYPE_CHECKING
 
+PUNCTUATION_CHARS = {',', '.', '!', '?', ';', ':', "'"}
+
 if TYPE_CHECKING:
   from google.genai import live
 
@@ -181,13 +183,27 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
         # generation_complete, causing transcription to appear after
         # tool_call in the session log.
         if message.server_content.input_transcription:
-          if message.server_content.input_transcription.text:
-            self._input_transcription_text += (
-                message.server_content.input_transcription.text
+          if (
+              new_input_transcription_chunk := message.server_content.input_transcription.text
+          ):
+            existing = self._input_transcription_text
+            # Insert a space when joining fragments except when the new
+            # chunk starts with a punctuation character that should attach
+            # to the previous token, or the existing text ends with an
+            # apostrophe.
+            conditional_space = (
+                ' '
+                if existing
+                and not (
+                    new_input_transcription_chunk[0] in PUNCTUATION_CHARS
+                    or existing.endswith("'")
+                )
+                else ''
             )
+            self._input_transcription_text = f'{existing}{conditional_space}{new_input_transcription_chunk.strip()}'.strip()
           yield LlmResponse(
               input_transcription=types.Transcription(
-                  text=message.server_content.input_transcription.text,
+                  text=new_input_transcription_chunk,
                   finished=False,
               ),
               partial=True,
@@ -204,13 +220,27 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           )
           self._input_transcription_text = ''
         if message.server_content.output_transcription:
-          if message.server_content.output_transcription.text:
-            self._output_transcription_text += (
-                message.server_content.output_transcription.text
+          if (
+              new_output_transcription_chunk := message.server_content.output_transcription.text
+          ):
+            existing = self._output_transcription_text
+            # Insert a space when joining fragments except when the new
+            # chunk starts with a punctuation character that should attach
+            # to the previous token, or the existing text ends with an
+            # apostrophe.
+            conditional_space = (
+                ' '
+                if existing
+                and not (
+                    new_output_transcription_chunk[0] in PUNCTUATION_CHARS
+                    or existing.endswith("'")
+                )
+                else ''
             )
+            self._output_transcription_text = f'{existing}{conditional_space}{new_output_transcription_chunk.strip()}'.strip()
           yield LlmResponse(
               output_transcription=types.Transcription(
-                  text=message.server_content.output_transcription.text,
+                  text=new_output_transcription_chunk,
                   finished=False,
               ),
               partial=True,
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index 190007603c..1a6460b3a5 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -593,3 +593,109 @@ async def mock_receive_generator():
   assert responses[2].output_transcription.text == 'How can I help?'
   assert responses[2].output_transcription.finished is True
   assert responses[2].partial is False
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('tx_direction', ['input', 'output'])
+@pytest.mark.parametrize(
+    'fragments',
+    [
+        ('That', "'s great", "That's great"),
+        ("That'", 's great', "That's great"),
+        ("That's", 'great', "That's great"),
+        ("That's", ' great', "That's great"),
+        ("That's ", 'great', "That's great"),
+        ("Great", '! Good to hear', 'Great! Good to hear'),
+        ("Great!", 'Good to hear', 'Great! Good to hear'),
+        ("Great! ", 'Good to hear', 'Great! Good to hear'),
+        ("Great! Good", 'to hear', 'Great! Good to hear'),
+        ("Great! Good ", 'to hear', 'Great! Good to hear'),
+        ("Great! Good", ' to hear', 'Great! Good to hear'),
+    ],
+)
+async def test_receive_final_transcription_space_between_fragments(
+    gemini_connection, mock_gemini_session, tx_direction, fragments
+):
+  """Test receive final transcription fragments are joined with a space between words."""
+  fragment1, fragment2, expected = fragments
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock.Mock()
+  message1.server_content.model_turn = None
+  message1.server_content.interrupted = False
+  message1.server_content.turn_complete = False
+  message1.server_content.generation_complete = False
+  message1.tool_call = None
+  message1.session_resumption_update = None
+  message1.server_content.input_transcription = (
+      types.Transcription(text=fragment1, finished=False)
+      if tx_direction == 'input'
+      else None
+  )
+  message1.server_content.output_transcription = (
+      types.Transcription(text=fragment1, finished=False)
+      if tx_direction == 'output'
+      else None
+  )
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock.Mock()
+  message2.server_content.model_turn = None
+  message2.server_content.interrupted = False
+  message2.server_content.turn_complete = False
+  message2.server_content.generation_complete = False
+  message2.tool_call = None
+  message2.session_resumption_update = None
+  message2.server_content.input_transcription = (
+      types.Transcription(text=fragment2, finished=False)
+      if tx_direction == 'input'
+      else None
+  )
+  message2.server_content.output_transcription = (
+      types.Transcription(text=fragment2, finished=False)
+      if tx_direction == 'output'
+      else None
+  )
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock.Mock()
+  message3.server_content.model_turn = None
+  message3.server_content.interrupted = False
+  message3.server_content.turn_complete = False
+  message3.server_content.generation_complete = False
+  message3.tool_call = None
+  message3.session_resumption_update = None
+  message3.server_content.input_transcription = (
+      types.Transcription(text=None, finished=True)
+      if tx_direction == 'input'
+      else None
+  )
+  message3.server_content.output_transcription = (
+      types.Transcription(text=None, finished=True)
+      if tx_direction == 'output'
+      else None
+  )
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # find the finished transcription response
+  attr_name = f'{tx_direction}_transcription'
+  finished_resps = [
+      r
+      for r in responses
+      if getattr(r, attr_name) and getattr(r, attr_name).finished
+  ]
+  assert finished_resps, 'Expected finished transcription response'
+  transcription = getattr(finished_resps[0], attr_name)
+  assert transcription.text == expected