Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd]
from typing import TYPE_CHECKING

PUNCTUATION_CHARS = {'.', '!', '?', ';', ':', "'"}

if TYPE_CHECKING:
from google.genai import live

Expand Down Expand Up @@ -181,13 +183,27 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
# generation_complete, causing transcription to appear after
# tool_call in the session log.
if message.server_content.input_transcription:
if message.server_content.input_transcription.text:
self._input_transcription_text += (
message.server_content.input_transcription.text
if (
new_input_transcription_chunk := message.server_content.input_transcription.text
):
existing = self._input_transcription_text
# Insert a space when joining fragments except when the new
# chunk starts with a punctuation character that should attach
# to the previous token, or the existing text ends with an
# apostrophe.
conditional_space = (
' '
if existing
and not (
new_input_transcription_chunk[0] in PUNCTUATION_CHARS
or existing.endswith("'")
)
else ''
)
self._input_transcription_text = f'{existing}{conditional_space}{new_input_transcription_chunk.strip()}'.strip()
Comment on lines +189 to +203
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for stitching transcription chunks is duplicated for both input_transcription (here) and output_transcription (lines 226-240). To improve maintainability and adhere to the Don't Repeat Yourself (DRY) principle, this logic should be extracted into a private helper method.

For example, you could create a method like this:

def _stitch_transcription_chunk(self, existing_text: str, new_chunk: str) -> str:
    if not new_chunk:
        return existing_text

    # Insert a space when joining fragments except when the new
    # chunk starts with a punctuation character that should attach
    # to the previous token, or the existing text ends with an
    # apostrophe.
    conditional_space = (
        ' '
        if existing_text
        and not (
            new_chunk[0] in PUNCTUATION_CHARS
            or existing_text.endswith("'")
        )
        else ''
    )
    return f'{existing_text}{conditional_space}{new_chunk.strip()}'.strip()

Then you could call it like so:
self._input_transcription_text = self._stitch_transcription_chunk(self._input_transcription_text, new_input_transcription_chunk)

yield LlmResponse(
input_transcription=types.Transcription(
text=message.server_content.input_transcription.text,
text=new_input_transcription_chunk,
finished=False,
),
partial=True,
Expand All @@ -204,13 +220,27 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
self._input_transcription_text = ''
if message.server_content.output_transcription:
if message.server_content.output_transcription.text:
self._output_transcription_text += (
message.server_content.output_transcription.text
if (
new_output_transcription_chunk := message.server_content.output_transcription.text
):
existing = self._output_transcription_text
# Insert a space when joining fragments except when the new
# chunk starts with a punctuation character that should attach
# to the previous token, or the existing text ends with an
# apostrophe.
conditional_space = (
' '
if existing
and not (
new_output_transcription_chunk[0] in PUNCTUATION_CHARS
or existing.endswith("'")
)
else ''
)
self._output_transcription_text = f'{existing}{conditional_space}{new_output_transcription_chunk.strip()}'.strip()
yield LlmResponse(
output_transcription=types.Transcription(
text=message.server_content.output_transcription.text,
text=new_output_transcription_chunk,
finished=False,
),
partial=True,
Expand Down
106 changes: 106 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,109 @@ async def mock_receive_generator():
assert responses[2].output_transcription.text == 'How can I help?'
assert responses[2].output_transcription.finished is True
assert responses[2].partial is False


def _create_mock_transcription_message(text, finished, direction):
  """Build a mock live-server message carrying one transcription chunk.

  Args:
    text: Transcription chunk text, or None for the finished marker message.
    finished: Whether this chunk closes the transcription.
    direction: 'input' or 'output' — which transcription field to populate;
      the other field is left as None.

  Returns:
    A mock message with all non-transcription server fields nulled out so
    that `receive()` only reacts to the transcription content.
  """
  msg = mock.Mock()
  msg.usage_metadata = None
  msg.server_content = mock.Mock()
  msg.server_content.model_turn = None
  msg.server_content.interrupted = False
  msg.server_content.turn_complete = False
  msg.server_content.generation_complete = False
  msg.tool_call = None
  msg.session_resumption_update = None

  transcription = types.Transcription(text=text, finished=finished)
  if direction == 'input':
    msg.server_content.input_transcription = transcription
    msg.server_content.output_transcription = None
  else:
    msg.server_content.input_transcription = None
    msg.server_content.output_transcription = transcription
  return msg


@pytest.mark.asyncio
@pytest.mark.parametrize('tx_direction', ['input', 'output'])
@pytest.mark.parametrize(
    'fragments',
    [
        ('That', "'s great", "That's great"),
        ("That'", 's great', "That's great"),
        ("That's", 'great', "That's great"),
        ("That's", ' great', "That's great"),
        ("That's ", 'great', "That's great"),
        ('Great', '! Good to hear', 'Great! Good to hear'),
        ('Great!', 'Good to hear', 'Great! Good to hear'),
        ('Great! ', 'Good to hear', 'Great! Good to hear'),
        ('Great! Good', 'to hear', 'Great! Good to hear'),
        ('Great! Good ', 'to hear', 'Great! Good to hear'),
        ('Great! Good', ' to hear', 'Great! Good to hear'),
    ],
)
async def test_receive_final_transcription_space_between_fragments(
    gemini_connection, mock_gemini_session, tx_direction, fragments
):
  """Test receive final transcription fragments are joined with a space between words."""
  fragment1, fragment2, expected = fragments

  # Two partial chunks followed by a finished marker (text=None) that
  # triggers emission of the stitched transcription.
  message1 = _create_mock_transcription_message(fragment1, False, tx_direction)
  message2 = _create_mock_transcription_message(fragment2, False, tx_direction)
  message3 = _create_mock_transcription_message(None, True, tx_direction)

  async def mock_receive_generator():
    yield message1
    yield message2
    yield message3

  receive_mock = mock.Mock(return_value=mock_receive_generator())
  mock_gemini_session.receive = receive_mock

  responses = [resp async for resp in gemini_connection.receive()]

  # Find the finished transcription response and check the stitched text.
  attr_name = f'{tx_direction}_transcription'
  finished_resps = [
      r
      for r in responses
      if getattr(r, attr_name) and getattr(r, attr_name).finished
  ]
  assert finished_resps, 'Expected finished transcription response'
  transcription = getattr(finished_resps[0], attr_name)
  assert transcription.text == expected
Loading