Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/google/adk/tools/_automatic_function_calling_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,10 @@ def from_function_with_options(

return_annotation = inspect.signature(func).return_annotation

# Resolve deferred type hints.
if 'return' in annotation_under_future:
return_annotation = annotation_under_future['return']

# Handle AsyncGenerator and Generator return types (streaming tools)
# AsyncGenerator[YieldType, SendType] -> use YieldType as response schema
# Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema
Expand Down
8 changes: 8 additions & 0 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ async def _invoke_callable(
) -> Any:
"""Invokes a callable, handling both sync and async cases."""

# Handle async generator functions (streaming tools)
is_async_gen = inspect.isasyncgenfunction(target) or (
hasattr(target, '__call__')
and inspect.isasyncgenfunction(target.__call__)
)
if is_async_gen:
return [item async for item in target(**args_to_call)]

# Functions are callable objects, but not all callable objects are functions
# checking coroutine function is not enough. We also need to check whether
# Callable's __call__ function is a coroutine function
Expand Down
23 changes: 23 additions & 0 deletions tests/unittests/tools/test_build_function_declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
from enum import Enum

from google.adk.features import FeatureName
Expand Down Expand Up @@ -661,3 +662,25 @@ def greet(name: str = 'World') -> str:
schema = decl.parameters_json_schema
assert schema['properties']['name']['default'] == 'World'
assert 'name' not in schema.get('required', [])


def test_schema_generation_for_streaming_tool_with_string_annotations():
"""Test schema generation for AsyncGenerator with string annotations."""

# Simulate string annotation by using forward reference string
# This mimics "from __future__ import annotations" behavior
async def streaming_tool(
param: str,
) -> 'collections.abc.AsyncGenerator[str, None]':
"""A streaming tool."""
yield f'result {param}'

function_decl = _automatic_function_calling_util.build_function_declaration(
func=streaming_tool, variant=GoogleLLMVariant.VERTEX_AI
)

assert function_decl.name == 'streaming_tool'
assert function_decl.parameters.type == 'OBJECT'
# VERTEX_AI should have response schema for string return (yield type)
assert function_decl.response is not None
assert function_decl.response.type == types.Type.STRING
42 changes: 36 additions & 6 deletions tests/unittests/tools/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
from typing import AsyncGenerator
from unittest.mock import MagicMock

from google.adk.agents.invocation_context import InvocationContext
Expand Down Expand Up @@ -200,9 +202,11 @@ async def test_run_async_1_missing_arg_sync_func():
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": """Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"error": (
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg2
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}


Expand All @@ -213,9 +217,11 @@ async def test_run_async_1_missing_arg_async_func():
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": """Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"error": (
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}


Expand All @@ -226,11 +232,13 @@ async def test_run_async_3_missing_arg_sync_func():
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}


Expand All @@ -241,11 +249,13 @@ async def test_run_async_3_missing_arg_async_func():
args = {"arg3": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}


Expand All @@ -256,12 +266,14 @@ async def test_run_async_missing_all_arg_sync_func():
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}


Expand All @@ -272,12 +284,14 @@ async def test_run_async_missing_all_arg_async_func():
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}


Expand Down Expand Up @@ -428,3 +442,19 @@ def explicit_params_func(arg1: str, arg2: int):
assert result == {"arg1": "test", "arg2": 42}
# Explicitly verify that unexpected_param was filtered out and not passed to the function
assert "unexpected_param" not in result


@pytest.mark.asyncio
async def test_run_async_streaming_generator():
"""Test that run_async consumes the async generator and returns a list."""

async def streaming_tool(param: str) -> AsyncGenerator[str, None]:
yield f"part 1 {param}"
yield f"part 2 {param}"

tool = FunctionTool(streaming_tool)

result = await tool.run_async(args={"param": "test"}, tool_context=MagicMock())

assert isinstance(result, list)
assert result == ["part 1 test", "part 2 test"]