diff --git a/libs/genai/langchain_google_genai/_common.py b/libs/genai/langchain_google_genai/_common.py
index 59a253d66..423b033f3 100644
--- a/libs/genai/langchain_google_genai/_common.py
+++ b/libs/genai/langchain_google_genai/_common.py
@@ -1,12 +1,17 @@
 import os
 from importlib import metadata
-from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Dict, List, Optional, Tuple
 
 from google.api_core.gapic_v1.client_info import ClientInfo
 from langchain_core.utils import secret_from_env
 from pydantic import BaseModel, Field, SecretStr
 
-from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory, Modality
+from langchain_google_genai._enums import (
+    HarmBlockThreshold,
+    HarmCategory,
+    Modality,
+    SafetySetting,
+)
 
 _TELEMETRY_TAG = "remote_reasoning_engine"
 _TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID"
@@ -38,6 +43,8 @@ class _BaseGoogleGenerativeAI(BaseModel):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the GOOGLE_API_KEY envvar"
+    base_url: Optional[str] = None
+    """The base URL for the API. If not provided, will default to the public API."""
     temperature: float = 0.7
     """Run inference with this temperature. Must be within ``[0.0, 2.0]``."""
     top_p: Optional[float] = None
@@ -158,6 +165,4 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
     )
 
 
-class SafetySettingDict(TypedDict):
-    category: HarmCategory
-    threshold: HarmBlockThreshold
+SafetySettingDict = SafetySetting
diff --git a/libs/genai/langchain_google_genai/_enums.py b/libs/genai/langchain_google_genai/_enums.py
index 2c2bd12cf..aa92abebc 100644
--- a/libs/genai/langchain_google_genai/_enums.py
+++ b/libs/genai/langchain_google_genai/_enums.py
@@ -1,7 +1,17 @@
-import google.ai.generativelanguage_v1beta as genai
+from google.genai.types import (
+    BlockedReason,
+    HarmBlockThreshold,
+    HarmCategory,
+    MediaModality,
+    Modality,
+    SafetySetting,
+)
 
-HarmBlockThreshold = genai.SafetySetting.HarmBlockThreshold
-HarmCategory = genai.HarmCategory
-Modality = genai.GenerationConfig.Modality
+HarmCategory = HarmCategory
+MediaModality = MediaModality
+SafetySetting = SafetySetting
+HarmBlockThreshold = HarmBlockThreshold
+BlockedReason = BlockedReason
 
-__all__ = ["HarmBlockThreshold", "HarmCategory", "Modality"]
+
+__all__ = ["SafetySetting", "HarmBlockThreshold", "HarmCategory", "Modality", "BlockedReason"]
diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py
index 80ec30717..8f0867254 100644
--- a/libs/genai/langchain_google_genai/_function_utils.py
+++ b/libs/genai/langchain_google_genai/_function_utils.py
@@ -2,7 +2,6 @@
 
 import collections
 import importlib
-import json
 import logging
 from typing import (
     Any,
@@ -18,9 +17,7 @@
     cast,
 )
 
-import google.ai.generativelanguage as glm
-import google.ai.generativelanguage_v1beta.types as gapic
-import proto  # type: ignore[import]
+from google.genai import types
 from langchain_core.tools import BaseTool
 from langchain_core.tools import tool as callable_as_lc_tool
 from langchain_core.utils.function_calling import (
@@ -36,38 +33,51 @@
 TYPE_ENUM = {
-    "string": glm.Type.STRING,
-    "number": glm.Type.NUMBER,
-    "integer": glm.Type.INTEGER,
-    "boolean": glm.Type.BOOLEAN,
-    "array": glm.Type.ARRAY,
-    "object": glm.Type.OBJECT,
+    "string": types.Type.STRING,
+    "number": types.Type.NUMBER,
+    "integer": types.Type.INTEGER,
+    "boolean": types.Type.BOOLEAN,
+
"array": types.Type.ARRAY, + "object": types.Type.OBJECT, "null": None, } -_ALLOWED_SCHEMA_FIELDS = [] -_ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields]) -_ALLOWED_SCHEMA_FIELDS.extend( - [ - f - for f in gapic.Schema.to_dict( - gapic.Schema(), preserving_proto_field_name=False - ).keys() - ] -) +# Note: For google.genai, we'll use a simplified approach for allowed schema fields +# since the new library doesn't expose protobuf fields in the same way +_ALLOWED_SCHEMA_FIELDS = [ + "type", + "type_", + "description", + "enum", + "format", + "items", + "properties", + "required", + "nullable", + "anyOf", + "default", + "minimum", + "maximum", + "minLength", + "maxLength", + "pattern", + "minItems", + "maxItems", + "title", +] _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS) # Info: This is a FunctionDeclaration(=fc). _FunctionDeclarationLike = Union[ - BaseTool, Type[BaseModel], gapic.FunctionDeclaration, Callable, Dict[str, Any] + BaseTool, Type[BaseModel], types.FunctionDeclaration, Callable, Dict[str, Any] ] _GoogleSearchRetrievalLike = Union[ - gapic.GoogleSearchRetrieval, + types.GoogleSearchRetrieval, Dict[str, Any], ] -_GoogleSearchLike = Union[gapic.Tool.GoogleSearch, Dict[str, Any]] -_CodeExecutionLike = Union[gapic.CodeExecution, Dict[str, Any]] +_GoogleSearchLike = Union[types.GoogleSearch, Dict[str, Any]] +_CodeExecutionLike = Union[types.ToolCodeExecution, Dict[str, Any]] class _ToolDict(TypedDict): @@ -78,9 +88,9 @@ class _ToolDict(TypedDict): # Info: This means one tool=Sequence of FunctionDeclaration -# The dict should be gapic.Tool like. {"function_declarations": [ { "name": ...}. +# The dict should be Tool like. {"function_declarations": [ { "name": ...}. # OpenAI like dict is not be accepted. {{'type': 'function', 'function': {'name': ...} -_ToolType = Union[gapic.Tool, _ToolDict, _FunctionDeclarationLike] +_ToolType = Union[types.Tool, _ToolDict, _FunctionDeclarationLike] _ToolsType = Sequence[_ToolType] @@ -90,7 +100,8 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]: if key == "definitions": continue elif key == "items": - converted_schema["items"] = _format_json_schema_to_gapic(value) + if value is not None: + converted_schema["items"] = _format_json_schema_to_gapic(value) elif key == "properties": converted_schema["properties"] = _get_properties_from_schema(value) continue @@ -101,8 +112,13 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]: f"Got {len(value)}, ignoring other than first value!" 
) return _format_json_schema_to_gapic(value[0]) - elif key in ["type", "_type"]: - converted_schema["type"] = str(value).upper() + elif key in ["type", "type_"]: + if isinstance(value, dict): + converted_schema["type"] = value["_value_"] + elif isinstance(value, str): + converted_schema["type"] = value + else: + raise ValueError(f"Invalid type: {value}") elif key not in _ALLOWED_SCHEMA_FIELDS_SET: logger.warning(f"Key '{key}' is not supported in schema, ignoring") else: @@ -110,52 +126,88 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]: return converted_schema -def _dict_to_gapic_schema(schema: Dict[str, Any]) -> Optional[gapic.Schema]: +def _dict_to_genai_schema(schema: Dict[str, Any]) -> Optional[types.Schema]: if schema: dereferenced_schema = dereference_refs(schema) formatted_schema = _format_json_schema_to_gapic(dereferenced_schema) - json_schema = json.dumps(formatted_schema) - return gapic.Schema.from_json(json_schema) + + # Convert the formatted schema to google.genai.types.Schema + schema_dict = {} + if "type" in formatted_schema: + type_value = "STRING" + type_obj = formatted_schema["type"] + if isinstance(type_obj, dict): + type_value = type_obj["_value_"] + elif isinstance(type_obj, str): + type_value = type_obj + else: + raise ValueError(f"Invalid type: {type_obj}") + schema_dict["type"] = types.Type(type_value) + if "description" in formatted_schema: + schema_dict["description"] = formatted_schema["description"] + if "title" in formatted_schema: + schema_dict["title"] = formatted_schema["title"] + if "properties" in formatted_schema: + schema_dict["properties"] = formatted_schema["properties"] + # Always set required to empty list if not present + schema_dict["required"] = formatted_schema.get("required", []) + if "items" in formatted_schema: + schema_dict["items"] = formatted_schema["items"] + if "enum" in formatted_schema: + schema_dict["enum"] = formatted_schema["enum"] + if "nullable" in formatted_schema: + schema_dict["nullable"] = formatted_schema["nullable"] + + return types.Schema.model_validate(schema_dict) return None def _format_dict_to_function_declaration( tool: Union[FunctionDescription, Dict[str, Any]], -) -> gapic.FunctionDeclaration: - return gapic.FunctionDeclaration( - name=tool.get("name") or tool.get("title"), - description=tool.get("description"), - parameters=_dict_to_gapic_schema(tool.get("parameters", {})), +) -> types.FunctionDeclaration: + name = tool.get("name") or tool.get("title") or "MISSING_NAME" + description = tool.get("description") or None + parameters = _dict_to_genai_schema(tool.get("parameters", {})) + return types.FunctionDeclaration( + name=str(name), + description=description, + parameters=parameters, ) -# Info: gapic.Tool means function_declarations and proto.Message. +# Info: Tool means function_declarations and other tool types. def convert_to_genai_function_declarations( tools: _ToolsType, -) -> gapic.Tool: +) -> types.Tool: if not isinstance(tools, collections.abc.Sequence): logger.warning( "convert_to_genai_function_declarations expects a Sequence " "and not a single tool." 
) tools = [tools] - gapic_tool = gapic.Tool() + + tool_dict: Dict[str, Any] = {} + function_declarations: List[types.FunctionDeclaration] = [] + for tool in tools: - if any(f in gapic_tool for f in ["google_search_retrieval"]): - raise ValueError( - "Providing multiple google_search_retrieval" - " or mixing with function_declarations is not supported" - ) - if isinstance(tool, (gapic.Tool)): - rt: gapic.Tool = ( - tool if isinstance(tool, gapic.Tool) else tool._raw_tool # type: ignore - ) - if "google_search_retrieval" in rt: - gapic_tool.google_search_retrieval = rt.google_search_retrieval - if "function_declarations" in rt: - gapic_tool.function_declarations.extend(rt.function_declarations) - if "google_search" in rt: - gapic_tool.google_search = rt.google_search + if isinstance(tool, types.Tool): + # Handle existing Tool objects + if hasattr(tool, "function_declarations") and tool.function_declarations: + function_declarations.extend(tool.function_declarations) + if ( + hasattr(tool, "google_search_retrieval") + and tool.google_search_retrieval + ): + if "google_search_retrieval" in tool_dict: + raise ValueError( + "Providing multiple google_search_retrieval" + " or mixing with function_declarations is not supported" + ) + tool_dict["google_search_retrieval"] = tool.google_search_retrieval + if hasattr(tool, "google_search") and tool.google_search: + tool_dict["google_search"] = tool.google_search + if hasattr(tool, "code_execution") and tool.code_execution: + tool_dict["code_execution"] = tool.code_execution elif isinstance(tool, dict): # not _ToolDictLike if not any( @@ -167,58 +219,87 @@ def convert_to_genai_function_declarations( "code_execution", ] ): - fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type] - gapic_tool.function_declarations.append(fd) + fd = _format_to_genai_function_declaration(tool) # type: ignore[arg-type] + function_declarations.append(fd) continue # _ToolDictLike tool = cast(_ToolDict, tool) if "function_declarations" in tool: - function_declarations = tool["function_declarations"] - if not isinstance( + tool_function_declarations = tool["function_declarations"] + if tool_function_declarations is not None and not isinstance( tool["function_declarations"], collections.abc.Sequence ): raise ValueError( "function_declarations should be a list" - f"got '{type(function_declarations)}'" + f"got '{type(tool_function_declarations)}'" ) - if function_declarations: + if tool_function_declarations: fds = [ - _format_to_gapic_function_declaration(fd) - for fd in function_declarations + _format_to_genai_function_declaration(fd) + for fd in tool_function_declarations ] - gapic_tool.function_declarations.extend(fds) + function_declarations.extend(fds) if "google_search_retrieval" in tool: - gapic_tool.google_search_retrieval = gapic.GoogleSearchRetrieval( - tool["google_search_retrieval"] - ) + if "google_search_retrieval" in tool_dict: + raise ValueError( + "Providing multiple google_search_retrieval" + " or mixing with function_declarations is not supported" + ) + if isinstance(tool["google_search_retrieval"], dict): + tool_dict["google_search_retrieval"] = types.GoogleSearchRetrieval( + **tool["google_search_retrieval"] + ) + else: + tool_dict["google_search_retrieval"] = tool[ + "google_search_retrieval" + ] if "google_search" in tool: - gapic_tool.google_search = gapic.Tool.GoogleSearch( - tool["google_search"] - ) + if isinstance(tool["google_search"], dict): + tool_dict["google_search"] = types.GoogleSearch( + **tool["google_search"] + ) + else: + 
tool_dict["google_search"] = tool["google_search"] if "code_execution" in tool: - gapic_tool.code_execution = gapic.CodeExecution(tool["code_execution"]) + if isinstance(tool["code_execution"], dict): + tool_dict["code_execution"] = types.ToolCodeExecution( + **tool["code_execution"] + ) + else: + tool_dict["code_execution"] = tool["code_execution"] else: - fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type] - gapic_tool.function_declarations.append(fd) - return gapic_tool + fd = _format_to_genai_function_declaration(tool) # type: ignore[arg-type] + function_declarations.append(fd) + + if function_declarations: + tool_dict["function_declarations"] = function_declarations + return types.Tool(**tool_dict) -def tool_to_dict(tool: gapic.Tool) -> _ToolDict: + +def tool_to_dict(tool: types.Tool) -> _ToolDict: def _traverse_values(raw: Any) -> Any: if isinstance(raw, list): return [_traverse_values(v) for v in raw] if isinstance(raw, dict): - return {k: _traverse_values(v) for k, v in raw.items()} - if isinstance(raw, proto.Message): - return _traverse_values(type(raw).to_dict(raw)) + processed = {k: _traverse_values(v) for k, v in raw.items()} + return processed + if hasattr(raw, "__dict__"): + return _traverse_values(raw.__dict__) return raw - return _traverse_values(type(tool).to_dict(tool)) + if hasattr(tool, "model_dump"): + raw_result = tool.model_dump() + else: + raw_result = tool.__dict__ + + result = _traverse_values(raw_result) + return result -def _format_to_gapic_function_declaration( +def _format_to_genai_function_declaration( tool: _FunctionDeclarationLike, -) -> gapic.FunctionDeclaration: +) -> types.FunctionDeclaration: if isinstance(tool, BaseTool): return _format_base_tool_to_function_declaration(tool) elif isinstance(tool, type) and is_basemodel_subclass_safe(tool): @@ -249,15 +330,15 @@ def _format_to_gapic_function_declaration( def _format_base_tool_to_function_declaration( tool: BaseTool, -) -> gapic.FunctionDeclaration: +) -> types.FunctionDeclaration: if not tool.args_schema: - return gapic.FunctionDeclaration( + return types.FunctionDeclaration( name=tool.name, description=tool.description, - parameters=gapic.Schema( - type=gapic.Type.OBJECT, + parameters=types.Schema( + type=types.Type.OBJECT, properties={ - "__arg1": gapic.Schema(type=gapic.Type.STRING), + "__arg1": types.Schema(type=types.Type.STRING), }, required=["__arg1"], ), @@ -274,9 +355,9 @@ def _format_base_tool_to_function_declaration( "args_schema must be a Pydantic BaseModel or JSON schema, " f"got {tool.args_schema}." 
) - parameters = _dict_to_gapic_schema(schema) + parameters = _dict_to_genai_schema(schema) - return gapic.FunctionDeclaration( + return types.FunctionDeclaration( name=tool.name or schema.get("title"), description=tool.description or schema.get("description"), parameters=parameters, @@ -287,7 +368,7 @@ def _convert_pydantic_to_genai_function( pydantic_model: Type[BaseModel], tool_name: Optional[str] = None, tool_description: Optional[str] = None, -) -> gapic.FunctionDeclaration: +) -> types.FunctionDeclaration: if issubclass(pydantic_model, BaseModel): schema = pydantic_model.model_json_schema() elif issubclass(pydantic_model, BaseModelV1): @@ -298,19 +379,18 @@ def _convert_pydantic_to_genai_function( ) schema = dereference_refs(schema) schema.pop("definitions", None) - function_declaration = gapic.FunctionDeclaration( + + # Convert to google.genai Schema format + parameters_dict = { + "type": TYPE_ENUM[schema["type"]], + "properties": _get_properties_from_schema_any(schema.get("properties")), + "required": schema.get("required", []), + } + + function_declaration = types.FunctionDeclaration( name=tool_name if tool_name else schema.get("title"), description=tool_description if tool_description else schema.get("description"), - parameters={ - "properties": _get_properties_from_schema_any( - schema.get("properties") - ), # TODO: use _dict_to_gapic_schema() if possible - # "items": _get_items_from_schema_any( - # schema - # ), # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema - "required": schema.get("required", []), - "type_": TYPE_ENUM[schema["type"]], - }, + parameters=types.Schema(**parameters_dict), ) return function_declaration @@ -340,14 +420,14 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: ] elif v.get("type") or v.get("anyOf") or v.get("type_"): item_type_ = _get_type_from_schema(v) - properties_item["type_"] = item_type_ + properties_item["type"] = item_type_ if _is_nullable_schema(v): properties_item["nullable"] = True # Replace `v` with chosen definition for array / object json types any_of_types = v.get("anyOf") - if any_of_types and item_type_ in [glm.Type.ARRAY, glm.Type.OBJECT]: - json_type_ = "array" if item_type_ == glm.Type.ARRAY else "object" + if any_of_types and item_type_ in [types.Type.ARRAY, types.Type.OBJECT]: + json_type_ = "array" if item_type_ == types.Type.ARRAY else "object" # Use Index -1 for consistency with `_get_nullable_type_from_schema` v = [val for val in any_of_types if val.get("type") == json_type_][-1] @@ -358,10 +438,10 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: if description and isinstance(description, str): properties_item["description"] = description - if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"): + if properties_item.get("type") == types.Type.ARRAY and v.get("items"): properties_item["items"] = _get_items_from_schema_any(v.get("items")) - if properties_item.get("type_") == glm.Type.OBJECT: + if properties_item.get("type") == types.Type.OBJECT: if ( v.get("anyOf") and isinstance(v["anyOf"], list) @@ -379,7 +459,7 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]: ] else: # Providing dummy type for object without properties - properties_item["type_"] = glm.Type.STRING + properties_item["type"] = types.Type.STRING if k == "title" and "description" not in properties_item: properties_item["description"] = k + " is " + str(v) @@ -401,65 +481,75 @@ def _get_items_from_schema(schema: Union[Dict, List, 
str]) -> Dict[str, Any]: for i, v in enumerate(schema): items[f"item{i}"] = _get_properties_from_schema_any(v) elif isinstance(schema, Dict): - items["type_"] = _get_type_from_schema(schema) - if items["type_"] == glm.Type.OBJECT and "properties" in schema: + items["type"] = _get_type_from_schema(schema) + if items["type"] == types.Type.OBJECT and "properties" in schema: items["properties"] = _get_properties_from_schema_any(schema["properties"]) - if items["type_"] == glm.Type.ARRAY and "items" in schema: + if items["type"] == types.Type.ARRAY and "items" in schema: items["items"] = _format_json_schema_to_gapic(schema["items"]) if "title" in schema or "description" in schema: - items["description"] = ( - schema.get("description") or schema.get("title") or "" - ) + items["description"] = schema.get("description") or schema.get("title") if _is_nullable_schema(schema): items["nullable"] = True if "required" in schema: items["required"] = schema["required"] else: # str - items["type_"] = _get_type_from_schema({"type": schema}) + items["type"] = _get_type_from_schema({"type": schema}) if _is_nullable_schema({"type": schema}): items["nullable"] = True return items -def _get_type_from_schema(schema: Dict[str, Any]) -> int: - return _get_nullable_type_from_schema(schema) or glm.Type.STRING +def _get_type_from_schema(schema: Dict[str, Any]) -> types.Type: + type_ = _get_nullable_type_from_schema(schema) + return type_ if type_ is not None else types.Type.STRING -def _get_nullable_type_from_schema(schema: Dict[str, Any]) -> Optional[int]: +def _get_nullable_type_from_schema(schema: Dict[str, Any]) -> Optional[types.Type]: if "anyOf" in schema: - types = [ + schema_types = [ _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"] ] - types = [t for t in types if t is not None] # Remove None values - if types: - return types[-1] # TODO: update FunctionDeclaration and pass all types? + schema_types = [t for t in schema_types if t is not None] # Remove None values + # TODO: update FunctionDeclaration and pass all types? 
+        if schema_types:
+            return schema_types[-1]
         else:
             pass
     elif "type" in schema or "type_" in schema:
         type_ = schema["type"] if "type" in schema else schema["type_"]
-        if isinstance(type_, int):
+        if isinstance(type_, types.Type):
             return type_
-        stype = str(schema["type"]) if "type" in schema else str(schema["type_"])
-        return TYPE_ENUM.get(stype, glm.Type.STRING)
+        elif isinstance(type_, int):
+            raise ValueError(f"Invalid type, int not supported: {type_}")
+        elif isinstance(type_, dict):
+            return types.Type(type_["_value_"])
+        elif isinstance(type_, str):
+            if type_ == "null":
+                return None
+            return types.Type(type_)
+        else:
+            return None
     else:
         pass
-    return glm.Type.STRING  # Default to string if no valid types found
+    return None  # No explicit type found; callers fall back to STRING
 
 
 def _is_nullable_schema(schema: Dict[str, Any]) -> bool:
     if "anyOf" in schema:
-        types = [
+        schema_types = [
             _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]
         ]
-        return any(t is None for t in types)
+        return any(t is None for t in schema_types)
     elif "type" in schema or "type_" in schema:
         type_ = schema["type"] if "type" in schema else schema["type_"]
-        if isinstance(type_, int):
+        if isinstance(type_, types.Type):
             return False
-        stype = str(schema["type"]) if "type" in schema else str(schema["type_"])
-        return TYPE_ENUM.get(stype, glm.Type.STRING) is None
+        elif isinstance(type_, int):
+            # Integer values come from tool_to_dict serialization;
+            # only the NULL enum value is nullable
+            return type_ == 7  # 7 corresponds to NULL type
     else:
         pass
     return False
@@ -470,19 +560,10 @@ def _is_nullable_schema(schema: Dict[str, Any]) -> bool:
 ]
 
 
-class _FunctionCallingConfigDict(TypedDict):
-    mode: Union[gapic.FunctionCallingConfig.Mode, str]
-    allowed_function_names: Optional[List[str]]
-
-
-class _ToolConfigDict(TypedDict):
-    function_calling_config: _FunctionCallingConfigDict
-
-
 def _tool_choice_to_tool_config(
     tool_choice: _ToolChoiceType,
     all_names: List[str],
-) -> _ToolConfigDict:
+) -> types.ToolConfig:
     allowed_function_names: Optional[List[str]] = None
     if tool_choice is True or tool_choice == "any":
         mode = "ANY"
@@ -513,11 +594,11 @@ def _tool_choice_to_tool_config(
         )
     else:
         raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
-    return _ToolConfigDict(
-        function_calling_config={
-            "mode": mode.upper(),
-            "allowed_function_names": allowed_function_names,
-        }
+    return types.ToolConfig(
+        function_calling_config=types.FunctionCallingConfig(
+            mode=types.FunctionCallingConfigMode(mode),
+            allowed_function_names=allowed_function_names,
+        )
     )
@@ -597,3 +678,7 @@ def _get_def_key_from_schema_path(schema_path: str) -> str:
         raise ValueError(error_message)
 
     return parts[-1]
+
+
+# Backward compatibility alias
+_dict_to_gapic_schema = _dict_to_genai_schema
diff --git a/libs/genai/langchain_google_genai/_genai_extension.py b/libs/genai/langchain_google_genai/_genai_extension.py
index 07d51ba4e..4daea19e4 100644
--- a/libs/genai/langchain_google_genai/_genai_extension.py
+++ b/libs/genai/langchain_google_genai/_genai_extension.py
@@ -10,7 +10,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Iterator, List, MutableSequence, Optional
 
-import google.ai.generativelanguage as genai
 import langchain_core
 from google.ai.generativelanguage_v1beta import (
     GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
@@ -18,6 +17,13 @@
 from google.ai.generativelanguage_v1beta import (
     GenerativeServiceClient as
v1betaGenerativeServiceClient, ) +from google.ai.generativelanguage_v1beta import types as old_genai +from google.ai.generativelanguage_v1beta.services.generative_service import ( + GenerativeServiceClient, +) +from google.ai.generativelanguage_v1beta.services.retriever_service import ( + RetrieverServiceClient, +) from google.api_core import client_options as client_options_lib from google.api_core import exceptions as gapi_exception from google.api_core import gapic_v1 @@ -91,7 +97,7 @@ def corpus_id(self) -> str: return name.corpus_id @classmethod - def from_corpus(cls, c: genai.Corpus) -> "Corpus": + def from_corpus(cls, c: old_genai.Corpus) -> "Corpus": return cls( name=c.name, display_name=c.display_name, @@ -106,7 +112,7 @@ class Document: display_name: Optional[str] create_time: Optional[timestamp_pb2.Timestamp] update_time: Optional[timestamp_pb2.Timestamp] - custom_metadata: Optional[MutableSequence[genai.CustomMetadata]] + custom_metadata: Optional[MutableSequence[old_genai.CustomMetadata]] @property def corpus_id(self) -> str: @@ -120,7 +126,7 @@ def document_id(self) -> str: return name.document_id @classmethod - def from_document(cls, d: genai.Document) -> "Document": + def from_document(cls, d: old_genai.Document) -> "Document": return cls( name=d.name, display_name=d.display_name, @@ -220,9 +226,9 @@ def _get_credentials() -> Optional[credentials.Credentials]: return None -def build_semantic_retriever() -> genai.RetrieverServiceClient: +def build_semantic_retriever() -> RetrieverServiceClient: credentials = _get_credentials() - return genai.RetrieverServiceClient( + return RetrieverServiceClient( credentials=credentials, client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), client_options=client_options_lib.ClientOptions( @@ -295,10 +301,10 @@ def build_generative_async_service( def list_corpora( *, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Iterator[Corpus]: for corpus in client.list_corpora( - genai.ListCorporaRequest(page_size=_config.page_size) + old_genai.ListCorporaRequest(page_size=_config.page_size) ): yield Corpus.from_corpus(corpus) @@ -306,11 +312,11 @@ def list_corpora( def get_corpus( *, corpus_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Optional[Corpus]: try: corpus = client.get_corpus( - genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id))) + old_genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id))) ) return Corpus.from_corpus(corpus) except Exception as e: @@ -325,7 +331,7 @@ def create_corpus( *, corpus_id: Optional[str] = None, display_name: Optional[str] = None, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Corpus: name: Optional[str] if corpus_id is not None: @@ -336,8 +342,8 @@ def create_corpus( new_display_name = display_name or f"Untitled {datetime.datetime.now()}" new_corpus = client.create_corpus( - genai.CreateCorpusRequest( - corpus=genai.Corpus(name=name, display_name=new_display_name) + old_genai.CreateCorpusRequest( + corpus=old_genai.Corpus(name=name, display_name=new_display_name) ) ) @@ -347,20 +353,22 @@ def create_corpus( def delete_corpus( *, corpus_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> None: client.delete_corpus( - genai.DeleteCorpusRequest(name=str(EntityName(corpus_id=corpus_id)), force=True) + old_genai.DeleteCorpusRequest( + name=str(EntityName(corpus_id=corpus_id)), force=True + ) ) def list_documents( *, corpus_id: str, - client: 
genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Iterator[Document]: for document in client.list_documents( - genai.ListDocumentsRequest( + old_genai.ListDocumentsRequest( parent=str(EntityName(corpus_id=corpus_id)), page_size=_DEFAULT_PAGE_SIZE ) ): @@ -371,11 +379,11 @@ def get_document( *, corpus_id: str, document_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Optional[Document]: try: document = client.get_document( - genai.GetDocumentRequest( + old_genai.GetDocumentRequest( name=str(EntityName(corpus_id=corpus_id, document_id=document_id)) ) ) @@ -393,7 +401,7 @@ def create_document( document_id: Optional[str] = None, display_name: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Document: name: Optional[str] if document_id is not None: @@ -405,9 +413,9 @@ def create_document( new_metadatas = _convert_to_metadata(metadata) if metadata else None new_document = client.create_document( - genai.CreateDocumentRequest( + old_genai.CreateDocumentRequest( parent=str(EntityName(corpus_id=corpus_id)), - document=genai.Document( + document=old_genai.Document( name=name, display_name=new_display_name, custom_metadata=new_metadatas ), ) @@ -420,10 +428,10 @@ def delete_document( *, corpus_id: str, document_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> None: client.delete_document( - genai.DeleteDocumentRequest( + old_genai.DeleteDocumentRequest( name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), force=True, ) @@ -436,8 +444,8 @@ def batch_create_chunk( document_id: str, texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None, - client: genai.RetrieverServiceClient, -) -> List[genai.Chunk]: + client: RetrieverServiceClient, +) -> List[old_genai.Chunk]: if metadatas is None: metadatas = [{} for _ in texts] if len(texts) != len(metadatas): @@ -448,18 +456,18 @@ def batch_create_chunk( doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id)) - created_chunks: List[genai.Chunk] = [] + created_chunks: List[old_genai.Chunk] = [] - batch_request = genai.BatchCreateChunksRequest( + batch_request = old_genai.BatchCreateChunksRequest( parent=doc_name, requests=[], ) for text, metadata in zip(texts, metadatas): batch_request.requests.append( - genai.CreateChunkRequest( + old_genai.CreateChunkRequest( parent=doc_name, - chunk=genai.Chunk( - data=genai.ChunkData(string_value=text), + chunk=old_genai.Chunk( + data=old_genai.ChunkData(string_value=text), custom_metadata=_convert_to_metadata(metadata), ), ) @@ -469,7 +477,7 @@ def batch_create_chunk( response = client.batch_create_chunks(batch_request) created_chunks.extend(list(response.chunks)) # Prepare a new batch for next round. 
- batch_request = genai.BatchCreateChunksRequest( + batch_request = old_genai.BatchCreateChunksRequest( parent=doc_name, requests=[], ) @@ -487,10 +495,10 @@ def delete_chunk( corpus_id: str, document_id: str, chunk_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> None: client.delete_chunk( - genai.DeleteChunkRequest( + old_genai.DeleteChunkRequest( name=str( EntityName( corpus_id=corpus_id, document_id=document_id, chunk_id=chunk_id @@ -506,10 +514,10 @@ def query_corpus( query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, - client: genai.RetrieverServiceClient, -) -> List[genai.RelevantChunk]: + client: RetrieverServiceClient, +) -> List[old_genai.RelevantChunk]: response = client.query_corpus( - genai.QueryCorpusRequest( + old_genai.QueryCorpusRequest( name=str(EntityName(corpus_id=corpus_id)), query=query, metadata_filters=_convert_filter(filter), @@ -526,10 +534,10 @@ def query_document( query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, - client: genai.RetrieverServiceClient, -) -> List[genai.RelevantChunk]: + client: RetrieverServiceClient, +) -> List[old_genai.RelevantChunk]: response = client.query_document( - genai.QueryDocumentRequest( + old_genai.QueryDocumentRequest( name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), query=query, metadata_filters=_convert_filter(filter), @@ -554,9 +562,9 @@ class GroundedAnswer: @dataclass class GenerateAnswerError(Exception): - finish_reason: genai.Candidate.FinishReason + finish_reason: old_genai.Candidate.FinishReason finish_message: str - safety_ratings: MutableSequence[genai.SafetyRating] + safety_ratings: MutableSequence[old_genai.SafetyRating] def __str__(self) -> str: return ( @@ -570,28 +578,28 @@ def generate_answer( *, prompt: str, passages: List[str], - answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, - safety_settings: List[genai.SafetySetting] = [], + answer_style: int = old_genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + safety_settings: List[old_genai.SafetySetting] = [], temperature: Optional[float] = None, - client: genai.GenerativeServiceClient, + client: GenerativeServiceClient, ) -> GroundedAnswer: # TODO: Consider passing in the corpus ID instead of the actual # passages. response = client.generate_answer( - genai.GenerateAnswerRequest( + old_genai.GenerateAnswerRequest( contents=[ - genai.Content(parts=[genai.Part(text=prompt)]), + old_genai.Content(parts=[old_genai.Part(text=prompt)]), ], model=_DEFAULT_GENERATE_SERVICE_MODEL, answer_style=answer_style, safety_settings=safety_settings, temperature=temperature, - inline_passages=genai.GroundingPassages( + inline_passages=old_genai.GroundingPassages( passages=[ - genai.GroundingPassage( + old_genai.GroundingPassage( # IDs here takes alphanumeric only. No dashes allowed. id=str(index), - content=genai.Content(parts=[genai.Part(text=chunk)]), + content=old_genai.Content(parts=[old_genai.Part(text=chunk)]), ) for index, chunk in enumerate(passages) ] @@ -599,7 +607,7 @@ def generate_answer( ) ) - if response.answer.finish_reason != genai.Candidate.FinishReason.STOP: + if response.answer.finish_reason != old_genai.Candidate.FinishReason.STOP: finish_message = _get_finish_message(response.answer) raise GenerateAnswerError( finish_reason=response.answer.finish_reason, @@ -624,11 +632,11 @@ def generate_answer( # TODO: Use candidate.finish_message when that field is launched. # For now, we derive this message from other existing fields. 
-def _get_finish_message(candidate: genai.Candidate) -> str: +def _get_finish_message(candidate: old_genai.Candidate) -> str: finish_messages: Dict[int, str] = { - genai.Candidate.FinishReason.MAX_TOKENS: "Maximum token in context window reached", # noqa: E501 - genai.Candidate.FinishReason.SAFETY: "Blocked because of safety", - genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation", + old_genai.Candidate.FinishReason.MAX_TOKENS: "Maximum token in context window reached", # noqa: E501 + old_genai.Candidate.FinishReason.SAFETY: "Blocked because of safety", + old_genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation", } finish_reason = candidate.finish_reason @@ -638,13 +646,13 @@ def _get_finish_message(candidate: genai.Candidate) -> str: return finish_messages[finish_reason] -def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]: - cs: List[genai.CustomMetadata] = [] +def _convert_to_metadata(metadata: Dict[str, Any]) -> List[old_genai.CustomMetadata]: + cs: List[old_genai.CustomMetadata] = [] for key, value in metadata.items(): if isinstance(value, str): - c = genai.CustomMetadata(key=key, string_value=value) + c = old_genai.CustomMetadata(key=key, string_value=value) elif isinstance(value, (float, int)): - c = genai.CustomMetadata(key=key, numeric_value=value) + c = old_genai.CustomMetadata(key=key, numeric_value=value) else: raise ValueError(f"Metadata value {value} is not supported") @@ -652,24 +660,24 @@ def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata] return cs -def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[genai.MetadataFilter]: +def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[old_genai.MetadataFilter]: if fs is None: return [] assert isinstance(fs, dict) - filters: List[genai.MetadataFilter] = [] + filters: List[old_genai.MetadataFilter] = [] for key, value in fs.items(): if isinstance(value, str): - condition = genai.Condition( - operation=genai.Condition.Operator.EQUAL, string_value=value + condition = old_genai.Condition( + operation=old_genai.Condition.Operator.EQUAL, string_value=value ) elif isinstance(value, (float, int)): - condition = genai.Condition( - operation=genai.Condition.Operator.EQUAL, numeric_value=value + condition = old_genai.Condition( + operation=old_genai.Condition.Operator.EQUAL, numeric_value=value ) else: raise ValueError(f"Filter value {value} is not supported") - filters.append(genai.MetadataFilter(key=key, conditions=[condition])) + filters.append(old_genai.MetadataFilter(key=key, conditions=[condition])) return filters diff --git a/libs/genai/langchain_google_genai/_image_utils.py b/libs/genai/langchain_google_genai/_image_utils.py index 9a9e47f53..c5b4842fd 100644 --- a/libs/genai/langchain_google_genai/_image_utils.py +++ b/libs/genai/langchain_google_genai/_image_utils.py @@ -5,12 +5,11 @@ import os import re from enum import Enum -from typing import Any, Dict from urllib.parse import urlparse import filetype # type: ignore[import] import requests -from google.ai.generativelanguage_v1beta.types import Part +from google.genai.types import Blob, Part class Route(Enum): @@ -87,18 +86,14 @@ def load_part(self, image_string: str) -> Part: ) raise ValueError(msg) - inline_data: Dict[str, Any] = {"data": bytes_} - mime_type, _ = mimetypes.guess_type(image_string) if not mime_type: kind = filetype.guess(bytes_) if kind: mime_type = kind.mime - if mime_type: - inline_data["mime_type"] = mime_type - - return Part(inline_data=inline_data) + 
blob = Blob(data=bytes_, mime_type=mime_type) + return Part(inline_data=blob) def _route(self, image_string: str) -> Route: if image_string.startswith("data:image/"): diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 13419dcee..f7f7858ce 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -6,7 +6,6 @@ import json import logging import mimetypes -import time import uuid import warnings import wave @@ -29,18 +28,15 @@ cast, ) -import filetype # type: ignore[import] -import google.api_core - -# TODO: remove ignore once the Google package is published with types -import proto # type: ignore[import] -from google.ai.generativelanguage_v1beta import ( - GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient, +from google.api_core.exceptions import ResourceExhausted, ServiceUnavailable +from google.genai.client import Client +from google.genai.errors import ( + ClientError, + ServerError, ) -from google.ai.generativelanguage_v1beta.types import ( +from google.genai.types import ( Blob, Candidate, - CodeExecution, CodeExecutionResult, Content, ExecutableCode, @@ -48,15 +44,21 @@ FunctionCall, FunctionDeclaration, FunctionResponse, - GenerateContentRequest, + GenerateContentConfig, GenerateContentResponse, GenerationConfig, + HttpOptions, Part, SafetySetting, + ThinkingConfig, + ToolCodeExecution, ToolConfig, VideoMetadata, ) -from google.ai.generativelanguage_v1beta.types import Tool as GoogleTool +from google.genai.types import ( + Outcome as CodeExecutionResultOutcome, +) +from google.genai.types import Tool as GoogleTool from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -113,13 +115,11 @@ GoogleGenerativeAIError, SafetySettingDict, _BaseGoogleGenerativeAI, - get_client_info, ) from langchain_google_genai._function_utils import ( _dict_to_gapic_schema, _tool_choice_to_tool_config, _ToolChoiceType, - _ToolConfigDict, _ToolDict, convert_to_genai_function_declarations, is_basemodel_subclass_safe, @@ -131,11 +131,20 @@ image_bytes_to_b64_string, ) -from . 
import _genai_extension as genaix - logger = logging.getLogger(__name__) -_allowed_params_prediction_service = ["request", "timeout", "metadata", "labels"] +_allowed_params_prediction_service_gapi = [ + "request", + "timeout", + "metadata", + "labels", +] + +_allowed_params_prediction_service_genai = [ + "model", + "contents", + "config", +] _FunctionDeclarationType = Union[ FunctionDeclaration, @@ -180,9 +189,9 @@ def _create_retry_decorator( max=wait_exponential_max, ), retry=( - retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) - | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) - | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) + retry_if_exception_type( + (ServerError, ResourceExhausted, ServiceUnavailable) + ) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -210,38 +219,25 @@ def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: wait_exponential_max=kwargs.get("wait_exponential_max", 60.0), ) + allowed_params = kwargs.get( + "allowed_params", _allowed_params_prediction_service_gapi + ) + @retry_decorator def _chat_with_retry(**kwargs: Any) -> Any: try: return generation_method(**kwargs) - except google.api_core.exceptions.FailedPrecondition as exc: - if "location is not supported" in exc.message: - error_msg = ( - "Your location is not supported by google-generativeai " - "at the moment. Try to use ChatVertexAI LLM from " - "langchain_google_vertexai." - ) - raise ValueError(error_msg) - - except google.api_core.exceptions.InvalidArgument as e: - raise ChatGoogleGenerativeAIError( - f"Invalid argument provided to Gemini: {e}" - ) from e - except google.api_core.exceptions.ResourceExhausted as e: - # Handle quota-exceeded error with recommended retry delay - if hasattr(e, "retry_after") and e.retry_after < kwargs.get( - "wait_exponential_max", 60.0 - ): - time.sleep(e.retry_after) + # Do not retry for these errors. + except ClientError as e: + if e.status == "INVALID_ARGUMENT": + raise ChatGoogleGenerativeAIError(f"Invalid argument: {e}") from e raise e except Exception as e: raise e params = ( - {k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service} - if (request := kwargs.get("request")) - and hasattr(request, "model") - and "gemini" in request.model + {k: v for k, v in kwargs.items() if k in allowed_params} + if (model := kwargs.get("model")) and "gemini" in model else kwargs ) return _chat_with_retry(**params) @@ -262,26 +258,32 @@ async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: Returns: Any: The result from the chat generation method. """ - retry_decorator = _create_retry_decorator() - from google.api_core.exceptions import InvalidArgument # type: ignore + retry_decorator = _create_retry_decorator( + max_retries=kwargs.get("max_retries", 6), + wait_exponential_multiplier=kwargs.get("wait_exponential_multiplier", 2.0), + wait_exponential_min=kwargs.get("wait_exponential_min", 1.0), + wait_exponential_max=kwargs.get("wait_exponential_max", 60.0), + ) + + allowed_params = kwargs.get( + "allowed_params", _allowed_params_prediction_service_gapi + ) @retry_decorator async def _achat_with_retry(**kwargs: Any) -> Any: try: return await generation_method(**kwargs) - except InvalidArgument as e: + except ClientError as e: # Do not retry for these errors. 
- raise ChatGoogleGenerativeAIError( - f"Invalid argument provided to Gemini: {e}" - ) from e + if e.status == "INVALID_ARGUMENT": + raise ChatGoogleGenerativeAIError(f"Invalid argument: {e}") from e + raise e except Exception as e: raise e params = ( - {k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service} - if (request := kwargs.get("request")) - and hasattr(request, "model") - and "gemini" in request.model + {k: v for k, v in kwargs.items() if k in allowed_params} + if (model := kwargs.get("model")) and "gemini" in model else kwargs ) return await _achat_with_retry(**params) @@ -329,19 +331,18 @@ def _convert_to_parts( bytes_ = base64.b64decode(part["data"]) else: raise ValueError("source_type must be url or base64.") - inline_data: dict = {"data": bytes_} - if "mime_type" in part: - inline_data["mime_type"] = part["mime_type"] - else: + mime_type = part.get("mime_type") + if not mime_type: source = cast(str, part.get("url") or part.get("data")) mime_type, _ = mimetypes.guess_type(source) - if not mime_type: - kind = filetype.guess(bytes_) - if kind: - mime_type = kind.mime - if mime_type: - inline_data["mime_type"] = mime_type - parts.append(Part(inline_data=inline_data)) + parts.append( + Part( + inline_data=Blob( + data=bytes_, + mime_type=mime_type, + ) + ) + ) elif part["type"] == "image_url": img_url = part["image_url"] if isinstance(img_url, dict): @@ -372,7 +373,7 @@ def _convert_to_parts( f"Media part must have either data or file_uri: {part}" ) if "video_metadata" in part: - metadata = VideoMetadata(part["video_metadata"]) + metadata = VideoMetadata.model_validate(part["video_metadata"]) media_part.video_metadata = metadata parts.append(media_part) elif part["type"] == "executable_code": @@ -397,10 +398,12 @@ def _convert_to_parts( outcome = part["outcome"] else: # Backward compatibility - outcome = 1 # Default to success if not specified + # Default to success if not specified + outcome = CodeExecutionResultOutcome.OUTCOME_OK code_execution_result_part = Part( code_execution_result=CodeExecutionResult( - output=part["code_execution_result"], outcome=outcome + output=part["code_execution_result"], + outcome=outcome, ) ) parts.append(code_execution_result_part) @@ -511,7 +514,10 @@ def _parse_chat_history( if i == 0: system_instruction = Content(parts=system_parts) elif system_instruction is not None: - system_instruction.parts.extend(system_parts) + if system_instruction.parts is None: + system_instruction.parts = system_parts + else: + system_instruction.parts.extend(system_parts) else: pass continue @@ -521,10 +527,8 @@ def _parse_chat_history( ai_message_parts = [] for tool_call in message.tool_calls: function_call = FunctionCall( - { - "name": tool_call["name"], - "args": tool_call["args"], - } + name=tool_call["name"], + args=tool_call["args"], ) ai_message_parts.append(Part(function_call=function_call)) tool_messages_parts = _get_ai_message_tool_messages_parts( @@ -535,10 +539,8 @@ def _parse_chat_history( continue elif raw_function_call := message.additional_kwargs.get("function_call"): function_call = FunctionCall( - { - "name": raw_function_call["name"], - "args": json.loads(raw_function_call["arguments"]), - } + name=raw_function_call["name"], + args=json.loads(raw_function_call["arguments"]), ) parts = [Part(function_call=function_call)] else: @@ -547,7 +549,7 @@ def _parse_chat_history( role = "user" parts = _convert_to_parts(message.content) if i == 1 and convert_system_message_to_human and system_instruction: - parts = [p for p in 
system_instruction.parts] + parts + parts = [p for p in system_instruction.parts or []] + parts system_instruction = None elif isinstance(message, FunctionMessage): role = "user" @@ -590,7 +592,9 @@ def _parse_response_candidate( invalid_tool_calls = [] tool_call_chunks = [] - for part in response_candidate.content.parts: + parts = response_candidate.content.parts or [] if response_candidate.content else [] + + for part in parts: text: Optional[str] = None try: if hasattr(part, "text") and part.text is not None: @@ -609,7 +613,6 @@ def _parse_response_candidate( content = _append_to_content(content, thinking_message) elif text is not None and text: content = _append_to_content(content, text) - if hasattr(part, "executable_code") and part.executable_code is not None: if part.executable_code.code and part.executable_code.language: code_message = { @@ -631,34 +634,35 @@ def _parse_response_candidate( } content = _append_to_content(content, execution_result) - if part.inline_data.mime_type.startswith("audio/"): - buffer = io.BytesIO() - - with wave.open(buffer, "wb") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - # TODO: Read Sample Rate from MIME content type. - wf.setframerate(24000) - wf.writeframes(part.inline_data.data) - - additional_kwargs["audio"] = buffer.getvalue() - - if part.inline_data.mime_type.startswith("image/"): - image_format = part.inline_data.mime_type[6:] - image_message = { - "type": "image_url", - "image_url": { - "url": image_bytes_to_b64_string( - part.inline_data.data, image_format=image_format - ) - }, - } - content = _append_to_content(content, image_message) + if part.inline_data and part.inline_data.data and part.inline_data.mime_type: + if part.inline_data.mime_type.startswith("audio/"): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + # TODO: Read Sample Rate from MIME content type. 
+ wf.setframerate(24000) + wf.writeframes(part.inline_data.data) + + additional_kwargs["audio"] = buffer.getvalue() + + if part.inline_data.mime_type.startswith("image/"): + image_format = part.inline_data.mime_type[6:] + image_message = { + "type": "image_url", + "image_url": { + "url": image_bytes_to_b64_string( + part.inline_data.data, image_format=image_format + ) + }, + } + content = _append_to_content(content, image_message) if part.function_call: function_call = {"name": part.function_call.name} # dump to match other function calling llm for now - function_call_args_dict = proto.Message.to_dict(part.function_call)["args"] + function_call_args_dict = part.function_call.model_dump()["args"] function_call["arguments"] = json.dumps( {k: function_call_args_dict[k] for k in function_call_args_dict} ) @@ -698,6 +702,7 @@ def _parse_response_candidate( ) if content is None: content = "" + if any(isinstance(item, dict) and "executable_code" in item for item in content): warnings.warn( """ @@ -729,16 +734,24 @@ def _response_to_result( stream: bool = False, prev_usage: Optional[UsageMetadata] = None, ) -> ChatResult: - """Converts a PaLM API response into a LangChain ChatResult.""" - llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)} + """Converts a Google AI response into a LangChain ChatResult.""" + llm_output = ( + {"prompt_feedback": response.prompt_feedback.model_dump()} + if response.prompt_feedback + else {} + ) # Get usage metadata try: - input_tokens = response.usage_metadata.prompt_token_count - thought_tokens = response.usage_metadata.thoughts_token_count - output_tokens = response.usage_metadata.candidates_token_count + thought_tokens - total_tokens = response.usage_metadata.total_token_count - cache_read_tokens = response.usage_metadata.cached_content_token_count + if response.usage_metadata is None: + raise AttributeError("Usage metadata is None") + input_tokens = response.usage_metadata.prompt_token_count or 0 + thought_tokens = response.usage_metadata.thoughts_token_count or 0 + output_tokens = ( + response.usage_metadata.candidates_token_count or 0 + ) + thought_tokens + total_tokens = response.usage_metadata.total_token_count or 0 + cache_read_tokens = response.usage_metadata.cached_content_token_count or 0 if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0: if thought_tokens > 0: cumulative_usage = UsageMetadata( @@ -776,21 +789,22 @@ def _response_to_result( generations: List[ChatGeneration] = [] - for candidate in response.candidates: - generation_info = {} + for candidate in response.candidates or []: + generation_info: Dict[str, Any] = {} if candidate.finish_reason: generation_info["finish_reason"] = candidate.finish_reason.name # Add model_name in last chunk - generation_info["model_name"] = response.model_version - generation_info["safety_ratings"] = [ - proto.Message.to_dict(safety_rating, use_integers_for_enums=False) - for safety_rating in candidate.safety_ratings - ] + generation_info["model_name"] = response.model_version or "" + generation_info["safety_ratings"] = ( + [safety_rating.model_dump() for safety_rating in candidate.safety_ratings] + if candidate.safety_ratings + else [] + ) try: if candidate.grounding_metadata: - generation_info["grounding_metadata"] = proto.Message.to_dict( - candidate.grounding_metadata - ) + generation_info[ + "grounding_metadata" + ] = candidate.grounding_metadata.model_dump() except AttributeError: pass message = _parse_response_candidate(candidate, streaming=stream) @@ -1055,11 
+1069,11 @@ class GetPopulation(BaseModel): Use Search with Gemini 2: .. code-block:: python - from google.ai.generativelanguage_v1beta.types import Tool as GenAITool + from google.genai.types import Tool as GoogleTool llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash") resp = llm.invoke( "When is the next total solar eclipse in US?", - tools=[GenAITool(google_search={})], + tools=[GoogleTool(google_search={})], ) Structured output: @@ -1189,7 +1203,7 @@ class Joke(BaseModel): .. code-block:: python - 'The video is a demo of multimodal live streaming in Gemini 2.0. The narrator is sharing his screen in AI Studio and asks if the AI can see it. The AI then reads text that is highlighted on the screen, defines the word “multimodal,” and summarizes everything that was seen and heard.' + 'The video is a demo of multimodal live streaming in Gemini 2.0. The narrator is sharing his screen in AI Studio and asks if the AI can see it. The AI then reads text that is highlighted on the screen, defines the word "multimodal," and summarizes everything that was seen and heard.' Audio input: .. code-block:: python @@ -1279,8 +1293,7 @@ class Joke(BaseModel): """ # noqa: E501 - client: Any = Field(default=None, exclude=True) #: :meta private: - async_client_running: Any = Field(default=None, exclude=True) #: :meta private: + client: Optional[Client] = Field(default=None, exclude=True) #: :meta private: default_metadata: Sequence[Tuple[str, str]] = Field( default_factory=list ) #: :meta private: @@ -1391,51 +1404,25 @@ def validate_environment(self) -> Self: additional_headers = self.additional_headers or {} self.default_metadata = tuple(additional_headers.items()) - client_info = get_client_info(f"ChatGoogleGenerativeAI:{self.model}") - google_api_key = None - if not self.credentials: - if isinstance(self.google_api_key, SecretStr): - google_api_key = self.google_api_key.get_secret_value() - else: - google_api_key = self.google_api_key - transport: Optional[str] = self.transport - self.client = genaix.build_generative_service( - credentials=self.credentials, - api_key=google_api_key, - client_info=client_info, - client_options=self.client_options, - transport=transport, - ) - self.async_client_running = None - return self - - @property - def async_client(self) -> v1betaGenerativeServiceAsyncClient: google_api_key = None if not self.credentials: if isinstance(self.google_api_key, SecretStr): google_api_key = self.google_api_key.get_secret_value() else: google_api_key = self.google_api_key - # NOTE: genaix.build_generative_async_service requires - # a running event loop, which causes an error - # when initialized inside a ThreadPoolExecutor. 
- # this check ensures that async client is only initialized - # within an asyncio event loop to avoid the error - if not self.async_client_running and _is_event_loop_running(): - # async clients don't support "rest" transport - # https://github.com/googleapis/gapic-generator-python/issues/1962 - transport = self.transport - if transport == "rest": - transport = "grpc_asyncio" - self.async_client_running = genaix.build_generative_async_service( - credentials=self.credentials, - api_key=google_api_key, - client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"), - client_options=self.client_options, - transport=transport, + http_options = HttpOptions(base_url=self.base_url, headers=additional_headers) + if google_api_key: + self.client = Client(api_key=google_api_key, http_options=http_options) + else: + project_id = getattr(self.credentials, "project_id", None) + location = getattr(self.credentials, "location", "us-central1") + self.client = Client( + vertexai=True, + project=project_id, + location=location, + http_options=http_options, ) - return self.async_client_running + return self @property def _identifying_params(self) -> Dict[str, Any]: @@ -1477,7 +1464,7 @@ def invoke( Current model: {self.model}" ) if "tools" not in kwargs: - code_execution_tool = GoogleTool(code_execution=CodeExecution()) + code_execution_tool = GoogleTool(code_execution=ToolCodeExecution()) kwargs["tools"] = [code_execution_tool] else: @@ -1510,60 +1497,322 @@ def _get_ls_params( ls_params["ls_stop"] = ls_stop return ls_params + def _supports_thinking(self) -> bool: + """Check if the current model supports thinking capabilities.""" + # Models that don't support thinking based on known patterns + non_thinking_models = [ + "image-generation", # Image generation models don't support thinking + "tts", # Text-to-speech models don't support thinking + ] + + model_name = self.model.lower() + return not any(pattern in model_name for pattern in non_thinking_models) + def _prepare_params( self, stop: Optional[List[str]], generation_config: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> GenerationConfig: - gen_config = { - k: v - for k, v in { - "candidate_count": self.n, - "temperature": self.temperature, - "stop_sequences": stop, - "max_output_tokens": self.max_output_tokens, - "top_k": self.top_k, - "top_p": self.top_p, - "response_modalities": self.response_modalities, - "thinking_config": ( - ( - {"thinking_budget": self.thinking_budget} - if self.thinking_budget is not None - else {} - ) - | ( - {"include_thoughts": self.include_thoughts} - if self.include_thoughts is not None - else {} - ) - ) - if self.thinking_budget is not None or self.include_thoughts is not None - else None, - }.items() - if v is not None - } + """Prepare generation parameters with common configuration logic.""" + gen_config = self._build_base_generation_config(stop, **kwargs) + if generation_config: - gen_config = {**gen_config, **generation_config} + gen_config = self._merge_generation_config(gen_config, generation_config) + + # Handle response-specific parameters + gen_config = self._add_response_parameters(gen_config, **kwargs) + + return GenerationConfig.model_validate(gen_config) + def _build_base_generation_config( + self, stop: Optional[List[str]], **kwargs: Any + ) -> Dict[str, Any]: + """Build the base generation configuration from instance attributes.""" + config: Dict[str, Any] = { + "candidate_count": self.n, + "temperature": self.temperature, + "stop_sequences": stop, + "max_output_tokens": self.max_output_tokens, 
+ "top_k": self.top_k, + "top_p": self.top_p, + "response_modalities": self.response_modalities, + } + + # Add thinking config if supported + thinking_config = self._build_thinking_config() + if thinking_config is not None: + config["thinking_config"] = thinking_config + + return {k: v for k, v in config.items() if v is not None} + + def _build_thinking_config(self) -> Optional[Dict[str, Any]]: + """Build thinking configuration if supported by the model.""" + if not (self.thinking_budget is not None or self.include_thoughts is not None): + return None + + if not self._supports_thinking(): + return None + + config = {} + if self.thinking_budget is not None: + config["thinking_budget"] = self.thinking_budget + if self.include_thoughts is not None: + config["include_thoughts"] = self.include_thoughts + + return config + + def _merge_generation_config( + self, base_config: Dict[str, Any], generation_config: Dict[str, Any] + ) -> Dict[str, Any]: + """Merge user-provided generation config with base config.""" + processed_config = dict(generation_config) + + # Convert string response_modalities to Modality enums if needed + if "response_modalities" in processed_config: + modalities = processed_config["response_modalities"] + if ( + isinstance(modalities, list) + and modalities + and isinstance(modalities[0], str) + ): + from langchain_google_genai import Modality + + try: + processed_config["response_modalities"] = [ + getattr(Modality, modality) for modality in modalities + ] + except AttributeError as e: + raise ValueError(f"Invalid response modality: {e}") from e + + return {**base_config, **processed_config} + + def _add_response_parameters( + self, gen_config: Dict[str, Any], **kwargs: Any + ) -> Dict[str, Any]: + """Add response-specific parameters to generation config.""" + # Handle response mime type response_mime_type = kwargs.get("response_mime_type", self.response_mime_type) if response_mime_type is not None: gen_config["response_mime_type"] = response_mime_type + # Handle response schema response_schema = kwargs.get("response_schema", self.response_schema) if response_schema is not None: - allowed_mime_types = ("application/json", "text/x.enum") - if response_mime_type not in allowed_mime_types: - error_message = ( - "`response_schema` is only supported when " - f"`response_mime_type` is set to one of {allowed_mime_types}" + self._validate_and_add_response_schema( + gen_config, response_schema, response_mime_type + ) + + return gen_config + + def _validate_and_add_response_schema( + self, + gen_config: Dict[str, Any], + response_schema: Dict[str, Any], + response_mime_type: Optional[str], + ) -> None: + """Validate and add response schema to generation config.""" + allowed_mime_types = ("application/json", "text/x.enum") + if response_mime_type not in allowed_mime_types: + error_message = ( + "`response_schema` is only supported when " + f"`response_mime_type` is set to one of {allowed_mime_types}" + ) + raise ValueError(error_message) + + gapic_response_schema = _dict_to_gapic_schema(response_schema) + if gapic_response_schema is not None: + gen_config["response_schema"] = gapic_response_schema + + def _prepare_request( + self, + messages: List[BaseMessage], + *, + stop: Optional[List[str]] = None, + tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, + functions: Optional[Sequence[_FunctionDeclarationType]] = None, + safety_settings: Optional[SafetySettingDict] = None, + tool_config: Optional[Union[Dict, ToolConfig]] = None, + tool_choice: Optional[Union[_ToolChoiceType, 
bool]] = None, + generation_config: Optional[Dict[str, Any]] = None, + cached_content: Optional[str] = None, + **kwargs: Any, + ) -> Tuple[GenerateContentConfig, Dict[str, Any]]: + """Prepare the request configuration for the API call.""" + # Validate tool configuration + if tool_choice and tool_config: + raise ValueError( + "Must specify at most one of tool_choice and tool_config, received " + f"both:\n\n{tool_choice=}\n\n{tool_config=}" + ) + + # Process tools and functions + formatted_tools = self._format_tools(tools, functions) + + # Filter and parse messages + filtered_messages = self._filter_messages(messages) + system_instruction, history = _parse_chat_history( + filtered_messages, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + + # Process tool configuration + formatted_tool_config = self._process_tool_config( + tool_choice, tool_config, formatted_tools + ) + + # Process safety settings + formatted_safety_settings = self._format_safety_settings(safety_settings) + + # Get generation parameters + params = self._prepare_params( + stop, generation_config=generation_config, **kwargs + ) + + # Build request configuration + request = self._build_request_config( + formatted_tools, + formatted_tool_config, + formatted_safety_settings, + params, + cached_content, + system_instruction, + stop, + **kwargs, + ) + + # Return config and additional params needed for API call + api_params = {"model": self.model, "contents": history, "config": request} + return request, api_params + + def _format_tools( + self, + tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, + functions: Optional[Sequence[_FunctionDeclarationType]] = None, + ) -> Optional[List]: + """Format tools and functions for the API.""" + code_execution_tool = GoogleTool(code_execution=ToolCodeExecution()) + + if tools == [code_execution_tool]: + return list(tools) + elif tools: + return [convert_to_genai_function_declarations(tools)] + elif functions: + return [convert_to_genai_function_declarations(functions)] + return None + + def _filter_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + """Filter out messages with empty content.""" + filtered_messages = [] + for message in messages: + if isinstance(message, HumanMessage) and not message.content: + warnings.warn( + "HumanMessage with empty content was removed to prevent API error" ) - raise ValueError(error_message) + else: + filtered_messages.append(message) + return filtered_messages - gapic_response_schema = _dict_to_gapic_schema(response_schema) - if gapic_response_schema is not None: - gen_config["response_schema"] = gapic_response_schema - return GenerationConfig(**gen_config) + def _process_tool_config( + self, + tool_choice: Optional[Union[_ToolChoiceType, bool]], + tool_config: Optional[Union[Dict, ToolConfig]], + formatted_tools: Optional[List], + ) -> Optional[ToolConfig]: + """Process tool configuration and choice.""" + if tool_choice: + if not formatted_tools: + msg = ( + f"Received {tool_choice=} but no {formatted_tools=}. 'tool_choice' " + "can only be specified if 'tools' is specified." 
+ ) + raise ValueError(msg) + + all_names = self._extract_tool_names(formatted_tools) + return _tool_choice_to_tool_config(tool_choice, all_names) + elif tool_config: + if isinstance(tool_config, dict): + return ToolConfig.model_validate(tool_config) + return tool_config + return None + + def _extract_tool_names(self, formatted_tools: List) -> List[str]: + """Extract tool names from formatted tools.""" + all_names: List[str] = [] + for t in formatted_tools: + if hasattr(t, "function_declarations"): + t_with_declarations = cast(Any, t) + all_names.extend( + f.name for f in t_with_declarations.function_declarations + ) + elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"): + continue + else: + raise TypeError( + f"Tool {t} doesn't have function_declarations attribute" + ) + return all_names + + def _format_safety_settings( + self, safety_settings: Optional[SafetySettingDict] + ) -> List[SafetySetting]: + """Format safety settings for the API.""" + if not safety_settings: + return [] + + if isinstance(safety_settings, dict): + # Handle dictionary format: {HarmCategory: HarmBlockThreshold} + return [ + SafetySetting(category=category, threshold=threshold) + for category, threshold in safety_settings.items() + ] + elif isinstance(safety_settings, list): + # Handle list format: [SafetySetting, ...] + return safety_settings + else: + # Handle single SafetySetting object + return [safety_settings] + + def _build_request_config( + self, + formatted_tools: Optional[List], + formatted_tool_config: Optional[ToolConfig], + formatted_safety_settings: List[SafetySetting], + params: GenerationConfig, + cached_content: Optional[str], + system_instruction: Optional[Content], + stop: Optional[List[str]], + **kwargs: Any, + ) -> GenerateContentConfig: + """Build the final request configuration.""" + # Convert response modalities + response_modalities = ( + [m.value for m in params.response_modalities] + if params.response_modalities + else None + ) + + # Create thinking config if supported + thinking_config = None + if params.thinking_config is not None and self._supports_thinking(): + thinking_config = ThinkingConfig( + include_thoughts=params.thinking_config.include_thoughts, + thinking_budget=params.thinking_config.thinking_budget, + ) + + request = GenerateContentConfig( + tools=list(formatted_tools) if formatted_tools else None, + tool_config=formatted_tool_config, + safety_settings=formatted_safety_settings, + response_modalities=response_modalities if response_modalities else None, + thinking_config=thinking_config, + cached_content=cached_content, + system_instruction=system_instruction, + stop_sequences=stop, + **kwargs, + ) + + return request def _generate( self, @@ -1574,13 +1823,15 @@ def _generate( tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, functions: Optional[Sequence[_FunctionDeclarationType]] = None, safety_settings: Optional[SafetySettingDict] = None, - tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + tool_config: Optional[Union[Dict, ToolConfig]] = None, generation_config: Optional[Dict[str, Any]] = None, cached_content: Optional[str] = None, tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, **kwargs: Any, ) -> ChatResult: - request = self._prepare_request( + if self.client is None: + raise ValueError("Client not initialized.") + request, api_params = self._prepare_request( messages, stop=stop, tools=tools, @@ -1593,9 +1844,10 @@ def _generate( **kwargs, ) response: GenerateContentResponse = _chat_with_retry( - request=request, + 
**api_params, **kwargs, - generation_method=self.client.generate_content, + generation_method=self.client.models.generate_content, + allowed_params=_allowed_params_prediction_service_genai, metadata=self.default_metadata, ) return _response_to_result(response) @@ -1609,28 +1861,15 @@ async def _agenerate( tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, functions: Optional[Sequence[_FunctionDeclarationType]] = None, safety_settings: Optional[SafetySettingDict] = None, - tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + tool_config: Optional[Union[Dict, ToolConfig]] = None, generation_config: Optional[Dict[str, Any]] = None, cached_content: Optional[str] = None, tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, **kwargs: Any, ) -> ChatResult: - if not self.async_client: - updated_kwargs = { - **kwargs, - **{ - "tools": tools, - "functions": functions, - "safety_settings": safety_settings, - "tool_config": tool_config, - "generation_config": generation_config, - }, - } - return await super()._agenerate( - messages, stop, run_manager, **updated_kwargs - ) - - request = self._prepare_request( + if self.client is None: + raise ValueError("Client not initialized.") + request, api_params = self._prepare_request( messages, stop=stop, tools=tools, @@ -1643,9 +1882,10 @@ async def _agenerate( **kwargs, ) response: GenerateContentResponse = await _achat_with_retry( - request=request, + **api_params, **kwargs, - generation_method=self.async_client.generate_content, + generation_method=self.client.aio.models.generate_content, + allowed_params=_allowed_params_prediction_service_genai, metadata=self.default_metadata, ) return _response_to_result(response) @@ -1659,13 +1899,15 @@ def _stream( tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, functions: Optional[Sequence[_FunctionDeclarationType]] = None, safety_settings: Optional[SafetySettingDict] = None, - tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + tool_config: Optional[Union[Dict, ToolConfig]] = None, generation_config: Optional[Dict[str, Any]] = None, cached_content: Optional[str] = None, tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - request = self._prepare_request( + if self.client is None: + raise ValueError("Client not initialized.") + request, api_params = self._prepare_request( messages, stop=stop, tools=tools, @@ -1677,20 +1919,22 @@ def _stream( tool_choice=tool_choice, **kwargs, ) - response: GenerateContentResponse = _chat_with_retry( - request=request, - generation_method=self.client.stream_generate_content, + response: Iterator[GenerateContentResponse] = _chat_with_retry( + **api_params, + generation_method=self.client.models.generate_content_stream, + allowed_params=_allowed_params_prediction_service_genai, **kwargs, metadata=self.default_metadata, ) prev_usage_metadata: UsageMetadata | None = None # cumulative usage for chunk in response: - _chat_result = _response_to_result( - chunk, stream=True, prev_usage=prev_usage_metadata - ) - gen = cast(ChatGenerationChunk, _chat_result.generations[0]) - message = cast(AIMessageChunk, gen.message) + if chunk: + _chat_result = _response_to_result( + chunk, stream=True, prev_usage=prev_usage_metadata + ) + gen = cast(ChatGenerationChunk, _chat_result.generations[0]) + message = cast(AIMessageChunk, gen.message) prev_usage_metadata = ( message.usage_metadata @@ -1711,159 +1955,54 @@ async def _astream( tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, 
functions: Optional[Sequence[_FunctionDeclarationType]] = None, safety_settings: Optional[SafetySettingDict] = None, - tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + tool_config: Optional[Union[Dict, ToolConfig]] = None, generation_config: Optional[Dict[str, Any]] = None, cached_content: Optional[str] = None, tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - if not self.async_client: - updated_kwargs = { - **kwargs, - **{ - "tools": tools, - "functions": functions, - "safety_settings": safety_settings, - "tool_config": tool_config, - "generation_config": generation_config, - }, - } - async for value in super()._astream( - messages, stop, run_manager, **updated_kwargs - ): - yield value - else: - request = self._prepare_request( - messages, - stop=stop, - tools=tools, - functions=functions, - safety_settings=safety_settings, - tool_config=tool_config, - generation_config=generation_config, - cached_content=cached_content or self.cached_content, - tool_choice=tool_choice, - **kwargs, - ) - prev_usage_metadata: UsageMetadata | None = None # cumulative usage - async for chunk in await _achat_with_retry( - request=request, - generation_method=self.async_client.stream_generate_content, - **kwargs, - metadata=self.default_metadata, - ): - _chat_result = _response_to_result( - chunk, stream=True, prev_usage=prev_usage_metadata - ) - gen = cast(ChatGenerationChunk, _chat_result.generations[0]) - message = cast(AIMessageChunk, gen.message) - - prev_usage_metadata = ( - message.usage_metadata - if prev_usage_metadata is None - else add_usage(prev_usage_metadata, message.usage_metadata) - ) - - if run_manager: - await run_manager.on_llm_new_token(gen.text, chunk=gen) - yield gen - - def _prepare_request( - self, - messages: List[BaseMessage], - *, - stop: Optional[List[str]] = None, - tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None, - functions: Optional[Sequence[_FunctionDeclarationType]] = None, - safety_settings: Optional[SafetySettingDict] = None, - tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, - tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, - generation_config: Optional[Dict[str, Any]] = None, - cached_content: Optional[str] = None, - **kwargs: Any, - ) -> Tuple[GenerateContentRequest, Dict[str, Any]]: - if tool_choice and tool_config: - raise ValueError( - "Must specify at most one of tool_choice and tool_config, received " - f"both:\n\n{tool_choice=}\n\n{tool_config=}" - ) - - formatted_tools = None - code_execution_tool = GoogleTool(code_execution=CodeExecution()) - if tools == [code_execution_tool]: - formatted_tools = tools - elif tools: - formatted_tools = [convert_to_genai_function_declarations(tools)] - elif functions: - formatted_tools = [convert_to_genai_function_declarations(functions)] - - filtered_messages = [] - for message in messages: - if isinstance(message, HumanMessage) and not message.content: - warnings.warn( - "HumanMessage with empty content was removed to prevent API error" - ) - else: - filtered_messages.append(message) - messages = filtered_messages - - system_instruction, history = _parse_chat_history( + if self.client is None: + raise ValueError("Client not initialized.") + request, api_params = self._prepare_request( messages, - convert_system_message_to_human=self.convert_system_message_to_human, + stop=stop, + tools=tools, + functions=functions, + safety_settings=safety_settings, + tool_config=tool_config, + 
generation_config=generation_config, + cached_content=cached_content or self.cached_content, + tool_choice=tool_choice, + **kwargs, ) - if tool_choice: - if not formatted_tools: - msg = ( - f"Received {tool_choice=} but no {tools=}. 'tool_choice' can only " - f"be specified if 'tools' is specified." - ) - raise ValueError(msg) - all_names: List[str] = [] - for t in formatted_tools: - if hasattr(t, "function_declarations"): - t_with_declarations = cast(Any, t) - all_names.extend( - f.name for f in t_with_declarations.function_declarations - ) - elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"): - continue - else: - raise TypeError( - f"Tool {t} doesn't have function_declarations attribute" - ) - - tool_config = _tool_choice_to_tool_config(tool_choice, all_names) + prev_usage_metadata: UsageMetadata | None = None # cumulative usage + async for chunk in await _achat_with_retry( + **api_params, + generation_method=self.client.aio.models.generate_content_stream, + allowed_params=_allowed_params_prediction_service_genai, + **kwargs, + metadata=self.default_metadata, + ): + chunk = cast(GenerateContentResponse, chunk) + _chat_result = _response_to_result( + chunk, stream=True, prev_usage=prev_usage_metadata + ) + gen = cast(ChatGenerationChunk, _chat_result.generations[0]) + message = cast(AIMessageChunk, gen.message) - formatted_tool_config = None - if tool_config: - formatted_tool_config = ToolConfig( - function_calling_config=tool_config["function_calling_config"] + prev_usage_metadata = ( + message.usage_metadata + if prev_usage_metadata is None + else add_usage(prev_usage_metadata, message.usage_metadata) ) - formatted_safety_settings = [] - if safety_settings: - formatted_safety_settings = [ - SafetySetting(category=c, threshold=t) - for c, t in safety_settings.items() - ] - request = GenerateContentRequest( - model=self.model, - contents=history, - tools=formatted_tools, - tool_config=formatted_tool_config, - safety_settings=formatted_safety_settings, - generation_config=self._prepare_params( - stop, - generation_config=generation_config, - **kwargs, - ), - cached_content=cached_content, - ) - if system_instruction: - request.system_instruction = system_instruction - return request + if run_manager: + await run_manager.on_llm_new_token(gen.text, chunk=gen) + yield gen def get_num_tokens(self, text: str) -> int: + if self.client is None: + raise ValueError("Client not initialized.") """Get the number of tokens present in the text. Useful for checking if an input will fit in a model's context window. @@ -1874,10 +2013,10 @@ def get_num_tokens(self, text: str) -> int: Returns: The integer number of tokens in the text. """ - result = self.client.count_tokens( + result = self.client.models.count_tokens( model=self.model, contents=[Content(parts=[Part(text=text)])] ) - return result.total_tokens + return result.total_tokens if result and result.total_tokens is not None else 0 def with_structured_output( self, @@ -1957,7 +2096,7 @@ def bind_tools( tools: Sequence[ dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool ], - tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + tool_config: Optional[Union[Dict, ToolConfig]] = None, *, tool_choice: Optional[Union[_ToolChoiceType, bool]] = None, **kwargs: Any, diff --git a/libs/genai/poetry.lock b/libs/genai/poetry.lock index b7745ea03..d1ffee6c0 100644 --- a/libs/genai/poetry.lock +++ b/libs/genai/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. 
+# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -332,46 +332,46 @@ files = [ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras = ["grpc"]} google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0" proto-plus = [ - {version = ">=1.22.3,<2.0.0"}, {version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""}, + {version = ">=1.22.3,<2.0.0", markers = "python_version < \"3.13\""}, ] protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0" [[package]] name = "google-api-core" -version = "2.24.1" +version = "2.25.1" description = "Google API client core library" optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "google_api_core-2.24.1-py3-none-any.whl", hash = "sha256:bc78d608f5a5bf853b80bd70a795f703294de656c096c0968320830a4bc280f1"}, - {file = "google_api_core-2.24.1.tar.gz", hash = "sha256:f8b36f5456ab0dd99a1b693a40a31d1e7757beea380ad1b38faaf8941eae9d8a"}, + {file = "google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7"}, + {file = "google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8"}, ] [package.dependencies] -google-auth = ">=2.14.1,<3.0.dev0" -googleapis-common-protos = ">=1.56.2,<2.0.dev0" +google-auth = ">=2.14.1,<3.0.0" +googleapis-common-protos = ">=1.56.2,<2.0.0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.0", optional = true, markers = "extra == \"grpc\""}, ] proto-plus = [ - {version = ">=1.22.3,<2.0.0dev"}, - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, + {version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""}, + {version = ">=1.22.3,<2.0.0"}, ] -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" -requests = ">=2.18.0,<3.0.0.dev0" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0" +requests = ">=2.18.0,<3.0.0" [package.extras] -async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"] -grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev) ; python_version >= \"3.11\"", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0) ; python_version >= \"3.11\""] -grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] -grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.0)"] +grpc = 
["grpcio (>=1.33.2,<2.0.0)", "grpcio (>=1.49.1,<2.0.0) ; python_version >= \"3.11\"", "grpcio-status (>=1.33.2,<2.0.0)", "grpcio-status (>=1.49.1,<2.0.0) ; python_version >= \"3.11\""] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"] [[package]] name = "google-auth" @@ -398,108 +398,129 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "google-genai" +version = "1.27.0" +description = "GenAI Python SDK" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "google_genai-1.27.0-py3-none-any.whl", hash = "sha256:afd6b4efaf8ec1d20a6e6657d768b68d998d60007c6e220e9024e23c913c1833"}, + {file = "google_genai-1.27.0.tar.gz", hash = "sha256:15a13ffe7b3938da50b9ab77204664d82122617256f55b5ce403d593848ef635"}, +] + +[package.dependencies] +anyio = ">=4.8.0,<5.0.0" +google-auth = ">=2.14.1,<3.0.0" +httpx = ">=0.28.1,<1.0.0" +pydantic = ">=2.0.0,<3.0.0" +requests = ">=2.28.1,<3.0.0" +tenacity = ">=8.2.3,<9.0.0" +typing-extensions = ">=4.11.0,<5.0.0" +websockets = ">=13.0.0,<15.1.0" + +[package.extras] +aiohttp = ["aiohttp (<4.0.0)"] + [[package]] name = "googleapis-common-protos" -version = "1.68.0" +version = "1.70.0" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "googleapis_common_protos-1.68.0-py2.py3-none-any.whl", hash = "sha256:aaf179b2f81df26dfadac95def3b16a95064c76a5f45f07e4c68a21bb371c4ac"}, - {file = "googleapis_common_protos-1.68.0.tar.gz", hash = "sha256:95d38161f4f9af0d9423eed8fb7b64ffd2568c3464eb542ff02c5bfa1953ab3c"}, + {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"}, + {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"}, ] [package.dependencies] -protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0" [package.extras] -grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +grpc = ["grpcio (>=1.44.0,<2.0.0)"] [[package]] name = "grpcio" -version = "1.70.0" +version = "1.74.0" description = "HTTP/2-based RPC framework" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "grpcio-1.70.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:95469d1977429f45fe7df441f586521361e235982a0b39e33841549143ae2851"}, - {file = "grpcio-1.70.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:ed9718f17fbdb472e33b869c77a16d0b55e166b100ec57b016dc7de9c8d236bf"}, - {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:374d014f29f9dfdb40510b041792e0e2828a1389281eb590df066e1cc2b404e5"}, - {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2af68a6f5c8f78d56c145161544ad0febbd7479524a59c16b3e25053f39c87f"}, - {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce7df14b2dcd1102a2ec32f621cc9fab6695effef516efbc6b063ad749867295"}, - {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c78b339869f4dbf89881e0b6fbf376313e4f845a42840a7bdf42ee6caed4b11f"}, - {file = 
"grpcio-1.70.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58ad9ba575b39edef71f4798fdb5c7b6d02ad36d47949cd381d4392a5c9cbcd3"}, - {file = "grpcio-1.70.0-cp310-cp310-win32.whl", hash = "sha256:2b0d02e4b25a5c1f9b6c7745d4fa06efc9fd6a611af0fb38d3ba956786b95199"}, - {file = "grpcio-1.70.0-cp310-cp310-win_amd64.whl", hash = "sha256:0de706c0a5bb9d841e353f6343a9defc9fc35ec61d6eb6111802f3aa9fef29e1"}, - {file = "grpcio-1.70.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:17325b0be0c068f35770f944124e8839ea3185d6d54862800fc28cc2ffad205a"}, - {file = "grpcio-1.70.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:dbe41ad140df911e796d4463168e33ef80a24f5d21ef4d1e310553fcd2c4a386"}, - {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5ea67c72101d687d44d9c56068328da39c9ccba634cabb336075fae2eab0d04b"}, - {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb5277db254ab7586769e490b7b22f4ddab3876c490da0a1a9d7c695ccf0bf77"}, - {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7831a0fc1beeeb7759f737f5acd9fdcda520e955049512d68fda03d91186eea"}, - {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:27cc75e22c5dba1fbaf5a66c778e36ca9b8ce850bf58a9db887754593080d839"}, - {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d63764963412e22f0491d0d32833d71087288f4e24cbcddbae82476bfa1d81fd"}, - {file = "grpcio-1.70.0-cp311-cp311-win32.whl", hash = "sha256:bb491125103c800ec209d84c9b51f1c60ea456038e4734688004f377cfacc113"}, - {file = "grpcio-1.70.0-cp311-cp311-win_amd64.whl", hash = "sha256:d24035d49e026353eb042bf7b058fb831db3e06d52bee75c5f2f3ab453e71aca"}, - {file = "grpcio-1.70.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:ef4c14508299b1406c32bdbb9fb7b47612ab979b04cf2b27686ea31882387cff"}, - {file = "grpcio-1.70.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:aa47688a65643afd8b166928a1da6247d3f46a2784d301e48ca1cc394d2ffb40"}, - {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:880bfb43b1bb8905701b926274eafce5c70a105bc6b99e25f62e98ad59cb278e"}, - {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e654c4b17d07eab259d392e12b149c3a134ec52b11ecdc6a515b39aceeec898"}, - {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2394e3381071045a706ee2eeb6e08962dd87e8999b90ac15c55f56fa5a8c9597"}, - {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b3c76701428d2df01964bc6479422f20e62fcbc0a37d82ebd58050b86926ef8c"}, - {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f"}, - {file = "grpcio-1.70.0-cp312-cp312-win32.whl", hash = "sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528"}, - {file = "grpcio-1.70.0-cp312-cp312-win_amd64.whl", hash = "sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655"}, - {file = "grpcio-1.70.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa573896aeb7d7ce10b1fa425ba263e8dddd83d71530d1322fd3a16f31257b4a"}, - {file = "grpcio-1.70.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:d405b005018fd516c9ac529f4b4122342f60ec1cee181788249372524e6db429"}, - {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f32090238b720eb585248654db8e3afc87b48d26ac423c8dde8334a232ff53c9"}, - {file = 
"grpcio-1.70.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfa089a734f24ee5f6880c83d043e4f46bf812fcea5181dcb3a572db1e79e01c"}, - {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f19375f0300b96c0117aca118d400e76fede6db6e91f3c34b7b035822e06c35f"}, - {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:7c73c42102e4a5ec76608d9b60227d917cea46dff4d11d372f64cbeb56d259d0"}, - {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:0a5c78d5198a1f0aa60006cd6eb1c912b4a1520b6a3968e677dbcba215fabb40"}, - {file = "grpcio-1.70.0-cp313-cp313-win32.whl", hash = "sha256:fe9dbd916df3b60e865258a8c72ac98f3ac9e2a9542dcb72b7a34d236242a5ce"}, - {file = "grpcio-1.70.0-cp313-cp313-win_amd64.whl", hash = "sha256:4119fed8abb7ff6c32e3d2255301e59c316c22d31ab812b3fbcbaf3d0d87cc68"}, - {file = "grpcio-1.70.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:8058667a755f97407fca257c844018b80004ae8035565ebc2812cc550110718d"}, - {file = "grpcio-1.70.0-cp38-cp38-macosx_10_14_universal2.whl", hash = "sha256:879a61bf52ff8ccacbedf534665bb5478ec8e86ad483e76fe4f729aaef867cab"}, - {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:0ba0a173f4feacf90ee618fbc1a27956bfd21260cd31ced9bc707ef551ff7dc7"}, - {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558c386ecb0148f4f99b1a65160f9d4b790ed3163e8610d11db47838d452512d"}, - {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:412faabcc787bbc826f51be261ae5fa996b21263de5368a55dc2cf824dc5090e"}, - {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3b0f01f6ed9994d7a0b27eeddea43ceac1b7e6f3f9d86aeec0f0064b8cf50fdb"}, - {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7385b1cb064734005204bc8994eed7dcb801ed6c2eda283f613ad8c6c75cf873"}, - {file = "grpcio-1.70.0-cp38-cp38-win32.whl", hash = "sha256:07269ff4940f6fb6710951116a04cd70284da86d0a4368fd5a3b552744511f5a"}, - {file = "grpcio-1.70.0-cp38-cp38-win_amd64.whl", hash = "sha256:aba19419aef9b254e15011b230a180e26e0f6864c90406fdbc255f01d83bc83c"}, - {file = "grpcio-1.70.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4f1937f47c77392ccd555728f564a49128b6a197a05a5cd527b796d36f3387d0"}, - {file = "grpcio-1.70.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:0cd430b9215a15c10b0e7d78f51e8a39d6cf2ea819fd635a7214fae600b1da27"}, - {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e27585831aa6b57b9250abaf147003e126cd3a6c6ca0c531a01996f31709bed1"}, - {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1af8e15b0f0fe0eac75195992a63df17579553b0c4af9f8362cc7cc99ccddf4"}, - {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbce24409beaee911c574a3d75d12ffb8c3e3dd1b813321b1d7a96bbcac46bf4"}, - {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ff4a8112a79464919bb21c18e956c54add43ec9a4850e3949da54f61c241a4a6"}, - {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5413549fdf0b14046c545e19cfc4eb1e37e9e1ebba0ca390a8d4e9963cab44d2"}, - {file = "grpcio-1.70.0-cp39-cp39-win32.whl", hash = "sha256:b745d2c41b27650095e81dea7091668c040457483c9bdb5d0d9de8f8eb25e59f"}, - {file = "grpcio-1.70.0-cp39-cp39-win_amd64.whl", hash = "sha256:a31d7e3b529c94e930a117b2175b2efd179d96eb3c7a21ccb0289a8ab05b645c"}, - {file = "grpcio-1.70.0.tar.gz", hash = 
"sha256:8d1584a68d5922330025881e63a6c1b54cc8117291d382e4fa69339b6d914c56"}, + {file = "grpcio-1.74.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:85bd5cdf4ed7b2d6438871adf6afff9af7096486fcf51818a81b77ef4dd30907"}, + {file = "grpcio-1.74.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:68c8ebcca945efff9d86d8d6d7bfb0841cf0071024417e2d7f45c5e46b5b08eb"}, + {file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:e154d230dc1bbbd78ad2fdc3039fa50ad7ffcf438e4eb2fa30bce223a70c7486"}, + {file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8978003816c7b9eabe217f88c78bc26adc8f9304bf6a594b02e5a49b2ef9c11"}, + {file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3d7bd6e3929fd2ea7fbc3f562e4987229ead70c9ae5f01501a46701e08f1ad9"}, + {file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:136b53c91ac1d02c8c24201bfdeb56f8b3ac3278668cbb8e0ba49c88069e1bdc"}, + {file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fe0f540750a13fd8e5da4b3eaba91a785eea8dca5ccd2bc2ffe978caa403090e"}, + {file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4e4181bfc24413d1e3a37a0b7889bea68d973d4b45dd2bc68bb766c140718f82"}, + {file = "grpcio-1.74.0-cp310-cp310-win32.whl", hash = "sha256:1733969040989f7acc3d94c22f55b4a9501a30f6aaacdbccfaba0a3ffb255ab7"}, + {file = "grpcio-1.74.0-cp310-cp310-win_amd64.whl", hash = "sha256:9e912d3c993a29df6c627459af58975b2e5c897d93287939b9d5065f000249b5"}, + {file = "grpcio-1.74.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:69e1a8180868a2576f02356565f16635b99088da7df3d45aaa7e24e73a054e31"}, + {file = "grpcio-1.74.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:8efe72fde5500f47aca1ef59495cb59c885afe04ac89dd11d810f2de87d935d4"}, + {file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:a8f0302f9ac4e9923f98d8e243939a6fb627cd048f5cd38595c97e38020dffce"}, + {file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f609a39f62a6f6f05c7512746798282546358a37ea93c1fcbadf8b2fed162e3"}, + {file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c98e0b7434a7fa4e3e63f250456eaef52499fba5ae661c58cc5b5477d11e7182"}, + {file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:662456c4513e298db6d7bd9c3b8df6f75f8752f0ba01fb653e252ed4a59b5a5d"}, + {file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3d14e3c4d65e19d8430a4e28ceb71ace4728776fd6c3ce34016947474479683f"}, + {file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1bf949792cee20d2078323a9b02bacbbae002b9e3b9e2433f2741c15bdeba1c4"}, + {file = "grpcio-1.74.0-cp311-cp311-win32.whl", hash = "sha256:55b453812fa7c7ce2f5c88be3018fb4a490519b6ce80788d5913f3f9d7da8c7b"}, + {file = "grpcio-1.74.0-cp311-cp311-win_amd64.whl", hash = "sha256:86ad489db097141a907c559988c29718719aa3e13370d40e20506f11b4de0d11"}, + {file = "grpcio-1.74.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:8533e6e9c5bd630ca98062e3a1326249e6ada07d05acf191a77bc33f8948f3d8"}, + {file = "grpcio-1.74.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:2918948864fec2a11721d91568effffbe0a02b23ecd57f281391d986847982f6"}, + {file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:60d2d48b0580e70d2e1954d0d19fa3c2e60dd7cbed826aca104fff518310d1c5"}, + {file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", 
hash = "sha256:3601274bc0523f6dc07666c0e01682c94472402ac2fd1226fd96e079863bfa49"}, + {file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:176d60a5168d7948539def20b2a3adcce67d72454d9ae05969a2e73f3a0feee7"}, + {file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e759f9e8bc908aaae0412642afe5416c9f983a80499448fcc7fab8692ae044c3"}, + {file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e7c4389771855a92934b2846bd807fc25a3dfa820fd912fe6bd8136026b2707"}, + {file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cce634b10aeab37010449124814b05a62fb5f18928ca878f1bf4750d1f0c815b"}, + {file = "grpcio-1.74.0-cp312-cp312-win32.whl", hash = "sha256:885912559974df35d92219e2dc98f51a16a48395f37b92865ad45186f294096c"}, + {file = "grpcio-1.74.0-cp312-cp312-win_amd64.whl", hash = "sha256:42f8fee287427b94be63d916c90399ed310ed10aadbf9e2e5538b3e497d269bc"}, + {file = "grpcio-1.74.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:2bc2d7d8d184e2362b53905cb1708c84cb16354771c04b490485fa07ce3a1d89"}, + {file = "grpcio-1.74.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:c14e803037e572c177ba54a3e090d6eb12efd795d49327c5ee2b3bddb836bf01"}, + {file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f6ec94f0e50eb8fa1744a731088b966427575e40c2944a980049798b127a687e"}, + {file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:566b9395b90cc3d0d0c6404bc8572c7c18786ede549cdb540ae27b58afe0fb91"}, + {file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1ea6176d7dfd5b941ea01c2ec34de9531ba494d541fe2057c904e601879f249"}, + {file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:64229c1e9cea079420527fa8ac45d80fc1e8d3f94deaa35643c381fa8d98f362"}, + {file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:0f87bddd6e27fc776aacf7ebfec367b6d49cad0455123951e4488ea99d9b9b8f"}, + {file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3b03d8f2a07f0fea8c8f74deb59f8352b770e3900d143b3d1475effcb08eec20"}, + {file = "grpcio-1.74.0-cp313-cp313-win32.whl", hash = "sha256:b6a73b2ba83e663b2480a90b82fdae6a7aa6427f62bf43b29912c0cfd1aa2bfa"}, + {file = "grpcio-1.74.0-cp313-cp313-win_amd64.whl", hash = "sha256:fd3c71aeee838299c5887230b8a1822795325ddfea635edd82954c1eaa831e24"}, + {file = "grpcio-1.74.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4bc5fca10aaf74779081e16c2bcc3d5ec643ffd528d9e7b1c9039000ead73bae"}, + {file = "grpcio-1.74.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:6bab67d15ad617aff094c382c882e0177637da73cbc5532d52c07b4ee887a87b"}, + {file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:655726919b75ab3c34cdad39da5c530ac6fa32696fb23119e36b64adcfca174a"}, + {file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a2b06afe2e50ebfd46247ac3ba60cac523f54ec7792ae9ba6073c12daf26f0a"}, + {file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f251c355167b2360537cf17bea2cf0197995e551ab9da6a0a59b3da5e8704f9"}, + {file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f7b5882fb50632ab1e48cb3122d6df55b9afabc265582808036b6e51b9fd6b7"}, + {file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:834988b6c34515545b3edd13e902c1acdd9f2465d386ea5143fb558f153a7176"}, + {file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_x86_64.whl", 
hash = "sha256:22b834cef33429ca6cc28303c9c327ba9a3fafecbf62fae17e9a7b7163cc43ac"}, + {file = "grpcio-1.74.0-cp39-cp39-win32.whl", hash = "sha256:7d95d71ff35291bab3f1c52f52f474c632db26ea12700c2ff0ea0532cb0b5854"}, + {file = "grpcio-1.74.0-cp39-cp39-win_amd64.whl", hash = "sha256:ecde9ab49f58433abe02f9ed076c7b5be839cf0153883a6d23995937a82392fa"}, + {file = "grpcio-1.74.0.tar.gz", hash = "sha256:80d1f4fbb35b0742d3e3d3bb654b7381cd5f015f8497279a1e9c21ba623e01b1"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.70.0)"] +protobuf = ["grpcio-tools (>=1.74.0)"] [[package]] name = "grpcio-status" -version = "1.70.0" +version = "1.74.0" description = "Status proto mapping for gRPC" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "grpcio_status-1.70.0-py3-none-any.whl", hash = "sha256:fc5a2ae2b9b1c1969cc49f3262676e6854aa2398ec69cb5bd6c47cd501904a85"}, - {file = "grpcio_status-1.70.0.tar.gz", hash = "sha256:0e7b42816512433b18b9d764285ff029bde059e9d41f8fe10a60631bd8348101"}, + {file = "grpcio_status-1.74.0-py3-none-any.whl", hash = "sha256:52cdbd759a6760fc8f668098a03f208f493dd5c76bf8e02598bbbaf1f6fc2876"}, + {file = "grpcio_status-1.74.0.tar.gz", hash = "sha256:c58c1b24aa454e30f1fc6a7e0dbbc194c54a408143971a94b5f4e40bb5831432"}, ] [package.dependencies] googleapis-common-protos = ">=1.5.5" -grpcio = ">=1.70.0" -protobuf = ">=5.26.1,<6.0dev" +grpcio = ">=1.74.0" +protobuf = ">=6.31.1,<7.0.0" [[package]] name = "h11" @@ -676,8 +697,8 @@ files = [ httpx = ">=0.25.0,<1" langchain-core = ">=0.3.63,<1.0.0" numpy = [ - {version = ">=1.26.2", markers = "python_version < \"3.13\""}, {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, + {version = ">=1.26.2", markers = "python_version < \"3.13\""}, ] pytest = ">=7,<9" pytest-asyncio = ">=0.20,<1" @@ -1282,41 +1303,39 @@ files = [ [[package]] name = "proto-plus" -version = "1.26.0" +version = "1.26.1" description = "Beautiful, Pythonic protocol buffers" optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "proto_plus-1.26.0-py3-none-any.whl", hash = "sha256:bf2dfaa3da281fc3187d12d224c707cb57214fb2c22ba854eb0c105a3fb2d4d7"}, - {file = "proto_plus-1.26.0.tar.gz", hash = "sha256:6e93d5f5ca267b54300880fff156b6a3386b3fa3f43b1da62e680fc0c586ef22"}, + {file = "proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66"}, + {file = "proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012"}, ] [package.dependencies] -protobuf = ">=3.19.0,<6.0.0dev" +protobuf = ">=3.19.0,<7.0.0" [package.extras] testing = ["google-api-core (>=1.31.5)"] [[package]] name = "protobuf" -version = "5.29.3" +version = "6.31.1" description = "" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "protobuf-5.29.3-cp310-abi3-win32.whl", hash = "sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888"}, - {file = "protobuf-5.29.3-cp310-abi3-win_amd64.whl", hash = "sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a"}, - {file = "protobuf-5.29.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e"}, - {file = "protobuf-5.29.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84"}, - {file = "protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl", hash = 
"sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f"}, - {file = "protobuf-5.29.3-cp38-cp38-win32.whl", hash = "sha256:84a57163a0ccef3f96e4b6a20516cedcf5bb3a95a657131c5c3ac62200d23252"}, - {file = "protobuf-5.29.3-cp38-cp38-win_amd64.whl", hash = "sha256:b89c115d877892a512f79a8114564fb435943b59067615894c3b13cd3e1fa107"}, - {file = "protobuf-5.29.3-cp39-cp39-win32.whl", hash = "sha256:0eb32bfa5219fc8d4111803e9a690658aa2e6366384fd0851064b963b6d1f2a7"}, - {file = "protobuf-5.29.3-cp39-cp39-win_amd64.whl", hash = "sha256:6ce8cc3389a20693bfde6c6562e03474c40851b44975c9b2bf6df7d8c4f864da"}, - {file = "protobuf-5.29.3-py3-none-any.whl", hash = "sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f"}, - {file = "protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620"}, + {file = "protobuf-6.31.1-cp310-abi3-win32.whl", hash = "sha256:7fa17d5a29c2e04b7d90e5e32388b8bfd0e7107cd8e616feef7ed3fa6bdab5c9"}, + {file = "protobuf-6.31.1-cp310-abi3-win_amd64.whl", hash = "sha256:426f59d2964864a1a366254fa703b8632dcec0790d8862d30034d8245e1cd447"}, + {file = "protobuf-6.31.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:6f1227473dc43d44ed644425268eb7c2e488ae245d51c6866d19fe158e207402"}, + {file = "protobuf-6.31.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:a40fc12b84c154884d7d4c4ebd675d5b3b5283e155f324049ae396b95ddebc39"}, + {file = "protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4ee898bf66f7a8b0bd21bce523814e6fbd8c6add948045ce958b73af7e8878c6"}, + {file = "protobuf-6.31.1-cp39-cp39-win32.whl", hash = "sha256:0414e3aa5a5f3ff423828e1e6a6e907d6c65c1d5b7e6e975793d5590bdeecc16"}, + {file = "protobuf-6.31.1-cp39-cp39-win_amd64.whl", hash = "sha256:8764cf4587791e7564051b35524b72844f845ad0bb011704c3736cce762d8fe9"}, + {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, + {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, ] [[package]] @@ -1921,14 +1940,14 @@ pytest = ">=7.0.0,<9.0.0" [[package]] name = "tenacity" -version = "9.0.0" +version = "8.5.0" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" groups = ["main", "test"] files = [ - {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, - {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, + {file = "tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687"}, + {file = "tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78"}, ] [package.extras] @@ -2111,7 +2130,7 @@ files = [ [package.dependencies] PyYAML = "*" urllib3 = [ - {version = "<2", markers = "python_version < \"3.10\" or platform_python_implementation == \"PyPy\""}, + {version = "<2", markers = "platform_python_implementation == \"PyPy\""}, {version = "*", markers = "platform_python_implementation != \"PyPy\" and python_version >= \"3.10\""}, ] wrapt = "*" @@ -2163,6 +2182,85 @@ files = [ [package.extras] watchmedo = ["PyYAML (>=3.10)"] +[[package]] +name = "websockets" +version = "15.0.1" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = 
"websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"}, + {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"}, + {file = "websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a"}, + {file = "websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e"}, + {file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf"}, + {file = "websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb"}, + {file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d"}, + {file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9"}, + {file = "websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c"}, + {file = "websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256"}, + {file = "websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41"}, + {file = "websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431"}, + {file = "websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57"}, + {file = "websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905"}, + {file = "websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562"}, + {file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792"}, + {file = "websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413"}, + {file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8"}, + {file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3"}, + {file = "websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf"}, + {file = "websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85"}, + {file = "websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065"}, + {file = 
"websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3"}, + {file = "websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665"}, + {file = "websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2"}, + {file = "websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215"}, + {file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5"}, + {file = "websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65"}, + {file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe"}, + {file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4"}, + {file = "websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597"}, + {file = "websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9"}, + {file = "websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7"}, + {file = "websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931"}, + {file = "websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675"}, + {file = "websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151"}, + {file = "websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22"}, + {file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f"}, + {file = "websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8"}, + {file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375"}, + {file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d"}, + {file = "websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4"}, + {file = "websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa"}, + {file = "websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561"}, + {file = 
"websockets-15.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f4c04ead5aed67c8a1a20491d54cdfba5884507a48dd798ecaf13c74c4489f5"}, + {file = "websockets-15.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abdc0c6c8c648b4805c5eacd131910d2a7f6455dfd3becab248ef108e89ab16a"}, + {file = "websockets-15.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a625e06551975f4b7ea7102bc43895b90742746797e2e14b70ed61c43a90f09b"}, + {file = "websockets-15.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d591f8de75824cbb7acad4e05d2d710484f15f29d4a915092675ad3456f11770"}, + {file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47819cea040f31d670cc8d324bb6435c6f133b8c7a19ec3d61634e62f8d8f9eb"}, + {file = "websockets-15.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac017dd64572e5c3bd01939121e4d16cf30e5d7e110a119399cf3133b63ad054"}, + {file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4a9fac8e469d04ce6c25bb2610dc535235bd4aa14996b4e6dbebf5e007eba5ee"}, + {file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:363c6f671b761efcb30608d24925a382497c12c506b51661883c3e22337265ed"}, + {file = "websockets-15.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2034693ad3097d5355bfdacfffcbd3ef5694f9718ab7f29c29689a9eae841880"}, + {file = "websockets-15.0.1-cp39-cp39-win32.whl", hash = "sha256:3b1ac0d3e594bf121308112697cf4b32be538fb1444468fb0a6ae4feebc83411"}, + {file = "websockets-15.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7643a03db5c95c799b89b31c036d5f27eeb4d259c798e878d6937d71832b1e4"}, + {file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3"}, + {file = "websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1"}, + {file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475"}, + {file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9"}, + {file = "websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04"}, + {file = "websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122"}, + {file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7f493881579c90fc262d9cdbaa05a6b54b3811c2f300766748db79f098db9940"}, + {file = "websockets-15.0.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:47b099e1f4fbc95b701b6e85768e1fcdaf1630f3cbe4765fa216596f12310e2e"}, + {file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67f2b6de947f8c757db2db9c71527933ad0019737ec374a8a6be9a956786aaf9"}, + {file = "websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d08eb4c2b7d6c41da6ca0600c077e93f5adcfd979cd777d747e9ee624556da4b"}, + {file = 
"websockets-15.0.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b826973a4a2ae47ba357e4e82fa44a463b8f168e1ca775ac64521442b19e87f"}, + {file = "websockets-15.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:21c1fa28a6a7e3cbdc171c694398b6df4744613ce9b36b1a498e816787e28123"}, + {file = "websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f"}, + {file = "websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee"}, +] + [[package]] name = "wrapt" version = "1.17.2" @@ -2508,4 +2606,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "f42832bca1df9dfcf9ab7cf07135292f67dd6e96f79bc1883d04e1e4eb2c2733" +content-hash = "760a7d9bface17546e685ce6157a1708e2df2d8500f8eff15cc3749664b3c89f" diff --git a/libs/genai/pyproject.toml b/libs/genai/pyproject.toml index 3209e84aa..994f3f8de 100644 --- a/libs/genai/pyproject.toml +++ b/libs/genai/pyproject.toml @@ -13,6 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.9,<4.0" langchain-core = "^0.3.68" +google-genai = "^1.27.0" google-ai-generativelanguage = "^0.6.18" pydantic = ">=2,<3" filetype = "^1.2.0" diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py index 31d3fa436..34a33b3e5 100644 --- a/libs/genai/tests/integration_tests/test_chat_models.py +++ b/libs/genai/tests/integration_tests/test_chat_models.py @@ -5,6 +5,7 @@ from typing import Dict, Generator, List, Literal, Optional import pytest +from google.genai.types import Tool as GoogleTool from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -23,10 +24,11 @@ HarmCategory, Modality, ) +from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError _MODEL = "models/gemini-1.5-flash-latest" _VISION_MODEL = "models/gemini-2.0-flash-001" -_IMAGE_OUTPUT_MODEL = "models/gemini-2.0-flash-exp-image-generation" +_IMAGE_OUTPUT_MODEL = "models/gemini-2.0-flash-preview-image-generation" _AUDIO_OUTPUT_MODEL = "models/gemini-2.5-flash-preview-tts" _THINKING_MODEL = "models/gemini-2.5-flash" _B64_string = 
"""iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501 @@ -365,16 +367,14 @@ def test_chat_google_genai_invoke_no_image_generation_without_modalities() -> No """Test invoke tokens with image from ChatGoogleGenerativeAI without response modalities.""" llm = ChatGoogleGenerativeAI(model=_IMAGE_OUTPUT_MODEL) - - result = llm.invoke( - "Generate an image of a cat. Then, say meow!", - config=dict(tags=["meow"]), - generation_config=dict(top_k=2, top_p=1, temperature=0.7), - ) - assert isinstance(result, AIMessage) - assert isinstance(result.content, str) - assert not result.content.startswith(" ") - _check_usage_metadata(result) + # NOTE: This API changed with `google-genai` and now raises an error + # instead of returning a response with no image. + with pytest.raises(ChatGoogleGenerativeAIError): + llm.invoke( + "Generate an image of a cat. 
Then, say meow!", + config=dict(tags=["meow"]), + generation_config=dict(top_k=2, top_p=1, temperature=0.7), + ) @pytest.mark.xfail(reason=("investigate")) @@ -828,7 +828,7 @@ async def model_astream(context: str) -> List[BaseMessageChunk]: def test_search_builtin() -> None: llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind_tools( - [{"google_search": {}}] + [{"google_search": {}}] # type: ignore[arg-type] ) input_message = { "role": "user", @@ -853,9 +853,19 @@ def test_search_builtin() -> None: _ = llm.invoke([input_message, full, next_message]) +def test_search_with_googletool() -> None: + """Test using GoogleTool with Google Search.""" + llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001") + resp = llm.invoke( + "When is the next total solar eclipse in US?", + tools=[GoogleTool(google_search={})], # type: ignore[arg-type] + ) + assert "grounding_metadata" in resp.response_metadata + + def test_code_execution_builtin() -> None: llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind_tools( - [{"code_execution": {}}] + [{"code_execution": {}}] # type: ignore[arg-type] ) input_message = { "role": "user", diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py index 8ea33c529..b8e5e71ab 100644 --- a/libs/genai/tests/unit_tests/test_chat_models.py +++ b/libs/genai/tests/unit_tests/test_chat_models.py @@ -1,21 +1,27 @@ """Test chat model integration.""" -import asyncio import base64 import json from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from unittest.mock import ANY, Mock, patch -import google.ai.generativelanguage as glm import pytest -from google.ai.generativelanguage_v1beta.types import ( +from google.api_core.exceptions import ResourceExhausted +from google.genai.types import ( Candidate, Content, + FunctionCall, + FunctionResponse, GenerateContentResponse, + GenerateContentResponseUsageMetadata, + HttpOptions, + Language, Part, ) -from google.api_core.exceptions import ResourceExhausted +from google.genai.types import ( + Outcome as CodeExecutionResultOutcome, +) from langchain_core.load import dumps, loads from langchain_core.messages import ( AIMessage, @@ -27,17 +33,20 @@ from langchain_core.messages.tool import tool_call as create_tool_call from pydantic import SecretStr from pydantic_core._pydantic_core import ValidationError -from pytest import CaptureFixture from langchain_google_genai.chat_models import ( ChatGoogleGenerativeAI, _chat_with_retry, + _convert_to_parts, _convert_tool_message_to_parts, + _get_ai_message_tool_messages_parts, _parse_chat_history, _parse_response_candidate, _response_to_result, ) +SMALL_VIEWABLE_BASE64_IMAGE = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII=" # noqa: E501 + def test_integration_initialization() -> None: """Test chat model initialization.""" @@ -106,28 +115,6 @@ def test_initialization_inside_threadpool() -> None: ).result() -def test_initalization_without_async() -> None: - chat = ChatGoogleGenerativeAI( - model="gemini-nano", - google_api_key=SecretStr("secret-api-key"), # type: ignore[call-arg] - ) - assert chat.async_client is None - - -def test_initialization_with_async() -> None: - async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI: - model = ChatGoogleGenerativeAI( - model="gemini-nano", - 
google_api_key=SecretStr("secret-api-key"), # type: ignore[call-arg] - ) - _ = model.async_client - return model - - loop = asyncio.get_event_loop() - chat = loop.run_until_complete(initialize_chat_with_async_client()) - assert chat.async_client is not None - - def test_api_key_is_string() -> None: chat = ChatGoogleGenerativeAI( model="gemini-nano", @@ -136,15 +123,26 @@ def test_api_key_is_string() -> None: assert isinstance(chat.google_api_key, SecretStr) -def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: +def test_base_url_set_in_constructor() -> None: chat = ChatGoogleGenerativeAI( model="gemini-nano", google_api_key=SecretStr("secret-api-key"), # type: ignore[call-arg] + base_url="http://localhost:8000", ) - print(chat.google_api_key, end="") # noqa: T201 - captured = capsys.readouterr() + assert chat.base_url == "http://localhost:8000" - assert captured.out == "**********" + +def test_base_url_passed_to_client() -> None: + with patch("langchain_google_genai.chat_models.Client") as mock_client: + ChatGoogleGenerativeAI( + model="gemini-nano", + google_api_key=SecretStr("secret-api-key"), # type: ignore[call-arg] + base_url="http://localhost:8000", + ) + mock_client.assert_called_once_with( + api_key="secret-api-key", + http_options=HttpOptions(base_url="http://localhost:8000", headers={}), + ) @pytest.mark.parametrize("convert_system_message_to_human", [False, True]) @@ -218,113 +216,95 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: ) assert len(history) == 8 if convert_system_message_to_human: - assert history[0] == glm.Content( + assert history[0] == Content( role="user", - parts=[glm.Part(text=system_input), glm.Part(text=text_question1)], + parts=[Part(text=system_input), Part(text=text_question1)], ) else: - assert history[0] == glm.Content( - role="user", parts=[glm.Part(text=text_question1)] - ) - assert history[1] == glm.Content( + assert history[0] == Content(role="user", parts=[Part(text=text_question1)]) + assert history[1] == Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": function_call_1["args"], - } + Part( + function_call=FunctionCall( + name="calculator", + args=function_call_1["args"], # type: ignore[arg-type] ) ) ], ) - assert history[2] == glm.Content( + assert history[2] == Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 4}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 4}, ) ) ], ) - assert history[3] == glm.Content( + assert history[3] == Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": json.loads(function_call_2["arguments"]), - } + Part( + function_call=FunctionCall( + name="calculator", + args=json.loads(function_call_2["arguments"]), ) ) ], ) - assert history[4] == glm.Content( + assert history[4] == Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 4}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 4}, ) ) ], ) - assert history[5] == glm.Content( + assert history[5] == Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": function_call_3["args"], - } + Part( + function_call=FunctionCall( + name="calculator", + args=function_call_3["args"], # 
type: ignore[arg-type] ) ), - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": function_call_4["args"], - } + Part( + function_call=FunctionCall( + name="calculator", + args=function_call_4["args"], # type: ignore[arg-type] ) ), ], ) - assert history[6] == glm.Content( + assert history[6] == Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 4}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 4}, ) ), - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 6}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 6}, ) ), ], ) - assert history[7] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)]) + assert history[7] == Content(role="model", parts=[Part(text=text_answer1)]) if convert_system_message_to_human: assert system_instruction is None else: - assert system_instruction == glm.Content(parts=[glm.Part(text=system_input)]) + assert system_instruction == Content(parts=[Part(text=system_input)]) @pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"]) @@ -338,26 +318,27 @@ def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> ) def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: mock_client = Mock() + mock_models = Mock() mock_generate_content = Mock() mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] + candidates=[Candidate(content=Content(parts=[Part(text="test response")]))], + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=5, + total_token_count=15, + ), ) - mock_client.return_value.generate_content = mock_generate_content + mock_models.generate_content = mock_generate_content + mock_client.return_value.models = mock_models api_endpoint = "http://127.0.0.1:8000/ai" param_api_key = "[secret]" param_secret_api_key = SecretStr(param_api_key) - param_client_options = {"api_endpoint": api_endpoint} - param_transport = "rest" - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): + with patch("langchain_google_genai.chat_models.Client", mock_client): chat = ChatGoogleGenerativeAI( model="gemini-pro", google_api_key=param_secret_api_key, # type: ignore[call-arg] - client_options=param_client_options, - transport=param_transport, + base_url=api_endpoint, additional_headers=headers, ) @@ -374,16 +355,15 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: assert response.content == "test response" mock_client.assert_called_once_with( - transport=param_transport, - client_options=ANY, - client_info=ANY, + api_key=param_api_key, + http_options=ANY, ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == param_api_key - assert call_client_options.api_endpoint == api_endpoint - call_client_info = mock_client.call_args_list[0].kwargs["client_info"] - assert "langchain-google-genai" in call_client_info.user_agent - assert "ChatGoogleGenerativeAI" in call_client_info.user_agent + call_http_options = mock_client.call_args_list[0].kwargs["http_options"] + assert call_http_options.base_url == api_endpoint + if headers: + assert call_http_options.headers == headers + else: + 
assert call_http_options.headers == {} @pytest.mark.parametrize( @@ -471,7 +451,7 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"name": "Ben"} ) } @@ -500,7 +480,7 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"info": ["A", "B", "C"]}, ) @@ -530,7 +510,7 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={ "people": [ @@ -577,7 +557,7 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"info": [[1, 2, 3], [4, 5, 6]]}, ) @@ -608,8 +588,9 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: "parts": [ {"text": "Mike age is 30"}, { - "function_call": glm.FunctionCall( - name="Information", args={"name": "Ben"} + "function_call": FunctionCall( + name="Information", + args={"name": "Ben"}, ) }, ] @@ -637,8 +618,9 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( - name="Information", args={"name": "Ben"} + "function_call": FunctionCall( + name="Information", + args={"name": "Ben"}, ) }, {"text": "Mike age is 30"}, @@ -667,7 +649,7 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: def test_parse_response_candidate(raw_candidate: Dict, expected: AIMessage) -> None: with patch("langchain_google_genai.chat_models.uuid.uuid4") as uuid4: uuid4.return_value = "00000000-0000-0000-0000-00000000000" - response_candidate = glm.Candidate(raw_candidate) + response_candidate = Candidate.model_validate(raw_candidate) result = _parse_response_candidate(response_candidate) assert result.content == expected.content assert result.tool_calls == expected.tool_calls @@ -719,10 +701,40 @@ def test__convert_tool_message_to_parts__sets_tool_name( parts = _convert_tool_message_to_parts(tool_message) assert len(parts) == 1 part = parts[0] + assert part.function_response is not None assert part.function_response.name == "tool_name" assert part.function_response.response == {"output": "test_content"} +def test_supports_thinking() -> None: + """Test that _supports_thinking correctly identifies model capabilities.""" + # Test models that don't support thinking + llm_image_gen = ChatGoogleGenerativeAI( + model="models/gemini-2.0-flash-preview-image-generation", + google_api_key=SecretStr("..."), # type: ignore[call-arg] + ) + assert not llm_image_gen._supports_thinking() + + llm_tts = ChatGoogleGenerativeAI( + model="models/gemini-2.5-flash-preview-tts", + google_api_key=SecretStr("..."), # type: ignore[call-arg] + ) + assert not llm_tts._supports_thinking() + + # Test models that do support thinking + llm_normal = ChatGoogleGenerativeAI( + model="models/gemini-2.5-flash", + google_api_key=SecretStr("..."), # type: ignore[call-arg] + ) + assert llm_normal._supports_thinking() + + llm_15 = ChatGoogleGenerativeAI( + model="models/gemini-1.5-flash-latest", + google_api_key=SecretStr("..."), # type: ignore[call-arg] + ) + assert llm_15._supports_thinking() + + 
def test_temperature_range_pydantic_validation() -> None: """Test that temperature is in the range [0.0, 2.0]""" @@ -756,8 +768,10 @@ def test_temperature_range_model_validation() -> None: ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=-0.5) -def test_model_kwargs() -> None: +@patch("langchain_google_genai.chat_models.Client") +def test_model_kwargs(mock_client: Mock) -> None: """Test we can transfer unknown params to model_kwargs.""" + llm = ChatGoogleGenerativeAI( model="my-model", convert_system_message_to_human=True, @@ -773,9 +787,9 @@ def test_model_kwargs() -> None: convert_system_message_to_human=True, foo="bar", ) - assert llm.model == "models/my-model" - assert llm.convert_system_message_to_human is True - assert llm.model_kwargs == {"foo": "bar"} + assert llm.model == "models/my-model" + assert llm.convert_system_message_to_human is True + assert llm.model_kwargs == {"foo": "bar"} def test_retry_decorator_with_custom_parameters() -> None: @@ -830,7 +844,10 @@ def test_retry_decorator_with_custom_parameters() -> None: }, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 10, "candidates_token_count": 5, @@ -839,7 +856,14 @@ def test_retry_decorator_with_custom_parameters() -> None: }, { "grounding_chunks": [ - {"web": {"uri": "https://example.com", "title": "Example Site"}} + { + "retrieved_context": None, + "web": { + "domain": None, + "uri": "https://example.com", + "title": "Example Site", + }, + } ], "grounding_supports": [ { @@ -847,12 +871,15 @@ def test_retry_decorator_with_custom_parameters() -> None: "start_index": 0, "end_index": 13, "text": "Test response", - "part_index": 0, + "part_index": None, }, "grounding_chunk_indices": [0], "confidence_scores": [0.95], } ], + "retrieval_metadata": None, + "retrieval_queries": None, + "search_entry_point": None, "web_search_queries": ["test query"], }, ), @@ -864,7 +891,10 @@ def test_retry_decorator_with_custom_parameters() -> None: "content": {"parts": [{"text": "Test response"}]}, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 10, "candidates_token_count": 5, @@ -879,7 +909,7 @@ def test_response_to_result_grounding_metadata( raw_response: Dict, expected_grounding_metadata: Dict ) -> None: """Test that _response_to_result includes grounding_metadata in the response.""" - response = GenerateContentResponse(raw_response) + response = GenerateContentResponse.model_validate(raw_response) result = _response_to_result(response, stream=False) assert len(result.generations) == len(raw_response["candidates"]) @@ -891,3 +921,611 @@ def test_response_to_result_grounding_metadata( else {} ) assert grounding_metadata == expected_grounding_metadata + + +def test_convert_to_parts_text_only() -> None: + """Test _convert_to_parts with text content.""" + # Test single string + result = _convert_to_parts("Hello, world!") + assert len(result) == 1 + assert result[0].text == "Hello, world!" + assert result[0].inline_data is None + + # Test list of strings + result = _convert_to_parts(["Hello", "world", "!"]) + assert len(result) == 3 + assert result[0].text == "Hello" + assert result[1].text == "world" + assert result[2].text == "!" 
+ + +def test_convert_to_parts_text_content_block() -> None: + """Test _convert_to_parts with text content blocks.""" + content = [{"type": "text", "text": "Hello, world!"}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].text == "Hello, world!" + + +def test_convert_to_parts_image_url() -> None: + """Test _convert_to_parts with image_url content blocks.""" + content = [{"type": "image_url", "image_url": {"url": SMALL_VIEWABLE_BASE64_IMAGE}}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "image/png" + + +def test_convert_to_parts_image_url_string() -> None: + """Test _convert_to_parts with image_url as string.""" + content = [{"type": "image_url", "image_url": SMALL_VIEWABLE_BASE64_IMAGE}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "image/png" + + +def test_convert_to_parts_file_data_url() -> None: + """Test _convert_to_parts with file data URL.""" + content = [ + { + "type": "file", + "source_type": "url", + "url": "https://example.com/image.jpg", + "mime_type": "image/jpeg", + } + ] + with patch("langchain_google_genai.chat_models.ImageBytesLoader") as mock_loader: + mock_loader_instance = Mock() + mock_loader_instance._bytes_from_url.return_value = b"fake_image_data" + mock_loader.return_value = mock_loader_instance + + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "image/jpeg" + assert result[0].inline_data.data == b"fake_image_data" + + +def test_convert_to_parts_file_data_base64() -> None: + """Test _convert_to_parts with file data base64.""" + content = [ + { + "type": "file", + "source_type": "base64", + "data": "SGVsbG8gV29ybGQ=", # "Hello World" in base64 + "mime_type": "text/plain", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "text/plain" + assert result[0].inline_data.data == b"Hello World" + + +def test_convert_to_parts_file_data_auto_mime_type() -> None: + """Test _convert_to_parts with auto-detected mime type.""" + content = [ + { + "type": "file", + "source_type": "base64", + "data": "SGVsbG8gV29ybGQ=", + # No mime_type specified, should be auto-detected + } + ] + with patch("langchain_google_genai.chat_models.mimetypes.guess_type") as mock_guess: + mock_guess.return_value = ("text/plain", None) + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "text/plain" + + +def test_convert_to_parts_media_with_data() -> None: + """Test _convert_to_parts with media type containing data.""" + content = [{"type": "media", "mime_type": "video/mp4", "data": b"fake_video_data"}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "video/mp4" + assert result[0].inline_data.data == b"fake_video_data" + + +def test_convert_to_parts_media_with_file_uri() -> None: + """Test _convert_to_parts with media type containing file_uri.""" + content = [ + { + "type": "media", + "mime_type": "application/pdf", + "file_uri": "gs://bucket/file.pdf", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].file_data is not 
None + assert result[0].file_data.mime_type == "application/pdf" + assert result[0].file_data.file_uri == "gs://bucket/file.pdf" + + +def test_convert_to_parts_media_with_video_metadata() -> None: + """Test _convert_to_parts with media type containing video metadata.""" + content = [ + { + "type": "media", + "mime_type": "video/mp4", + "file_uri": "gs://bucket/video.mp4", + "video_metadata": {"start_offset": "10s", "end_offset": "20s"}, + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].file_data is not None + assert result[0].video_metadata is not None + assert result[0].video_metadata.start_offset == "10s" + assert result[0].video_metadata.end_offset == "20s" + + +def test_convert_to_parts_executable_code() -> None: + """Test _convert_to_parts with executable code.""" + content = [ + { + "type": "executable_code", + "language": "python", + "executable_code": "print('Hello, World!')", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].executable_code is not None + assert result[0].executable_code.language == Language.PYTHON + assert result[0].executable_code.code == "print('Hello, World!')" + + +def test_convert_to_parts_code_execution_result() -> None: + """Test _convert_to_parts with code execution result.""" + content = [ + { + "type": "code_execution_result", + "code_execution_result": "Hello, World!", + "outcome": CodeExecutionResultOutcome.OUTCOME_OK, + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].code_execution_result is not None + assert result[0].code_execution_result.output == "Hello, World!" + assert ( + result[0].code_execution_result.outcome == CodeExecutionResultOutcome.OUTCOME_OK + ) + + +def test_convert_to_parts_code_execution_result_backward_compatibility() -> None: + """Test _convert_to_parts with code execution result without outcome (compat).""" + content = [ + { + "type": "code_execution_result", + "code_execution_result": "Hello, World!", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].code_execution_result is not None + assert result[0].code_execution_result.output == "Hello, World!" + assert ( + result[0].code_execution_result.outcome == CodeExecutionResultOutcome.OUTCOME_OK + ) + + +def test_convert_to_parts_thinking() -> None: + """Test _convert_to_parts with thinking content.""" + content = [{"type": "thinking", "thinking": "I need to think about this..."}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].text == "I need to think about this..." 
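+    # Descriptive note: the "thinking" block is converted to a text Part with the
+    # google.genai `thought` flag set, which is what distinguishes reasoning output
+    # from ordinary text parts (this is exactly what the two asserts here verify).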
+ assert result[0].thought is True + + +def test_convert_to_parts_mixed_content() -> None: + """Test _convert_to_parts with mixed content types.""" + content: List[Dict[str, Any]] = [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": "World"}, + {"type": "image_url", "image_url": {"url": SMALL_VIEWABLE_BASE64_IMAGE}}, + ] + result = _convert_to_parts(content) + assert len(result) == 3 + assert result[0].text == "Hello" + assert result[1].text == "World" + assert result[2].inline_data is not None + + +def test_convert_to_parts_invalid_type() -> None: + """Test _convert_to_parts with invalid source_type.""" + content = [ + { + "type": "file", + "source_type": "invalid", + "data": "some_data", + } + ] + with pytest.raises(ValueError, match="Unrecognized message part type: file"): + _convert_to_parts(content) + + +def test_convert_to_parts_invalid_source_type() -> None: + """Test _convert_to_parts with invalid source_type.""" + content = [ + { + "type": "media", + "source_type": "invalid", + "data": "some_data", + "mime_type": "text/plain", + } + ] + with pytest.raises(ValueError, match="Data should be valid base64"): + _convert_to_parts(content) + + +def test_convert_to_parts_invalid_image_url_format() -> None: + """Test _convert_to_parts with invalid image_url format.""" + content = [{"type": "image_url", "image_url": {"invalid_key": "value"}}] + with pytest.raises(ValueError, match="Unrecognized message image format"): + _convert_to_parts(content) + + +def test_convert_to_parts_missing_mime_type_in_media() -> None: + """Test _convert_to_parts with missing mime_type in media.""" + content = [ + { + "type": "media", + "file_uri": "gs://bucket/file.pdf", + # Missing mime_type + } + ] + with pytest.raises(ValueError, match="Missing mime_type in media part"): + _convert_to_parts(content) + + +def test_convert_to_parts_media_missing_data_and_file_uri() -> None: + """Test _convert_to_parts with media missing both data and file_uri.""" + content = [ + { + "type": "media", + "mime_type": "application/pdf", + # Missing both data and file_uri + } + ] + with pytest.raises( + ValueError, match="Media part must have either data or file_uri" + ): + _convert_to_parts(content) + + +def test_convert_to_parts_missing_executable_code_keys() -> None: + """Test _convert_to_parts with missing keys in executable_code.""" + content = [ + { + "type": "executable_code", + "language": "python", + # Missing executable_code key + } + ] + with pytest.raises( + ValueError, match="Executable code part must have 'code' and 'language'" + ): + _convert_to_parts(content) + + +def test_convert_to_parts_missing_code_execution_result_key() -> None: + """Test _convert_to_parts with missing code_execution_result key.""" + content = [ + { + "type": "code_execution_result" + # Missing code_execution_result key + } + ] + with pytest.raises( + ValueError, match="Code execution result part must have 'code_execution_result'" + ): + _convert_to_parts(content) + + +def test_convert_to_parts_unrecognized_type() -> None: + """Test _convert_to_parts with unrecognized type.""" + content = [{"type": "unrecognized_type", "data": "some_data"}] + with pytest.raises(ValueError, match="Unrecognized message part type"): + _convert_to_parts(content) + + +def test_convert_to_parts_non_dict_mapping() -> None: + """Test _convert_to_parts with non-dict mapping.""" + content = [123] # Not a string or dict + with pytest.raises( + Exception, match="Gemini only supports text and inline_data parts" + ): + _convert_to_parts(content) # type: 
ignore[arg-type] + + +def test_convert_to_parts_unrecognized_format_warning() -> None: + """Test _convert_to_parts with unrecognized format triggers warning.""" + content = [{"some_key": "some_value"}] # Not a recognized format + with patch("langchain_google_genai.chat_models.logger.warning") as mock_warning: + result = _convert_to_parts(content) + mock_warning.assert_called_once() + assert "Unrecognized message part format" in mock_warning.call_args[0][0] + assert len(result) == 1 + assert result[0].text == "{'some_key': 'some_value'}" + + +def test_convert_tool_message_to_parts_string_content() -> None: + """Test _convert_tool_message_to_parts with string content.""" + message = ToolMessage(name="test_tool", content="test_result", tool_call_id="123") + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_tool" + assert result[0].function_response.response == {"output": "test_result"} + + +def test_convert_tool_message_to_parts_json_content() -> None: + """Test _convert_tool_message_to_parts with JSON string content.""" + message = ToolMessage( + name="test_tool", + content='{"result": "success", "data": [1, 2, 3]}', + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_tool" + assert result[0].function_response.response == { + "result": "success", + "data": [1, 2, 3], + } + + +def test_convert_tool_message_to_parts_dict_content() -> None: + """Test _convert_tool_message_to_parts with dict content.""" + message = ToolMessage( + name="test_tool", + content={"result": "success", "data": [1, 2, 3]}, # type: ignore[arg-type] + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_tool" + assert result[0].function_response.response == { + "output": str({"result": "success", "data": [1, 2, 3]}) + } + + +def test_convert_tool_message_to_parts_list_content_with_media() -> None: + """Test _convert_tool_message_to_parts with list content containing media.""" + message = ToolMessage( + name="test_tool", + content=[ + "Text response", + {"type": "image_url", "image_url": {"url": SMALL_VIEWABLE_BASE64_IMAGE}}, + ], + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 2 + # First part should be the media (image) + assert result[0].inline_data is not None + # Second part should be the function response + assert result[1].function_response is not None + assert result[1].function_response.name == "test_tool" + assert result[1].function_response.response == {"output": [str("Text response")]} + + +def test_convert_tool_message_to_parts_with_name_parameter() -> None: + """Test _convert_tool_message_to_parts with explicit name parameter.""" + message = ToolMessage( + content="test_result", + tool_call_id="123", + # No name in message + ) + result = _convert_tool_message_to_parts(message, name="explicit_tool_name") + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "explicit_tool_name" + + +def test_convert_tool_message_to_parts_legacy_name_in_kwargs() -> None: + """Test _convert_tool_message_to_parts with legacy name in additional_kwargs.""" + message = ToolMessage( + 
content="test_result", + tool_call_id="123", + additional_kwargs={"name": "legacy_tool_name"}, + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "legacy_tool_name" + + +def test_convert_tool_message_to_parts_function_message() -> None: + """Test _convert_tool_message_to_parts with FunctionMessage.""" + message = FunctionMessage(name="test_function", content="function_result") + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_function" + assert result[0].function_response.response == {"output": "function_result"} + + +def test_convert_tool_message_to_parts_invalid_json_fallback() -> None: + """Test _convert_tool_message_to_parts with invalid JSON falls back to string.""" + message = ToolMessage( + name="test_tool", + content='{"invalid": json}', # Invalid JSON + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.response == {"output": '{"invalid": json}'} + + +def test_get_ai_message_tool_messages_parts_basic() -> None: + """Test _get_ai_message_tool_messages_parts with basic tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + {"id": "call_2", "name": "tool_2", "args": {"arg2": "value2"}}, + ], + ) + + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1"), + ToolMessage(name="tool_2", content="result_2", tool_call_id="call_2"), + ] + + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 2 + + # Check first tool response + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_1" + assert result[0].function_response.response == {"output": "result_1"} + + # Check second tool response + assert result[1].function_response is not None + assert result[1].function_response.name == "tool_2" + assert result[1].function_response.response == {"output": "result_2"} + + +def test_get_ai_message_tool_messages_parts_partial_matches() -> None: + """Test _get_ai_message_tool_messages_parts with partial tool message matches.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + {"id": "call_2", "name": "tool_2", "args": {"arg2": "value2"}}, + ], + ) + + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1"), + # Missing tool_2 response + ] + + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 1 + + # Only tool_1 response should be included + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_1" + assert result[0].function_response.response == {"output": "result_1"} + + +def test_get_ai_message_tool_messages_parts_no_matches() -> None: + """Test _get_ai_message_tool_messages_parts with no matching tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}], + ) + + tool_messages = [ + ToolMessage(name="tool_2", content="result_2", tool_call_id="call_2"), + ToolMessage(name="tool_3", content="result_3", tool_call_id="call_3"), + ] + + result = 
_get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 0 + + +def test_get_ai_message_tool_messages_parts_empty_tool_calls() -> None: + """Test _get_ai_message_tool_messages_parts with empty tool calls.""" + ai_message = AIMessage(content="No tool calls") + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1") + ] + + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 0 + + +def test_get_ai_message_tool_messages_parts_empty_tool_messages() -> None: + """Test _get_ai_message_tool_messages_parts with empty tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}], + ) + + result = _get_ai_message_tool_messages_parts([], ai_message) + assert len(result) == 0 + + +def test_get_ai_message_tool_messages_parts_duplicate_tool_calls() -> None: + """Test _get_ai_message_tool_messages_parts handles duplicate tool call IDs.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + { + "id": "call_1", + "name": "tool_1", + "args": {"arg1": "value1"}, + }, # Duplicate ID + ], + ) + + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1") + ] + + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 1 # Should only process the first match + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_1" + + +def test_get_ai_message_tool_messages_parts_order_preserved() -> None: + """Test _get_ai_message_tool_messages_parts preserves order of tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + {"id": "call_2", "name": "tool_2", "args": {"arg2": "value2"}}, + ], + ) + + tool_messages = [ + ToolMessage(name="tool_2", content="result_2", tool_call_id="call_2"), + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1"), + ] + + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 2 + + # Order should be preserved based on tool_messages order, not tool_calls order + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_2" + assert result[1].function_response is not None + assert result[1].function_response.name == "tool_1" + + +def test_get_ai_message_tool_messages_parts_with_name_from_tool_call() -> None: + """Test _get_ai_message_tool_messages_parts uses name from tool call""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_from_call", "args": {"arg1": "value1"}} + ], + ) + + tool_messages = [ + ToolMessage(content="result_1", tool_call_id="call_1") # No name in message + ] + + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 1 + assert result[0].function_response is not None + assert ( + result[0].function_response.name == "tool_from_call" + ) # Should use name from tool call diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index 1ad068795..99a345169 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -1,8 +1,16 @@ from typing import Any, Dict, Generator, List, Optional, Tuple, Union from unittest.mock import MagicMock, patch 
-import google.ai.generativelanguage as glm import pytest +from google.genai.types import ( + FunctionCallingConfig, + FunctionCallingConfigMode, + FunctionDeclaration, + Schema, + Tool, + ToolConfig, + Type, +) from langchain_core.documents import Document from langchain_core.tools import BaseTool, InjectedToolArg, tool from langchain_core.utils.function_calling import convert_to_openai_tool @@ -15,13 +23,82 @@ _format_dict_to_function_declaration, _FunctionDeclarationLike, _tool_choice_to_tool_config, - _ToolConfigDict, convert_to_genai_function_declarations, replace_defs_in_schema, tool_to_dict, ) +def assert_property_type( + property_dict: dict, expected_type: Type, property_name: str = "property" +) -> None: + """ + Utility function to assert that a property has the expected Type enum value. + + Since tool_to_dict serializes Type enums to dictionaries with '_value_' field, + this function handles the comparison correctly. + + Args: + property_dict: The property dictionary from the serialized schema + expected_type: The expected Type enum value + property_name: Name of the property for error messages (optional) + """ + actual_type_dict = property_dict.get("type", {}) + if isinstance(actual_type_dict, dict): + actual_value = actual_type_dict.get("_value_") + assert actual_value == expected_type.value, ( + f"Expected '{property_name}' to be {expected_type.value}, " + f"but got {actual_value}" + ) + else: + # In case the type is not serialized as a dict (fallback) + assert actual_type_dict == expected_type, ( + f"Expected '{property_name}' to be {expected_type}, " + f"but got {actual_type_dict}" + ) + + +def find_any_of_option_by_type(any_of_list: list, expected_type: Type) -> dict: + """ + Utility function to find an option in an any_of list that has the expected Type. + + Since tool_to_dict serializes Type enums to dictionaries with '_value_' field, + this function handles the search correctly. + + Args: + any_of_list: List of options from an any_of field + expected_type: The Type enum value to search for + + Returns: + The matching option dictionary + + Raises: + AssertionError: If no option with the expected type is found + """ + for opt in any_of_list: + type_dict = opt.get("type", {}) + if isinstance(type_dict, dict): + if type_dict.get("_value_") == expected_type.value: + return opt + else: + if type_dict == expected_type: + return opt + + # If we get here, no matching option was found + available_types = [] + for opt in any_of_list: + type_dict = opt.get("type", {}) + if isinstance(type_dict, dict): + available_types.append(type_dict.get("_value_", "unknown")) + else: + available_types.append(str(type_dict)) + + raise AssertionError( + f"No option with type {expected_type.value} found in any_of. " + f"Available types: {available_types}" + ) + + def test_tool_with_anyof_nullable_param() -> None: """ Example test that checks a string parameter marked as Optional, @@ -64,7 +141,7 @@ def possibly_none( a_property = properties.get("a") assert isinstance(a_property, dict), "Expected a dict." - assert a_property.get("type_") == glm.Type.STRING, "Expected 'a' to be STRING." + assert_property_type(a_property, Type.STRING, "a") assert a_property.get("nullable") is True, "Expected 'a' to be marked as nullable." @@ -110,16 +187,14 @@ def possibly_none_list( assert isinstance(items_property, dict), "Expected a dict." # Assertions - assert ( - items_property.get("type_") == glm.Type.ARRAY - ), "Expected 'items' to be ARRAY." 
+ assert_property_type(items_property, Type.ARRAY, "items") assert items_property.get("nullable"), "Expected 'items' to be marked as nullable." # Check that the array items are recognized as strings items = items_property.get("items") assert isinstance(items, dict), "Expected 'items' to be a dict." - assert items.get("type_") == glm.Type.STRING, "Expected array items to be STRING." + assert_property_type(items, Type.STRING, "array items") def test_tool_with_nested_object_anyof_nullable_param() -> None: @@ -162,10 +237,19 @@ def possibly_none_dict( data_property = properties.get("data") assert isinstance(data_property, dict), "Expected a dict." - assert data_property.get("type_") in [ - glm.Type.OBJECT, - glm.Type.STRING, - ], "Expected 'data' to be recognized as an OBJECT or fallback to STRING." + # Check if it's OBJECT or STRING (fallback) + actual_type_dict = data_property.get("type", {}) + if isinstance(actual_type_dict, dict): + actual_value = actual_type_dict.get("_value_") + assert actual_value in [ + Type.OBJECT.value, + Type.STRING.value, + ], f"Expected 'data' to be OBJECT or STRING, but got {actual_value}" + else: + assert actual_type_dict in [ + Type.OBJECT, + Type.STRING, + ], f"Expected 'data' to be OBJECT or STRING, but got {actual_type_dict}" assert ( data_property.get("nullable") is True ), "Expected 'data' to be marked as nullable." @@ -220,9 +304,7 @@ def possibly_none_enum( assert isinstance(status_property, dict), "Expected a dict." # Assertions - assert ( - status_property.get("type_") == glm.Type.STRING - ), "Expected 'status' to be STRING." + assert_property_type(status_property, Type.STRING, "status") assert ( status_property.get("nullable") is True ), "Expected 'status' to be marked as nullable." @@ -240,13 +322,13 @@ def search(question: str) -> str: search_tool = tool(search) -search_exp = glm.FunctionDeclaration( +search_exp = FunctionDeclaration( name="search", description="Search tool", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, description="Search tool", - properties={"question": glm.Schema(type=glm.Type.STRING)}, + properties={"question": Schema(type=Type.STRING)}, required=["question"], title="search", ), @@ -259,13 +341,13 @@ def _run(self) -> None: search_base_tool = SearchBaseTool(name="search", description="Search tool") -search_base_tool_exp = glm.FunctionDeclaration( +search_base_tool_exp = FunctionDeclaration( name=search_base_tool.name, description=search_base_tool.description, - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, properties={ - "__arg1": glm.Schema(type=glm.Type.STRING), + "__arg1": Schema(type=Type.STRING), }, required=["__arg1"], ), @@ -284,27 +366,27 @@ class SearchModel(BaseModel): "description": search_model_schema["description"], "parameters": search_model_schema, } -search_model_exp = glm.FunctionDeclaration( +search_model_exp = FunctionDeclaration( name="SearchModel", description="Search model", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, description="Search model", properties={ - "question": glm.Schema(type=glm.Type.STRING), + "question": Schema(type=Type.STRING), }, required=["question"], title="SearchModel", ), ) -search_model_exp_pyd = glm.FunctionDeclaration( +search_model_exp_pyd = FunctionDeclaration( name="SearchModel", description="Search model", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, properties={ - "question": 
glm.Schema(type=glm.Type.STRING), + "question": Schema(type=Type.STRING), }, required=["question"], ), @@ -319,7 +401,7 @@ class SearchModel(BaseModel): ) SRC_EXP_MOCKS_DESC: List[ - Tuple[_FunctionDeclarationLike, glm.FunctionDeclaration, List[MagicMock], str] + Tuple[_FunctionDeclarationLike, FunctionDeclaration, List[MagicMock], str] ] = [ (search, search_exp, [mock_base_tool], "plain function"), (search_tool, search_exp, [mock_base_tool], "LC tool"), @@ -338,6 +420,8 @@ def get_datetime() -> str: return datetime.datetime.now().strftime("%Y-%m-%d") schema = convert_to_genai_function_declarations([get_datetime]) + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "get_datetime" assert function_declaration.description == "Gets the current datetime" @@ -353,10 +437,13 @@ def sum_two_numbers(a: float, b: float) -> str: """ return str(a + b) - schema = convert_to_genai_function_declarations([sum_two_numbers]) # type: ignore + schema = convert_to_genai_function_declarations([sum_two_numbers]) + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "sum_two_numbers" assert function_declaration.parameters + assert function_declaration.parameters.required is not None assert len(function_declaration.parameters.required) == 2 @tool @@ -364,19 +451,23 @@ def do_something_optional(a: float, b: float = 0) -> str: """Some description""" return str(a + b) - schema = convert_to_genai_function_declarations([do_something_optional]) # type: ignore + schema = convert_to_genai_function_declarations([do_something_optional]) + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "do_something_optional" assert function_declaration.parameters + assert function_declaration.parameters.required is not None assert len(function_declaration.parameters.required) == 1 src = [src for src, _, _, _ in SRC_EXP_MOCKS_DESC] fds = [fd for _, fd, _, _ in SRC_EXP_MOCKS_DESC] - expected = glm.Tool(function_declarations=fds) + expected = Tool(function_declarations=fds) result = convert_to_genai_function_declarations(src) assert result == expected - src_2 = glm.Tool(google_search_retrieval={}) + # Create Tool objects with proper typing + src_2 = Tool(google_search_retrieval={}) # type: ignore[arg-type] result = convert_to_genai_function_declarations([src_2]) assert result == src_2 @@ -384,19 +475,19 @@ def do_something_optional(a: float, b: float = 0) -> str: result = convert_to_genai_function_declarations([src_3]) assert result == src_2 - src_4 = glm.Tool(google_search={}) + src_4 = Tool(google_search={}) # type: ignore[arg-type] result = convert_to_genai_function_declarations([src_4]) assert result == src_4 with pytest.raises(ValueError) as exc_info1: - _ = convert_to_genai_function_declarations(["fake_tool"]) # type: ignore + _ = convert_to_genai_function_declarations(["fake_tool"]) # type: ignore[list-item] assert str(exc_info1.value).startswith("Unsupported tool") with pytest.raises(Exception) as exc_info: _ = convert_to_genai_function_declarations( [ - glm.Tool(google_search_retrieval={}), - glm.Tool(google_search_retrieval={}), + Tool(google_search_retrieval={}), # type: ignore[arg-type] + Tool(google_search_retrieval={}), # type: 
ignore[arg-type] ] ) assert str(exc_info.value).startswith("Providing multiple google_search_retrieval") @@ -442,134 +533,86 @@ def search_web( tools = [split_documents, search_web] # Convert to OpenAI first to mimic what we do in bind_tools. oai_tools = [convert_to_openai_tool(t) for t in tools] - expected = [ - { - "name": "split_documents", - "description": "Tool.", - "parameters": { - "any_of": [], - "type_": 6, - "properties": { - "chunk_overlap": { - "any_of": [], - "type_": 3, - "description": "chunk overlap.", - "format_": "", - "nullable": True, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - }, - "chunk_size": { - "any_of": [], - "type_": 3, - "description": "chunk size.", - "format_": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - }, - }, - "property_ordering": [], - "required": ["chunk_size"], - "title": "", - "format_": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - }, - }, - { - "name": "search_web", - "description": "Tool.", - "parameters": { - "any_of": [], - "type_": 6, - "properties": { - "truncate_threshold": { - "any_of": [], - "type_": 3, - "description": "truncate threshold.", - "format_": "", - "nullable": True, - "enum": [], - "max_items": "0", - "min_items": "0", - "property_ordering": [], - "properties": {}, - "required": [], - "title": "", - }, - "query": { - "any_of": [], - "type_": 1, - "description": "query.", - "format_": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - }, - "engine": { - "any_of": [], - "type_": 1, - "description": "engine.", - "format_": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "property_ordering": [], - "properties": {}, - "required": [], - "title": "", - }, - "num_results": { - "any_of": [], - "type_": 3, - "description": "number of results.", - "format_": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - }, - }, - "property_ordering": [], - "required": ["query"], - "title": "", - "format_": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - }, - }, - ] actual = tool_to_dict(convert_to_genai_function_declarations(oai_tools))[ "function_declarations" ] - assert expected == actual + + # Check that we have the expected number of function declarations + assert len(actual) == 2 + + # Check the first function declaration (split_documents) + assert len(actual) > 0 + split_docs = actual[0] + assert isinstance(split_docs, dict) + assert split_docs["name"] == "split_documents" + assert split_docs["description"] == "Tool." + assert split_docs["behavior"] is None + + # Check parameters structure + params = split_docs["parameters"] + assert params["type"]["_value_"] == "OBJECT" + assert params["required"] == ["chunk_size"] + + # Check properties + properties = params["properties"] + assert "chunk_size" in properties + assert "chunk_overlap" in properties + + # Check chunk_size property + chunk_size_prop = properties["chunk_size"] + assert chunk_size_prop["type"]["_value_"] == "INTEGER" + assert chunk_size_prop["description"] == "chunk size." 
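+    # Note: google.genai's Pydantic types serialize unset fields as None rather than
+    # proto-style defaults (e.g. False), so only explicitly Optional parameters are
+    # expected to carry nullable=True in the checks below.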
+    assert chunk_size_prop["nullable"] is None
+
+    # Check chunk_overlap property
+    chunk_overlap_prop = properties["chunk_overlap"]
+    assert chunk_overlap_prop["type"]["_value_"] == "INTEGER"
+    assert chunk_overlap_prop["description"] == "chunk overlap."
+    assert chunk_overlap_prop["nullable"] is True
+
+    # Check the second function declaration (search_web)
+    assert len(actual) > 1
+    search_web_func = actual[1]
+    assert isinstance(search_web_func, dict)
+    assert search_web_func["name"] == "search_web"
+    assert search_web_func["description"] == "Tool."
+    assert search_web_func["behavior"] is None
+
+    # Check parameters structure
+    params = search_web_func["parameters"]
+    assert params["type"]["_value_"] == "OBJECT"
+    assert params["required"] == ["query"]
+
+    # Check properties
+    properties = params["properties"]
+    assert "query" in properties
+    assert "engine" in properties
+    assert "num_results" in properties
+    assert "truncate_threshold" in properties
+
+    # Check query property
+    query_prop = properties["query"]
+    assert query_prop["type"]["_value_"] == "STRING"
+    assert query_prop["description"] == "query."
+    assert query_prop["nullable"] is None
+
+    # Check engine property
+    engine_prop = properties["engine"]
+    assert engine_prop["type"]["_value_"] == "STRING"
+    assert engine_prop["description"] == "engine."
+    assert engine_prop["nullable"] is None
+
+    # Check num_results property
+    num_results_prop = properties["num_results"]
+    assert num_results_prop["type"]["_value_"] == "INTEGER"
+    assert num_results_prop["description"] == "number of results."
+    assert num_results_prop["nullable"] is None
+
+    # Check truncate_threshold property
+    truncate_prop = properties["truncate_threshold"]
+    assert truncate_prop["type"]["_value_"] == "INTEGER"
+    assert truncate_prop["description"] == "truncate threshold."
+ assert truncate_prop["nullable"] is True def test_format_native_dict_to_genai_function() -> None: @@ -582,9 +625,9 @@ def test_format_native_dict_to_genai_function() -> None: ] } schema = convert_to_genai_function_declarations([calculator]) - expected = glm.Tool( + expected = Tool( function_declarations=[ - glm.FunctionDeclaration( + FunctionDeclaration( name="multiply", description="Returns the product of two numbers.", parameters=None, @@ -610,6 +653,8 @@ def test_format_dict_to_genai_function() -> None: ] } schema = convert_to_genai_function_declarations([calculator]) + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "search" assert function_declaration.parameters @@ -618,27 +663,27 @@ def test_format_dict_to_genai_function() -> None: @pytest.mark.parametrize("choice", (True, "foo", ["foo"], "any")) def test__tool_choice_to_tool_config(choice: Any) -> None: - expected = _ToolConfigDict( - function_calling_config={ - "mode": "ANY", - "allowed_function_names": ["foo"], - }, + expected = ToolConfig( + function_calling_config=FunctionCallingConfig( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=["foo"], + ), ) actual = _tool_choice_to_tool_config(choice, ["foo"]) assert expected == actual def test_tool_to_dict_glm_tool() -> None: - tool = glm.Tool( + tool = Tool( function_declarations=[ - glm.FunctionDeclaration( + FunctionDeclaration( name="multiply", description="Returns the product of two numbers.", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, properties={ - "a": glm.Schema(type=glm.Type.NUMBER), - "b": glm.Schema(type=glm.Type.NUMBER), + "a": Schema(type=Type.NUMBER), + "b": Schema(type=Type.NUMBER), }, required=["a", "b"], ), @@ -676,86 +721,54 @@ class Models(BaseModel): gapic_tool = convert_to_genai_function_declarations([Models]) tool_dict = tool_to_dict(gapic_tool) - assert tool_dict == { - "function_declarations": [ - { - "description": "", - "name": "Models", - "parameters": { - "any_of": [], - "description": "", - "enum": [], - "format_": "", - "max_items": "0", - "min_items": "0", - "nullable": False, - "properties": { - "models": { - "any_of": [], - "description": "", - "enum": [], - "format_": "", - "items": { - "any_of": [], - "description": "MyModel", - "enum": [], - "format_": "", - "max_items": "0", - "min_items": "0", - "nullable": False, - "properties": { - "age": { - "any_of": [], - "description": "", - "enum": [], - "format_": "", - "max_items": "0", - "min_items": "0", - "nullable": False, - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - "type_": 3, - }, - "name": { - "any_of": [], - "description": "", - "enum": [], - "format_": "", - "max_items": "0", - "min_items": "0", - "nullable": False, - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - "type_": 1, - }, - }, - "property_ordering": [], - "required": ["name", "age"], - "title": "", - "type_": 6, - }, - "max_items": "0", - "min_items": "0", - "nullable": False, - "properties": {}, - "property_ordering": [], - "required": [], - "title": "", - "type_": 5, - } - }, - "property_ordering": [], - "required": ["models"], - "title": "", - "type_": 6, - }, - } - ] - } + + # Check that we have the expected structure + assert "function_declarations" in tool_dict + assert len(tool_dict["function_declarations"]) == 1 + + # Check the function declaration + assert 
"function_declarations" in tool_dict + assert len(tool_dict["function_declarations"]) > 0 + func_decl = tool_dict["function_declarations"][0] + assert isinstance(func_decl, dict) + assert func_decl["name"] == "Models" + assert func_decl["description"] is None + assert func_decl["behavior"] is None + + # Check parameters structure + params = func_decl["parameters"] + assert params["type"]["_value_"] == "OBJECT" + assert params["required"] == ["models"] + + # Check properties + properties = params["properties"] + assert "models" in properties + + # Check models property (array of MyModel) + models_prop = properties["models"] + assert models_prop["type"]["_value_"] == "ARRAY" + assert models_prop["nullable"] is None + + # Check items of the array + items = models_prop["items"] + assert items["type"]["_value_"] == "OBJECT" + assert items["description"] == "MyModel" + assert items["required"] == ["name", "age"] + + # Check properties of MyModel + model_properties = items["properties"] + assert "name" in model_properties + assert "age" in model_properties + + # Check name property + name_prop = model_properties["name"] + assert name_prop["type"]["_value_"] == "STRING" + assert name_prop["nullable"] is None + + # Check age property + age_prop = model_properties["age"] + assert age_prop["type"]["_value_"] == "INTEGER" + assert age_prop["nullable"] is None def test_tool_to_dict_pydantic_without_import(mock_safe_import: MagicMock) -> None: @@ -809,21 +822,15 @@ def process_nested_data( matrix_property = properties.get("matrix") assert isinstance(matrix_property, dict) - assert ( - matrix_property.get("type_") == glm.Type.ARRAY - ), "Expected 'matrix' to be ARRAY." + assert_property_type(matrix_property, Type.ARRAY, "matrix") items_level1 = matrix_property.get("items") assert isinstance(items_level1, dict), "Expected first level 'items' to be a dict." - assert ( - items_level1.get("type_") == glm.Type.ARRAY - ), "Expected first level items to be ARRAY." + assert_property_type(items_level1, Type.ARRAY, "first level items") items_level2 = items_level1.get("items") assert isinstance(items_level2, dict), "Expected second level 'items' to be a dict." - assert ( - items_level2.get("type_") == glm.Type.STRING - ), "Expected second level items to be STRING." + assert_property_type(items_level2, Type.STRING, "second level items") assert "description" in matrix_property assert "description" in items_level1 @@ -941,18 +948,14 @@ class GetWeather(BaseModel): assert isinstance(helper1, dict), "Expected first option to be a dict." assert "properties" in helper1, "Expected first option to have properties." assert "x" in helper1["properties"], "Expected first option to have 'x' property." - assert ( - helper1["properties"]["x"]["type_"] == glm.Type.BOOLEAN - ), "Expected 'x' to be BOOLEAN." + assert_property_type(helper1["properties"]["x"], Type.BOOLEAN, "x") # Check second option (Helper2) helper2 = any_of[1] assert isinstance(helper2, dict), "Expected second option to be a dict." assert "properties" in helper2, "Expected second option to have properties." assert "y" in helper2["properties"], "Expected second option to have 'y' property." - assert ( - helper2["properties"]["y"]["type_"] == glm.Type.STRING - ), "Expected 'y' to be STRING." + assert_property_type(helper2["properties"]["y"], Type.STRING, "y") def test_tool_with_union_primitive_types() -> None: @@ -1002,23 +1005,31 @@ class SearchQuery(BaseModel): assert len(any_of) == 2, "Expected 'any_of' to have 2 options." 
# One option should be a string - string_option = next( - (opt for opt in any_of if opt.get("type_") == glm.Type.STRING), None - ) - assert string_option is not None, "Expected one option to be a STRING." + # Just verify string option exists + _ = find_any_of_option_by_type(any_of, Type.STRING) # One option should be an object (Helper) - object_option = next( - (opt for opt in any_of if opt.get("type_") == glm.Type.OBJECT), None - ) - assert object_option is not None, "Expected one option to be an OBJECT." + object_option = find_any_of_option_by_type(any_of, Type.OBJECT) assert "properties" in object_option, "Expected object option to have properties." assert ( "value" in object_option["properties"] ), "Expected object option to have 'value' property." - assert ( - object_option["properties"]["value"]["type_"] == 3 - ), "Expected 'value' to be NUMBER or INTEGER." + # Note: This assertion expects the raw enum integer value (3 for NUMBER) + # This is a special case where the test was expecting the integer value + value_type = object_option["properties"]["value"].get("type", {}) + if isinstance(value_type, dict): + # For serialized enum, check _value_ and convert to enum to get integer + type_str = value_type.get("_value_") + if type_str == "NUMBER": + assert True, "Expected 'value' to be NUMBER." + elif type_str == "INTEGER": + assert True, "Expected 'value' to be INTEGER." + else: + assert False, f"Expected 'value' to be NUMBER or INTEGER, got {type_str}" + else: + assert ( + value_type == 3 + ), f"Expected 'value' to be NUMBER or INTEGER (3), got {value_type}" def test_tool_with_nested_union_types() -> None: @@ -1070,16 +1081,11 @@ class Person(BaseModel): assert len(location_any_of) == 2, "Expected 'location.any_of' to have 2 options." # One option should be a string - string_option = next( - (opt for opt in location_any_of if opt.get("type_") == glm.Type.STRING), None - ) - assert string_option is not None, "Expected one location option to be a STRING." + # Just verify string option exists + _ = find_any_of_option_by_type(location_any_of, Type.STRING) # One option should be an object (Address) - address_option = next( - (opt for opt in location_any_of if opt.get("type_") == glm.Type.OBJECT), None - ) - assert address_option is not None, "Expected one location option to be an OBJECT." 
+ address_option = find_any_of_option_by_type(location_any_of, Type.OBJECT) assert "properties" in address_option, "Expected address option to have properties" assert ( "city" in address_option["properties"] @@ -1115,6 +1121,8 @@ def configure_service(service_name: str, config: Union[str, Configuration]) -> s genai_tool = convert_to_genai_function_declarations([oai_tool]) # Get function declaration + assert genai_tool.function_declarations is not None + assert len(genai_tool.function_declarations) > 0 function_declaration = genai_tool.function_declarations[0] # Check parameters @@ -1124,6 +1132,7 @@ def configure_service(service_name: str, config: Union[str, Configuration]) -> s # Check for config property config_property = None + assert parameters.properties is not None, "Expected properties to exist" for prop_name, prop in parameters.properties.items(): if prop_name == "config": config_property = prop @@ -1131,22 +1140,24 @@ def configure_service(service_name: str, config: Union[str, Configuration]) -> s assert config_property is not None, "Expected 'config' property to exist" assert hasattr(config_property, "any_of"), "Expected any_of attribute on config" + assert config_property.any_of is not None, "Expected any_of to not be None" assert len(config_property.any_of) == 2, "Expected config.any_of to have 2 options" # Check both variants of the Union type type_variants = [option.type for option in config_property.any_of] - assert glm.Type.STRING in type_variants, "Expected STRING to be one of the variants" - assert glm.Type.OBJECT in type_variants, "Expected OBJECT to be one of the variants" + assert Type.STRING in type_variants, "Expected STRING to be one of the variants" + assert Type.OBJECT in type_variants, "Expected OBJECT to be one of the variants" # Find the object variant object_variant = None for option in config_property.any_of: - if option.type == glm.Type.OBJECT: + if option.type == Type.OBJECT: object_variant = option break assert object_variant is not None, "Expected to find an object variant" assert hasattr(object_variant, "properties"), "Expected object to have properties" + assert object_variant.properties is not None, "Expected properties to not be None" # Check for settings property has_settings = False @@ -1195,15 +1206,16 @@ class GetWeather(BaseModel): function_declarations = genai_tool_dict.get("function_declarations", []) assert len(function_declarations) > 0, "Expected at least one function declaration" fn_decl = function_declarations[0] + assert isinstance(fn_decl, dict), "Expected function declaration to be a dict" # Check the name and description - assert fn_decl.get("name") == "GetWeather", "Expected name to be 'GetWeather'" # type: ignore - assert "Get weather information" in fn_decl.get("description", ""), ( # type: ignore - "Expected description to include weather information" - ) + assert fn_decl.get("name") == "GetWeather", "Expected name to be 'GetWeather'" + assert "Get weather information" in fn_decl.get( + "description", "" + ), "Expected description to include weather information" # Check parameters - parameters = fn_decl.get("parameters", {}) # type: ignore + parameters = fn_decl.get("parameters", {}) properties = parameters.get("properties", {}) # Check location property diff --git a/libs/genai/tests/unit_tests/test_standard.py b/libs/genai/tests/unit_tests/test_standard.py index 26f354b19..9f49849d9 100644 --- a/libs/genai/tests/unit_tests/test_standard.py +++ b/libs/genai/tests/unit_tests/test_standard.py @@ -13,14 +13,14 @@ def 
chat_model_class(self) -> Type[BaseChatModel]:

     @property
     def chat_model_params(self) -> dict:
-        return {"model": "models/gemini-1.0-pro-001"}
+        return {"model": "models/gemini-1.0-pro-001", "google_api_key": "test_api_key"}

     @property
     def init_from_env_params(self) -> Tuple[dict, dict, dict]:
         return (
-            {"GOOGLE_API_KEY": "api_key"},
+            {"GOOGLE_API_KEY": "test_api_key"},
             self.chat_model_params,
-            {"google_api_key": "api_key"},
+            {"google_api_key": "test_api_key"},
         )
@@ -31,12 +31,12 @@ def chat_model_class(self) -> Type[BaseChatModel]:

     @property
     def chat_model_params(self) -> dict:
-        return {"model": "models/gemini-1.5-pro-001"}
+        return {"model": "models/gemini-1.5-pro-001", "google_api_key": "test_api_key"}

     @property
     def init_from_env_params(self) -> Tuple[dict, dict, dict]:
         return (
-            {"GOOGLE_API_KEY": "api_key"},
+            {"GOOGLE_API_KEY": "test_api_key"},
             self.chat_model_params,
-            {"google_api_key": "api_key"},
+            {"google_api_key": "test_api_key"},
         )
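For reference, a minimal sketch (not part of the patch, assuming the google-genai SDK is installed) of the tool-construction pattern the updated tests exercise with google.genai.types instead of the removed glm classes. The langchain_google_genai._function_utils import path and the standalone usage are assumptions for illustration; only calls that appear in this diff are relied on.

from google.genai import types

# Assumed import path; these helpers are exercised throughout the test diff.
from langchain_google_genai._function_utils import (
    convert_to_genai_function_declarations,
    tool_to_dict,
)

# Build a Tool directly with google.genai types, mirroring the migrated tests
# (previously glm.Tool / glm.FunctionDeclaration / glm.Schema).
multiply = types.FunctionDeclaration(
    name="multiply",
    description="Returns the product of two numbers.",
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={
            "a": types.Schema(type=types.Type.NUMBER),
            "b": types.Schema(type=types.Type.NUMBER),
        },
        required=["a", "b"],
    ),
)
tool = types.Tool(function_declarations=[multiply])

# A plain dict in the same shape is still accepted; tool_to_dict now
# serializes enum members as mappings (e.g. {"_value_": "OBJECT"}) rather than
# the old integer type_ codes, which is what the rewritten assertions check.
calculator = {
    "function_declarations": [
        {"name": "multiply", "description": "Returns the product of two numbers."}
    ]
}
converted = convert_to_genai_function_declarations([calculator])
as_dict = tool_to_dict(converted)

# Forcing the model to call a specific function uses the new config types.
tool_config = types.ToolConfig(
    function_calling_config=types.FunctionCallingConfig(
        mode=types.FunctionCallingConfigMode.ANY,
        allowed_function_names=["multiply"],
    ),
)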