3131_LOCATION_PATTERN = re .compile (r"/locations/([^/]+)(?:/|$)" )
3232
3333
34+ def _resolve_location (resource_id : str , location : Optional [str ]) -> str :
35+ """Resolves the Discovery Engine location to use for the endpoint."""
36+ location_match = _LOCATION_PATTERN .search (resource_id )
37+ inferred_location = (
38+ location_match .group (1 ).lower () if location_match else None
39+ )
40+
41+ if location is not None :
42+ normalized_location = location .strip ().lower ()
43+ if not normalized_location :
44+ raise ValueError ("location must not be empty if specified." )
45+ if inferred_location and normalized_location != inferred_location :
46+ raise ValueError (
47+ "location must match the location in data_store_id or "
48+ "search_engine_id."
49+ )
50+ return normalized_location
51+
52+ if inferred_location :
53+ return inferred_location
54+ return _GLOBAL_LOCATION
55+
56+
3457def _build_client_options (
35- resource_id : str , quota_project_id : Optional [str ]
58+ resource_id : str ,
59+ quota_project_id : Optional [str ],
60+ location : Optional [str ],
3661) -> Optional [client_options .ClientOptions ]:
3762 """Builds client options for Discovery Engine requests."""
3863 client_options_kwargs = {}
39- location_match = _LOCATION_PATTERN .search (resource_id )
40- location = location_match .group (1 ) if location_match else _GLOBAL_LOCATION
64+ resolved_location = _resolve_location (resource_id , location )
4165
42- if location != _GLOBAL_LOCATION :
43- client_options_kwargs ["api_endpoint" ] = f"{ location } -{ _DEFAULT_ENDPOINT } "
66+ if resolved_location != _GLOBAL_LOCATION :
67+ client_options_kwargs ["api_endpoint" ] = (
68+ f"{ resolved_location } -{ _DEFAULT_ENDPOINT } "
69+ )
4470 if quota_project_id :
4571 client_options_kwargs ["quota_project_id" ] = quota_project_id
4672
@@ -61,6 +87,7 @@ def __init__(
6187 search_engine_id : Optional [str ] = None ,
6288 filter : Optional [str ] = None ,
6389 max_results : Optional [int ] = None ,
90+ location : Optional [str ] = None ,
6491 ):
6592 """Initializes the DiscoveryEngineSearchTool.
6693
@@ -74,6 +101,9 @@ def __init__(
74101 "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
75102 filter: The filter to be applied to the search request. Default is None.
76103 max_results: The maximum number of results to return. Default is None.
104+ location: Optional endpoint location override.
105+ Examples: "global", "us", "eu". If not specified, location is inferred
106+ from `data_store_id` or `search_engine_id` and defaults to "global".
77107 """
78108 super ().__init__ (self .discovery_engine_search )
79109 if (data_store_id is None and search_engine_id is None ) or (
@@ -94,11 +124,16 @@ def __init__(
94124 self ._search_engine_id = search_engine_id
95125 self ._filter = filter
96126 self ._max_results = max_results
127+ self ._location = location
97128
98129 credentials , _ = google .auth .default ()
99130 quota_project_id = getattr (credentials , "quota_project_id" , None )
100131 resource_id = data_store_id or search_engine_id or ""
101- options = _build_client_options (resource_id , quota_project_id )
132+ options = _build_client_options (
133+ resource_id = resource_id ,
134+ quota_project_id = quota_project_id ,
135+ location = location ,
136+ )
102137 self ._discovery_engine_client = discoveryengine .SearchServiceClient (
103138 credentials = credentials , client_options = options
104139 )
0 commit comments