From ad8b67f94579976b3fe509d30b1d14f7948b1b3e Mon Sep 17 00:00:00 2001 From: sarojrout Date: Fri, 5 Dec 2025 17:24:48 -0800 Subject: [PATCH 01/13] feat(flows): Add streaming tools support for non-live mode Implements streaming tools support for non-live mode (run_async/SSE), allowing tools to yield intermediate results as Events. Changes: - Added _is_streaming_tool() helper to detect async generator functions - Added handle_function_calls_async_with_streaming() to handle streaming tools - Added _execute_streaming_tool_async() to execute and yield Events for streaming tools - Integrated streaming tool detection in _postprocess_handle_function_calls_async() - Added cancellation support via task tracking in active_streaming_tools - Added unit tests Fixes #3837 --- .../adk/flows/llm_flows/base_llm_flow.py | 96 +++++-- src/google/adk/flows/llm_flows/functions.py | 198 ++++++++++++++ .../test_streaming_tools_non_live.py | 253 ++++++++++++++++++ 3 files changed, 524 insertions(+), 23 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 824cd26be1..dd3f48b09c 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -678,35 +678,85 @@ async def _postprocess_handle_function_calls_async( function_call_event: Event, llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: - if function_response_event := await functions.handle_function_calls_async( - invocation_context, function_call_event, llm_request.tools_dict - ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event - ) - if auth_event: - yield auth_event + function_calls = function_call_event.get_function_calls() + if not function_calls: + return + + # Check if any tools are streaming tools + has_streaming_tools = any( + functions._is_streaming_tool(tool) + for call in function_calls + if (tool := llm_request.tools_dict.get(call.name)) + ) - tool_confirmation_event = functions.generate_request_confirmation_event( - invocation_context, function_call_event, function_response_event + if has_streaming_tools: + # Use streaming handler + tool_confirmation_dict = ( + invocation_context.tool_confirmation_dict + if hasattr(invocation_context, 'tool_confirmation_dict') + else None ) - if tool_confirmation_event: - yield tool_confirmation_event + async for event in functions.handle_function_calls_async_with_streaming( + invocation_context, + function_calls, + llm_request.tools_dict, + tool_confirmation_dict, + ): + auth_event = functions.generate_auth_event(invocation_context, event) + if auth_event: + yield auth_event - # Always yield the function response event first - yield function_response_event + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, event + ) + if tool_confirmation_event: + yield tool_confirmation_event - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event + yield event + + # Check if this is a set_model_response function response + if json_response := ( + _output_schema_processor.get_structured_model_response(event) + ): + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event + else: + # Use regular handler + if function_response_event := await functions.handle_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) + auth_event = functions.generate_auth_event( + invocation_context, function_response_event ) - yield final_event + if auth_event: + yield auth_event + + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, function_response_event + ) + if tool_confirmation_event: + yield tool_confirmation_event + + # Always yield the function response event first + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := ( + _output_schema_processor.get_structured_model_response( + function_response_event + ) + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ffe1657be1..614cc390ff 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -788,6 +788,204 @@ async def __call_tool_async( return await tool.run_async(args=args, tool_context=tool_context) +def _is_streaming_tool(tool: BaseTool) -> bool: + """Checks if a tool is a streaming tool (async generator function). + + Args: + tool: The tool to check. + + Returns: + True if the tool's function is an async generator, False otherwise. + """ + return hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func) + + +async def _execute_streaming_tool_async( + tool: BaseTool, + function_args: dict[str, Any], + tool_context: ToolContext, + invocation_context: InvocationContext, +) -> AsyncGenerator[Event, None]: + """Executes a streaming tool and yields Events for each intermediate result. + + Args: + tool: The streaming tool to execute. + function_args: The function call arguments. + tool_context: The tool context. + invocation_context: The invocation context. + + Yields: + Events for each intermediate result yielded by the streaming tool. + """ + task = None + try: + # Run before_tool_callbacks + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) + ) + if function_response is not None: + # Plugin overrode the function response, yield it and return + event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + yield event + return + + # Create a queue to buffer results from the async generator + result_queue: asyncio.Queue[Optional[dict[str, Any]]] = asyncio.Queue() + + # Background task to run the generator and put results in queue + async def _run_generator(): + try: + # Get the async generator from the tool + agen = tool.func(**function_args) + async with Aclosing(agen) as gen: + async for result in gen: + await result_queue.put(result) + await result_queue.put(None) # Signal completion + except Exception as e: + await result_queue.put(e) # Signal error + + task = asyncio.create_task(_run_generator()) + + # Track the task for cancellation + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + invocation_context.active_streaming_tools[tool.name] = ActiveStreamingTool( + task=task + ) + + # Yield Events as results come in + while True: + result = await result_queue.get() + if result is None: + # Generator completed normally + break + if isinstance(result, Exception): + # Generator raised an exception + raise result + + # Ensure result is a dict + if not isinstance(result, dict): + result = {'result': result} + + # Create and yield Event for this intermediate result + event = __build_response_event( + tool, result, tool_context, invocation_context + ) + yield event + + # Clean up + if tool.name in invocation_context.active_streaming_tools: + del invocation_context.active_streaming_tools[tool.name] + + except asyncio.CancelledError: + # Clean up on cancellation + if task and not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + if tool.name in invocation_context.active_streaming_tools: + del invocation_context.active_streaming_tools[tool.name] + raise + except Exception as e: + # Handle errors + if tool.name in invocation_context.active_streaming_tools: + del invocation_context.active_streaming_tools[tool.name] + + # Run error callbacks + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=e, + ) + ) + if error_response is not None: + event = __build_response_event( + tool, error_response, tool_context, invocation_context + ) + yield event + return + + # Re-raise if no error callback handled it + raise + + +async def handle_function_calls_async_with_streaming( + invocation_context: InvocationContext, + function_calls: list[types.FunctionCall], + tools_dict: dict[str, BaseTool], + tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, +) -> AsyncGenerator[Event, None]: + """Handles function calls with support for streaming tools. + + Separates streaming tools from regular tools, processes regular tools + normally, and yields Events for streaming tools as they produce results. + + Args: + invocation_context: The invocation context. + function_calls: List of function calls to handle. + tools_dict: Dictionary of available tools. + tool_confirmation_dict: Optional dictionary of tool confirmations. + + Yields: + Events for function responses, including intermediate results from + streaming tools. + """ + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + + # Separate streaming and non-streaming tools + streaming_calls = [] + regular_calls = [] + + for function_call in function_calls: + tool = tools_dict.get(function_call.name) + if tool and _is_streaming_tool(tool): + streaming_calls.append(function_call) + else: + regular_calls.append(function_call) + + # Handle regular tools using existing logic + if regular_calls: + regular_event = await handle_function_call_list_async( + invocation_context, + regular_calls, + tools_dict, + filters=None, + tool_confirmation_dict=tool_confirmation_dict, + ) + if regular_event: + yield regular_event + + # Handle streaming tools + for function_call in streaming_calls: + tool = tools_dict[function_call.name] + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + tool_context = _create_tool_context( + invocation_context, + function_call, + tool_confirmation_dict.get(function_call.id) + if tool_confirmation_dict + else None, + ) + + async for event in _execute_streaming_tool_async( + tool, function_args, tool_context, invocation_context + ): + yield event + + def __build_response_event( tool: BaseTool, function_result: dict[str, object], diff --git a/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py new file mode 100644 index 0000000000..7245518712 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py @@ -0,0 +1,253 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for streaming tools in non-live mode.""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator + +from google.adk.agents.llm_agent import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import _execute_streaming_tool_async +from google.adk.flows.llm_flows.functions import _is_streaming_tool +from google.adk.flows.llm_flows.functions import handle_function_calls_async_with_streaming +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + +from ... import testing_utils + + +def test_is_streaming_tool_detects_async_generator(): + """Test that _is_streaming_tool correctly identifies async generators.""" + + async def regular_function(x: int) -> int: + return x + 1 + + async def streaming_function(x: int) -> AsyncGenerator[dict, None]: + yield {'result': x + 1} + yield {'result': x + 2} + + regular_tool = FunctionTool(func=regular_function) + streaming_tool = FunctionTool(func=streaming_function) + + assert not _is_streaming_tool(regular_tool) + assert _is_streaming_tool(streaming_tool) + + +@pytest.mark.asyncio +async def test_streaming_tool_yields_multiple_events(): + """Test that a streaming tool yields multiple Events.""" + + async def monitor_stock(symbol: str) -> AsyncGenerator[dict, None]: + prices = [100, 105, 110] + for price in prices: + await asyncio.sleep(0.01) # Small delay for testing + yield {'symbol': symbol, 'price': price} + + tool = FunctionTool(func=monitor_stock) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-1', + ) + + events = [] + async for event in _execute_streaming_tool_async( + tool, + {'symbol': 'AAPL'}, + tool_context, + invocation_context, + ): + events.append(event) + + # Should have 3 events (one per price) + assert len(events) == 3 + assert all(isinstance(e, Event) for e in events) + assert all( + e.content and e.content.parts and e.content.parts[0].function_response + for e in events + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_with_string_results(): + """Test streaming tool that yields string results.""" + + async def process_data(data: str) -> AsyncGenerator[str, None]: + for i, char in enumerate(data): + await asyncio.sleep(0.01) + yield f'Processed {i}: {char}' + + tool = FunctionTool(func=process_data) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-2', + ) + + events = [] + async for event in _execute_streaming_tool_async( + tool, + {'data': 'ABC'}, + tool_context, + invocation_context, + ): + events.append(event) + + assert len(events) == 3 + + +@pytest.mark.asyncio +async def test_streaming_tool_tracks_task_for_cancellation(): + """Test that streaming tool tasks are tracked for cancellation.""" + + async def long_running_task() -> AsyncGenerator[dict, None]: + for i in range(10): + await asyncio.sleep(0.1) + yield {'progress': i} + + tool = FunctionTool(func=long_running_task) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-3', + ) + + # Start the streaming tool + task = asyncio.create_task( + _execute_streaming_tool_async( + tool, {}, tool_context, invocation_context + ).__anext__() + ) + + # Wait a bit + await asyncio.sleep(0.05) + + # Check that task is tracked + assert invocation_context.active_streaming_tools is not None + assert tool.name in invocation_context.active_streaming_tools + assert invocation_context.active_streaming_tools[tool.name].task is not None + + # Cancel and clean up + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_streaming_tool_handles_errors(): + """Test error handling in streaming tools.""" + + async def failing_tool() -> AsyncGenerator[dict, None]: + yield {'status': 'started'} + raise ValueError('Test error') + yield {'status': 'never reached'} # type: ignore[unreachable] + + tool = FunctionTool(func=failing_tool) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-4', + ) + + events = [] + with pytest.raises(ValueError, match='Test error'): + async for event in _execute_streaming_tool_async( + tool, {}, tool_context, invocation_context + ): + events.append(event) + + # Should have yielded at least one event before error + assert len(events) >= 1 + + +@pytest.mark.asyncio +async def test_handle_function_calls_async_with_streaming_separates_tools(): + """Test that handler correctly separates streaming and non-streaming tools.""" + + async def regular_tool(x: int) -> int: + return x * 2 + + async def streaming_tool(x: int) -> AsyncGenerator[dict, None]: + yield {'value': x} + yield {'value': x * 2} + + regular = FunctionTool(func=regular_tool) + streaming = FunctionTool(func=streaming_tool) + + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + function_calls = [ + types.FunctionCall(name='regular_tool', args={'x': 5}, id='call-1'), + types.FunctionCall(name='streaming_tool', args={'x': 3}, id='call-2'), + ] + tools_dict = { + 'regular_tool': regular, + 'streaming_tool': streaming, + } + + events = [] + async for event in handle_function_calls_async_with_streaming( + invocation_context, function_calls, tools_dict + ): + events.append(event) + + # Should have 1 event from regular tool + 2 events from streaming tool + assert len(events) >= 3 + + +@pytest.mark.asyncio +async def test_streaming_tool_with_tool_context(): + """Test that tool_context is correctly passed to streaming tools.""" + + async def context_aware_tool(x: int) -> AsyncGenerator[dict, None]: + # Note: tool_context is injected by FunctionTool, not passed as arg + yield { + 'value': x, + 'status': 'processed', + } + + tool = FunctionTool(func=context_aware_tool) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-5', + ) + + events = [] + async for event in _execute_streaming_tool_async( + tool, + {'x': 42}, + tool_context, + invocation_context, + ): + events.append(event) + + assert len(events) == 1 + response = events[0].content.parts[0].function_response.response + assert response['value'] == 42 + assert response['status'] == 'processed' + # Verify the function_call_id is set correctly in the event + assert events[0].content.parts[0].function_response.id == 'test-call-5' From 5e9a352374d6243ebaa318a5b1d24b879e67f81a Mon Sep 17 00:00:00 2001 From: sarojrout Date: Fri, 5 Dec 2025 17:29:59 -0800 Subject: [PATCH 02/13] feat(samples): Add streaming tools non-live agent sample Adds a sample agent demonstrating streaming tools in non-live mode --- .../streaming_tools_non_live_agent/README.md | 42 ++++++ .../__init__.py | 15 ++ .../streaming_tools_non_live_agent/agent.py | 128 ++++++++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 contributing/samples/streaming_tools_non_live_agent/README.md create mode 100644 contributing/samples/streaming_tools_non_live_agent/__init__.py create mode 100644 contributing/samples/streaming_tools_non_live_agent/agent.py diff --git a/contributing/samples/streaming_tools_non_live_agent/README.md b/contributing/samples/streaming_tools_non_live_agent/README.md new file mode 100644 index 0000000000..4396cb8d49 --- /dev/null +++ b/contributing/samples/streaming_tools_non_live_agent/README.md @@ -0,0 +1,42 @@ +# Streaming Tools Non-Live Agent + +This agent demonstrates streaming tools in non-live mode (run_async/SSE). + +## Features + +- **monitor_stock_price**: Monitors stock prices with real-time updates +- **process_large_dataset**: Processes datasets with progress updates +- **monitor_system_health**: Monitors system health metrics continuously + +## Testing + +### With ADK Web UI + +```bash +cd /Users/usharout/projects/adk-python +source .venv/bin/activate +adk web contributing/samples/streaming_tools_non_live_agent +``` + +Then try: +- "Monitor the stock price for AAPL" +- "Process a large dataset at /tmp/data.csv" +- "Monitor system health" + +### With ADK CLI + +```bash +cd /Users/usharout/projects/adk-python +source .venv/bin/activate +adk run contributing/samples/streaming_tools_non_live_agent +``` + +### With API Server (SSE) + +```bash +cd /Users/usharout/projects/adk-python +source .venv/bin/activate +adk api_server contributing/samples/streaming_tools_non_live_agent +``` + +Then send a POST request to `/run_sse` with `streaming: true` to see intermediate Events. diff --git a/contributing/samples/streaming_tools_non_live_agent/__init__.py b/contributing/samples/streaming_tools_non_live_agent/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/streaming_tools_non_live_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/streaming_tools_non_live_agent/agent.py b/contributing/samples/streaming_tools_non_live_agent/agent.py new file mode 100644 index 0000000000..95deaf673c --- /dev/null +++ b/contributing/samples/streaming_tools_non_live_agent/agent.py @@ -0,0 +1,128 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example agent demonstrating streaming tools in non-live mode (run_async/SSE). + +This agent shows how to use streaming tools that yield intermediate results +in non-live mode. Streaming tools work with both run_async and SSE endpoints. +""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator + +from google.adk.agents import Agent + + +async def monitor_stock_price(symbol: str) -> AsyncGenerator[dict, None]: + """Monitor stock price with real-time updates. + + This is a streaming tool that yields intermediate results as the stock + price changes. The agent can react to these intermediate results. + + Args: + symbol: The stock symbol to monitor (e.g., 'AAPL', 'GOOGL'). + + Yields: + Dictionary containing stock price updates with status indicators. + """ + # Simulate stock price changes + prices = [100, 105, 110, 108, 112, 115] + for i, price in enumerate(prices): + await asyncio.sleep(1) # Simulate real-time updates + yield { + 'symbol': symbol, + 'price': price, + 'update': i + 1, + 'status': 'streaming' if i < len(prices) - 1 else 'complete', + } + + +async def process_large_dataset(file_path: str) -> AsyncGenerator[dict, None]: + """Process dataset with progress updates. + + This streaming tool demonstrates how to provide progress feedback + for long-running operations. + + Args: + file_path: Path to the dataset file to process. + + Yields: + Dictionary containing progress information and final result. + """ + total_rows = 100 + processed = 0 + + # Simulate processing in batches + for batch in range(10): + await asyncio.sleep(0.5) # Simulate processing time + processed += 10 + yield { + 'progress': processed / total_rows, + 'processed': processed, + 'total': total_rows, + 'status': 'streaming', + 'message': f'Processed {processed}/{total_rows} rows', + } + + # Final result + yield { + 'result': 'Processing complete', + 'status': 'complete', + 'file_path': file_path, + 'total_processed': total_rows, + } + + +async def monitor_system_health() -> AsyncGenerator[dict, None]: + """Monitor system health metrics with continuous updates. + + This streaming tool demonstrates continuous monitoring that can be + stopped by the agent when thresholds are reached. + + Yields: + Dictionary containing system health metrics. + """ + metrics = [ + {'cpu': 45, 'memory': 60, 'disk': 70}, + {'cpu': 50, 'memory': 65, 'disk': 72}, + {'cpu': 55, 'memory': 70, 'disk': 75}, + {'cpu': 60, 'memory': 75, 'disk': 78}, + ] + + for i, metric in enumerate(metrics): + await asyncio.sleep(2) # Check every 2 seconds + yield { + 'metrics': metric, + 'timestamp': i + 1, + 'status': 'streaming' if i < len(metrics) - 1 else 'complete', + 'alert': 'high' if metric['cpu'] > 55 else 'normal', + } + + +root_agent = Agent( + name='streaming_tools_agent', + model='gemini-2.5-flash-lite', + instruction=( + 'You are a helpful assistant that can monitor stock prices, process' + ' datasets, and monitor system health using streaming tools. When' + ' using streaming tools, you will receive intermediate results that' + ' you can react to. For example, if monitoring stock prices, you can' + ' alert the user when prices change significantly. If processing a' + ' dataset, you can provide progress updates. If monitoring system' + ' health, you can alert when metrics exceed thresholds.' + ), + tools=[monitor_stock_price, process_large_dataset, monitor_system_health], +) From 9c6e11515c92962c93aaefb8a94e3ed00bee3fae Mon Sep 17 00:00:00 2001 From: sarojrout Date: Fri, 5 Dec 2025 17:37:47 -0800 Subject: [PATCH 03/13] docs(samples): Update README with instructions --- .../streaming_tools_non_live_agent/README.md | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/contributing/samples/streaming_tools_non_live_agent/README.md b/contributing/samples/streaming_tools_non_live_agent/README.md index 4396cb8d49..801073a783 100644 --- a/contributing/samples/streaming_tools_non_live_agent/README.md +++ b/contributing/samples/streaming_tools_non_live_agent/README.md @@ -13,9 +13,8 @@ This agent demonstrates streaming tools in non-live mode (run_async/SSE). ### With ADK Web UI ```bash -cd /Users/usharout/projects/adk-python -source .venv/bin/activate -adk web contributing/samples/streaming_tools_non_live_agent +cd contributing/samples +adk web . ``` Then try: @@ -26,17 +25,15 @@ Then try: ### With ADK CLI ```bash -cd /Users/usharout/projects/adk-python -source .venv/bin/activate -adk run contributing/samples/streaming_tools_non_live_agent +cd contributing/samples/streaming_tools_non_live_agent +adk run . ``` ### With API Server (SSE) ```bash -cd /Users/usharout/projects/adk-python -source .venv/bin/activate -adk api_server contributing/samples/streaming_tools_non_live_agent +cd contributing/samples +adk api_server . ``` Then send a POST request to `/run_sse` with `streaming: true` to see intermediate Events. From 69385a21c5a093e2f2f4545e3caceb5f372990e8 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Fri, 5 Dec 2025 18:46:51 -0800 Subject: [PATCH 04/13] refactor(flows): review comments incorporated for streaming tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added agent transfer handling for streaming tools * Check transfer_to_agent after each streaming event * Transfer agent and exit streaming loop when detected * Enables agent reaction to intermediate results - Refactored duplicated code into _yield_function_response_events helper * Both streaming and regular paths use the same helper * Improves maintainability - Moved cleanup logic to finally block * Ensures cleanup always happens (3 places → 1 place) - Removed unused imports and variables * Removed unused LlmAgent import from handle_function_calls_async_with_streaming * Removed unused agent variable --- .../adk/flows/llm_flows/base_llm_flow.py | 121 ++++++++++-------- src/google/adk/flows/llm_flows/functions.py | 16 +-- 2 files changed, 73 insertions(+), 64 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index dd3f48b09c..0b69c238b8 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -672,6 +672,48 @@ async def _postprocess_run_processors_async( async for event in agen: yield event + async def _yield_function_response_events( + self, + invocation_context: InvocationContext, + function_call_event: Event, + function_response_event: Event, + ) -> AsyncGenerator[Event, None]: + """Yields auth, confirmation, and set_model_response events for a function response. + + Args: + invocation_context: The invocation context. + function_call_event: The original function call event. + function_response_event: The function response event. + + Yields: + Auth events, confirmation events, the function response event, and + set_model_response events if applicable. + """ + auth_event = functions.generate_auth_event( + invocation_context, function_response_event + ) + if auth_event: + yield auth_event + + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, function_response_event + ) + if tool_confirmation_event: + yield tool_confirmation_event + + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := ( + _output_schema_processor.get_structured_model_response( + function_response_event + ) + ): + final_event = _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + yield final_event + async def _postprocess_handle_function_calls_async( self, invocation_context: InvocationContext, @@ -702,69 +744,44 @@ async def _postprocess_handle_function_calls_async( llm_request.tools_dict, tool_confirmation_dict, ): - auth_event = functions.generate_auth_event(invocation_context, event) - if auth_event: - yield auth_event - - tool_confirmation_event = functions.generate_request_confirmation_event( + async for secondary_event in self._yield_function_response_events( invocation_context, function_call_event, event - ) - if tool_confirmation_event: - yield tool_confirmation_event - - yield event - - # Check if this is a set_model_response function response - if json_response := ( - _output_schema_processor.get_structured_model_response(event) ): - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) + yield secondary_event + + # Check for agent transfer after each streaming event + transfer_to_agent = event.actions.transfer_to_agent + if transfer_to_agent: + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent ) - yield final_event + async with Aclosing( + agent_to_run.run_async(invocation_context) + ) as agen: + async for transfer_event in agen: + yield transfer_event + # Agent transfer handled, exit the streaming loop + return else: # Use regular handler if function_response_event := await functions.handle_function_calls_async( invocation_context, function_call_event, llm_request.tools_dict ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event - ) - if auth_event: - yield auth_event - - tool_confirmation_event = functions.generate_request_confirmation_event( + async for secondary_event in self._yield_function_response_events( invocation_context, function_call_event, function_response_event - ) - if tool_confirmation_event: - yield tool_confirmation_event - - # Always yield the function response event first - yield function_response_event - - # Check if this is a set_model_response function response - if json_response := ( - _output_schema_processor.get_structured_model_response( - function_response_event - ) ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) + yield secondary_event + + transfer_to_agent = function_response_event.actions.transfer_to_agent + if transfer_to_agent: + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent ) - yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: - async for event in agen: - yield event + async with Aclosing( + agent_to_run.run_async(invocation_context) + ) as agen: + async for event in agen: + yield event def _get_agent_to_run( self, invocation_context: InvocationContext, agent_name: str diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 614cc390ff..f615ba3206 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -877,10 +877,6 @@ async def _run_generator(): ) yield event - # Clean up - if tool.name in invocation_context.active_streaming_tools: - del invocation_context.active_streaming_tools[tool.name] - except asyncio.CancelledError: # Clean up on cancellation if task and not task.done(): @@ -889,13 +885,9 @@ async def _run_generator(): await task except (asyncio.CancelledError, Exception): pass - if tool.name in invocation_context.active_streaming_tools: - del invocation_context.active_streaming_tools[tool.name] raise except Exception as e: # Handle errors - if tool.name in invocation_context.active_streaming_tools: - del invocation_context.active_streaming_tools[tool.name] # Run error callbacks error_response = ( @@ -915,6 +907,10 @@ async def _run_generator(): # Re-raise if no error callback handled it raise + finally: + # Clean up active_streaming_tools tracking + if tool.name in invocation_context.active_streaming_tools: + del invocation_context.active_streaming_tools[tool.name] async def handle_function_calls_async_with_streaming( @@ -938,10 +934,6 @@ async def handle_function_calls_async_with_streaming( Events for function responses, including intermediate results from streaming tools. """ - from ...agents.llm_agent import LlmAgent - - agent = invocation_context.agent - # Separate streaming and non-streaming tools streaming_calls = [] regular_calls = [] From 1fda8097788854b590a38f5229182d0b475978eb Mon Sep 17 00:00:00 2001 From: sarojrout Date: Fri, 5 Dec 2025 19:15:26 -0800 Subject: [PATCH 05/13] perf(flows): Execute all tools concurrently and fix race condition in streaming tools --- src/google/adk/flows/llm_flows/functions.py | 105 +++++++++++++----- .../test_streaming_tools_non_live.py | 32 ++++-- 2 files changed, 98 insertions(+), 39 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index f615ba3206..183f77bc5f 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -851,6 +851,9 @@ async def _run_generator(): task = asyncio.create_task(_run_generator()) # Track the task for cancellation + # Initialize if None (for direct calls, e.g., in tests) + # Note: When called from handle_function_calls_async_with_streaming, + # this should already be initialized before concurrent tasks start if invocation_context.active_streaming_tools is None: invocation_context.active_streaming_tools = {} invocation_context.active_streaming_tools[tool.name] = ActiveStreamingTool( @@ -909,7 +912,10 @@ async def _run_generator(): raise finally: # Clean up active_streaming_tools tracking - if tool.name in invocation_context.active_streaming_tools: + if ( + invocation_context.active_streaming_tools + and tool.name in invocation_context.active_streaming_tools + ): del invocation_context.active_streaming_tools[tool.name] @@ -921,8 +927,9 @@ async def handle_function_calls_async_with_streaming( ) -> AsyncGenerator[Event, None]: """Handles function calls with support for streaming tools. - Separates streaming tools from regular tools, processes regular tools - normally, and yields Events for streaming tools as they produce results. + Executes all tools (both regular and streaming) concurrently and yields + Events as they become available. This maximizes asyncio performance by + allowing all tools to run in parallel. Args: invocation_context: The invocation context. @@ -932,8 +939,13 @@ async def handle_function_calls_async_with_streaming( Yields: Events for function responses, including intermediate results from - streaming tools. + streaming tools, as they become available from any tool. """ + # Initialize active_streaming_tools before starting any concurrent tasks + # to avoid race conditions + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + # Separate streaming and non-streaming tools streaming_calls = [] regular_calls = [] @@ -945,38 +957,75 @@ async def handle_function_calls_async_with_streaming( else: regular_calls.append(function_call) - # Handle regular tools using existing logic + # Queue to merge events from all concurrent tasks + event_queue: asyncio.Queue[Optional[Event]] = asyncio.Queue() + active_tasks: list[asyncio.Task] = [] + + # Task to handle regular tools + async def _handle_regular_tools(): + if regular_calls: + regular_event = await handle_function_call_list_async( + invocation_context, + regular_calls, + tools_dict, + filters=None, + tool_confirmation_dict=tool_confirmation_dict, + ) + if regular_event: + await event_queue.put(regular_event) + await event_queue.put(None) # Signal completion + + # Task to handle a single streaming tool + async def _handle_streaming_tool(function_call: types.FunctionCall): + try: + tool = tools_dict[function_call.name] + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + tool_context = _create_tool_context( + invocation_context, + function_call, + tool_confirmation_dict.get(function_call.id) + if tool_confirmation_dict + else None, + ) + + async for event in _execute_streaming_tool_async( + tool, function_args, tool_context, invocation_context + ): + await event_queue.put(event) + finally: + await event_queue.put(None) # Signal completion + + # Start all tasks concurrently if regular_calls: - regular_event = await handle_function_call_list_async( - invocation_context, - regular_calls, - tools_dict, - filters=None, - tool_confirmation_dict=tool_confirmation_dict, - ) - if regular_event: - yield regular_event + active_tasks.append(asyncio.create_task(_handle_regular_tools())) - # Handle streaming tools for function_call in streaming_calls: - tool = tools_dict[function_call.name] - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} + active_tasks.append( + asyncio.create_task(_handle_streaming_tool(function_call)) ) - tool_context = _create_tool_context( - invocation_context, - function_call, - tool_confirmation_dict.get(function_call.id) - if tool_confirmation_dict - else None, - ) + # If no tasks, return early + if not active_tasks: + return - async for event in _execute_streaming_tool_async( - tool, function_args, tool_context, invocation_context - ): + # Yield events as they arrive from any task + completed_tasks = 0 + total_tasks = len(active_tasks) + + while completed_tasks < total_tasks: + event = await event_queue.get() + if event is None: + # Task completed + completed_tasks += 1 + else: yield event + # Wait for all tasks to complete (in case of errors) + await asyncio.gather(*active_tasks, return_exceptions=True) + def __build_response_event( tool: BaseTool, diff --git a/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py index 7245518712..78309d3b9c 100644 --- a/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py +++ b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py @@ -116,7 +116,7 @@ async def process_data(data: str) -> AsyncGenerator[str, None]: @pytest.mark.asyncio async def test_streaming_tool_tracks_task_for_cancellation(): - """Test that streaming tool tasks are tracked for cancellation.""" + """Test that streaming tool tasks are tracked for cancellation and cleaned up.""" async def long_running_task() -> AsyncGenerator[dict, None]: for i in range(10): @@ -131,14 +131,17 @@ async def long_running_task() -> AsyncGenerator[dict, None]: function_call_id='test-call-3', ) - # Start the streaming tool - task = asyncio.create_task( - _execute_streaming_tool_async( - tool, {}, tool_context, invocation_context - ).__anext__() - ) + # Start the streaming tool and consume it + async def consume_generator(): + async for event in _execute_streaming_tool_async( + tool, {}, tool_context, invocation_context + ): + # Consume events until cancelled + pass - # Wait a bit + consume_task = asyncio.create_task(consume_generator()) + + # Wait a bit for the tool to start await asyncio.sleep(0.05) # Check that task is tracked @@ -146,13 +149,20 @@ async def long_running_task() -> AsyncGenerator[dict, None]: assert tool.name in invocation_context.active_streaming_tools assert invocation_context.active_streaming_tools[tool.name].task is not None - # Cancel and clean up - task.cancel() + # Cancel the generator consumption task + consume_task.cancel() try: - await task + await consume_task except asyncio.CancelledError: pass + # Wait a bit for cleanup to complete + await asyncio.sleep(0.1) + + # Verify cleanup: task should be removed from active_streaming_tools + # The finally block in _execute_streaming_tool_async should have cleaned it up + assert tool.name not in invocation_context.active_streaming_tools + @pytest.mark.asyncio async def test_streaming_tool_handles_errors(): From 93da43ff405f577204fb12725f0401a6715861c0 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Mon, 8 Dec 2025 22:20:04 -0800 Subject: [PATCH 06/13] minor review coment incorporated #3848 --- src/google/adk/flows/llm_flows/functions.py | 2 +- .../unittests/flows/llm_flows/test_streaming_tools_non_live.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 183f77bc5f..e1fa23199f 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -886,7 +886,7 @@ async def _run_generator(): task.cancel() try: await task - except (asyncio.CancelledError, Exception): + except asyncio.CancelledError: pass raise except Exception as e: diff --git a/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py index 78309d3b9c..2b3a7edd4c 100644 --- a/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py +++ b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py @@ -224,7 +224,7 @@ async def streaming_tool(x: int) -> AsyncGenerator[dict, None]: events.append(event) # Should have 1 event from regular tool + 2 events from streaming tool - assert len(events) >= 3 + assert len(events) == 3 @pytest.mark.asyncio From bc155e459700d32d793865950a372953b7ae448d Mon Sep 17 00:00:00 2001 From: sarojrout Date: Tue, 9 Dec 2025 10:57:03 -0800 Subject: [PATCH 07/13] fix(flows): Checked and propagated exceptions from concurrent tool execution as per the review comments Fixed exception handling in handle_function_calls_async_with_streaming to prevent silent failures when tools raise exceptions. Changes are: - Stored results from asyncio.gather instead of discarding them - Checked each result for exceptions - Re-raised first exception encountered to signal failure This will ensure exceptions from failed tools are properly propagated instead of being silently swallowed. --- src/google/adk/flows/llm_flows/functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index e1fa23199f..a19d8ccf40 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -1023,8 +1023,12 @@ async def _handle_streaming_tool(function_call: types.FunctionCall): else: yield event - # Wait for all tasks to complete (in case of errors) - await asyncio.gather(*active_tasks, return_exceptions=True) + # Wait for all tasks to complete and check for exceptions + results = await asyncio.gather(*active_tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + # Re-raise the first exception encountered to signal failure + raise result def __build_response_event( From f142af88fc3dc049c8f8eb272422d6296912c886 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Tue, 9 Dec 2025 11:00:17 -0800 Subject: [PATCH 08/13] fix(flows): Correct result_queue type hint to Optional[Any] --- src/google/adk/flows/llm_flows/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index a19d8ccf40..359325291e 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -834,7 +834,8 @@ async def _execute_streaming_tool_async( return # Create a queue to buffer results from the async generator - result_queue: asyncio.Queue[Optional[dict[str, Any]]] = asyncio.Queue() + # Queue can contain: dict results, non-dict types (e.g., strings), Exception objects, or None + result_queue: asyncio.Queue[Optional[Any]] = asyncio.Queue() # Background task to run the generator and put results in queue async def _run_generator(): From ab4df98150f3c72713e4b2375c151ed5e151c9f3 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Tue, 9 Dec 2025 11:09:34 -0800 Subject: [PATCH 09/13] review comments incorporated and refactored by using getattr and made it more idiomatic and cleaner --- src/google/adk/flows/llm_flows/base_llm_flow.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 0b69c238b8..94e5aad03b 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -733,10 +733,8 @@ async def _postprocess_handle_function_calls_async( if has_streaming_tools: # Use streaming handler - tool_confirmation_dict = ( - invocation_context.tool_confirmation_dict - if hasattr(invocation_context, 'tool_confirmation_dict') - else None + tool_confirmation_dict = getattr( + invocation_context, 'tool_confirmation_dict', None ) async for event in functions.handle_function_calls_async_with_streaming( invocation_context, From 66e7046c378d912fe8a28efae59f73634be0e8e8 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Wed, 10 Dec 2025 16:31:16 -0800 Subject: [PATCH 10/13] fixed the unit tests by wrapping _yield_function_response_events with Aclosing --- .../adk/flows/llm_flows/base_llm_flow.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 94e5aad03b..cdf8685dc9 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -677,7 +677,7 @@ async def _yield_function_response_events( invocation_context: InvocationContext, function_call_event: Event, function_response_event: Event, - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event, None]: """Yields auth, confirmation, and set_model_response events for a function response. Args: @@ -742,10 +742,13 @@ async def _postprocess_handle_function_calls_async( llm_request.tools_dict, tool_confirmation_dict, ): - async for secondary_event in self._yield_function_response_events( - invocation_context, function_call_event, event - ): - yield secondary_event + async with Aclosing( + self._yield_function_response_events( + invocation_context, function_call_event, event + ) + ) as agen: + async for secondary_event in agen: + yield secondary_event # Check for agent transfer after each streaming event transfer_to_agent = event.actions.transfer_to_agent @@ -765,10 +768,13 @@ async def _postprocess_handle_function_calls_async( if function_response_event := await functions.handle_function_calls_async( invocation_context, function_call_event, llm_request.tools_dict ): - async for secondary_event in self._yield_function_response_events( - invocation_context, function_call_event, function_response_event - ): - yield secondary_event + async with Aclosing( + self._yield_function_response_events( + invocation_context, function_call_event, function_response_event + ) + ) as agen: + async for secondary_event in agen: + yield secondary_event transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: From b3e98db789b06ff67cbd52cd5596680d1eab283e Mon Sep 17 00:00:00 2001 From: sarojrout Date: Wed, 10 Dec 2025 21:17:26 -0800 Subject: [PATCH 11/13] fix(flows): Apply FunctionTool argument preprocessing and confirmation to streaming tools as per the review comments #3848 --- src/google/adk/flows/llm_flows/functions.py | 21 ++++++ src/google/adk/tools/function_tool.py | 81 +++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 359325291e..253df52712 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -39,6 +39,7 @@ from ...telemetry.tracing import trace_tool_call from ...telemetry.tracing import tracer from ...tools.base_tool import BaseTool +from ...tools.function_tool import FunctionTool from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -833,6 +834,24 @@ async def _execute_streaming_tool_async( yield event return + # For FunctionTool, prepare arguments using the same logic as run_async + # This ensures argument preprocessing, tool_context injection, confirmation + # handling, and mandatory args validation are applied consistently. + if isinstance(tool, FunctionTool): + prepared_args, error_response = ( + await tool._prepare_args_and_check_confirmation( + args=function_args, tool_context=tool_context + ) + ) + if error_response is not None: + # Confirmation required/rejected or missing mandatory args + event = __build_response_event( + tool, error_response, tool_context, invocation_context + ) + yield event + return + function_args = prepared_args + # Create a queue to buffer results from the async generator # Queue can contain: dict results, non-dict types (e.g., strings), Exception objects, or None result_queue: asyncio.Queue[Optional[Any]] = asyncio.Queue() @@ -841,6 +860,8 @@ async def _execute_streaming_tool_async( async def _run_generator(): try: # Get the async generator from the tool + # For FunctionTool, function_args now includes preprocessed args and tool_context + # For other tools, use the original function_args agen = tool.func(**function_args) async with Aclosing(agen) as gen: async for result in gen: diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index d957d1c16b..beac21a490 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -153,6 +153,87 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: return converted_args + async def _prepare_args_and_check_confirmation( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> tuple[dict[str, Any], Optional[dict[str, Any]]]: + """Prepares arguments and checks confirmation for function invocation. + + This method extracts the argument preparation and confirmation logic from + run_async so it can be reused by streaming tools and other callers. + + Args: + args: Raw arguments from the LLM tool call. + tool_context: The tool context. + + Returns: + A tuple of (prepared_args, error_response). If error_response is not None, + it indicates that the tool call should not proceed (e.g., missing args or + confirmation required/rejected). Otherwise, prepared_args contains the + processed arguments ready for function invocation. + """ + # Preprocess arguments (includes Pydantic model conversion) + args_to_call = self._preprocess_args(args) + + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + # Filter args_to_call to only include valid parameters for the function + args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} + + # Before invoking the function, we check for if the list of args passed in + # has all the mandatory arguments or not. + # If the check fails, then we don't invoke the tool and let the Agent know + # that there was a missing input parameter. This will basically help + # the underlying model fix the issue and retry. + mandatory_args = self._get_mandatory_args() + missing_mandatory_args = [ + arg for arg in mandatory_args if arg not in args_to_call + ] + + if missing_mandatory_args: + missing_mandatory_args_str = '\n'.join(missing_mandatory_args) + error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present: +{missing_mandatory_args_str} +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + return (args_to_call, {'error': error_str}) + + if isinstance(self._require_confirmation, Callable): + require_confirmation = await self._invoke_callable( + self._require_confirmation, args_to_call + ) + else: + require_confirmation = bool(self._require_confirmation) + + if require_confirmation: + if not tool_context.tool_confirmation: + args_to_show = args_to_call.copy() + if 'tool_context' in args_to_show: + args_to_show.pop('tool_context') + + tool_context.request_confirmation( + hint=( + f'Please approve or reject the tool call {self.name}() by' + ' responding with a FunctionResponse with an expected' + ' ToolConfirmation payload.' + ), + ) + tool_context.actions.skip_summarization = True + return ( + args_to_call, + { + 'error': ( + 'This tool call requires confirmation, please approve or' + ' reject.' + ) + }, + ) + elif not tool_context.tool_confirmation.confirmed: + return (args_to_call, {'error': 'This tool call is rejected.'}) + + return (args_to_call, None) + @override async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext From d1fae7bbf9e0343fbfaaf31f765d7b16fefea764 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Wed, 10 Dec 2025 22:32:17 -0800 Subject: [PATCH 12/13] refactor(flows): Extracted duplicate agent transfer logic into helper method --- .../adk/flows/llm_flows/base_llm_flow.py | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index cdf8685dc9..2090a3dedf 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -753,14 +753,10 @@ async def _postprocess_handle_function_calls_async( # Check for agent transfer after each streaming event transfer_to_agent = event.actions.transfer_to_agent if transfer_to_agent: - agent_to_run = self._get_agent_to_run( + async for transfer_event in self._handle_agent_transfer( invocation_context, transfer_to_agent - ) - async with Aclosing( - agent_to_run.run_async(invocation_context) - ) as agen: - async for transfer_event in agen: - yield transfer_event + ): + yield transfer_event # Agent transfer handled, exit the streaming loop return else: @@ -778,14 +774,10 @@ async def _postprocess_handle_function_calls_async( transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: - agent_to_run = self._get_agent_to_run( + async for event in self._handle_agent_transfer( invocation_context, transfer_to_agent - ) - async with Aclosing( - agent_to_run.run_async(invocation_context) - ) as agen: - async for event in agen: - yield event + ): + yield event def _get_agent_to_run( self, invocation_context: InvocationContext, agent_name: str @@ -796,6 +788,23 @@ def _get_agent_to_run( raise ValueError(f'Agent {agent_name} not found in the agent tree.') return agent_to_run + async def _handle_agent_transfer( + self, invocation_context: InvocationContext, agent_name: str + ) -> AsyncGenerator[Event, None]: + """Handles agent transfer by running the specified agent and yielding its events. + + Args: + invocation_context: The invocation context. + agent_name: The name of the agent to transfer to. + + Yields: + Events from the transferred agent's execution. + """ + agent_to_run = self._get_agent_to_run(invocation_context, agent_name) + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event + async def _call_llm_async( self, invocation_context: InvocationContext, From 1f47d4dc5fcd1d6f1282ec191fe6c2bb79af8182 Mon Sep 17 00:00:00 2001 From: sarojrout Date: Thu, 11 Dec 2025 11:28:09 -0800 Subject: [PATCH 13/13] small fix to reduce duplication --- src/google/adk/flows/llm_flows/base_llm_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 2090a3dedf..afc64e8bd5 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -677,7 +677,7 @@ async def _yield_function_response_events( invocation_context: InvocationContext, function_call_event: Event, function_response_event: Event, - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator[Event, None]: """Yields auth, confirmation, and set_model_response events for a function response. Args: