From 66c5f62b6a0bbcb81f37e3d22d40a96b2857bf37 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 8 Apr 2026 13:03:49 +0000 Subject: [PATCH] feat: add async tool support, pause/resume, and MCP elicitation (#93, #104) Wire the existing FinishReasonInputRequired scaffolding into a working pause/resume system for human-in-the-loop and long-running tools. Breaking change: Tool interface gains IsAsynchronous() bool. All Tool implementations must add this method (return false for synchronous tools). Core changes: - tool.Tool: add IsAsynchronous() to signal tools that cannot complete in a single synchronous call (user input, CI/CD, batch jobs) - agent/llmagent: executeTools() detects async tools via IsAsynchronous() and returns inputRequiredIDs; executeSingleTurn() emits FinishReasonInputRequired with InputRequiredToolIDs - runner: add Resume() for providing tool results after pause; refactor Run()/Resume() to share runAgent() helper - tool/registry: annotate async tool descriptions with re-invocation warning for the LLM - tool/mcp: wire ElicitationHandler to MCP SDK ClientOptions for server-initiated user input requests - tool/builtin/require_input: mark as IsAsynchronous()=true; simplify Execute() to return pending status https://claude.ai/code/session_01WFwbKjni1JS6Rf17VNySb1 --- adapter/a2a/executor_test.go | 2 + agent/conformance/tools.go | 2 + agent/llmagent/llmagent.go | 90 +++++++--- agent/llmagent/llmagent_integration_test.go | 2 + agent/llmagent/llmagent_test.go | 185 ++++++++++++++++++++ examples/agent_interceptors/tools.go | 6 + runner/runner.go | 147 +++++++++++----- runner/runner_integration_test.go | 2 + tool/agenttool/agenttool.go | 3 + tool/builtin/artifact_emit.go | 3 + tool/builtin/plot/tool.go | 3 + tool/builtin/require_input.go | 20 +-- tool/builtin/todo/todo_update.go | 6 + tool/builtin/webfetch/tool.go | 3 + tool/mcp/client.go | 49 ++++-- tool/mcp/client_options.go | 50 ++++++ tool/mcp/tools.go | 3 + tool/registry.go | 11 +- 
tool/registry_integration_test.go | 2 + tool/tool.go | 15 ++ 20 files changed, 507 insertions(+), 97 deletions(-) diff --git a/adapter/a2a/executor_test.go b/adapter/a2a/executor_test.go index 5a1003d5..9e8fffb3 100644 --- a/adapter/a2a/executor_test.go +++ b/adapter/a2a/executor_test.go @@ -465,6 +465,8 @@ func (m *mockWeatherTool) Execute(_ context.Context, _ json.RawMessage) (json.Ra return json.RawMessage(`{"temperature": "72°F", "conditions": "sunny"}`), nil } +func (*mockWeatherTool) IsAsynchronous() bool { return false } + func TestExecutor_SessionPersistence_Mock(t *testing.T) { t.Parallel() diff --git a/agent/conformance/tools.go b/agent/conformance/tools.go index 8f51bad7..2f0b45da 100644 --- a/agent/conformance/tools.go +++ b/agent/conformance/tools.go @@ -51,6 +51,8 @@ func (*CalculatorTool) Definition() llm.ToolDefinition { } } +func (*CalculatorTool) IsAsynchronous() bool { return false } + func (*CalculatorTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { var params struct { A float64 `json:"a"` diff --git a/agent/llmagent/llmagent.go b/agent/llmagent/llmagent.go index 0767ed1e..5fafd72d 100644 --- a/agent/llmagent/llmagent.go +++ b/agent/llmagent/llmagent.go @@ -187,12 +187,17 @@ func (a *LLMAgent) Run(ctx context.Context, inv *agent.InvocationMetadata) iter. 
// Check if interceptor or turn logic wants to end execution if finishReason != "" { - // Emit terminal event - yield(agent.InvocationEndEvent{ + endEvent := agent.InvocationEndEvent{ Envelope: makeEnvelope(), FinishReason: finishReason, Usage: new(inv.TotalUsage()), - }, nil) + } + // Populate InputRequiredToolIDs when pausing for external input + if ids, ok := inv.GetMetadata("input_required_tool_ids").([]string); ok { + endEvent.InputRequiredToolIDs = ids + } + + yield(endEvent, nil) return } @@ -347,12 +352,27 @@ func (a *LLMAgent) executeSingleTurn( return "", agent.ErrToolRegistry } - toolParts := a.executeTools(ctx, inv, toolReqs, req.Tools, makeEnvelope, yield) + toolResult := a.executeTools(ctx, inv, toolReqs, req.Tools, makeEnvelope, yield) // Build single message with all tool response parts - toolMsg := llm.NewMessage(llm.RoleUser, toolParts...) + toolMsg := llm.NewMessage(llm.RoleUser, toolResult.parts...) sess.Messages = append(sess.Messages, toolMsg) + // Check: any tools requiring external input? + if len(toolResult.inputRequiredIDs) > 0 { + // Emit input required status + yield(agent.StatusEvent{ + Envelope: makeEnvelope(), + Stage: agent.StatusStageInputRequired, + Details: fmt.Sprintf("%d tools awaiting external input", len(toolResult.inputRequiredIDs)), + }, nil) + + // Store for InvocationEndEvent + inv.SetMetadata("input_required_tool_ids", toolResult.inputRequiredIDs) + + return agent.FinishReasonInputRequired, nil + } + // Emit turn completed if !yield(agent.StatusEvent{ Envelope: makeEnvelope(), @@ -479,6 +499,12 @@ func (a *LLMAgent) generateWithStreaming( return response, nil } +// executeToolsResult holds the outcome of tool execution. +type executeToolsResult struct { + parts []*llm.Part // All tool response parts + inputRequiredIDs []string // Tool call IDs requiring external input (asynchronous tools) +} + // executeTools runs tool calls concurrently. // // Tool execution is limited by toolConcurrency. 
Individual tool errors @@ -491,10 +517,9 @@ func (a *LLMAgent) generateWithStreaming( // // ToolResponseEvents are yielded as tools complete. // -// Returns tool response parts in the order they were requested. -// -// Future: Will also return list of tool IDs requiring input for -// StatusStageInputRequired / FinishReasonInputRequired support. +// Returns tool response parts and a list of tool call IDs that require +// external input (from asynchronous tools). When inputRequiredIDs is non-empty, +// the agent should pause and emit FinishReasonInputRequired. func (a *LLMAgent) executeTools( ctx context.Context, inv *agent.InvocationMetadata, @@ -502,18 +527,19 @@ func (a *LLMAgent) executeTools( toolDefs []llm.ToolDefinition, makeEnvelope func() agent.EventEnvelope, yield func(agent.Event, error) bool, -) []*llm.Part { +) executeToolsResult { // Execute tools concurrently with limited parallelism g, gctx := errgroup.WithContext(ctx) g.SetLimit(min(a.config.toolConcurrency, len(toolReqs))) // Results channel (buffered to avoid blocking) type toolResult struct { - idx int - requestID string - name string - response *llm.ToolResponse - err error + idx int + requestID string + name string + response *llm.ToolResponse + err error + isAsynchronous bool } results := make(chan toolResult, len(toolReqs)) @@ -535,19 +561,24 @@ func (a *LLMAgent) executeTools( // Launch tool executions for i, req := range toolReqs { g.Go(func() error { + // Check if tool is asynchronous + t, _ := a.config.tools.Get(req.Name) + async := t != nil && t.IsAsynchronous() + toolInfo := &agent.ToolCallInfo{ Inv: inv, Req: req, - Definition: toolDefMap[req.Name], // Add tool definition + Definition: toolDefMap[req.Name], } resp, err := executor(gctx, toolInfo) results <- toolResult{ - idx: i, - requestID: req.ID, - name: req.Name, - response: resp, - err: err, + idx: i, + requestID: req.ID, + name: req.Name, + response: resp, + err: err, + isAsynchronous: async, } return nil // Never return error to 
errgroup (we handle errors individually) @@ -557,6 +588,8 @@ func (a *LLMAgent) executeTools( // Collect tool response parts and yield events as they arrive parts := make([]*llm.Part, 0, len(toolReqs)) + var inputRequiredIDs []string + for range toolReqs { result := <-results @@ -574,7 +607,7 @@ func (a *LLMAgent) executeTools( Envelope: makeEnvelope(), Response: *errResp, }, nil) { - return parts // Consumer stopped listening + return executeToolsResult{parts: parts, inputRequiredIDs: inputRequiredIDs} } } else { // Tool execution succeeded @@ -585,12 +618,17 @@ func (a *LLMAgent) executeTools( Envelope: makeEnvelope(), Response: *result.response, }, nil) { - return parts // Consumer stopped listening + return executeToolsResult{parts: parts, inputRequiredIDs: inputRequiredIDs} + } + + // Track asynchronous tools that need external completion + if result.isAsynchronous { + inputRequiredIDs = append(inputRequiredIDs, result.requestID) } } } - return parts + return executeToolsResult{parts: parts, inputRequiredIDs: inputRequiredIDs} } // recoverIncompleteToolCalls detects and executes incomplete tool calls from a @@ -645,12 +683,12 @@ func (a *LLMAgent) recoverIncompleteToolCalls( // Execute the incomplete tools toolDefs := a.config.tools.List() - toolParts := a.executeTools(ctx, inv, incomplete, toolDefs, makeEnvelope, yield) + toolResult := a.executeTools(ctx, inv, incomplete, toolDefs, makeEnvelope, yield) // Insert tool response message BEFORE the last user message. // Current: [..., assistant(tool_req), user(text)] // After: [..., assistant(tool_req), user(tool_resp), user(text)] - toolMsg := llm.NewMessage(llm.RoleUser, toolParts...) + toolMsg := llm.NewMessage(llm.RoleUser, toolResult.parts...) 
lastIdx := len(sess.Messages) - 1 sess.Messages = append(sess.Messages[:lastIdx], toolMsg, sess.Messages[lastIdx]) diff --git a/agent/llmagent/llmagent_integration_test.go b/agent/llmagent/llmagent_integration_test.go index 10884983..c9bd0703 100644 --- a/agent/llmagent/llmagent_integration_test.go +++ b/agent/llmagent/llmagent_integration_test.go @@ -61,6 +61,8 @@ func (*calculatorTool) Definition() llm.ToolDefinition { } } +func (*calculatorTool) IsAsynchronous() bool { return false } + func (*calculatorTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { // Parse arguments var params struct { diff --git a/agent/llmagent/llmagent_test.go b/agent/llmagent/llmagent_test.go index a9489652..bd1ff1ce 100644 --- a/agent/llmagent/llmagent_test.go +++ b/agent/llmagent/llmagent_test.go @@ -972,3 +972,188 @@ func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (json.RawM return json.RawMessage(`{}`), nil } + +func (*mockTool) IsAsynchronous() bool { return false } + +// asyncMockTool is a mock tool that declares itself as asynchronous. +type asyncMockTool struct { + mockTool +} + +func (*asyncMockTool) IsAsynchronous() bool { return true } + +// TestRun_AsyncToolPausesExecution tests that an asynchronous tool +// causes the agent to pause with FinishReasonInputRequired. 
+func TestRun_AsyncToolPausesExecution(t *testing.T) { + t.Parallel() + + // Setup: Create an asynchronous deploy tool + deployTool := &asyncMockTool{ + mockTool: mockTool{ + name: "deploy", + definition: llm.ToolDefinition{ + Name: "deploy", + Description: "Deploy to staging", + Parameters: json.RawMessage(`{ + "type": "object", + "properties": { + "version": {"type": "string"} + }, + "required": ["version"] + }`), + }, + executeFn: func(_ context.Context, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`{"status": "pending", "task_id": "deploy-42"}`), nil + }, + }, + } + + registry := tool.NewRegistry(tool.RegistryConfig{}) + err := registry.Register(deployTool) + require.NoError(t, err) + + // Setup: Configure fake model to call the deploy tool + model := fakellm.NewFakeModel() + model.When(fakellm.FirstTurn()). + Times(1). + ThenRespondWithToolCall("deploy", map[string]any{"version": "v2.5.0"}) + + // Create agent + ag, err := llmagent.New( + "deploy-agent", + "You help deploy applications", + model, + llmagent.WithTools(registry), + ) + require.NoError(t, err) + + // Create session + sess := &session.State{ + ID: "test-session", + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("Deploy v2.5.0 to staging"))}, + } + inv := agent.NewInvocationMetadata(sess, agent.Info{}) + + // Execute + events := collectEvents(t, ag.Run(t.Context(), inv)) + + // Assert: Should end with FinishReasonInputRequired + endEvent := findInvocationEndEvent(events) + require.NotNil(t, endEvent) + assert.Equal(t, agent.FinishReasonInputRequired, endEvent.FinishReason) + + // Assert: Should have the deploy tool's call ID in InputRequiredToolIDs + require.Len(t, endEvent.InputRequiredToolIDs, 1) + assert.NotEmpty(t, endEvent.InputRequiredToolIDs[0]) + + // Assert: Should have emitted StatusStageInputRequired + statusEvents := filterEvents[agent.StatusEvent](events) + hasInputRequired := false + for _, se := range statusEvents { + if se.Stage == 
agent.StatusStageInputRequired { + hasInputRequired = true + assert.Contains(t, se.Details, "1 tools awaiting external input") + } + } + assert.True(t, hasInputRequired, "expected StatusStageInputRequired event") + + // Assert: Tool response event was emitted with the pending result + toolRespEvents := filterEvents[agent.ToolResponseEvent](events) + require.NotEmpty(t, toolRespEvents) + assert.Contains(t, string(toolRespEvents[0].Response.Result), "pending") + + // Assert: Tool result was added to session + // Session should have: [user_msg, assistant(tool_call), user(tool_resp)] + assert.Len(t, sess.Messages, 3) +} + +// TestRun_MixedSyncAndAsyncTools tests that a mix of normal and async tools +// correctly pauses only for the async tool IDs. +func TestRun_MixedSyncAndAsyncTools(t *testing.T) { + t.Parallel() + + // Setup: One normal tool, one async tool + normalTool := &mockTool{ + name: "get_status", + definition: llm.ToolDefinition{ + Name: "get_status", + Description: "Get current status", + Parameters: json.RawMessage(`{"type": "object", "properties": {}}`), + }, + executeFn: func(_ context.Context, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`{"status": "ok"}`), nil + }, + } + + asyncTool := &asyncMockTool{ + mockTool: mockTool{ + name: "long_task", + definition: llm.ToolDefinition{ + Name: "long_task", + Description: "Start a long running task", + Parameters: json.RawMessage(`{"type": "object", "properties": {}}`), + }, + executeFn: func(_ context.Context, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`{"status": "pending", "task_id": "task-99"}`), nil + }, + }, + } + + registry := tool.NewRegistry(tool.RegistryConfig{}) + require.NoError(t, registry.Register(normalTool)) + require.NoError(t, registry.Register(asyncTool)) + + // Setup: Model calls both tools simultaneously + model := fakellm.NewFakeModel() + model.When(fakellm.FirstTurn()). + Times(1). 
+ ThenRespondWith(func(_ *llm.Request, cc *fakellm.CallContext) (*llm.Response, error) { + return &llm.Response{ + Message: llm.Message{ + Role: llm.RoleAssistant, + Content: []*llm.Part{ + llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call_status", + Name: "get_status", + Arguments: json.RawMessage(`{}`), + }), + llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call_task", + Name: "long_task", + Arguments: json.RawMessage(`{}`), + }), + }, + }, + FinishReason: llm.FinishReasonToolCalls, + ID: "resp-1", + }, nil + }) + + ag, err := llmagent.New( + "mixed-agent", + "You handle mixed tasks", + model, + llmagent.WithTools(registry), + ) + require.NoError(t, err) + + sess := &session.State{ + ID: "test-session", + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("Check status and start task"))}, + } + inv := agent.NewInvocationMetadata(sess, agent.Info{}) + + events := collectEvents(t, ag.Run(t.Context(), inv)) + + // Assert: Should pause with input required + endEvent := findInvocationEndEvent(events) + require.NotNil(t, endEvent) + assert.Equal(t, agent.FinishReasonInputRequired, endEvent.FinishReason) + + // Assert: Only the async tool's ID should be in InputRequiredToolIDs + require.Len(t, endEvent.InputRequiredToolIDs, 1) + + // Assert: Both tool responses were emitted + toolRespEvents := filterEvents[agent.ToolResponseEvent](events) + assert.Len(t, toolRespEvents, 2) +} diff --git a/examples/agent_interceptors/tools.go b/examples/agent_interceptors/tools.go index 979ce18e..ffee31ce 100644 --- a/examples/agent_interceptors/tools.go +++ b/examples/agent_interceptors/tools.go @@ -66,6 +66,9 @@ func (t *TemperatureSensorTool) Definition() llm.ToolDefinition { } } +// IsAsynchronous implements tool.Tool. 
+func (*TemperatureSensorTool) IsAsynchronous() bool { return false } + func (t *TemperatureSensorTool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) { var input TemperatureSensorInput if err := json.Unmarshal(args, &input); err != nil { @@ -136,6 +139,9 @@ func (t *GetSecretValueTool) Definition() llm.ToolDefinition { } } +// IsAsynchronous implements tool.Tool. +func (*GetSecretValueTool) IsAsynchronous() bool { return false } + func (t *GetSecretValueTool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) { var input GetSecretValueInput if err := json.Unmarshal(args, &input); err != nil { diff --git a/runner/runner.go b/runner/runner.go index c1e1db95..bffc2a6a 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -174,64 +174,119 @@ func (r *Runner) Run( // 2. Add user message to session sess.Messages = append(sess.Messages, userMessage) - // 3. Create invocation metadata with agent snapshot - // The snapshot captures agent identity for observability. - inv := agent.NewInvocationMetadata(sess, r.config.agent.Info()) + // 3. Execute agent and forward events + r.runAgent(ctx, sess, yield) + } +} - // Track whether the consumer stopped iteration (yield returned false). - // When yield returns false, we must not call it again or Go panics with - // "range function continued iteration after function for loop body returned false". - consumerStopped := false +// Resume continues a paused invocation by providing results for pending tools. +// +// This is used after an invocation ended with FinishReasonInputRequired. +// The caller provides tool results for the tools listed in +// InvocationEndEvent.InputRequiredToolIDs. +// +// The provided tool results are added to the session as a user message +// containing ToolResponse parts, then the agent resumes from where it +// left off. The LLM sees the full history: original tool call, the initial +// pending result, and the final result provided here. 
+// +// # Example +// +// // After receiving InvocationEndEvent with FinishReason "input_required" +// // and InputRequiredToolIDs: ["call_abc123"] +// for evt, err := range runner.Resume(ctx, userID, sessionID, +// []llm.ToolResponse{{ +// ID: "call_abc123", +// Name: "deploy", +// Result: json.RawMessage(`{"url": "https://staging.example.com"}`), +// }}, +// ) { +// // ... handle events as with Run() ... +// } +func (r *Runner) Resume( + ctx context.Context, + _ string, // UserID will be used in the future. + sessionID string, + toolResults []llm.ToolResponse, +) iter.Seq2[agent.Event, error] { + return func(yield func(agent.Event, error) bool) { + // 1. Load session (must exist — Resume is only valid for existing sessions) + sess, err := r.config.sessionStore.Load(ctx, sessionID) + if err != nil { + yield(nil, fmt.Errorf("%w: %w", agent.ErrSessionLoad, err)) + return + } - // 4. Save session on exit (handles normal completion, cancellation, errors) - defer func() { - if err := r.config.sessionStore.Save(ctx, sess); err != nil { - // Only yield error if consumer hasn't explicitly stopped iteration. - // If consumer broke out of their for loop (yield returned false), - // calling yield again would panic. - if !consumerStopped { - yield(nil, fmt.Errorf("%w: %w", agent.ErrSessionSave, err)) - } else { - r.config.logger.Error("session save failed after consumer stopped", - "sessionID", sess.ID, - "error", err) - } - } - }() - - // 5. Execute agent and forward events - for evt, err := range r.config.agent.Run(ctx, inv) { - if err != nil { - // Forward error - if !yield(nil, err) { - consumerStopped = true - return - } - - continue - } + // 2. Add tool results as user message with ToolResponse parts + parts := make([]*llm.Part, 0, len(toolResults)) + for i := range toolResults { + parts = append(parts, llm.NewToolResponsePart(&toolResults[i])) + } + + sess.Messages = append(sess.Messages, llm.NewMessage(llm.RoleUser, parts...)) + + // 3. 
Execute agent and forward events + r.runAgent(ctx, sess, yield) + } +} - // Save session after each assistant message (incremental persistence) - // Note: Agent already appended the message to sess.Messages, we just save it - if _, ok := evt.(agent.MessageEvent); ok { - if err := r.config.sessionStore.Save(ctx, sess); err != nil { - yield(nil, fmt.Errorf("%w: %w", agent.ErrSessionSave, err)) - return - } +// runAgent is the shared agent execution and event forwarding logic +// used by both Run() and Resume(). +func (r *Runner) runAgent( + ctx context.Context, + sess *session.State, + yield func(agent.Event, error) bool, +) { + // Create invocation metadata with agent snapshot + inv := agent.NewInvocationMetadata(sess, r.config.agent.Info()) + + // Track whether the consumer stopped iteration (yield returned false). + // When yield returns false, we must not call it again or Go panics with + // "range function continued iteration after function for loop body returned false". + consumerStopped := false + + // Save session on exit (handles normal completion, cancellation, errors) + defer func() { + if err := r.config.sessionStore.Save(ctx, sess); err != nil { + if !consumerStopped { + yield(nil, fmt.Errorf("%w: %w", agent.ErrSessionSave, err)) + } else { + r.config.logger.Error("session save failed after consumer stopped", + "sessionID", sess.ID, + "error", err) } + } + }() - // Forward event to caller - if !yield(evt, nil) { + // Execute agent and forward events + for evt, err := range r.config.agent.Run(ctx, inv) { + if err != nil { + if !yield(nil, err) { consumerStopped = true return } - // Exit after completion event - consumer is still active here, - // defer can still yield if needed - if _, ok := evt.(agent.InvocationEndEvent); ok { + continue + } + + // Save session after each assistant message (incremental persistence) + if _, ok := evt.(agent.MessageEvent); ok { + if err := r.config.sessionStore.Save(ctx, sess); err != nil { + yield(nil, fmt.Errorf("%w: 
%w", agent.ErrSessionSave, err)) return } } + + // Forward event to caller + if !yield(evt, nil) { + consumerStopped = true + return + } + + // Exit after completion event + if _, ok := evt.(agent.InvocationEndEvent); ok { + return + } } } diff --git a/runner/runner_integration_test.go b/runner/runner_integration_test.go index 40228964..849d86e1 100644 --- a/runner/runner_integration_test.go +++ b/runner/runner_integration_test.go @@ -62,6 +62,8 @@ func (*calculatorTool) Definition() llm.ToolDefinition { } } +func (*calculatorTool) IsAsynchronous() bool { return false } + func (*calculatorTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { // Parse arguments var params struct { diff --git a/tool/agenttool/agenttool.go b/tool/agenttool/agenttool.go index b4eecc09..9118aa47 100644 --- a/tool/agenttool/agenttool.go +++ b/tool/agenttool/agenttool.go @@ -100,6 +100,9 @@ type Result struct { // - Each invocation creates a fresh session (no context sharing) // - This prevents context pollution and keeps parent/child boundaries clear // - For context sharing, pass relevant information explicitly in args +// IsAsynchronous implements tool.Tool. +func (*AgentTool) IsAsynchronous() bool { return false } + func (at *AgentTool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) { info := at.agent.Info() diff --git a/tool/builtin/artifact_emit.go b/tool/builtin/artifact_emit.go index 9c32fe67..37bba930 100644 --- a/tool/builtin/artifact_emit.go +++ b/tool/builtin/artifact_emit.go @@ -79,6 +79,9 @@ Append to existing: {"append_to_artifact_id": "artifact-123", "text": "Additiona } } +// IsAsynchronous implements tool.Tool. +func (*ArtifactEmitTool) IsAsynchronous() bool { return false } + // Execute processes the artifact emit request. 
func (*ArtifactEmitTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { var input EmitArtifactInput diff --git a/tool/builtin/plot/tool.go b/tool/builtin/plot/tool.go index decce648..d2e5ae4c 100644 --- a/tool/builtin/plot/tool.go +++ b/tool/builtin/plot/tool.go @@ -85,6 +85,9 @@ Histogram: {"name": "Response Time Distribution", "description": "API response t } } +// IsAsynchronous implements tool.Tool. +func (*Tool) IsAsynchronous() bool { return false } + // Execute performs the plot generation. func (*Tool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { var input Input diff --git a/tool/builtin/require_input.go b/tool/builtin/require_input.go index 661cdf15..c6bb5e58 100644 --- a/tool/builtin/require_input.go +++ b/tool/builtin/require_input.go @@ -90,7 +90,15 @@ IMPORTANT: } } +// IsAsynchronous implements tool.Tool. RequireInputTool is asynchronous because +// it signals that the agent needs external input before continuing. The agent +// pauses and the caller provides the user's response via Runner.Resume(). +func (*RequireInputTool) IsAsynchronous() bool { return true } + // Execute processes the require input request. +// Returns a pending status describing what input is needed. Because +// IsAsynchronous() returns true, the framework will pause the agent +// and emit FinishReasonInputRequired. func (*RequireInputTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { var req RequireInputRequest @@ -120,17 +128,9 @@ func (*RequireInputTool) Execute(_ context.Context, args json.RawMessage) (json. 
return nil, fmt.Errorf("invalid type %q", req.Type) } - response := RequireInputResponse{ - Success: true, - Message: "Task marked as requiring user input: " + req.Message, - Status: "require_input", - } - - // Include the original request in the response for the reconciler to process responseWithDetails := map[string]any{ - "success": response.Success, - "message": response.Message, - "status": response.Status, + "status": "awaiting_input", + "message": "Awaiting user " + req.Type + ": " + req.Message, "input_message": req.Message, "input_type": req.Type, } diff --git a/tool/builtin/todo/todo_update.go b/tool/builtin/todo/todo_update.go index 243ea32f..241e6c8c 100644 --- a/tool/builtin/todo/todo_update.go +++ b/tool/builtin/todo/todo_update.go @@ -115,6 +115,9 @@ IMPORTANT RULES: } } +// IsAsynchronous implements tool.Tool. +func (*UpdateTodoStateTool) IsAsynchronous() bool { return false } + // Execute processes the update todo state request. func (t *UpdateTodoStateTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { var req UpdateTodoStateRequest @@ -246,6 +249,9 @@ IMPORTANT RULES: } } +// IsAsynchronous implements tool.Tool. +func (*AddTodoTool) IsAsynchronous() bool { return false } + // Execute processes the add todo request. func (t *AddTodoTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) { var req AddTodoRequest diff --git a/tool/builtin/webfetch/tool.go b/tool/builtin/webfetch/tool.go index 745939fb..f9239272 100644 --- a/tool/builtin/webfetch/tool.go +++ b/tool/builtin/webfetch/tool.go @@ -85,6 +85,9 @@ func (t *Tool) Definition() llm.ToolDefinition { } +// IsAsynchronous implements tool.Tool. 
+func (*Tool) IsAsynchronous() bool { return false } + // Execute performs the webfetch operation. func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) { var params struct { URL string `json:"url"` } diff --git a/tool/mcp/client.go b/tool/mcp/client.go index 7c7cd34e..eb3563f8 100644 --- a/tool/mcp/client.go +++ b/tool/mcp/client.go @@ -125,14 +125,15 @@ var _ Client = (*clientImpl)(nil) // - All operations respect both client lifetime (bgCtx) and caller deadlines (via opContext). type clientImpl struct { // Configuration (immutable after construction) - serverID string - transportFactory TransportFactory - registry tool.Registry - autoSyncInterval time.Duration - shutdownTimeout time.Duration - toolTimeout time.Duration - logger *slog.Logger - toolFilter ToolFilterFunc + serverID string + transportFactory TransportFactory + registry tool.Registry + autoSyncInterval time.Duration + shutdownTimeout time.Duration + toolTimeout time.Duration + logger *slog.Logger + toolFilter ToolFilterFunc + elicitationHandler ElicitationHandler // MCP SDK components mcpClient *sdkmcp.Client @@ -419,11 +420,7 @@ func (c *clientImpl) isShutdown() bool { // connect creates and connects the MCP client. 
func (c *clientImpl) connect(ctx context.Context) (*sdkmcp.ClientSession, *sdkmcp.Client, error) { - mcpClient := sdkmcp.NewClient(&sdkmcp.Implementation{ - Name: "redpanda-ai-agent-sdk", - Title: "Redpanda AI Agent SDK", - Version: "v1.0.0", - }, &sdkmcp.ClientOptions{ + opts := &sdkmcp.ClientOptions{ KeepAlive: 30 * time.Second, ToolListChangedHandler: func(_ context.Context, _ *sdkmcp.ToolListChangedRequest) { select { @@ -431,7 +428,31 @@ func (c *clientImpl) connect(ctx context.Context) (*sdkmcp.ClientSession, *sdkmc default: } }, - }) + } + + // Wire elicitation handler if configured + if c.elicitationHandler != nil { + opts.ElicitationHandler = func(ctx context.Context, req *sdkmcp.ElicitRequest) (*sdkmcp.ElicitResult, error) { + resp, err := c.elicitationHandler(ctx, &ElicitationRequest{ + Message: req.Params.Message, + RequestedSchema: req.Params.RequestedSchema, + }) + if err != nil { + return nil, err + } + + return &sdkmcp.ElicitResult{ + Action: resp.Action, + Content: resp.Content, + }, nil + } + } + + mcpClient := sdkmcp.NewClient(&sdkmcp.Implementation{ + Name: "redpanda-ai-agent-sdk", + Title: "Redpanda AI Agent SDK", + Version: "v1.0.0", + }, opts) transport, err := c.transportFactory() if err != nil { diff --git a/tool/mcp/client_options.go b/tool/mcp/client_options.go index 1d51a17f..680b81cc 100644 --- a/tool/mcp/client_options.go +++ b/tool/mcp/client_options.go @@ -15,6 +15,7 @@ package mcp import ( + "context" "log/slog" "time" @@ -102,3 +103,52 @@ func WithToolTimeout(timeout time.Duration) ClientOption { c.toolTimeout = timeout } } + +// ElicitationHandler is called when an MCP server requests user input during +// tool execution via the MCP elicitation protocol. +// +// The handler receives the server's elicitation request (message + optional +// schema) and returns the user's response. For synchronous contexts (CLI apps), +// the handler can prompt on stdin. For async contexts (web apps), it can +// integrate with a UI framework. 
+// +// Returning an error causes the MCP tool call to fail with that error. +type ElicitationHandler func(ctx context.Context, req *ElicitationRequest) (*ElicitationResponse, error) + +// ElicitationRequest contains the MCP server's request for user input. +type ElicitationRequest struct { + // Message is the human-readable message explaining what input is needed. + Message string `json:"message"` + // RequestedSchema is an optional JSON schema defining the expected input structure. + // Only used for "form" elicitation mode. + RequestedSchema any `json:"requested_schema,omitempty"` +} + +// ElicitationResponse contains the user's response to an elicitation request. +type ElicitationResponse struct { + // Action is the user's decision: "accept", "decline", or "cancel". + Action string `json:"action"` + // Content contains the submitted form data when Action is "accept". + Content map[string]any `json:"content,omitempty"` +} + +// WithElicitationHandler sets a handler for MCP server elicitation requests. +// When an MCP server requests user input during tool execution, this handler +// is called to obtain the user's response. +// +// Setting this handler automatically advertises elicitation capability to +// the MCP server during connection. +// +// Example: +// +// handler := func(ctx context.Context, req *mcp.ElicitationRequest) (*mcp.ElicitationResponse, error) { +// fmt.Printf("Server asks: %s\n", req.Message) +// // ... prompt user and collect response ... 
+// return &mcp.ElicitationResponse{Action: "accept", Content: response}, nil +// } +// client, err := NewClient(serverID, transport, WithElicitationHandler(handler)) +func WithElicitationHandler(handler ElicitationHandler) ClientOption { + return func(c *clientImpl) { + c.elicitationHandler = handler + } +} diff --git a/tool/mcp/tools.go b/tool/mcp/tools.go index f26d9985..7d18a893 100644 --- a/tool/mcp/tools.go +++ b/tool/mcp/tools.go @@ -377,6 +377,9 @@ func (w *toolWrapper) Definition() llm.ToolDefinition { return w.definition } +// IsAsynchronous implements tool.Tool. +func (*toolWrapper) IsAsynchronous() bool { return false } + // Execute forwards the tool execution to the MCP client. // Uses the namespaced tool name from the definition. func (w *toolWrapper) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) { diff --git a/tool/registry.go b/tool/registry.go index 6a1ca5e9..4aa5c1c6 100644 --- a/tool/registry.go +++ b/tool/registry.go @@ -148,13 +148,22 @@ func (r *registry) Unregister(name string) error { } // List returns tool definitions for use in llm.Request.Tools. +// Asynchronous tools (IsAsynchronous() == true) have a note appended to their +// description instructing the LLM not to re-invoke them after a pending status. func (r *registry) List() []llm.ToolDefinition { r.mu.RLock() defer r.mu.RUnlock() definitions := make([]llm.ToolDefinition, 0, len(r.tools)) for _, registered := range r.tools { - definitions = append(definitions, registered.tool.Definition()) + def := registered.tool.Definition() + if registered.tool.IsAsynchronous() { + def.Description += "\n\nNOTE: This is an asynchronous operation. " + + "Do not call this tool again if it has already returned " + + "an intermediate or pending status." 
+ } + + definitions = append(definitions, def) } return definitions diff --git a/tool/registry_integration_test.go b/tool/registry_integration_test.go index 5955a53a..156f66e9 100644 --- a/tool/registry_integration_test.go +++ b/tool/registry_integration_test.go @@ -520,6 +520,8 @@ func (m *mockTool) Definition() llm.ToolDefinition { } } +func (*mockTool) IsAsynchronous() bool { return false } + func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) { if m.delay > 0 { select { diff --git a/tool/tool.go b/tool/tool.go index 7a75af26..8f3fb5a8 100644 --- a/tool/tool.go +++ b/tool/tool.go @@ -36,4 +36,19 @@ type Tool interface { // Execute performs the tool's main operation synchronously // Input and output are JSON for maximum flexibility across tool types Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) + + // IsAsynchronous indicates whether this tool cannot complete in a single + // synchronous call. Asynchronous tools return an initial/pending result + // from Execute() and require external completion (e.g., user input, + // CI/CD deployment finish, batch job result). + // + // When true, the agent pauses after executing this tool and emits + // FinishReasonInputRequired with the tool's call ID, allowing the + // caller to provide the final result later via Runner.Resume(). + // + // The tool's Execute() should return a normal result describing the + // pending state (e.g., {"status": "pending", "task_id": "ci-42"}). + // This result is stored in the session so the LLM has context when + // the final result arrives. + IsAsynchronous() bool }