|
19 | 19 | from google.adk.agents.llm_agent import Agent |
20 | 20 | from google.adk.events.event import Event |
21 | 21 | from google.adk.flows.llm_flows.functions import find_matching_function_call |
| 22 | +from google.adk.flows.llm_flows.functions import handle_function_calls_async |
| 23 | +from google.adk.flows.llm_flows.functions import handle_function_calls_live |
22 | 24 | from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events |
| 25 | +from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool |
23 | 26 | from google.adk.tools.function_tool import FunctionTool |
24 | 27 | from google.adk.tools.tool_context import ToolContext |
25 | 28 | from google.genai import types |
@@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses(): |
397 | 400 | @pytest.mark.asyncio |
398 | 401 | async def test_function_call_args_not_modified(): |
399 | 402 | """Test that function_call.args is not modified when making a copy.""" |
400 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_async |
401 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_live |
402 | 403 |
|
403 | 404 | def simple_fn(**kwargs) -> dict: |
404 | 405 | return {'result': 'test'} |
@@ -455,8 +456,6 @@ def simple_fn(**kwargs) -> dict: |
455 | 456 | @pytest.mark.asyncio |
456 | 457 | async def test_function_call_args_none_handling(): |
457 | 458 | """Test that function_call.args=None is handled correctly.""" |
458 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_async |
459 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_live |
460 | 459 |
|
461 | 460 | def simple_fn(**kwargs) -> dict: |
462 | 461 | return {'result': 'test'} |
@@ -504,8 +503,6 @@ def simple_fn(**kwargs) -> dict: |
504 | 503 | @pytest.mark.asyncio |
505 | 504 | async def test_function_call_args_copy_behavior(): |
506 | 505 | """Test that modifying the copied args doesn't affect the original.""" |
507 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_async |
508 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_live |
509 | 506 |
|
510 | 507 | def simple_fn(test_param: str, other_param: int) -> dict: |
511 | 508 | # Modify the args to test that the copy prevents affecting the original |
@@ -565,8 +562,6 @@ def simple_fn(test_param: str, other_param: int) -> dict: |
565 | 562 | @pytest.mark.asyncio |
566 | 563 | async def test_function_call_args_deep_copy_behavior(): |
567 | 564 | """Test that deep copy behavior works correctly with nested structures.""" |
568 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_async |
569 | | - from google.adk.flows.llm_flows.functions import handle_function_calls_live |
570 | 565 |
|
571 | 566 | def simple_fn(nested_dict: dict, list_param: list) -> dict: |
572 | 567 | # Modify the nested structures to test deep copy |
@@ -1141,3 +1136,62 @@ async def yielding_async() -> dict: |
1141 | 1136 | 'yield_E', |
1142 | 1137 | 'yield_F', |
1143 | 1138 | ] |
| 1139 | + |
| 1140 | + |
| 1141 | +@pytest.mark.asyncio |
| 1142 | +@pytest.mark.parametrize( |
| 1143 | + 'handle_function_calls', |
| 1144 | + [ |
| 1145 | + (handle_function_calls_async), |
| 1146 | + (handle_function_calls_live), |
| 1147 | + ], |
| 1148 | +) |
| 1149 | +async def test_computer_use_tool_decoding_behavior(handle_function_calls): |
| 1150 | + """Tests that computer use tools automatically decode base64 images.""" |
| 1151 | + valid_b64 = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7' |
| 1152 | + |
| 1153 | + # make the tool return a dictionary with the image |
| 1154 | + async def mock_run(*args, **kwargs): |
| 1155 | + return { |
| 1156 | + 'image': {'data': valid_b64, 'mimetype': 'image/png'}, |
| 1157 | + 'url': 'https://example.com', |
| 1158 | + } |
| 1159 | + |
| 1160 | + # create a ComputerUseTool |
| 1161 | + tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768)) |
| 1162 | + |
| 1163 | + model = testing_utils.MockModel.create(responses=[]) |
| 1164 | + agent = Agent( |
| 1165 | + name='test_agent', |
| 1166 | + model=model, |
| 1167 | + tools=[tool], |
| 1168 | + ) |
| 1169 | + invocation_context = await testing_utils.create_invocation_context( |
| 1170 | + agent=agent, user_content='' |
| 1171 | + ) |
| 1172 | + |
| 1173 | + # Create function call |
| 1174 | + function_call = types.FunctionCall(name=tool.name, args={}) |
| 1175 | + content = types.Content(parts=[types.Part(function_call=function_call)]) |
| 1176 | + event = Event( |
| 1177 | + invocation_id=invocation_context.invocation_id, |
| 1178 | + author=agent.name, |
| 1179 | + content=content, |
| 1180 | + ) |
| 1181 | + tools_dict = {tool.name: tool} |
| 1182 | + |
| 1183 | + result = await handle_function_calls( |
| 1184 | + invocation_context, |
| 1185 | + event, |
| 1186 | + tools_dict, |
| 1187 | + ) |
| 1188 | + |
| 1189 | + assert result is not None |
| 1190 | + response_part = result.content.parts[0].function_response |
| 1191 | + |
| 1192 | + # Verify original image data is removed from the dict response |
| 1193 | + assert 'image' not in response_part.response |
| 1194 | + assert 'url' in response_part.response |
| 1195 | + # Verify the image was converted to a blob |
| 1196 | + assert len(response_part.parts) == 1 |
| 1197 | + assert response_part.parts[0].inline_data is not None |
0 commit comments