File tree Expand file tree Collapse file tree 16 files changed +285
-11
lines changed
Expand file tree Collapse file tree 16 files changed +285
-11
lines changed Original file line number Diff line number Diff line change 2020from ..agents .invocation_context import InvocationContext
2121from ..models import LlmRequest
2222from ..utils .model_name_utils import is_gemini_2_or_above
23+ from ..utils .model_name_utils import is_gemini_model_id_check_disabled
2324from .base_code_executor import BaseCodeExecutor
2425from .code_execution_utils import CodeExecutionInput
2526from .code_execution_utils import CodeExecutionResult
@@ -42,7 +43,8 @@ def execute_code(
4243
4344 def process_llm_request (self , llm_request : LlmRequest ) -> None :
4445 """Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool."""
45- if is_gemini_2_or_above (llm_request .model ):
46+ model_check_disabled = is_gemini_model_id_check_disabled ()
47+ if is_gemini_2_or_above (llm_request .model ) or model_check_disabled :
4648 llm_request .config = llm_request .config or types .GenerateContentConfig ()
4749 llm_request .config .tools = llm_request .config .tools or []
4850 llm_request .config .tools .append (
Original file line number Diff line number Diff line change 2121
2222from ..utils .model_name_utils import is_gemini_1_model
2323from ..utils .model_name_utils import is_gemini_model
24+ from ..utils .model_name_utils import is_gemini_model_id_check_disabled
2425from .base_tool import BaseTool
2526from .tool_context import ToolContext
2627
@@ -54,14 +55,16 @@ async def process_llm_request(
5455 tool_context : ToolContext ,
5556 llm_request : LlmRequest ,
5657 ) -> None :
57- if is_gemini_model (llm_request .model ):
58+ model_check_disabled = is_gemini_model_id_check_disabled ()
59+ llm_request .config = llm_request .config or types .GenerateContentConfig ()
60+ llm_request .config .tools = llm_request .config .tools or []
61+
62+ if is_gemini_model (llm_request .model ) or model_check_disabled :
5863 if is_gemini_1_model (llm_request .model ) and llm_request .config .tools :
5964 raise ValueError (
6065 'Enterprise Web Search tool cannot be used with other tools in'
6166 ' Gemini 1.x.'
6267 )
63- llm_request .config = llm_request .config or types .GenerateContentConfig ()
64- llm_request .config .tools = llm_request .config .tools or []
6568 llm_request .config .tools .append (
6669 types .Tool (enterprise_web_search = types .EnterpriseWebSearch ())
6770 )
Original file line number Diff line number Diff line change 2121
2222from ..utils .model_name_utils import is_gemini_1_model
2323from ..utils .model_name_utils import is_gemini_model
24+ from ..utils .model_name_utils import is_gemini_model_id_check_disabled
2425from .base_tool import BaseTool
2526from .tool_context import ToolContext
2627
@@ -49,13 +50,14 @@ async def process_llm_request(
4950 tool_context : ToolContext ,
5051 llm_request : LlmRequest ,
5152 ) -> None :
53+ model_check_disabled = is_gemini_model_id_check_disabled ()
5254 llm_request .config = llm_request .config or types .GenerateContentConfig ()
5355 llm_request .config .tools = llm_request .config .tools or []
5456 if is_gemini_1_model (llm_request .model ):
5557 raise ValueError (
5658 'Google Maps grounding tool cannot be used with Gemini 1.x models.'
5759 )
58- elif is_gemini_model (llm_request .model ):
60+ elif is_gemini_model (llm_request .model ) or model_check_disabled :
5961 llm_request .config .tools .append (
6062 types .Tool (google_maps = types .GoogleMaps ())
6163 )
Original file line number Diff line number Diff line change 2121
2222from ..utils .model_name_utils import is_gemini_1_model
2323from ..utils .model_name_utils import is_gemini_model
24+ from ..utils .model_name_utils import is_gemini_model_id_check_disabled
2425from .base_tool import BaseTool
2526from .tool_context import ToolContext
2627
@@ -67,6 +68,7 @@ async def process_llm_request(
6768 if self .model is not None :
6869 llm_request .model = self .model
6970
71+ model_check_disabled = is_gemini_model_id_check_disabled ()
7072 llm_request .config = llm_request .config or types .GenerateContentConfig ()
7173 llm_request .config .tools = llm_request .config .tools or []
7274 if is_gemini_1_model (llm_request .model ):
@@ -77,7 +79,7 @@ async def process_llm_request(
7779 llm_request .config .tools .append (
7880 types .Tool (google_search_retrieval = types .GoogleSearchRetrieval ())
7981 )
80- elif is_gemini_model (llm_request .model ):
82+ elif is_gemini_model (llm_request .model ) or model_check_disabled :
8183 llm_request .config .tools .append (
8284 types .Tool (google_search = types .GoogleSearch ())
8385 )
Original file line number Diff line number Diff line change 2424from typing_extensions import override
2525
2626from ...utils .model_name_utils import is_gemini_2_or_above
27+ from ...utils .model_name_utils import is_gemini_model_id_check_disabled
2728from ..tool_context import ToolContext
2829from .base_retrieval_tool import BaseRetrievalTool
2930
@@ -63,7 +64,8 @@ async def process_llm_request(
6364 llm_request : LlmRequest ,
6465 ) -> None :
6566 # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models.
66- if is_gemini_2_or_above (llm_request .model ):
67+ model_check_disabled = is_gemini_model_id_check_disabled ()
68+ if is_gemini_2_or_above (llm_request .model ) or model_check_disabled :
6769 llm_request .config = (
6870 types .GenerateContentConfig ()
6971 if not llm_request .config
Original file line number Diff line number Diff line change 2121
2222from ..utils .model_name_utils import is_gemini_1_model
2323from ..utils .model_name_utils import is_gemini_2_or_above
24+ from ..utils .model_name_utils import is_gemini_model_id_check_disabled
2425from .base_tool import BaseTool
2526from .tool_context import ToolContext
2627
@@ -46,11 +47,12 @@ async def process_llm_request(
4647 tool_context : ToolContext ,
4748 llm_request : LlmRequest ,
4849 ) -> None :
50+ model_check_disabled = is_gemini_model_id_check_disabled ()
4951 llm_request .config = llm_request .config or types .GenerateContentConfig ()
5052 llm_request .config .tools = llm_request .config .tools or []
5153 if is_gemini_1_model (llm_request .model ):
5254 raise ValueError ('Url context tool cannot be used in Gemini 1.x.' )
53- elif is_gemini_2_or_above (llm_request .model ):
55+ elif is_gemini_2_or_above (llm_request .model ) or model_check_disabled :
5456 llm_request .config .tools .append (
5557 types .Tool (url_context = types .UrlContext ())
5658 )
Original file line number Diff line number Diff line change 2424from ..agents .readonly_context import ReadonlyContext
2525from ..utils .model_name_utils import is_gemini_1_model
2626from ..utils .model_name_utils import is_gemini_model
27+ from ..utils .model_name_utils import is_gemini_model_id_check_disabled
2728from .base_tool import BaseTool
2829from .tool_context import ToolContext
2930
@@ -141,14 +142,16 @@ async def process_llm_request(
141142 tool_context : ToolContext ,
142143 llm_request : LlmRequest ,
143144 ) -> None :
144- if is_gemini_model (llm_request .model ):
145+ model_check_disabled = is_gemini_model_id_check_disabled ()
146+ llm_request .config = llm_request .config or types .GenerateContentConfig ()
147+ llm_request .config .tools = llm_request .config .tools or []
148+
149+ if is_gemini_model (llm_request .model ) or model_check_disabled :
145150 if is_gemini_1_model (llm_request .model ) and llm_request .config .tools :
146151 raise ValueError (
147152 'Vertex AI search tool cannot be used with other tools in Gemini'
148153 ' 1.x.'
149154 )
150- llm_request .config = llm_request .config or types .GenerateContentConfig ()
151- llm_request .config .tools = llm_request .config .tools or []
152155
153156 # Build the search config (can be overridden by subclasses)
154157 vertex_ai_search_config = self ._build_vertex_ai_search_config (
Original file line number Diff line number Diff line change 2222from packaging .version import InvalidVersion
2323from packaging .version import Version
2424
25+ from .env_utils import is_env_enabled
26+
27+ _DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK'
28+
29+
30+ def is_gemini_model_id_check_disabled () -> bool :
31+ """Returns True when Gemini model-id validation should be bypassed.
32+
33+ This opt-in environment variable is intended for internal usage where model
34+ ids may not follow the public ``gemini-*`` naming convention.
35+ """
36+ return is_env_enabled (_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR )
37+
2538
2639def extract_model_name (model_string : str ) -> str :
2740 """Extract the actual model name from either simple or path-based format.
Original file line number Diff line number Diff line change @@ -97,6 +97,22 @@ def test_process_llm_request_non_gemini_2_model(
9797 )
9898
9999
100+ def test_process_llm_request_non_gemini_2_model_with_disabled_check (
101+ built_in_executor : BuiltInCodeExecutor ,
102+ monkeypatch ,
103+ ):
104+ """Tests non-Gemini models pass when model-id check is disabled."""
105+ monkeypatch .setenv ("ADK_DISABLE_GEMINI_MODEL_ID_CHECK" , "true" )
106+ llm_request = LlmRequest (model = "internal-model-v1" )
107+
108+ built_in_executor .process_llm_request (llm_request )
109+
110+ assert llm_request .config is not None
111+ assert llm_request .config .tools == [
112+ types .Tool (code_execution = types .ToolCodeExecution ())
113+ ]
114+
115+
100116def test_process_llm_request_no_model_name (
101117 built_in_executor : BuiltInCodeExecutor ,
102118):
Original file line number Diff line number Diff line change @@ -145,3 +145,43 @@ def test_vertex_rag_retrieval_for_gemini_2_x():
145145 )
146146 ]
147147 assert 'rag_retrieval' not in mockModel .requests [0 ].tools_dict
148+
149+
150+ def test_vertex_rag_retrieval_for_non_gemini_with_disabled_check (monkeypatch ):
151+ monkeypatch .setenv ('ADK_DISABLE_GEMINI_MODEL_ID_CHECK' , 'true' )
152+ responses = [
153+ 'response1' ,
154+ ]
155+ mockModel = testing_utils .MockModel .create (responses = responses )
156+ mockModel .model = 'internal-model-v1'
157+
158+ agent = Agent (
159+ name = 'root_agent' ,
160+ model = mockModel ,
161+ tools = [
162+ VertexAiRagRetrieval (
163+ name = 'rag_retrieval' ,
164+ description = 'rag_retrieval' ,
165+ rag_corpora = [
166+ 'projects/123456789/locations/us-central1/ragCorpora/1234567890'
167+ ],
168+ )
169+ ],
170+ )
171+ runner = testing_utils .InMemoryRunner (agent )
172+ runner .run ('test1' )
173+
174+ assert len (mockModel .requests ) == 1
175+ assert len (mockModel .requests [0 ].config .tools ) == 1
176+ assert mockModel .requests [0 ].config .tools == [
177+ types .Tool (
178+ retrieval = types .Retrieval (
179+ vertex_rag_store = types .VertexRagStore (
180+ rag_corpora = [
181+ 'projects/123456789/locations/us-central1/ragCorpora/1234567890'
182+ ]
183+ )
184+ )
185+ )
186+ ]
187+ assert 'rag_retrieval' not in mockModel .requests [0 ].tools_dict
You can’t perform that action at this time.
0 commit comments