Skip to content

Commit e6b601a

Browse files
GWealecopybara-github
authored andcommitted
fix: Invoke on_tool_error_callback for missing tools in live mode
In live mode, when the model calls an unregistered tool, ADK now runs on_tool_error_callback before failing. If the callback returns a response, ADK emits that function response and continues; otherwise it keeps the old ValueError Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 872996178
1 parent 7478bda commit e6b601a

File tree

2 files changed

+131
-3
lines changed

2 files changed

+131
-3
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -660,14 +660,65 @@ async def _execute_single_function_call_live(
660660
streaming_lock: asyncio.Lock,
661661
) -> Optional[Event]:
662662
"""Execute a single function call for live mode with thread safety."""
663-
tool, tool_context = _get_tool_and_context(
664-
invocation_context, function_call, tools_dict
665-
)
666663

664+
async def _run_on_tool_error_callbacks(
665+
*,
666+
tool: BaseTool,
667+
tool_args: dict[str, Any],
668+
tool_context: ToolContext,
669+
error: Exception,
670+
) -> Optional[dict[str, Any]]:
671+
"""Runs the on_tool_error_callbacks for the given tool."""
672+
error_response = (
673+
await invocation_context.plugin_manager.run_on_tool_error_callback(
674+
tool=tool,
675+
tool_args=tool_args,
676+
tool_context=tool_context,
677+
error=error,
678+
)
679+
)
680+
if error_response is not None:
681+
return error_response
682+
683+
for callback in agent.canonical_on_tool_error_callbacks:
684+
error_response = callback(
685+
tool=tool,
686+
args=tool_args,
687+
tool_context=tool_context,
688+
error=error,
689+
)
690+
if inspect.isawaitable(error_response):
691+
error_response = await error_response
692+
if error_response is not None:
693+
return error_response
694+
695+
return None
696+
697+
# Do not use "args" as the variable name, because it is a reserved keyword
698+
# in python debugger.
699+
# Make a deep copy to avoid being modified.
667700
function_args = (
668701
copy.deepcopy(function_call.args) if function_call.args else {}
669702
)
670703

704+
tool_context = _create_tool_context(invocation_context, function_call)
705+
706+
try:
707+
tool = _get_tool(function_call, tools_dict)
708+
except ValueError as tool_error:
709+
tool = BaseTool(name=function_call.name, description='Tool not found')
710+
error_response = await _run_on_tool_error_callbacks(
711+
tool=tool,
712+
tool_args=function_args,
713+
tool_context=tool_context,
714+
error=tool_error,
715+
)
716+
if error_response is not None:
717+
return __build_response_event(
718+
tool, error_response, tool_context, invocation_context
719+
)
720+
raise tool_error
721+
671722
async def _run_with_trace():
672723
nonlocal function_args
673724

tests/unittests/flows/llm_flows/test_live_tool_callbacks.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,80 @@ def simple_fn(**kwargs) -> Dict[str, Any]:
386386
async_response = async_result.content.parts[0].function_response.response
387387
live_response = live_result.content.parts[0].function_response.response
388388
assert async_response == live_response == {"bypassed": "by_before_callback"}
389+
390+
391+
@pytest.mark.asyncio
392+
async def test_live_on_tool_error_callback_tool_not_found_noop():
393+
"""Test that on_tool_error_callback is a no-op when the tool is not found."""
394+
395+
def noop_on_tool_error_callback(tool, args, tool_context, error):
396+
return None
397+
398+
def simple_fn(**kwargs) -> Dict[str, Any]:
399+
return {"initial": "response"}
400+
401+
tool = FunctionTool(simple_fn)
402+
model = testing_utils.MockModel.create(responses=[])
403+
agent = Agent(
404+
name="agent",
405+
model=model,
406+
tools=[tool],
407+
on_tool_error_callback=noop_on_tool_error_callback,
408+
)
409+
invocation_context = await testing_utils.create_invocation_context(
410+
agent=agent, user_content=""
411+
)
412+
function_call = types.FunctionCall(name="nonexistent_function", args={})
413+
content = types.Content(parts=[types.Part(function_call=function_call)])
414+
event = Event(
415+
invocation_id=invocation_context.invocation_id,
416+
author=agent.name,
417+
content=content,
418+
)
419+
tools_dict = {tool.name: tool}
420+
421+
with pytest.raises(ValueError):
422+
await handle_function_calls_live(invocation_context, event, tools_dict)
423+
424+
425+
@pytest.mark.asyncio
426+
async def test_live_on_tool_error_callback_tool_not_found_modify_tool_response():
427+
"""Test that on_tool_error_callback modifies tool response when tool is not found."""
428+
429+
def mock_on_tool_error_callback(tool, args, tool_context, error):
430+
return {"result": "on_tool_error_callback_response"}
431+
432+
def simple_fn(**kwargs) -> Dict[str, Any]:
433+
return {"initial": "response"}
434+
435+
tool = FunctionTool(simple_fn)
436+
model = testing_utils.MockModel.create(responses=[])
437+
agent = Agent(
438+
name="agent",
439+
model=model,
440+
tools=[tool],
441+
on_tool_error_callback=mock_on_tool_error_callback,
442+
)
443+
invocation_context = await testing_utils.create_invocation_context(
444+
agent=agent, user_content=""
445+
)
446+
function_call = types.FunctionCall(name="nonexistent_function", args={})
447+
content = types.Content(parts=[types.Part(function_call=function_call)])
448+
event = Event(
449+
invocation_id=invocation_context.invocation_id,
450+
author=agent.name,
451+
content=content,
452+
)
453+
tools_dict = {tool.name: tool}
454+
455+
result_event = await handle_function_calls_live(
456+
invocation_context,
457+
event,
458+
tools_dict,
459+
)
460+
461+
assert result_event is not None
462+
part = result_event.content.parts[0]
463+
assert part.function_response.response == {
464+
"result": "on_tool_error_callback_response"
465+
}

0 commit comments

Comments
 (0)