6 changes: 6 additions & 0 deletions examples/tool-runner-hooks/fastagent.config.yaml
@@ -0,0 +1,6 @@
default_model: haiku

logger:
type: file
level: error
truncate_tools: true
56 changes: 56 additions & 0 deletions examples/tool-runner-hooks/tool_runner_hooks.py
@@ -0,0 +1,56 @@
import asyncio

from fast_agent import FastAgent
from fast_agent.agents.agent_types import AgentConfig
from fast_agent.agents.tool_agent import ToolAgent
from fast_agent.agents.tool_runner import ToolRunnerHooks
from fast_agent.context import Context
from fast_agent.interfaces import ToolRunnerHookCapable
from fast_agent.types import PromptMessageExtended


def get_video_call_transcript(video_id: str) -> str:
return "Assistant: Hi, how can I assist you today?\n\nCustomer: Hi, I wanted to ask you about last invoice I received..."


class HookedToolAgent(ToolAgent, ToolRunnerHookCapable):
def __init__(
self,
config: AgentConfig,
context: Context | None = None,
):
tools = [get_video_call_transcript]
super().__init__(config, tools, context)
self._hooks = ToolRunnerHooks(
before_llm_call=self._add_style_hint,
after_tool_call=self._log_tool_result,
)

@property
def tool_runner_hooks(self) -> ToolRunnerHooks | None:
return self._hooks

async def _add_style_hint(self, runner, messages: list[PromptMessageExtended]) -> None:
if runner.iteration == 0:
runner.append_messages("Keep the answer to one short sentence.")

async def _log_tool_result(self, runner, message: PromptMessageExtended) -> None:
if message.tool_results:
tool_names = ", ".join(message.tool_results.keys())
print(f"[hook] tool results received: {tool_names}")


fast = FastAgent("Example Tool Use Application (Hooks)")


@fast.custom(HookedToolAgent)
async def main() -> None:
async with fast.run() as agent:
await agent.default.generate(
"What is the topic of the video call no.1234?",
)
await agent.interactive()


if __name__ == "__main__":
asyncio.run(main())
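
Not part of the diff above: a variation on the after_tool_call hook, sketched under the assumption that message.tool_results maps correlation ids to CallToolResult objects (as run_tools() in mcp_agent.py below populates them) and that any coroutine taking (runner, message) can be passed to ToolRunnerHooks. It prints a short preview of each result's first text block instead of only the tool names.

async def preview_tool_results(runner, message: PromptMessageExtended) -> None:
    # Hypothetical alternative to _log_tool_result above.
    for call_id, result in (message.tool_results or {}).items():
        # Collect any text blocks from the CallToolResult content list.
        texts = [block.text for block in result.content if getattr(block, "text", None)]
        preview = texts[0][:80] if texts else "<non-text result>"
        print(f"[hook] {call_id}: {preview}")

Wiring it up would look like ToolRunnerHooks(before_llm_call=self._add_style_hint, after_tool_call=preview_tool_results).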
47 changes: 47 additions & 0 deletions examples/tool-runner-hooks/tool_runner_lowlevel.py
@@ -0,0 +1,47 @@
import asyncio

from mcp.types import TextContent

from fast_agent.agents.agent_types import AgentConfig
from fast_agent.agents.tool_agent import ToolAgent
from fast_agent.agents.tool_runner import ToolRunner
from fast_agent.core import Core
from fast_agent.llm.model_factory import ModelFactory
from fast_agent.types import PromptMessageExtended


def lookup_order_status(order_id: str) -> str:
return f"Order {order_id} is packed and ready to ship."


async def main() -> None:
core: Core = Core()
await core.initialize()

config = AgentConfig(name="order_bot")
agent = ToolAgent(config, tools=[lookup_order_status], context=core.context)
await agent.attach_llm(ModelFactory.create_factory("haiku"))

messages = [
PromptMessageExtended(
role="user",
content=[
TextContent(type="text", text="Check order 12345, then summarize in one line.")
],
)
]

runner = ToolRunner(
agent=agent,
messages=messages,
)

async for assistant_message in runner:
text = assistant_message.last_text() or "<no text>"
print(f"[assistant] {text}")

await core.cleanup()


if __name__ == "__main__":
asyncio.run(main())
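
Not part of the diff: the low-level runner is not limited to one tool. Assuming ToolAgent accepts any list of plain callables (as tools=[lookup_order_status] above suggests), a second, hypothetical tool can be registered the same way and the rest of the script stays unchanged.

def estimate_delivery(order_id: str) -> str:
    # Hypothetical second tool, for illustration only.
    return f"Order {order_id} should arrive within 3-5 business days."

# Register both tools; the model may then chain calls across iterations
# before the runner yields its final assistant message.
agent = ToolAgent(config, tools=[lookup_order_status, estimate_delivery], context=core.context)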
16 changes: 14 additions & 2 deletions src/fast_agent/agents/llm_agent.py
@@ -280,7 +280,13 @@ async def generate_impl(
Messages are already normalized to List[PromptMessageExtended].
"""
if "user" == messages[-1].role:
self.show_user_message(message=messages[-1])
trailing_users: list[PromptMessageExtended] = []
for message in reversed(messages):
if message.role != "user":
break
trailing_users.append(message)
for message in reversed(trailing_users):
self.show_user_message(message=message)

# TODO - manage error catch, recovery, pause
summary_text: Text | None = None
@@ -341,7 +347,13 @@ async def structured_impl(
request_params: RequestParams | None = None,
) -> Tuple[ModelT | None, PromptMessageExtended]:
if "user" == messages[-1].role:
self.show_user_message(message=messages[-1])
trailing_users: list[PromptMessageExtended] = []
for message in reversed(messages):
if message.role != "user":
break
trailing_users.append(message)
for message in reversed(trailing_users):
self.show_user_message(message=message)

(result, message), summary = await self._structured_with_summary(
messages, model, request_params
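Both hunks above make the same change: rather than echoing only messages[-1], generate_impl and structured_impl now walk backwards through the trailing run of user-role messages and display each one in its original order, so multi-message user turns are shown in full. A free-standing sketch of that shared loop (a hypothetical helper, not something this diff adds) may help when reading the two hunks:

def trailing_user_messages(messages: list[PromptMessageExtended]) -> list[PromptMessageExtended]:
    # Collect the contiguous run of user-role messages at the end of the
    # list and return them in their original order; mirrors both hunks above.
    trailing: list[PromptMessageExtended] = []
    for message in reversed(messages):
        if message.role != "user":
            break
        trailing.append(message)
    return list(reversed(trailing))

With such a helper, each call site would reduce to a single loop over trailing_user_messages(messages) calling self.show_user_message(message=message).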
102 changes: 89 additions & 13 deletions src/fast_agent/agents/mcp_agent.py
@@ -7,6 +7,7 @@

import asyncio
import fnmatch
import time
from abc import ABC
from typing import (
TYPE_CHECKING,
@@ -36,7 +37,7 @@
from fast_agent.agents.agent_types import AgentConfig, AgentType
from fast_agent.agents.llm_agent import DEFAULT_CAPABILITIES
from fast_agent.agents.tool_agent import ToolAgent
from fast_agent.constants import HUMAN_INPUT_TOOL_NAME
from fast_agent.constants import FORCE_SEQUENTIAL_TOOL_CALLS, HUMAN_INPUT_TOOL_NAME
from fast_agent.core.exceptions import PromptExitError
from fast_agent.core.logging.logger import get_logger
from fast_agent.interfaces import FastAgentLLMProtocol
@@ -59,9 +60,9 @@
PromptMessageExtended,
RequestParams,
ToolTimingInfo,
ToolTimings,
)
from fast_agent.ui import console
from fast_agent.utils.async_utils import gather_with_cancel

# Define a TypeVar for models
ModelT = TypeVar("ModelT", bound=BaseModel)
@@ -804,14 +805,12 @@ async def with_resource(

async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended:
"""Override ToolAgent's run_tools to use MCP tools via aggregator."""
import time

if not request.tool_calls:
self.logger.warning("No tool calls found in request", data=request)
return PromptMessageExtended(role="user", tool_results={})

tool_results: dict[str, CallToolResult] = {}
tool_timings: ToolTimings = {} # Track timing for each tool call
tool_timings: dict[str, ToolTimingInfo] = {}
tool_loop_error: str | None = None

# Cache available tool names exactly as advertised to the LLM for display/highlighting
@@ -832,8 +831,13 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended:
# Cache namespaced tools for routing/metadata
namespaced_tools = self._aggregator._namespaced_tool_map

# Process each tool call using our aggregator
for correlation_id, tool_request in request.tool_calls.items():
tool_call_items = list(request.tool_calls.items())
should_parallel = (not FORCE_SEQUENTIAL_TOOL_CALLS) and len(tool_call_items) > 1

planned_calls: list[dict[str, Any]] = []

# Plan each tool call using our aggregator
for correlation_id, tool_request in tool_call_items:
tool_name = tool_request.params.name
tool_args = tool_request.params.arguments or {}
# correlation_id is the tool_use_id from the LLM
Expand Down Expand Up @@ -926,21 +930,95 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
metadata=metadata,
)

planned_calls.append(
{
"correlation_id": correlation_id,
"tool_name": tool_name,
"tool_args": tool_args,
"display_tool_name": display_tool_name,
"namespaced_tool": namespaced_tool,
"candidate_namespaced_tool": candidate_namespaced_tool,
}
)

if should_parallel and planned_calls:

async def run_one(call: dict[str, Any]) -> tuple[str, CallToolResult, float]:
start_time = time.perf_counter()
result = await self.call_tool(
call["tool_name"], call["tool_args"], call["correlation_id"]
)
end_time = time.perf_counter()
return call["correlation_id"], result, round((end_time - start_time) * 1000, 2)

results = await gather_with_cancel(run_one(call) for call in planned_calls)

for i, item in enumerate(results):
call = planned_calls[i]
correlation_id = call["correlation_id"]
display_tool_name = call["display_tool_name"]
namespaced_tool = call["namespaced_tool"]
candidate_namespaced_tool = call["candidate_namespaced_tool"]

if isinstance(item, BaseException):
self.logger.error(f"MCP tool {display_tool_name} failed: {item}")
result = CallToolResult(
content=[TextContent(type="text", text=f"Error: {str(item)}")],
isError=True,
)
duration_ms = 0.0
else:
_, result, duration_ms = item

tool_results[correlation_id] = result
tool_timings[correlation_id] = ToolTimingInfo(
timing_ms=duration_ms,
transport_channel=getattr(result, "transport_channel", None),
)

skybridge_config = None
skybridge_tool = namespaced_tool or candidate_namespaced_tool
if skybridge_tool:
try:
skybridge_config = await self._aggregator.get_skybridge_config(
skybridge_tool.server_name
)
except Exception:
skybridge_config = None

if not getattr(result, "_suppress_display", False):
self.display.show_tool_result(
name=self._name,
result=result,
tool_name=display_tool_name,
skybridge_config=skybridge_config,
timing_ms=duration_ms,
)

return self._finalize_tool_results(
tool_results, tool_timings=tool_timings, tool_loop_error=tool_loop_error
)

for call in planned_calls:
correlation_id = call["correlation_id"]
tool_name = call["tool_name"]
tool_args = call["tool_args"]
display_tool_name = call["display_tool_name"]
namespaced_tool = call["namespaced_tool"]
candidate_namespaced_tool = call["candidate_namespaced_tool"]

try:
# Track timing for tool execution
start_time = time.perf_counter()
result = await self.call_tool(tool_name, tool_args, correlation_id)
end_time = time.perf_counter()
duration_ms = round((end_time - start_time) * 1000, 2)

tool_results[correlation_id] = result
# Store timing and transport channel info
tool_timings[correlation_id] = ToolTimingInfo(
timing_ms=duration_ms,
transport_channel=getattr(result, "transport_channel", None),
)

# Show tool result (like ToolAgent does)
skybridge_config = None
skybridge_tool = namespaced_tool or candidate_namespaced_tool
if skybridge_tool:
@@ -954,7 +1032,7 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended:
result=result,
tool_name=display_tool_name,
skybridge_config=skybridge_config,
timing_ms=duration_ms, # Use local duration_ms variable for display
timing_ms=duration_ms,
)

self.logger.debug(f"MCP tool {display_tool_name} executed successfully")
@@ -965,8 +1043,6 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtended:
isError=True,
)
tool_results[correlation_id] = error_result

# Show error result too (no need for skybridge config on errors)
self.display.show_tool_result(name=self._name, result=error_result)

return self._finalize_tool_results(
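To summarize the new control flow in run_tools: calls are first planned (name resolution, argument capture, display metadata), then executed; when FORCE_SEQUENTIAL_TOOL_CALLS is off and more than one call is planned, the per-call coroutines run concurrently via gather_with_cancel, and any exception that comes back in the results list is converted into an isError CallToolResult. The sketch below approximates that concurrent branch with plain asyncio.gather(return_exceptions=True); the real gather_with_cancel lives in fast_agent.utils.async_utils and its exact cancellation behaviour is not shown in this diff.

import asyncio
import time


async def gather_tool_calls(planned_calls, call_tool):
    # Approximate sketch of the parallel branch in run_tools above.
    # planned_calls: list of (correlation_id, tool_name, tool_args) tuples.
    # call_tool: coroutine function (tool_name, tool_args, correlation_id) -> result.
    async def run_one(correlation_id, tool_name, tool_args):
        start = time.perf_counter()
        result = await call_tool(tool_name, tool_args, correlation_id)
        duration_ms = round((time.perf_counter() - start) * 1000, 2)
        return correlation_id, result, duration_ms

    # return_exceptions=True mirrors how the diff receives BaseException items
    # in the results list and turns them into isError tool results.
    return await asyncio.gather(
        *(run_one(cid, name, args) for cid, name, args in planned_calls),
        return_exceptions=True,
    )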