Skip to content

Commit b7f8096

Browse files
authored
Merge branch 'main' into feature/ollama-llm
2 parents e9944c7 + 585ebfd commit b7f8096

6 files changed

Lines changed: 2386 additions & 2146 deletions

File tree

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 74 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,9 @@ def get_first_token_time(span_id: str) -> Optional[float]:
660660
# ==============================================================================
661661
# HELPER: BATCH PROCESSOR
662662
# ==============================================================================
663+
_SHUTDOWN_SENTINEL = object()
664+
665+
663666
class BatchProcessor:
664667
"""Handles asynchronous batching and writing of events to BigQuery."""
665668

@@ -809,11 +812,18 @@ async def _batch_writer(self) -> None:
809812
self._queue.get(), timeout=self.flush_interval
810813
)
811814

815+
if first_item is _SHUTDOWN_SENTINEL:
816+
self._queue.task_done()
817+
continue
818+
812819
batch.append(first_item)
813820

814821
while len(batch) < self.batch_size:
815822
try:
816823
item = self._queue.get_nowait()
824+
if item is _SHUTDOWN_SENTINEL:
825+
self._queue.task_done()
826+
continue
817827
batch.append(item)
818828
except asyncio.QueueEmpty:
819829
break
@@ -831,6 +841,13 @@ async def _batch_writer(self) -> None:
831841
except asyncio.CancelledError:
832842
logger.info("Batch writer task cancelled.")
833843
break
844+
except RuntimeError as e:
845+
if "Event loop is closed" in str(e):
846+
logger.info("Batch writer loop closed: %s", e)
847+
break
848+
# Other RuntimeErrors are logged below and the loop continues after a short backoff
849+
logger.error("RuntimeError in batch writer loop: %s", e, exc_info=True)
850+
await asyncio.sleep(1)
834851
except Exception as e:
835852
logger.error("Error in batch writer loop: %s", e, exc_info=True)
836853
await asyncio.sleep(1)
@@ -939,12 +956,24 @@ async def shutdown(self, timeout: float = 5.0) -> None:
939956
"""
940957
self._shutdown = True
941958
logger.info("BatchProcessor shutting down, draining queue...")
959+
960+
# Signal the writer to wake up and check shutdown status
961+
try:
962+
self._queue.put_nowait(_SHUTDOWN_SENTINEL)
963+
except asyncio.QueueFull:
964+
# If queue is full, the writer is active and will check _shutdown soon
965+
pass
966+
942967
if self._batch_processor_task:
943968
try:
944969
await asyncio.wait_for(self._batch_processor_task, timeout=timeout)
945970
except asyncio.TimeoutError:
946971
logger.warning("BatchProcessor shutdown timed out, cancelling worker.")
947972
self._batch_processor_task.cancel()
973+
try:
974+
await self._batch_processor_task
975+
except asyncio.CancelledError:
976+
pass
948977
except Exception as e:
949978
logger.error("Error during BatchProcessor shutdown: %s", e)
950979

@@ -1626,51 +1655,55 @@ def get_credentials():
16261655
def _atexit_cleanup(batch_processor: "BatchProcessor") -> None:
16271656
"""Clean up batch processor on script exit."""
16281657
# Check if the batch_processor object is still alive
1629-
if batch_processor and not batch_processor._shutdown:
1630-
# Emergency Flush: Rescue any logs remaining in the queue
1631-
remaining_items = []
1632-
try:
1633-
while True:
1634-
remaining_items.append(batch_processor._queue.get_nowait())
1635-
except (asyncio.QueueEmpty, AttributeError):
1636-
pass
1637-
1638-
if remaining_items:
1639-
# We need a new loop and client to flush these
1640-
async def rescue_flush():
1641-
try:
1642-
# Create a short-lived client just for this flush
1658+
try:
1659+
if batch_processor and not batch_processor._shutdown:
1660+
# Emergency Flush: Rescue any logs remaining in the queue
1661+
remaining_items = []
1662+
try:
1663+
while True:
1664+
remaining_items.append(batch_processor._queue.get_nowait())
1665+
except (asyncio.QueueEmpty, AttributeError):
1666+
pass
1667+
1668+
if remaining_items:
1669+
# We need a new loop and client to flush these
1670+
async def rescue_flush():
16431671
try:
1644-
# Note: This relies on google.auth.default() working in this context.
1645-
# pylint: disable=g-import-not-at-top
1646-
from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient
1647-
1648-
# pylint: enable=g-import-not-at-top
1649-
client = BigQueryWriteAsyncClient()
1672+
# Create a short-lived client just for this flush
1673+
try:
1674+
# Note: This relies on google.auth.default() working in this context.
1675+
# pylint: disable=g-import-not-at-top
1676+
from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient
1677+
1678+
# pylint: enable=g-import-not-at-top
1679+
client = BigQueryWriteAsyncClient()
1680+
except Exception as e:
1681+
logger.warning("Could not create rescue client: %s", e)
1682+
return
1683+
1684+
# Patch batch_processor.write_client temporarily
1685+
old_client = batch_processor.write_client
1686+
batch_processor.write_client = client
1687+
try:
1688+
# Force a write
1689+
await batch_processor._write_rows_with_retry(remaining_items)
1690+
logger.info("Rescued logs flushed successfully.")
1691+
except Exception as e:
1692+
logger.error("Failed to flush rescued logs: %s", e)
1693+
finally:
1694+
batch_processor.write_client = old_client
16501695
except Exception as e:
1651-
logger.warning("Could not create rescue client: %s", e)
1652-
return
1696+
logger.error("Rescue flush failed: %s", e)
16531697

1654-
# Patch batch_processor.write_client temporarily
1655-
old_client = batch_processor.write_client
1656-
batch_processor.write_client = client
1657-
try:
1658-
# Force a write
1659-
await batch_processor._write_rows_with_retry(remaining_items)
1660-
logger.info("Rescued logs flushed successfully.")
1661-
except Exception as e:
1662-
logger.error("Failed to flush rescued logs: %s", e)
1663-
finally:
1664-
batch_processor.write_client = old_client
1698+
try:
1699+
loop = asyncio.new_event_loop()
1700+
loop.run_until_complete(rescue_flush())
1701+
loop.close()
16651702
except Exception as e:
1666-
logger.error("Rescue flush failed: %s", e)
1667-
1668-
try:
1669-
loop = asyncio.new_event_loop()
1670-
loop.run_until_complete(rescue_flush())
1671-
loop.close()
1672-
except Exception as e:
1673-
logger.error("Failed to run rescue loop: %s", e)
1703+
logger.error("Failed to run rescue loop: %s", e)
1704+
except ReferenceError:
1705+
# batch_processor already GC'd, nothing to do
1706+
pass
16741707

16751708
def _ensure_schema_exists(self) -> None:
16761709
"""Ensures the BigQuery table exists with the correct schema."""

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import logging
1718
import ssl
1819
from typing import Any
1920
from typing import Callable
@@ -48,6 +49,8 @@
4849
from .operation_parser import OperationParser
4950
from .tool_auth_handler import ToolAuthHandler
5051

52+
logger = logging.getLogger("google_adk." + __name__)
53+
5154

5255
def snake_to_lower_camel(snake_case_string: str):
5356
"""Converts a snake_case string to a lower_camel_case string.
@@ -158,6 +161,7 @@ def __init__(
158161
self._default_headers: Dict[str, str] = {}
159162
self._ssl_verify = ssl_verify
160163
self._header_provider = header_provider
164+
self._logger = logger
161165
if should_parse_operation:
162166
self._operation_parser = OperationParser(self.operation)
163167

@@ -493,14 +497,40 @@ async def call(
493497
if provider_headers:
494498
request_params.setdefault("headers", {}).update(provider_headers)
495499

500+
# Log the API request
501+
self._logger.debug(
502+
"API Request: %s %s",
503+
request_params.get("method", "").upper(),
504+
request_params.get("url", ""),
505+
)
506+
self._logger.debug("API Request params: %s", request_params.get("params"))
507+
if "json" in request_params:
508+
self._logger.debug("API Request body: %s", request_params.get("json"))
509+
496510
response = requests.request(**request_params)
497511

512+
# Log the API response
513+
self._logger.debug(
514+
"API Response: %s %s - Status: %d",
515+
request_params.get("method", "").upper(),
516+
request_params.get("url", ""),
517+
response.status_code,
518+
)
519+
498520
# Parse API response
499521
try:
500522
response.raise_for_status() # Raise HTTPError for bad responses
501-
return response.json() # Try to decode JSON
523+
result = response.json() # Try to decode JSON
524+
self._logger.debug("API Response body: %s", result)
525+
return result
502526
except requests.exceptions.HTTPError:
503527
error_details = response.content.decode("utf-8")
528+
self._logger.warning(
529+
"API call failed for tool %s: Status %d - %s",
530+
self.name,
531+
response.status_code,
532+
error_details,
533+
)
504534
return {
505535
"error": (
506536
f"Tool {self.name} execution failed. Analyze this execution error"
@@ -510,6 +540,7 @@ async def call(
510540
)
511541
}
512542
except ValueError:
543+
self._logger.debug("API Response (non-JSON): %s", response.text)
513544
return {"text": response.text} # Return text if not JSON
514545

515546
def __str__(self):

src/google/adk/tools/vertex_ai_search_tool.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.genai import types
2222
from typing_extensions import override
2323

24+
from ..agents.readonly_context import ReadonlyContext
2425
from ..utils.model_name_utils import is_gemini_1_model
2526
from ..utils.model_name_utils import is_gemini_model
2627
from .base_tool import BaseTool
@@ -38,6 +39,25 @@ class VertexAiSearchTool(BaseTool):
3839
Attributes:
3940
data_store_id: The Vertex AI search data store resource ID.
4041
search_engine_id: The Vertex AI search engine resource ID.
42+
43+
To dynamically customize the search configuration at runtime (e.g., set
44+
filter based on user context), subclass this tool and override the
45+
`_build_vertex_ai_search_config` method.
46+
47+
Example:
48+
```python
49+
class DynamicFilterSearchTool(VertexAiSearchTool):
50+
def _build_vertex_ai_search_config(
51+
self, ctx: ReadonlyContext
52+
) -> types.VertexAISearch:
53+
user_id = ctx.state.get('user_id')
54+
return types.VertexAISearch(
55+
datastore=self.data_store_id,
56+
engine=self.search_engine_id,
57+
filter=f"user_id = '{user_id}'",
58+
max_results=self.max_results,
59+
)
60+
```
4161
"""
4262

4363
def __init__(
@@ -90,6 +110,30 @@ def __init__(
90110
self.max_results = max_results
91111
self.bypass_multi_tools_limit = bypass_multi_tools_limit
92112

113+
def _build_vertex_ai_search_config(
114+
self, readonly_context: ReadonlyContext
115+
) -> types.VertexAISearch:
116+
"""Builds the VertexAISearch configuration.
117+
118+
Override this method in a subclass to dynamically customize the search
119+
configuration based on the context (e.g., set filter based on session
120+
state).
121+
122+
Args:
123+
readonly_context: The readonly context with access to state and session
124+
info.
125+
126+
Returns:
127+
The VertexAISearch configuration to use for this request.
128+
"""
129+
return types.VertexAISearch(
130+
datastore=self.data_store_id,
131+
data_store_specs=self.data_store_specs,
132+
engine=self.search_engine_id,
133+
filter=self.filter,
134+
max_results=self.max_results,
135+
)
136+
93137
@override
94138
async def process_llm_request(
95139
self,
@@ -106,14 +150,20 @@ async def process_llm_request(
106150
llm_request.config = llm_request.config or types.GenerateContentConfig()
107151
llm_request.config.tools = llm_request.config.tools or []
108152

153+
# Build the search config (can be overridden by subclasses)
154+
vertex_ai_search_config = self._build_vertex_ai_search_config(
155+
tool_context
156+
)
157+
109158
# Format data_store_specs concisely for logging
110-
if self.data_store_specs:
159+
if vertex_ai_search_config.data_store_specs:
111160
spec_ids = [
112161
spec.data_store.split('/')[-1] if spec.data_store else 'unnamed'
113-
for spec in self.data_store_specs
162+
for spec in vertex_ai_search_config.data_store_specs
114163
]
115164
specs_info = (
116-
f'{len(self.data_store_specs)} spec(s): [{", ".join(spec_ids)}]'
165+
f'{len(vertex_ai_search_config.data_store_specs)} spec(s):'
166+
f' [{", ".join(spec_ids)}]'
117167
)
118168
else:
119169
specs_info = None
@@ -122,23 +172,17 @@ async def process_llm_request(
122172
'Adding Vertex AI Search tool config to LLM request: '
123173
'datastore=%s, engine=%s, filter=%s, max_results=%s, '
124174
'data_store_specs=%s',
125-
self.data_store_id,
126-
self.search_engine_id,
127-
self.filter,
128-
self.max_results,
175+
vertex_ai_search_config.datastore,
176+
vertex_ai_search_config.engine,
177+
vertex_ai_search_config.filter,
178+
vertex_ai_search_config.max_results,
129179
specs_info,
130180
)
131181

132182
llm_request.config.tools.append(
133183
types.Tool(
134184
retrieval=types.Retrieval(
135-
vertex_ai_search=types.VertexAISearch(
136-
datastore=self.data_store_id,
137-
data_store_specs=self.data_store_specs,
138-
engine=self.search_engine_id,
139-
filter=self.filter,
140-
max_results=self.max_results,
141-
)
185+
vertex_ai_search=vertex_ai_search_config
142186
)
143187
)
144188
)

0 commit comments

Comments (0)