Skip to content

Commit d7cfd8f

Browse files
google-genai-bot authored and copybara-github committed
fix: Decode image data from ComputerUse tool response into image blobs
PiperOrigin-RevId: 875292001
1 parent 35366f4 commit d7cfd8f

File tree

2 files changed

+118
-9
lines changed

2 files changed

+118
-9
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20+
import base64
21+
import binascii
2022
from concurrent.futures import ThreadPoolExecutor
2123
import copy
2224
import functools
@@ -31,6 +33,7 @@
3133
from typing import TYPE_CHECKING
3234
import uuid
3335

36+
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
3437
from google.genai import types
3538

3639
from ...agents.active_streaming_tool import ActiveStreamingTool
@@ -991,6 +994,50 @@ def _get_tool_and_context(
991994
return (tool, tool_context)
992995

993996

997+
def _try_decode_computer_use_image(
    tool: BaseTool,
    function_result: dict[str, object],
) -> Optional[list[types.FunctionResponsePart]]:
  """Decodes the image from the function result for a computer use tool.

  The result is expected to carry the screenshot as
  `{'image': {'data': <base64 text>, 'mimetype': <mime type>}}`. Any other
  shape is treated as "no image" rather than as an error, so a malformed tool
  result never prevents the function response from being built.

  Args:
    tool: The tool that produced the function result. Only `ComputerUseTool`
      instances are processed.
    function_result: The dictionary containing the function's result. This
      dictionary is modified in-place to remove the 'image' key, but only
      after the image has been successfully decoded.

  Returns:
    A single-element list containing a `types.FunctionResponsePart` with the
    decoded image data, or None if the tool is not a computer use tool, no
    well-formed image entry was found, or decoding failed.
  """
  if not isinstance(tool, ComputerUseTool) or not isinstance(
      function_result, dict
  ):
    return None

  image = function_result.get('image')
  # The payload is produced by arbitrary tool code. A non-dict value (None, a
  # plain string, ...) must be rejected here: membership tests on it would
  # either raise TypeError or silently do substring matching.
  if (
      not isinstance(image, dict)
      or 'data' not in image
      or 'mimetype' not in image
  ):
    return None

  try:
    image_data = base64.b64decode(image['data'])
    part = types.FunctionResponsePart.from_bytes(
        data=image_data, mime_type=image['mimetype']
    )
  except (binascii.Error, ValueError, TypeError):
    # TypeError covers 'data' values that are neither str nor bytes-like;
    # binascii.Error / ValueError cover malformed base64 text.
    logger.exception('Failed to decode image from computer use tool')
    return None

  # Drop the raw base64 text only once the blob exists, so a failed decode
  # leaves the original result untouched.
  del function_result['image']
  return [part]
1039+
1040+
9941041
async def __call_tool_live(
9951042
tool: BaseTool,
9961043
args: dict[str, object],
@@ -1028,8 +1075,16 @@ def __build_response_event(
10281075
if not isinstance(function_result, dict):
10291076
function_result = {'result': function_result}
10301077

1078+
function_response_parts = None
1079+
if isinstance(tool, ComputerUseTool):
1080+
function_response_parts = _try_decode_computer_use_image(
1081+
tool, function_result
1082+
)
1083+
10311084
part_function_response = types.Part.from_function_response(
1032-
name=tool.name, response=function_result
1085+
name=tool.name,
1086+
response=function_result,
1087+
parts=function_response_parts,
10331088
)
10341089
part_function_response.function_response.id = tool_context.function_call_id
10351090

tests/unittests/flows/llm_flows/test_functions_simple.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from google.adk.agents.llm_agent import Agent
2020
from google.adk.events.event import Event
2121
from google.adk.flows.llm_flows.functions import find_matching_function_call
22+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
23+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
2224
from google.adk.flows.llm_flows.functions import merge_parallel_function_response_events
25+
from google.adk.tools.computer_use.computer_use_tool import ComputerUseTool
2326
from google.adk.tools.function_tool import FunctionTool
2427
from google.adk.tools.tool_context import ToolContext
2528
from google.genai import types
@@ -397,8 +400,6 @@ def test_find_function_call_event_multiple_function_responses():
397400
@pytest.mark.asyncio
398401
async def test_function_call_args_not_modified():
399402
"""Test that function_call.args is not modified when making a copy."""
400-
from google.adk.flows.llm_flows.functions import handle_function_calls_async
401-
from google.adk.flows.llm_flows.functions import handle_function_calls_live
402403

403404
def simple_fn(**kwargs) -> dict:
404405
return {'result': 'test'}
@@ -455,8 +456,6 @@ def simple_fn(**kwargs) -> dict:
455456
@pytest.mark.asyncio
456457
async def test_function_call_args_none_handling():
457458
"""Test that function_call.args=None is handled correctly."""
458-
from google.adk.flows.llm_flows.functions import handle_function_calls_async
459-
from google.adk.flows.llm_flows.functions import handle_function_calls_live
460459

461460
def simple_fn(**kwargs) -> dict:
462461
return {'result': 'test'}
@@ -504,8 +503,6 @@ def simple_fn(**kwargs) -> dict:
504503
@pytest.mark.asyncio
505504
async def test_function_call_args_copy_behavior():
506505
"""Test that modifying the copied args doesn't affect the original."""
507-
from google.adk.flows.llm_flows.functions import handle_function_calls_async
508-
from google.adk.flows.llm_flows.functions import handle_function_calls_live
509506

510507
def simple_fn(test_param: str, other_param: int) -> dict:
511508
# Modify the args to test that the copy prevents affecting the original
@@ -565,8 +562,6 @@ def simple_fn(test_param: str, other_param: int) -> dict:
565562
@pytest.mark.asyncio
566563
async def test_function_call_args_deep_copy_behavior():
567564
"""Test that deep copy behavior works correctly with nested structures."""
568-
from google.adk.flows.llm_flows.functions import handle_function_calls_async
569-
from google.adk.flows.llm_flows.functions import handle_function_calls_live
570565

571566
def simple_fn(nested_dict: dict, list_param: list) -> dict:
572567
# Modify the nested structures to test deep copy
@@ -1141,3 +1136,62 @@ async def yielding_async() -> dict:
11411136
'yield_E',
11421137
'yield_F',
11431138
]
1139+
1140+
1141+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'handle_function_calls',
    [handle_function_calls_async, handle_function_calls_live],
)
async def test_computer_use_tool_decoding_behavior(handle_function_calls):
  """Checks that a computer use tool's base64 image becomes a response blob.

  Runs against both the async and the live function-call handlers: the
  'image' entry must disappear from the response dict, the other keys must
  survive, and exactly one inline-data part must be attached.
  """
  # NOTE(review): this text decodes to bytes starting with the GIF89a
  # signature although it is labelled 'image/png' below; the assertions here
  # do not depend on the pairing - confirm whether that is intentional.
  encoded_image = 'R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7'

  # Stand-in for the computer action: returns the screenshot plus a url.
  async def mock_run(*args, **kwargs):
    return {
        'image': {'data': encoded_image, 'mimetype': 'image/png'},
        'url': 'https://example.com',
    }

  tool = ComputerUseTool(func=mock_run, screen_size=(1024, 768))
  agent = Agent(
      name='test_agent',
      model=testing_utils.MockModel.create(responses=[]),
      tools=[tool],
  )
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content=''
  )

  # Event carrying a single, argument-less call to the tool.
  call_event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=types.Content(
          parts=[
              types.Part(
                  function_call=types.FunctionCall(name=tool.name, args={})
              )
          ]
      ),
  )

  response_event = await handle_function_calls(
      invocation_context,
      call_event,
      {tool.name: tool},
  )

  assert response_event is not None
  fn_response = response_event.content.parts[0].function_response

  # The raw base64 entry is gone, the remaining payload is untouched.
  assert 'image' not in fn_response.response
  assert 'url' in fn_response.response
  # The image now travels as a single inline-data part.
  assert len(fn_response.parts) == 1
  assert fn_response.parts[0].inline_data is not None

0 commit comments

Comments
 (0)