diff --git a/contributing/samples/streaming_tools_non_live_agent/README.md b/contributing/samples/streaming_tools_non_live_agent/README.md new file mode 100644 index 0000000000..801073a783 --- /dev/null +++ b/contributing/samples/streaming_tools_non_live_agent/README.md @@ -0,0 +1,39 @@ +# Streaming Tools Non-Live Agent + +This agent demonstrates streaming tools in non-live mode (run_async/SSE). + +## Features + +- **monitor_stock_price**: Monitors stock prices with real-time updates +- **process_large_dataset**: Processes datasets with progress updates +- **monitor_system_health**: Monitors system health metrics continuously + +## Testing + +### With ADK Web UI + +```bash +cd contributing/samples +adk web . +``` + +Then try: +- "Monitor the stock price for AAPL" +- "Process a large dataset at /tmp/data.csv" +- "Monitor system health" + +### With ADK CLI + +```bash +cd contributing/samples/streaming_tools_non_live_agent +adk run . +``` + +### With API Server (SSE) + +```bash +cd contributing/samples +adk api_server . +``` + +Then send a POST request to `/run_sse` with `streaming: true` to see intermediate Events. diff --git a/contributing/samples/streaming_tools_non_live_agent/__init__.py b/contributing/samples/streaming_tools_non_live_agent/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/streaming_tools_non_live_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/streaming_tools_non_live_agent/agent.py b/contributing/samples/streaming_tools_non_live_agent/agent.py new file mode 100644 index 0000000000..95deaf673c --- /dev/null +++ b/contributing/samples/streaming_tools_non_live_agent/agent.py @@ -0,0 +1,128 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example agent demonstrating streaming tools in non-live mode (run_async/SSE). + +This agent shows how to use streaming tools that yield intermediate results +in non-live mode. Streaming tools work with both run_async and SSE endpoints. +""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator + +from google.adk.agents import Agent + + +async def monitor_stock_price(symbol: str) -> AsyncGenerator[dict, None]: + """Monitor stock price with real-time updates. + + This is a streaming tool that yields intermediate results as the stock + price changes. The agent can react to these intermediate results. + + Args: + symbol: The stock symbol to monitor (e.g., 'AAPL', 'GOOGL'). + + Yields: + Dictionary containing stock price updates with status indicators. + """ + # Simulate stock price changes + prices = [100, 105, 110, 108, 112, 115] + for i, price in enumerate(prices): + await asyncio.sleep(1) # Simulate real-time updates + yield { + 'symbol': symbol, + 'price': price, + 'update': i + 1, + 'status': 'streaming' if i < len(prices) - 1 else 'complete', + } + + +async def process_large_dataset(file_path: str) -> AsyncGenerator[dict, None]: + """Process dataset with progress updates. + + This streaming tool demonstrates how to provide progress feedback + for long-running operations. + + Args: + file_path: Path to the dataset file to process. + + Yields: + Dictionary containing progress information and final result. + """ + total_rows = 100 + processed = 0 + + # Simulate processing in batches + for batch in range(10): + await asyncio.sleep(0.5) # Simulate processing time + processed += 10 + yield { + 'progress': processed / total_rows, + 'processed': processed, + 'total': total_rows, + 'status': 'streaming', + 'message': f'Processed {processed}/{total_rows} rows', + } + + # Final result + yield { + 'result': 'Processing complete', + 'status': 'complete', + 'file_path': file_path, + 'total_processed': total_rows, + } + + +async def monitor_system_health() -> AsyncGenerator[dict, None]: + """Monitor system health metrics with continuous updates. + + This streaming tool demonstrates continuous monitoring that can be + stopped by the agent when thresholds are reached. + + Yields: + Dictionary containing system health metrics. + """ + metrics = [ + {'cpu': 45, 'memory': 60, 'disk': 70}, + {'cpu': 50, 'memory': 65, 'disk': 72}, + {'cpu': 55, 'memory': 70, 'disk': 75}, + {'cpu': 60, 'memory': 75, 'disk': 78}, + ] + + for i, metric in enumerate(metrics): + await asyncio.sleep(2) # Check every 2 seconds + yield { + 'metrics': metric, + 'timestamp': i + 1, + 'status': 'streaming' if i < len(metrics) - 1 else 'complete', + 'alert': 'high' if metric['cpu'] > 55 else 'normal', + } + + +root_agent = Agent( + name='streaming_tools_agent', + model='gemini-2.5-flash-lite', + instruction=( + 'You are a helpful assistant that can monitor stock prices, process' + ' datasets, and monitor system health using streaming tools. When' + ' using streaming tools, you will receive intermediate results that' + ' you can react to. For example, if monitoring stock prices, you can' + ' alert the user when prices change significantly. If processing a' + ' dataset, you can provide progress updates. If monitoring system' + ' health, you can alert when metrics exceed thresholds.' + ), + tools=[monitor_stock_price, process_large_dataset, monitor_system_health], +) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 824cd26be1..afc64e8bd5 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -672,48 +672,111 @@ async def _postprocess_run_processors_async( async for event in agen: yield event - async def _postprocess_handle_function_calls_async( + async def _yield_function_response_events( self, invocation_context: InvocationContext, function_call_event: Event, - llm_request: LlmRequest, + function_response_event: Event, ) -> AsyncGenerator[Event, None]: - if function_response_event := await functions.handle_function_calls_async( - invocation_context, function_call_event, llm_request.tools_dict + """Yields auth, confirmation, and set_model_response events for a function response. + + Args: + invocation_context: The invocation context. + function_call_event: The original function call event. + function_response_event: The function response event. + + Yields: + Auth events, confirmation events, the function response event, and + set_model_response events if applicable. + """ + auth_event = functions.generate_auth_event( + invocation_context, function_response_event + ) + if auth_event: + yield auth_event + + tool_confirmation_event = functions.generate_request_confirmation_event( + invocation_context, function_call_event, function_response_event + ) + if tool_confirmation_event: + yield tool_confirmation_event + + yield function_response_event + + # Check if this is a set_model_response function response + if json_response := ( + _output_schema_processor.get_structured_model_response( + function_response_event + ) ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event + final_event = _output_schema_processor.create_final_model_response_event( + invocation_context, json_response ) - if auth_event: - yield auth_event + yield final_event - tool_confirmation_event = functions.generate_request_confirmation_event( - invocation_context, function_call_event, function_response_event - ) - if tool_confirmation_event: - yield tool_confirmation_event + async def _postprocess_handle_function_calls_async( + self, + invocation_context: InvocationContext, + function_call_event: Event, + llm_request: LlmRequest, + ) -> AsyncGenerator[Event, None]: + function_calls = function_call_event.get_function_calls() + if not function_calls: + return - # Always yield the function response event first - yield function_response_event + # Check if any tools are streaming tools + has_streaming_tools = any( + functions._is_streaming_tool(tool) + for call in function_calls + if (tool := llm_request.tools_dict.get(call.name)) + ) - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event + if has_streaming_tools: + # Use streaming handler + tool_confirmation_dict = getattr( + invocation_context, 'tool_confirmation_dict', None + ) + async for event in functions.handle_function_calls_async_with_streaming( + invocation_context, + function_calls, + llm_request.tools_dict, + tool_confirmation_dict, ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response + async with Aclosing( + self._yield_function_response_events( + invocation_context, function_call_event, event ) - ) - yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: - async for event in agen: + ) as agen: + async for secondary_event in agen: + yield secondary_event + + # Check for agent transfer after each streaming event + transfer_to_agent = event.actions.transfer_to_agent + if transfer_to_agent: + async for transfer_event in self._handle_agent_transfer( + invocation_context, transfer_to_agent + ): + yield transfer_event + # Agent transfer handled, exit the streaming loop + return + else: + # Use regular handler + if function_response_event := await functions.handle_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + async with Aclosing( + self._yield_function_response_events( + invocation_context, function_call_event, function_response_event + ) + ) as agen: + async for secondary_event in agen: + yield secondary_event + + transfer_to_agent = function_response_event.actions.transfer_to_agent + if transfer_to_agent: + async for event in self._handle_agent_transfer( + invocation_context, transfer_to_agent + ): yield event def _get_agent_to_run( @@ -725,6 +788,23 @@ def _get_agent_to_run( raise ValueError(f'Agent {agent_name} not found in the agent tree.') return agent_to_run + async def _handle_agent_transfer( + self, invocation_context: InvocationContext, agent_name: str + ) -> AsyncGenerator[Event, None]: + """Handles agent transfer by running the specified agent and yielding its events. + + Args: + invocation_context: The invocation context. + agent_name: The name of the agent to transfer to. + + Yields: + Events from the transferred agent's execution. + """ + agent_to_run = self._get_agent_to_run(invocation_context, agent_name) + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event + async def _call_llm_async( self, invocation_context: InvocationContext, diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ffe1657be1..253df52712 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -39,6 +39,7 @@ from ...telemetry.tracing import trace_tool_call from ...telemetry.tracing import tracer from ...tools.base_tool import BaseTool +from ...tools.function_tool import FunctionTool from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -788,6 +789,270 @@ async def __call_tool_async( return await tool.run_async(args=args, tool_context=tool_context) +def _is_streaming_tool(tool: BaseTool) -> bool: + """Checks if a tool is a streaming tool (async generator function). + + Args: + tool: The tool to check. + + Returns: + True if the tool's function is an async generator, False otherwise. + """ + return hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func) + + +async def _execute_streaming_tool_async( + tool: BaseTool, + function_args: dict[str, Any], + tool_context: ToolContext, + invocation_context: InvocationContext, +) -> AsyncGenerator[Event, None]: + """Executes a streaming tool and yields Events for each intermediate result. + + Args: + tool: The streaming tool to execute. + function_args: The function call arguments. + tool_context: The tool context. + invocation_context: The invocation context. + + Yields: + Events for each intermediate result yielded by the streaming tool. + """ + task = None + try: + # Run before_tool_callbacks + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) + ) + if function_response is not None: + # Plugin overrode the function response, yield it and return + event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + yield event + return + + # For FunctionTool, prepare arguments using the same logic as run_async + # This ensures argument preprocessing, tool_context injection, confirmation + # handling, and mandatory args validation are applied consistently. + if isinstance(tool, FunctionTool): + prepared_args, error_response = ( + await tool._prepare_args_and_check_confirmation( + args=function_args, tool_context=tool_context + ) + ) + if error_response is not None: + # Confirmation required/rejected or missing mandatory args + event = __build_response_event( + tool, error_response, tool_context, invocation_context + ) + yield event + return + function_args = prepared_args + + # Create a queue to buffer results from the async generator + # Queue can contain: dict results, non-dict types (e.g., strings), Exception objects, or None + result_queue: asyncio.Queue[Optional[Any]] = asyncio.Queue() + + # Background task to run the generator and put results in queue + async def _run_generator(): + try: + # Get the async generator from the tool + # For FunctionTool, function_args now includes preprocessed args and tool_context + # For other tools, use the original function_args + agen = tool.func(**function_args) + async with Aclosing(agen) as gen: + async for result in gen: + await result_queue.put(result) + await result_queue.put(None) # Signal completion + except Exception as e: + await result_queue.put(e) # Signal error + + task = asyncio.create_task(_run_generator()) + + # Track the task for cancellation + # Initialize if None (for direct calls, e.g., in tests) + # Note: When called from handle_function_calls_async_with_streaming, + # this should already be initialized before concurrent tasks start + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + invocation_context.active_streaming_tools[tool.name] = ActiveStreamingTool( + task=task + ) + + # Yield Events as results come in + while True: + result = await result_queue.get() + if result is None: + # Generator completed normally + break + if isinstance(result, Exception): + # Generator raised an exception + raise result + + # Ensure result is a dict + if not isinstance(result, dict): + result = {'result': result} + + # Create and yield Event for this intermediate result + event = __build_response_event( + tool, result, tool_context, invocation_context + ) + yield event + + except asyncio.CancelledError: + # Clean up on cancellation + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise + except Exception as e: + # Handle errors + + # Run error callbacks + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=e, + ) + ) + if error_response is not None: + event = __build_response_event( + tool, error_response, tool_context, invocation_context + ) + yield event + return + + # Re-raise if no error callback handled it + raise + finally: + # Clean up active_streaming_tools tracking + if ( + invocation_context.active_streaming_tools + and tool.name in invocation_context.active_streaming_tools + ): + del invocation_context.active_streaming_tools[tool.name] + + +async def handle_function_calls_async_with_streaming( + invocation_context: InvocationContext, + function_calls: list[types.FunctionCall], + tools_dict: dict[str, BaseTool], + tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None, +) -> AsyncGenerator[Event, None]: + """Handles function calls with support for streaming tools. + + Executes all tools (both regular and streaming) concurrently and yields + Events as they become available. This maximizes asyncio performance by + allowing all tools to run in parallel. + + Args: + invocation_context: The invocation context. + function_calls: List of function calls to handle. + tools_dict: Dictionary of available tools. + tool_confirmation_dict: Optional dictionary of tool confirmations. + + Yields: + Events for function responses, including intermediate results from + streaming tools, as they become available from any tool. + """ + # Initialize active_streaming_tools before starting any concurrent tasks + # to avoid race conditions + if invocation_context.active_streaming_tools is None: + invocation_context.active_streaming_tools = {} + + # Separate streaming and non-streaming tools + streaming_calls = [] + regular_calls = [] + + for function_call in function_calls: + tool = tools_dict.get(function_call.name) + if tool and _is_streaming_tool(tool): + streaming_calls.append(function_call) + else: + regular_calls.append(function_call) + + # Queue to merge events from all concurrent tasks + event_queue: asyncio.Queue[Optional[Event]] = asyncio.Queue() + active_tasks: list[asyncio.Task] = [] + + # Task to handle regular tools + async def _handle_regular_tools(): + if regular_calls: + regular_event = await handle_function_call_list_async( + invocation_context, + regular_calls, + tools_dict, + filters=None, + tool_confirmation_dict=tool_confirmation_dict, + ) + if regular_event: + await event_queue.put(regular_event) + await event_queue.put(None) # Signal completion + + # Task to handle a single streaming tool + async def _handle_streaming_tool(function_call: types.FunctionCall): + try: + tool = tools_dict[function_call.name] + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + tool_context = _create_tool_context( + invocation_context, + function_call, + tool_confirmation_dict.get(function_call.id) + if tool_confirmation_dict + else None, + ) + + async for event in _execute_streaming_tool_async( + tool, function_args, tool_context, invocation_context + ): + await event_queue.put(event) + finally: + await event_queue.put(None) # Signal completion + + # Start all tasks concurrently + if regular_calls: + active_tasks.append(asyncio.create_task(_handle_regular_tools())) + + for function_call in streaming_calls: + active_tasks.append( + asyncio.create_task(_handle_streaming_tool(function_call)) + ) + + # If no tasks, return early + if not active_tasks: + return + + # Yield events as they arrive from any task + completed_tasks = 0 + total_tasks = len(active_tasks) + + while completed_tasks < total_tasks: + event = await event_queue.get() + if event is None: + # Task completed + completed_tasks += 1 + else: + yield event + + # Wait for all tasks to complete and check for exceptions + results = await asyncio.gather(*active_tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + # Re-raise the first exception encountered to signal failure + raise result + + def __build_response_event( tool: BaseTool, function_result: dict[str, object], diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index d957d1c16b..beac21a490 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -153,6 +153,87 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: return converted_args + async def _prepare_args_and_check_confirmation( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> tuple[dict[str, Any], Optional[dict[str, Any]]]: + """Prepares arguments and checks confirmation for function invocation. + + This method extracts the argument preparation and confirmation logic from + run_async so it can be reused by streaming tools and other callers. + + Args: + args: Raw arguments from the LLM tool call. + tool_context: The tool context. + + Returns: + A tuple of (prepared_args, error_response). If error_response is not None, + it indicates that the tool call should not proceed (e.g., missing args or + confirmation required/rejected). Otherwise, prepared_args contains the + processed arguments ready for function invocation. + """ + # Preprocess arguments (includes Pydantic model conversion) + args_to_call = self._preprocess_args(args) + + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + # Filter args_to_call to only include valid parameters for the function + args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} + + # Before invoking the function, we check for if the list of args passed in + # has all the mandatory arguments or not. + # If the check fails, then we don't invoke the tool and let the Agent know + # that there was a missing input parameter. This will basically help + # the underlying model fix the issue and retry. + mandatory_args = self._get_mandatory_args() + missing_mandatory_args = [ + arg for arg in mandatory_args if arg not in args_to_call + ] + + if missing_mandatory_args: + missing_mandatory_args_str = '\n'.join(missing_mandatory_args) + error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present: +{missing_mandatory_args_str} +You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters.""" + return (args_to_call, {'error': error_str}) + + if isinstance(self._require_confirmation, Callable): + require_confirmation = await self._invoke_callable( + self._require_confirmation, args_to_call + ) + else: + require_confirmation = bool(self._require_confirmation) + + if require_confirmation: + if not tool_context.tool_confirmation: + args_to_show = args_to_call.copy() + if 'tool_context' in args_to_show: + args_to_show.pop('tool_context') + + tool_context.request_confirmation( + hint=( + f'Please approve or reject the tool call {self.name}() by' + ' responding with a FunctionResponse with an expected' + ' ToolConfirmation payload.' + ), + ) + tool_context.actions.skip_summarization = True + return ( + args_to_call, + { + 'error': ( + 'This tool call requires confirmation, please approve or' + ' reject.' + ) + }, + ) + elif not tool_context.tool_confirmation.confirmed: + return (args_to_call, {'error': 'This tool call is rejected.'}) + + return (args_to_call, None) + @override async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext diff --git a/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py new file mode 100644 index 0000000000..2b3a7edd4c --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_streaming_tools_non_live.py @@ -0,0 +1,263 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for streaming tools in non-live mode.""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator + +from google.adk.agents.llm_agent import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import _execute_streaming_tool_async +from google.adk.flows.llm_flows.functions import _is_streaming_tool +from google.adk.flows.llm_flows.functions import handle_function_calls_async_with_streaming +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + +from ... import testing_utils + + +def test_is_streaming_tool_detects_async_generator(): + """Test that _is_streaming_tool correctly identifies async generators.""" + + async def regular_function(x: int) -> int: + return x + 1 + + async def streaming_function(x: int) -> AsyncGenerator[dict, None]: + yield {'result': x + 1} + yield {'result': x + 2} + + regular_tool = FunctionTool(func=regular_function) + streaming_tool = FunctionTool(func=streaming_function) + + assert not _is_streaming_tool(regular_tool) + assert _is_streaming_tool(streaming_tool) + + +@pytest.mark.asyncio +async def test_streaming_tool_yields_multiple_events(): + """Test that a streaming tool yields multiple Events.""" + + async def monitor_stock(symbol: str) -> AsyncGenerator[dict, None]: + prices = [100, 105, 110] + for price in prices: + await asyncio.sleep(0.01) # Small delay for testing + yield {'symbol': symbol, 'price': price} + + tool = FunctionTool(func=monitor_stock) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-1', + ) + + events = [] + async for event in _execute_streaming_tool_async( + tool, + {'symbol': 'AAPL'}, + tool_context, + invocation_context, + ): + events.append(event) + + # Should have 3 events (one per price) + assert len(events) == 3 + assert all(isinstance(e, Event) for e in events) + assert all( + e.content and e.content.parts and e.content.parts[0].function_response + for e in events + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_with_string_results(): + """Test streaming tool that yields string results.""" + + async def process_data(data: str) -> AsyncGenerator[str, None]: + for i, char in enumerate(data): + await asyncio.sleep(0.01) + yield f'Processed {i}: {char}' + + tool = FunctionTool(func=process_data) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-2', + ) + + events = [] + async for event in _execute_streaming_tool_async( + tool, + {'data': 'ABC'}, + tool_context, + invocation_context, + ): + events.append(event) + + assert len(events) == 3 + + +@pytest.mark.asyncio +async def test_streaming_tool_tracks_task_for_cancellation(): + """Test that streaming tool tasks are tracked for cancellation and cleaned up.""" + + async def long_running_task() -> AsyncGenerator[dict, None]: + for i in range(10): + await asyncio.sleep(0.1) + yield {'progress': i} + + tool = FunctionTool(func=long_running_task) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-3', + ) + + # Start the streaming tool and consume it + async def consume_generator(): + async for event in _execute_streaming_tool_async( + tool, {}, tool_context, invocation_context + ): + # Consume events until cancelled + pass + + consume_task = asyncio.create_task(consume_generator()) + + # Wait a bit for the tool to start + await asyncio.sleep(0.05) + + # Check that task is tracked + assert invocation_context.active_streaming_tools is not None + assert tool.name in invocation_context.active_streaming_tools + assert invocation_context.active_streaming_tools[tool.name].task is not None + + # Cancel the generator consumption task + consume_task.cancel() + try: + await consume_task + except asyncio.CancelledError: + pass + + # Wait a bit for cleanup to complete + await asyncio.sleep(0.1) + + # Verify cleanup: task should be removed from active_streaming_tools + # The finally block in _execute_streaming_tool_async should have cleaned it up + assert tool.name not in invocation_context.active_streaming_tools + + +@pytest.mark.asyncio +async def test_streaming_tool_handles_errors(): + """Test error handling in streaming tools.""" + + async def failing_tool() -> AsyncGenerator[dict, None]: + yield {'status': 'started'} + raise ValueError('Test error') + yield {'status': 'never reached'} # type: ignore[unreachable] + + tool = FunctionTool(func=failing_tool) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-4', + ) + + events = [] + with pytest.raises(ValueError, match='Test error'): + async for event in _execute_streaming_tool_async( + tool, {}, tool_context, invocation_context + ): + events.append(event) + + # Should have yielded at least one event before error + assert len(events) >= 1 + + +@pytest.mark.asyncio +async def test_handle_function_calls_async_with_streaming_separates_tools(): + """Test that handler correctly separates streaming and non-streaming tools.""" + + async def regular_tool(x: int) -> int: + return x * 2 + + async def streaming_tool(x: int) -> AsyncGenerator[dict, None]: + yield {'value': x} + yield {'value': x * 2} + + regular = FunctionTool(func=regular_tool) + streaming = FunctionTool(func=streaming_tool) + + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + function_calls = [ + types.FunctionCall(name='regular_tool', args={'x': 5}, id='call-1'), + types.FunctionCall(name='streaming_tool', args={'x': 3}, id='call-2'), + ] + tools_dict = { + 'regular_tool': regular, + 'streaming_tool': streaming, + } + + events = [] + async for event in handle_function_calls_async_with_streaming( + invocation_context, function_calls, tools_dict + ): + events.append(event) + + # Should have 1 event from regular tool + 2 events from streaming tool + assert len(events) == 3 + + +@pytest.mark.asyncio +async def test_streaming_tool_with_tool_context(): + """Test that tool_context is correctly passed to streaming tools.""" + + async def context_aware_tool(x: int) -> AsyncGenerator[dict, None]: + # Note: tool_context is injected by FunctionTool, not passed as arg + yield { + 'value': x, + 'status': 'processed', + } + + tool = FunctionTool(func=context_aware_tool) + agent = Agent(name='test_agent') + invocation_context = await testing_utils.create_invocation_context(agent) + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id='test-call-5', + ) + + events = [] + async for event in _execute_streaming_tool_async( + tool, + {'x': 42}, + tool_context, + invocation_context, + ): + events.append(event) + + assert len(events) == 1 + response = events[0].content.parts[0].function_response.response + assert response['value'] == 42 + assert response['status'] == 'processed' + # Verify the function_call_id is set correctly in the event + assert events[0].content.parts[0].function_response.id == 'test-call-5'