Skip to content

Commit d3d01f7

Browse files
author
Gautier Masse
committed
fix(tools): add optional location override validation
1 parent b51a717 commit d3d01f7

2 files changed

Lines changed: 96 additions & 6 deletions

File tree

src/google/adk/tools/discovery_engine_search_tool.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,42 @@
3131
_LOCATION_PATTERN = re.compile(r"/locations/([^/]+)(?:/|$)")
3232

3333

34+
def _resolve_location(resource_id: str, location: Optional[str]) -> str:
35+
"""Resolves the Discovery Engine location to use for the endpoint."""
36+
location_match = _LOCATION_PATTERN.search(resource_id)
37+
inferred_location = (
38+
location_match.group(1).lower() if location_match else None
39+
)
40+
41+
if location is not None:
42+
normalized_location = location.strip().lower()
43+
if not normalized_location:
44+
raise ValueError("location must not be empty if specified.")
45+
if inferred_location and normalized_location != inferred_location:
46+
raise ValueError(
47+
"location must match the location in data_store_id or "
48+
"search_engine_id."
49+
)
50+
return normalized_location
51+
52+
if inferred_location:
53+
return inferred_location
54+
return _GLOBAL_LOCATION
55+
56+
3457
def _build_client_options(
35-
resource_id: str, quota_project_id: Optional[str]
58+
resource_id: str,
59+
quota_project_id: Optional[str],
60+
location: Optional[str],
3661
) -> Optional[client_options.ClientOptions]:
3762
"""Builds client options for Discovery Engine requests."""
3863
client_options_kwargs = {}
39-
location_match = _LOCATION_PATTERN.search(resource_id)
40-
location = location_match.group(1) if location_match else _GLOBAL_LOCATION
64+
resolved_location = _resolve_location(resource_id, location)
4165

42-
if location != _GLOBAL_LOCATION:
43-
client_options_kwargs["api_endpoint"] = f"{location}-{_DEFAULT_ENDPOINT}"
66+
if resolved_location != _GLOBAL_LOCATION:
67+
client_options_kwargs["api_endpoint"] = (
68+
f"{resolved_location}-{_DEFAULT_ENDPOINT}"
69+
)
4470
if quota_project_id:
4571
client_options_kwargs["quota_project_id"] = quota_project_id
4672

@@ -61,6 +87,7 @@ def __init__(
6187
search_engine_id: Optional[str] = None,
6288
filter: Optional[str] = None,
6389
max_results: Optional[int] = None,
90+
location: Optional[str] = None,
6491
):
6592
"""Initializes the DiscoveryEngineSearchTool.
6693
@@ -74,6 +101,9 @@ def __init__(
74101
"projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
75102
filter: The filter to be applied to the search request. Default is None.
76103
max_results: The maximum number of results to return. Default is None.
104+
location: Optional endpoint location override.
105+
Examples: "global", "us", "eu". If not specified, location is inferred
106+
from `data_store_id` or `search_engine_id` and defaults to "global".
77107
"""
78108
super().__init__(self.discovery_engine_search)
79109
if (data_store_id is None and search_engine_id is None) or (
@@ -94,11 +124,16 @@ def __init__(
94124
self._search_engine_id = search_engine_id
95125
self._filter = filter
96126
self._max_results = max_results
127+
self._location = location
97128

98129
credentials, _ = google.auth.default()
99130
quota_project_id = getattr(credentials, "quota_project_id", None)
100131
resource_id = data_store_id or search_engine_id or ""
101-
options = _build_client_options(resource_id, quota_project_id)
132+
options = _build_client_options(
133+
resource_id=resource_id,
134+
quota_project_id=quota_project_id,
135+
location=location,
136+
)
102137
self._discovery_engine_client = discoveryengine.SearchServiceClient(
103138
credentials=credentials, client_options=options
104139
)

tests/unittests/tools/test_discovery_engine_search_tool.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,61 @@ def test_init_with_regional_location_uses_regional_endpoint(
131131
client_options=mock_client_options.ClientOptions.return_value,
132132
)
133133

134+
@mock.patch.object(discovery_engine_search_tool, "client_options")
135+
@mock.patch.object(discoveryengine, "SearchServiceClient")
136+
def test_init_with_explicit_location_override_uses_input_location(
137+
self, mock_search_client, mock_client_options
138+
):
139+
"""Test initialization uses explicit location when resource has none."""
140+
DiscoveryEngineSearchTool(
141+
data_store_id="test_data_store",
142+
location="eu",
143+
)
144+
145+
mock_client_options.ClientOptions.assert_called_once_with(
146+
api_endpoint="eu-discoveryengine.googleapis.com"
147+
)
148+
mock_search_client.assert_called_once_with(
149+
credentials="credentials",
150+
client_options=mock_client_options.ClientOptions.return_value,
151+
)
152+
153+
@mock.patch.object(discoveryengine, "SearchServiceClient")
154+
def test_init_with_mismatched_location_raises_error(self, mock_search_client):
155+
"""Test initialization rejects mismatched location overrides."""
156+
with pytest.raises(
157+
ValueError,
158+
match=(
159+
"location must match the location in data_store_id or "
160+
"search_engine_id."
161+
),
162+
):
163+
DiscoveryEngineSearchTool(
164+
data_store_id=(
165+
"projects/test/locations/us/collections/default_collection/"
166+
"dataStores/test_data_store"
167+
),
168+
location="eu",
169+
)
170+
171+
mock_search_client.assert_not_called()
172+
173+
@mock.patch.object(discoveryengine, "SearchServiceClient")
174+
def test_init_with_empty_location_raises_error(self, mock_search_client):
175+
"""Test initialization rejects an empty location override."""
176+
with pytest.raises(
177+
ValueError, match="location must not be empty if specified."
178+
):
179+
DiscoveryEngineSearchTool(
180+
data_store_id=(
181+
"projects/test/locations/us/collections/default_collection/"
182+
"dataStores/test_data_store"
183+
),
184+
location=" ",
185+
)
186+
187+
mock_search_client.assert_not_called()
188+
134189
@mock.patch.object(discovery_engine_search_tool, "client_options")
135190
@mock.patch.object(discoveryengine, "SearchServiceClient")
136191
def test_init_with_global_location_keeps_default_endpoint(

0 commit comments

Comments
 (0)