Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions adapter/a2a/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ func (m *mockWeatherTool) Execute(_ context.Context, _ json.RawMessage) (json.Ra
return json.RawMessage(`{"temperature": "72°F", "conditions": "sunny"}`), nil
}

func (*mockWeatherTool) IsAsynchronous() bool { return false }

func TestExecutor_SessionPersistence_Mock(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 2 additions & 0 deletions agent/conformance/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func (*CalculatorTool) Definition() llm.ToolDefinition {
}
}

func (*CalculatorTool) IsAsynchronous() bool { return false }

func (*CalculatorTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) {
var params struct {
A float64 `json:"a"`
Expand Down
90 changes: 64 additions & 26 deletions agent/llmagent/llmagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,17 @@ func (a *LLMAgent) Run(ctx context.Context, inv *agent.InvocationMetadata) iter.

// Check if interceptor or turn logic wants to end execution
if finishReason != "" {
// Emit terminal event
yield(agent.InvocationEndEvent{
endEvent := agent.InvocationEndEvent{
Envelope: makeEnvelope(),
FinishReason: finishReason,
Usage: new(inv.TotalUsage()),
}, nil)
}
// Populate InputRequiredToolIDs when pausing for external input
if ids, ok := inv.GetMetadata("input_required_tool_ids").([]string); ok {
endEvent.InputRequiredToolIDs = ids
}

yield(endEvent, nil)

return
}
Expand Down Expand Up @@ -347,12 +352,27 @@ func (a *LLMAgent) executeSingleTurn(
return "", agent.ErrToolRegistry
}

toolParts := a.executeTools(ctx, inv, toolReqs, req.Tools, makeEnvelope, yield)
toolResult := a.executeTools(ctx, inv, toolReqs, req.Tools, makeEnvelope, yield)

// Build single message with all tool response parts
toolMsg := llm.NewMessage(llm.RoleUser, toolParts...)
toolMsg := llm.NewMessage(llm.RoleUser, toolResult.parts...)
sess.Messages = append(sess.Messages, toolMsg)

// Check: any tools requiring external input?
if len(toolResult.inputRequiredIDs) > 0 {
// Emit input required status
yield(agent.StatusEvent{
Envelope: makeEnvelope(),
Stage: agent.StatusStageInputRequired,
Details: fmt.Sprintf("%d tools awaiting external input", len(toolResult.inputRequiredIDs)),
}, nil)

// Store for InvocationEndEvent
inv.SetMetadata("input_required_tool_ids", toolResult.inputRequiredIDs)

return agent.FinishReasonInputRequired, nil
}

// Emit turn completed
if !yield(agent.StatusEvent{
Envelope: makeEnvelope(),
Expand Down Expand Up @@ -479,6 +499,12 @@ func (a *LLMAgent) generateWithStreaming(
return response, nil
}

// executeToolsResult holds the outcome of tool execution.
type executeToolsResult struct {
parts []*llm.Part // All tool response parts
inputRequiredIDs []string // Tool call IDs requiring external input (asynchronous tools)
}

// executeTools runs tool calls concurrently.
//
// Tool execution is limited by toolConcurrency. Individual tool errors
Expand All @@ -491,29 +517,29 @@ func (a *LLMAgent) generateWithStreaming(
//
// ToolResponseEvents are yielded as tools complete.
//
// Returns tool response parts in the order they were requested.
//
// Future: Will also return list of tool IDs requiring input for
// StatusStageInputRequired / FinishReasonInputRequired support.
// Returns tool response parts and a list of tool call IDs that require
// external input (from asynchronous tools). When inputRequiredIDs is non-empty,
// the agent should pause and emit FinishReasonInputRequired.
func (a *LLMAgent) executeTools(
ctx context.Context,
inv *agent.InvocationMetadata,
toolReqs []*llm.ToolRequest,
toolDefs []llm.ToolDefinition,
makeEnvelope func() agent.EventEnvelope,
yield func(agent.Event, error) bool,
) []*llm.Part {
) executeToolsResult {
// Execute tools concurrently with limited parallelism
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(min(a.config.toolConcurrency, len(toolReqs)))

// Results channel (buffered to avoid blocking)
type toolResult struct {
idx int
requestID string
name string
response *llm.ToolResponse
err error
idx int
requestID string
name string
response *llm.ToolResponse
err error
isAsynchronous bool
}

results := make(chan toolResult, len(toolReqs))
Expand All @@ -535,19 +561,24 @@ func (a *LLMAgent) executeTools(
// Launch tool executions
for i, req := range toolReqs {
g.Go(func() error {
// Check if tool is asynchronous
t, _ := a.config.tools.Get(req.Name)
async := t != nil && t.IsAsynchronous()

toolInfo := &agent.ToolCallInfo{
Inv: inv,
Req: req,
Definition: toolDefMap[req.Name], // Add tool definition
Definition: toolDefMap[req.Name],
}

resp, err := executor(gctx, toolInfo)
results <- toolResult{
idx: i,
requestID: req.ID,
name: req.Name,
response: resp,
err: err,
idx: i,
requestID: req.ID,
name: req.Name,
response: resp,
err: err,
isAsynchronous: async,
}

return nil // Never return error to errgroup (we handle errors individually)
Expand All @@ -557,6 +588,8 @@ func (a *LLMAgent) executeTools(
// Collect tool response parts and yield events as they arrive
parts := make([]*llm.Part, 0, len(toolReqs))

var inputRequiredIDs []string

for range toolReqs {
result := <-results

Expand All @@ -574,7 +607,7 @@ func (a *LLMAgent) executeTools(
Envelope: makeEnvelope(),
Response: *errResp,
}, nil) {
return parts // Consumer stopped listening
return executeToolsResult{parts: parts, inputRequiredIDs: inputRequiredIDs}
}
} else {
// Tool execution succeeded
Expand All @@ -585,12 +618,17 @@ func (a *LLMAgent) executeTools(
Envelope: makeEnvelope(),
Response: *result.response,
}, nil) {
return parts // Consumer stopped listening
return executeToolsResult{parts: parts, inputRequiredIDs: inputRequiredIDs}
}

// Track asynchronous tools that need external completion
if result.isAsynchronous {
inputRequiredIDs = append(inputRequiredIDs, result.requestID)
}
}
}

return parts
return executeToolsResult{parts: parts, inputRequiredIDs: inputRequiredIDs}
}

// recoverIncompleteToolCalls detects and executes incomplete tool calls from a
Expand Down Expand Up @@ -645,12 +683,12 @@ func (a *LLMAgent) recoverIncompleteToolCalls(

// Execute the incomplete tools
toolDefs := a.config.tools.List()
toolParts := a.executeTools(ctx, inv, incomplete, toolDefs, makeEnvelope, yield)
toolResult := a.executeTools(ctx, inv, incomplete, toolDefs, makeEnvelope, yield)

// Insert tool response message BEFORE the last user message.
// Current: [..., assistant(tool_req), user(text)]
// After: [..., assistant(tool_req), user(tool_resp), user(text)]
toolMsg := llm.NewMessage(llm.RoleUser, toolParts...)
toolMsg := llm.NewMessage(llm.RoleUser, toolResult.parts...)
lastIdx := len(sess.Messages) - 1
sess.Messages = append(sess.Messages[:lastIdx], toolMsg, sess.Messages[lastIdx])

Expand Down
2 changes: 2 additions & 0 deletions agent/llmagent/llmagent_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func (*calculatorTool) Definition() llm.ToolDefinition {
}
}

func (*calculatorTool) IsAsynchronous() bool { return false }

func (*calculatorTool) Execute(_ context.Context, args json.RawMessage) (json.RawMessage, error) {
// Parse arguments
var params struct {
Expand Down
Loading
Loading