Skip to content

Commit 2580c9d

Browse files
committed
Fix streaming tools support for string annotations and serialization
1 parent 4aa4751 commit 2580c9d

File tree

4 files changed

+74
-6
lines changed

4 files changed

+74
-6
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ def from_function_with_options(
394394

395395
return_annotation = inspect.signature(func).return_annotation
396396

397+
# Resolve deferred type hints.
398+
if 'return' in annotation_under_future:
399+
return_annotation = annotation_under_future['return']
400+
397401
# Handle AsyncGenerator and Generator return types (streaming tools)
398402
# AsyncGenerator[YieldType, SendType] -> use YieldType as response schema
399403
# Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema

src/google/adk/tools/function_tool.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,17 @@ async def _invoke_callable(
222222
) -> Any:
223223
"""Invokes a callable, handling both sync and async cases."""
224224

225+
# Handle async generator functions (streaming tools)
226+
is_async_gen = inspect.isasyncgenfunction(target) or (
227+
hasattr(target, '__call__')
228+
and inspect.isasyncgenfunction(target.__call__)
229+
)
230+
if is_async_gen:
231+
results = []
232+
async for item in target(**args_to_call):
233+
results.append(item)
234+
return results
235+
225236
# Functions are callable objects, but not all callable objects are functions
226237
# checking coroutine function is not enough. We also need to check whether
227238
# Callable's __call__ function is a coroutine function

tests/unittests/tools/test_build_function_declaration.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import collections.abc
1516
from enum import Enum
1617

1718
from google.adk.features import FeatureName
@@ -661,3 +662,25 @@ def greet(name: str = 'World') -> str:
661662
schema = decl.parameters_json_schema
662663
assert schema['properties']['name']['default'] == 'World'
663664
assert 'name' not in schema.get('required', [])
665+
666+
667+
def test_schema_generation_for_streaming_tool_with_string_annotations():
668+
"""Test schema generation for AsyncGenerator with string annotations."""
669+
670+
# Simulate string annotation by using forward reference string
671+
# This mimics "from __future__ import annotations" behavior
672+
async def streaming_tool(
673+
param: str,
674+
) -> 'collections.abc.AsyncGenerator[str, None]':
675+
"""A streaming tool."""
676+
yield f'result {param}'
677+
678+
function_decl = _automatic_function_calling_util.build_function_declaration(
679+
func=streaming_tool, variant=GoogleLLMVariant.VERTEX_AI
680+
)
681+
682+
assert function_decl.name == 'streaming_tool'
683+
assert function_decl.parameters.type == 'OBJECT'
684+
# VERTEX_AI should have response schema for string return (yield type)
685+
assert function_decl.response is not None
686+
assert function_decl.response.type == types.Type.STRING

tests/unittests/tools/test_function_tool.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import collections.abc
16+
from typing import AsyncGenerator
1517
from unittest.mock import MagicMock
1618

1719
from google.adk.agents.invocation_context import InvocationContext
@@ -200,9 +202,11 @@ async def test_run_async_1_missing_arg_sync_func():
200202
args = {"arg1": "test_value_1"}
201203
result = await tool.run_async(args=args, tool_context=MagicMock())
202204
assert result == {
203-
"error": """Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
205+
"error": (
206+
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
204207
arg2
205208
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
209+
)
206210
}
207211

208212

@@ -213,9 +217,11 @@ async def test_run_async_1_missing_arg_async_func():
213217
args = {"arg2": "test_value_1"}
214218
result = await tool.run_async(args=args, tool_context=MagicMock())
215219
assert result == {
216-
"error": """Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
220+
"error": (
221+
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
217222
arg1
218223
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
224+
)
219225
}
220226

221227

@@ -226,11 +232,13 @@ async def test_run_async_3_missing_arg_sync_func():
226232
args = {"arg2": "test_value_1"}
227233
result = await tool.run_async(args=args, tool_context=MagicMock())
228234
assert result == {
229-
"error": """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
235+
"error": (
236+
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
230237
arg1
231238
arg3
232239
arg4
233240
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
241+
)
234242
}
235243

236244

@@ -241,11 +249,13 @@ async def test_run_async_3_missing_arg_async_func():
241249
args = {"arg3": "test_value_1"}
242250
result = await tool.run_async(args=args, tool_context=MagicMock())
243251
assert result == {
244-
"error": """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
252+
"error": (
253+
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
245254
arg1
246255
arg2
247256
arg4
248257
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
258+
)
249259
}
250260

251261

@@ -256,12 +266,14 @@ async def test_run_async_missing_all_arg_sync_func():
256266
args = {}
257267
result = await tool.run_async(args=args, tool_context=MagicMock())
258268
assert result == {
259-
"error": """Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
269+
"error": (
270+
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
260271
arg1
261272
arg2
262273
arg3
263274
arg4
264275
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
276+
)
265277
}
266278

267279

@@ -272,12 +284,14 @@ async def test_run_async_missing_all_arg_async_func():
272284
args = {}
273285
result = await tool.run_async(args=args, tool_context=MagicMock())
274286
assert result == {
275-
"error": """Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
287+
"error": (
288+
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
276289
arg1
277290
arg2
278291
arg3
279292
arg4
280293
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
294+
)
281295
}
282296

283297

@@ -428,3 +442,19 @@ def explicit_params_func(arg1: str, arg2: int):
428442
assert result == {"arg1": "test", "arg2": 42}
429443
# Explicitly verify that unexpected_param was filtered out and not passed to the function
430444
assert "unexpected_param" not in result
445+
446+
447+
@pytest.mark.asyncio
448+
async def test_run_async_streaming_generator():
449+
"""Test that run_async consumes the async generator and returns a list."""
450+
451+
async def streaming_tool(param: str) -> AsyncGenerator[str, None]:
452+
yield f"part 1 {param}"
453+
yield f"part 2 {param}"
454+
455+
tool = FunctionTool(streaming_tool)
456+
457+
result = await tool.run_async(args={"param": "test"}, tool_context=None)
458+
459+
assert isinstance(result, list)
460+
assert result == ["part 1 test", "part 2 test"]

0 commit comments

Comments
 (0)