Skip to content

Commit 2060ea1

Browse files
committed
add capability and example
1 parent 4447b21 commit 2060ea1

File tree

7 files changed

+93
-6
lines changed

7 files changed

+93
-6
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import asyncio
2+
3+
from fast_agent import FastAgent
4+
from fast_agent.agents.agent_types import AgentConfig
5+
from fast_agent.agents.tool_agent import ToolAgent
6+
from fast_agent.agents.tool_runner import ToolRunnerHooks
7+
from fast_agent.context import Context
8+
from fast_agent.interfaces import ToolRunnerHookCapable
9+
from fast_agent.types import PromptMessageExtended
10+
11+
12+
def get_video_call_transcript(video_id: str) -> str:
13+
return "Assistant: Hi, how can I assist you today?\n\nCustomer: Hi, I wanted to ask you about last invoice I received..."
14+
15+
16+
class HookedToolAgent(ToolAgent, ToolRunnerHookCapable):
17+
def __init__(
18+
self,
19+
config: AgentConfig,
20+
context: Context | None = None,
21+
):
22+
tools = [get_video_call_transcript]
23+
super().__init__(config, tools, context)
24+
self._hooks = ToolRunnerHooks(
25+
before_llm_call=self._add_style_hint,
26+
after_tool_call=self._log_tool_result,
27+
)
28+
29+
@property
30+
def tool_runner_hooks(self) -> ToolRunnerHooks | None:
31+
return self._hooks
32+
33+
async def _add_style_hint(self, runner, messages: list[PromptMessageExtended]) -> None:
34+
if runner.iteration == 0:
35+
runner.append_messages("Keep the answer to one short sentence.")
36+
37+
async def _log_tool_result(self, runner, message: PromptMessageExtended) -> None:
38+
if message.tool_results:
39+
tool_names = ", ".join(message.tool_results.keys())
40+
print(f"[hook] tool results received: {tool_names}")
41+
42+
43+
fast = FastAgent("Example Tool Use Application (Hooks)")
44+
45+
46+
@fast.custom(HookedToolAgent)
47+
async def main() -> None:
48+
async with fast.run() as agent:
49+
await agent.default.generate(
50+
"What is the topic of the video call no.1234?",
51+
)
52+
await agent.interactive()
53+
54+
55+
if __name__ == "__main__":
56+
asyncio.run(main())
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
default_model: haiku
2+
3+
logger:
4+
type: file
5+
level: error
6+
truncate_tools: true

src/fast_agent/agents/llm_agent.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,13 @@ async def generate_impl(
280280
Messages are already normalized to List[PromptMessageExtended].
281281
"""
282282
if "user" == messages[-1].role:
283-
self.show_user_message(message=messages[-1])
283+
trailing_users: list[PromptMessageExtended] = []
284+
for message in reversed(messages):
285+
if message.role != "user":
286+
break
287+
trailing_users.append(message)
288+
for message in reversed(trailing_users):
289+
self.show_user_message(message=message)
284290

285291
# TODO - manage error catch, recovery, pause
286292
summary_text: Text | None = None
@@ -341,7 +347,13 @@ async def structured_impl(
341347
request_params: RequestParams | None = None,
342348
) -> Tuple[ModelT | None, PromptMessageExtended]:
343349
if "user" == messages[-1].role:
344-
self.show_user_message(message=messages[-1])
350+
trailing_users: list[PromptMessageExtended] = []
351+
for message in reversed(messages):
352+
if message.role != "user":
353+
break
354+
trailing_users.append(message)
355+
for message in reversed(trailing_users):
356+
self.show_user_message(message=message)
345357

346358
(result, message), summary = await self._structured_with_summary(
347359
messages, model, request_params

src/fast_agent/agents/tool_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from fast_agent.context import Context
1616
from fast_agent.core.logging.logger import get_logger
17+
from fast_agent.interfaces import ToolRunnerHookCapable
1718
from fast_agent.mcp.helpers.content_helpers import text_content
1819
from fast_agent.tools.elicitation import get_elicitation_fastmcp_tool
1920
from fast_agent.types import PromptMessageExtended, RequestParams, ToolTimingInfo
@@ -57,7 +58,6 @@ def __init__(
5758
logger.warning(f"Failed to initialize human-input tool: {e}")
5859

5960
for tool in working_tools:
60-
(tool)
6161
if isinstance(tool, FastMCPTool):
6262
fast_tool = tool
6363
elif callable(tool):
@@ -99,6 +99,8 @@ async def generate_impl(
9999
return await runner.until_done()
100100

101101
def _tool_runner_hooks(self) -> ToolRunnerHooks | None:
102+
if isinstance(self, ToolRunnerHookCapable):
103+
return self.tool_runner_hooks
102104
return None
103105

104106
async def _tool_runner_llm_step(

src/fast_agent/agents/workflow/agents_as_tools_agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,6 @@ async def _run_child_tools(
639639

640640
call_descriptors: list[dict[str, Any]] = []
641641
descriptor_by_id: dict[str, dict[str, Any]] = {}
642-
tasks: list[asyncio.Task[CallToolResult]] = []
643642
id_list: list[str] = []
644643

645644
for correlation_id, tool_request in (request.tool_calls or {}).items():

src/fast_agent/interfaces.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from fast_agent.acp.acp_aware_mixin import ACPCommand, ACPModeInfo
3434
from fast_agent.acp.acp_context import ACPContext
3535
from fast_agent.agents.agent_types import AgentConfig, AgentType
36+
from fast_agent.agents.tool_runner import ToolRunnerHooks
3637
from fast_agent.context import Context
3738
from fast_agent.llm.model_info import ModelInfo
3839

@@ -41,6 +42,7 @@
4142
"StreamingAgentProtocol",
4243
"LlmAgentProtocol",
4344
"AgentProtocol",
45+
"ToolRunnerHookCapable",
4446
"ACPAwareProtocol",
4547
"LLMFactoryProtocol",
4648
"ModelFactoryFunctionProtocol",
@@ -265,6 +267,14 @@ def initialized(self) -> bool: ...
265267
def set_instruction(self, instruction: str) -> None: ...
266268

267269

270+
@runtime_checkable
271+
class ToolRunnerHookCapable(Protocol):
272+
"""Optional capability for agents to expose ToolRunner hooks."""
273+
274+
@property
275+
def tool_runner_hooks(self) -> "ToolRunnerHooks | None": ...
276+
277+
268278
@runtime_checkable
269279
class StreamingAgentProtocol(AgentProtocol, Protocol):
270280
"""Optional extension for agents that expose LLM streaming callbacks."""

src/fast_agent/utils/async_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from collections.abc import Awaitable, Iterable
5-
from typing import TypeVar
4+
from typing import TYPE_CHECKING, TypeVar
5+
6+
if TYPE_CHECKING:
7+
from collections.abc import Awaitable, Iterable
68

79
T = TypeVar("T")
810

0 commit comments

Comments
 (0)