diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py
index 4e39cb425..6a4f3df7b 100644
--- a/libs/genai/langchain_google_genai/chat_models.py
+++ b/libs/genai/langchain_google_genai/chat_models.py
@@ -788,6 +788,8 @@ def _parse_response_candidate(
         except (AttributeError, TypeError):
             thought_sig = None
 
+        has_function_call = hasattr(part, "function_call") and part.function_call
+
         if hasattr(part, "thought") and part.thought:
             thinking_message = {
                 "type": "thinking",
@@ -797,7 +799,7 @@ def _parse_response_candidate(
             if thought_sig:
                 thinking_message["signature"] = thought_sig
             content = _append_to_content(content, thinking_message)
-        elif text is not None and text:
+        elif text is not None and text.strip() and not has_function_call:
             # Check if this text Part has a signature attached
             if thought_sig:
                 # Text with signature needs structured block to preserve signature
@@ -929,15 +931,25 @@ def _parse_response_candidate(
             }
             function_call_signatures.append(sig_block)
 
-        # Add function call signatures to content only if there's already other content
-        # This preserves backward compatibility where content is "" for
-        # function-only responses
-        if function_call_signatures and content is not None:
-            for sig_block in function_call_signatures:
-                content = _append_to_content(content, sig_block)
+    # Add function call signatures to content only if there's already other content
+    # This preserves backward compatibility where content is "" for
+    # function-only responses
+    if function_call_signatures and content is not None:
+        for sig_block in function_call_signatures:
+            content = _append_to_content(content, sig_block)
 
     if content is None:
         content = ""
+
+    if (
+        hasattr(response_candidate, "logprobs_result")
+        and response_candidate.logprobs_result
+    ):
+        response_metadata["logprobs"] = MessageToDict(
+            response_candidate.logprobs_result._pb,
+            preserving_proto_field_name=True,
+        )
+
     if isinstance(content, list) and any(
         isinstance(item, dict) and "executable_code" in item for item in content
     ):
@@ -1825,6 +1837,9 @@ class Joke(BaseModel):
     stop: Optional[List[str]] = None
     """Stop sequences for the model."""
 
+    logprobs: Optional[int] = None
+    """The number of logprobs to return."""
+
     streaming: Optional[bool] = None
     """Whether to stream responses from the model."""
 
@@ -1963,6 +1978,7 @@ def _identifying_params(self) -> Dict[str, Any]:
            "media_resolution": self.media_resolution,
            "thinking_budget": self.thinking_budget,
            "include_thoughts": self.include_thoughts,
+           "logprobs": self.logprobs,
        }
 
    def invoke(
@@ -2037,6 +2053,7 @@ def _prepare_params(
             "max_output_tokens": self.max_output_tokens,
             "top_k": self.top_k,
             "top_p": self.top_p,
+            "logprobs": getattr(self, "logprobs", None),
             "response_modalities": self.response_modalities,
             "thinking_config": (
                 (
@@ -2058,6 +2075,8 @@ def _prepare_params(
             }.items()
             if v is not None
         }
+        if getattr(self, "logprobs", None) is not None:
+            gen_config["response_logprobs"] = True
         if generation_config:
             gen_config = {**gen_config, **generation_config}
 
diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py
index f8f82d057..c78100e54 100644
--- a/libs/genai/tests/unit_tests/test_chat_models.py
+++ b/libs/genai/tests/unit_tests/test_chat_models.py
@@ -141,24 +141,98 @@ def test_initialization_inside_threadpool() -> None:
     ).result()
 
 
-def test_client_transport() -> None:
+def test_logprobs() -> None:
+    """Test that logprobs parameter is set correctly and is in the response."""
+    llm = ChatGoogleGenerativeAI(
+        model=MODEL_NAME,
+        google_api_key=SecretStr("secret-api-key"),
+        logprobs=10,
+    )
+    assert llm.logprobs == 10
+
+    # Create proper mock response with logprobs_result
+    raw_response = {
+        "candidates": [
+            {
+                "content": {"parts": [{"text": "Test response"}]},
+                "finish_reason": 1,
+                "safety_ratings": [],
+                "logprobs_result": {
+                    "top_candidates": [
+                        {
+                            "candidates": [
+                                {"token": "Test", "log_probability": -0.1},
+                            ]
+                        }
+                    ]
+                },
+            }
+        ],
+        "prompt_feedback": {"block_reason": 0, "safety_ratings": []},
+        "usage_metadata": {
+            "prompt_token_count": 5,
+            "candidates_token_count": 2,
+            "total_token_count": 7,
+        },
+    }
+    response = GenerateContentResponse(raw_response)
+
+    with patch(
+        "langchain_google_genai.chat_models._chat_with_retry"
+    ) as mock_chat_with_retry:
+        mock_chat_with_retry.return_value = response
+        llm = ChatGoogleGenerativeAI(
+            model=MODEL_NAME,
+            google_api_key="test-key",
+            logprobs=1,
+        )
+        result = llm.invoke("test")
+        assert "logprobs" in result.response_metadata
+        assert result.response_metadata["logprobs"] == {
+            "top_candidates": [
+                {
+                    "candidates": [
+                        {"token": "Test", "log_probability": -0.1},
+                    ]
+                }
+            ]
+        }
+
+        mock_chat_with_retry.assert_called_once()
+        request = mock_chat_with_retry.call_args.kwargs["request"]
+        assert request.generation_config.logprobs == 1
+        assert request.generation_config.response_logprobs is True
+
+
+@pytest.mark.enable_socket
+@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient")
+@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceClient")
+def test_client_transport(mock_client: Mock, mock_async_client: Mock) -> None:
     """Test client transport configuration."""
+    mock_client.return_value.transport = Mock()
+    mock_client.return_value.transport.kind = "grpc"
     model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
     assert model.client.transport.kind == "grpc"
 
+    mock_client.return_value.transport.kind = "rest"
     model = ChatGoogleGenerativeAI(
         model=MODEL_NAME, google_api_key="fake-key", transport="rest"
     )
     assert model.client.transport.kind == "rest"
 
     async def check_async_client() -> None:
+        mock_async_client.return_value.transport = Mock()
+        mock_async_client.return_value.transport.kind = "grpc_asyncio"
         model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
+        _ = model.async_client
         assert model.async_client.transport.kind == "grpc_asyncio"
 
         # Test auto conversion of transport to "grpc_asyncio" from "rest"
         model = ChatGoogleGenerativeAI(
             model=MODEL_NAME, google_api_key="fake-key", transport="rest"
         )
+        model.async_client_running = None
+        _ = model.async_client
         assert model.async_client.transport.kind == "grpc_asyncio"
 
     asyncio.run(check_async_client())
@@ -172,6 +246,7 @@ def test_initalization_without_async() -> None:
     assert chat.async_client is None
 
 
+@pytest.mark.enable_socket
 def test_initialization_with_async() -> None:
     async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
         model = ChatGoogleGenerativeAI(
@@ -1288,6 +1363,7 @@ def test_grounding_metadata_multiple_parts() -> None:
     assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1
 
 
+@pytest.mark.enable_socket
 @pytest.mark.parametrize(
     "is_async,mock_target,method_name",
     [
@@ -1414,6 +1490,7 @@ def mock_stream() -> Iterator[GenerateContentResponse]:
     assert "timeout" not in call_kwargs
 
 
+@pytest.mark.enable_socket
 @pytest.mark.parametrize(
     "is_async,mock_target,method_name",
     [
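
For reviewers, a minimal end-to-end sketch of the new parameter (not part of the diff). It assumes a valid GOOGLE_API_KEY in the environment; the model name is illustrative, and the metadata shape mirrors the proto field names exercised in test_logprobs above:

    from langchain_google_genai import ChatGoogleGenerativeAI

    # Illustrative model name; logprobs=5 requests the top-5 alternatives per
    # generated token and, per _prepare_params above, also sets
    # generation_config.response_logprobs = True on the outgoing request.
    llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", logprobs=5)

    result = llm.invoke("Say hello in one word.")

    # _parse_response_candidate converts the proto logprobs_result with
    # MessageToDict(preserving_proto_field_name=True), so the keys below are
    # snake_case proto field names rather than camelCase.
    logprobs = result.response_metadata.get("logprobs", {})
    for position in logprobs.get("top_candidates", []):
        for candidate in position["candidates"]:
            print(candidate["token"], candidate["log_probability"])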