Skip to content
35 changes: 28 additions & 7 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,8 @@ def _parse_response_candidate(
except (AttributeError, TypeError):
thought_sig = None

has_function_call = hasattr(part, "function_call") and part.function_call

if hasattr(part, "thought") and part.thought:
thinking_message = {
"type": "thinking",
Expand All @@ -785,7 +787,7 @@ def _parse_response_candidate(
if thought_sig:
thinking_message["signature"] = thought_sig
content = _append_to_content(content, thinking_message)
elif text is not None and text:
elif text is not None and text.strip() and not has_function_call:
# Check if this text Part has a signature attached
if thought_sig:
# Text with signature needs structured block to preserve signature
Expand Down Expand Up @@ -917,15 +919,27 @@ def _parse_response_candidate(
}
function_call_signatures.append(sig_block)

# Add function call signatures to content only if there's already other content
# This preserves backward compatibility where content is "" for
# function-only responses
if function_call_signatures and content is not None:
for sig_block in function_call_signatures:
content = _append_to_content(content, sig_block)
# Add function call signatures to content only if there's already other content
# This preserves backward compatibility where content is "" for
# function-only responses
if function_call_signatures and content is not None:
for sig_block in function_call_signatures:
content = _append_to_content(content, sig_block)

if content is None:
content = ""

if (
hasattr(response_candidate, "logprobs_result")
and response_candidate.logprobs_result
):
# Note: logprobs is flaky, sometimes available, sometimes not
# https://discuss.ai.google.dev/t/logprobs-is-not-enabled-for-gemini-models/107989/15
response_metadata["logprobs"] = MessageToDict(
response_candidate.logprobs_result._pb,
preserving_proto_field_name=True,
)

if isinstance(content, list) and any(
isinstance(item, dict) and "executable_code" in item for item in content
):
Expand Down Expand Up @@ -1812,6 +1826,9 @@ class Joke(BaseModel):
stop: list[str] | None = None
"""Stop sequences for the model."""

logprobs: int | None = None
"""The number of logprobs to return."""

streaming: bool | None = None
"""Whether to stream responses from the model."""

Expand Down Expand Up @@ -1979,6 +1996,7 @@ def _identifying_params(self) -> dict[str, Any]:
"media_resolution": self.media_resolution,
"thinking_budget": self.thinking_budget,
"include_thoughts": self.include_thoughts,
"logprobs": self.logprobs,
}

def invoke(
Expand Down Expand Up @@ -2051,6 +2069,7 @@ def _prepare_params(
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"logprobs": getattr(self, "logprobs", None),
"response_modalities": self.response_modalities,
"thinking_config": (
(
Expand All @@ -2072,6 +2091,8 @@ def _prepare_params(
}.items()
if v is not None
}
if getattr(self, "logprobs", None) is not None:
gen_config["response_logprobs"] = True
if generation_config:
gen_config = {**gen_config, **generation_config}

Expand Down
79 changes: 78 additions & 1 deletion libs/genai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,98 @@ def test_initialization_inside_threadpool() -> None:
).result()


def test_client_transport() -> None:
def test_logprobs() -> None:
    """Verify the ``logprobs`` option is stored, forwarded, and surfaced.

    Checks three things: the constructor keeps the value verbatim, the
    outgoing request carries ``logprobs`` plus ``response_logprobs=True``,
    and the candidate's ``logprobs_result`` lands in ``response_metadata``.
    """
    # Constructor should store the requested number of logprobs as-is.
    model_with_logprobs = ChatGoogleGenerativeAI(
        model=MODEL_NAME,
        google_api_key=SecretStr("secret-api-key"),
        logprobs=10,
    )
    assert model_with_logprobs.logprobs == 10

    # Fake API payload carrying a logprobs_result on the single candidate.
    logprobs_payload = {
        "top_candidates": [
            {
                "candidates": [
                    {"token": "Test", "log_probability": -0.1},
                ]
            }
        ]
    }
    fake_response = GenerateContentResponse(
        {
            "candidates": [
                {
                    "content": {"parts": [{"text": "Test response"}]},
                    "finish_reason": 1,
                    "safety_ratings": [],
                    "logprobs_result": logprobs_payload,
                }
            ],
            "prompt_feedback": {"block_reason": 0, "safety_ratings": []},
            "usage_metadata": {
                "prompt_token_count": 5,
                "candidates_token_count": 2,
                "total_token_count": 7,
            },
        }
    )

    with patch(
        "langchain_google_genai.chat_models._chat_with_retry"
    ) as retry_mock:
        retry_mock.return_value = fake_response
        llm = ChatGoogleGenerativeAI(
            model=MODEL_NAME,
            google_api_key="test-key",
            logprobs=1,
        )
        message = llm.invoke("test")
        # The raw logprobs_result must surface in response_metadata.
        assert "logprobs" in message.response_metadata
        assert message.response_metadata["logprobs"] == logprobs_payload

        retry_mock.assert_called_once()
        sent_request = retry_mock.call_args.kwargs["request"]
        # The count is forwarded and response_logprobs is switched on.
        assert sent_request.generation_config.logprobs == 1
        assert sent_request.generation_config.response_logprobs is True


@pytest.mark.enable_socket
@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient")
@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceClient")
def test_client_transport(mock_client: Mock, mock_async_client: Mock) -> None:
    """Verify sync and async clients report the configured transport kind."""
    # Sync client defaults to gRPC transport.
    sync_transport = Mock()
    sync_transport.kind = "grpc"
    mock_client.return_value.transport = sync_transport
    chat = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
    assert chat.client.transport.kind == "grpc"

    # Sync client honors an explicit REST transport.
    sync_transport.kind = "rest"
    chat = ChatGoogleGenerativeAI(
        model=MODEL_NAME, google_api_key="fake-key", transport="rest"
    )
    assert chat.client.transport.kind == "rest"

    async def verify_async_transport() -> None:
        # Async client defaults to grpc_asyncio transport.
        async_transport = Mock()
        async_transport.kind = "grpc_asyncio"
        mock_async_client.return_value.transport = async_transport
        chat = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
        _ = chat.async_client
        assert chat.async_client.transport.kind == "grpc_asyncio"

        # A "rest" transport is auto-converted to "grpc_asyncio" for the
        # async client.
        chat = ChatGoogleGenerativeAI(
            model=MODEL_NAME, google_api_key="fake-key", transport="rest"
        )
        chat.async_client_running = None
        _ = chat.async_client
        assert chat.async_client.transport.kind == "grpc_asyncio"

    asyncio.run(verify_async_transport())
Expand All @@ -171,6 +245,7 @@ def test_initalization_without_async() -> None:
assert chat.async_client is None


@pytest.mark.enable_socket
def test_initialization_with_async() -> None:
async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
model = ChatGoogleGenerativeAI(
Expand Down Expand Up @@ -1536,6 +1611,7 @@ def test_grounding_metadata_multiple_parts() -> None:
assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[
Expand Down Expand Up @@ -1662,6 +1738,7 @@ def mock_stream() -> Iterator[GenerateContentResponse]:
assert "timeout" not in call_kwargs


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[
Expand Down