From 7830492ecf2dd0bd28a644b7119b836e8321ba92 Mon Sep 17 00:00:00 2001 From: Megumin Date: Fri, 28 Feb 2025 12:14:35 +0800 Subject: [PATCH] feat: add tool list option to tools node --- compose/tool_node.go | 78 ++++++++++++++++++++++++++++----------- compose/tool_node_test.go | 70 +++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 22 deletions(-) diff --git a/compose/tool_node.go b/compose/tool_node.go index a4f858e..a2000b5 100644 --- a/compose/tool_node.go +++ b/compose/tool_node.go @@ -32,6 +32,7 @@ import ( type toolsNodeOptions struct { ToolOptions []tool.Option + ToolList []tool.BaseTool } // ToolsNodeOption is the option func type for ToolsNode. @@ -44,14 +45,19 @@ func WithToolOption(opts ...tool.Option) ToolsNodeOption { } } +// WithToolListOption sets the tool list for the ToolsNode. +func WithToolListOption(tool ...tool.BaseTool) ToolsNodeOption { + return func(o *toolsNodeOptions) { + o.ToolList = tool + } +} + // ToolsNode a node that can run tools in a graph. the interface in Graph Node as below: // // Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error) // Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) type ToolsNode struct { - runners []*runnablePacker[string, string, tool.Option] - toolsMeta []*executorMeta - indexes map[string]int // toolName vs index in runners + tuple *toolsTuple } // ToolsNodeConfig is the config for ToolsNode. It requires a list of tools. @@ -68,12 +74,29 @@ type ToolsNodeConfig struct { // } // toolsNode, err := NewToolNode(ctx, conf) func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) { - rps := make([]*runnablePacker[string, string, tool.Option], len(conf.Tools)) - toolsMeta := make([]*executorMeta, len(conf.Tools)) - indexes := make(map[string]int) + tuple, err := convTools(ctx, conf.Tools) + if err != nil { + return nil, err + } + + return &ToolsNode{ + tuple: tuple, + }, nil +} - for idx, bt := range conf.Tools { +type toolsTuple struct { + indexes map[string]int + meta []*executorMeta + rps []*runnablePacker[string, string, tool.Option] +} +func convTools(ctx context.Context, tools []tool.BaseTool) (*toolsTuple, error) { + ret := &toolsTuple{ + indexes: make(map[string]int), + meta: make([]*executorMeta, len(tools)), + rps: make([]*runnablePacker[string, string, tool.Option], len(tools)), + } + for idx, bt := range tools { tl, err := bt.Info(ctx) if err != nil { return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err) @@ -110,17 +133,12 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) meta = parseExecutorInfoFromComponent(components.ComponentOfTool, it) } - toolsMeta[idx] = meta - rps[idx] = newRunnablePacker(invokable, streamable, + ret.indexes[toolName] = idx + ret.meta[idx] = meta + ret.rps[idx] = newRunnablePacker(invokable, streamable, nil, nil, !meta.isComponentCallbackEnabled) - indexes[toolName] = idx } - - return &ToolsNode{ - runners: rps, - toolsMeta: toolsMeta, - indexes: indexes, - }, nil + return ret, nil } type toolCallTask struct { @@ -137,7 +155,7 @@ type toolCallTask struct { err error } -func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, error) { +func genToolCallTasks(tuple *toolsTuple, input *schema.Message) ([]toolCallTask, error) { if input.Role != schema.Assistant { return nil, fmt.Errorf("expected message role is Assistant, got %s", input.Role) } @@ -151,13 +169,13 @@ func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, er for i := 0; i < n; i++ { toolCall := input.ToolCalls[i] - index, ok := tn.indexes[toolCall.Function.Name] + index, ok := tuple.indexes[toolCall.Function.Name] if !ok { return nil, fmt.Errorf("tool %s not found in toolsNode indexes", toolCall.Function.Name) } - toolCallTasks[i].r = tn.runners[index] - toolCallTasks[i].meta = tn.toolsMeta[index] + toolCallTasks[i].r = tuple.rps[index] + toolCallTasks[i].meta = tuple.meta[index] toolCallTasks[i].name = toolCall.Function.Name toolCallTasks[i].arg = toolCall.Function.Arguments toolCallTasks[i].callID = toolCall.ID @@ -217,8 +235,16 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error) { opt := getToolsNodeOptions(opts...) + tuple := tn.tuple + if opt.ToolList != nil { + var err error + tuple, err = convTools(ctx, opt.ToolList) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + } + } - tasks, err := tn.genToolCallTasks(input) + tasks, err := genToolCallTasks(tuple, input) if err != nil { return nil, err } @@ -244,8 +270,16 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) { opt := getToolsNodeOptions(opts...) + tuple := tn.tuple + if opt.ToolList != nil { + var err error + tuple, err = convTools(ctx, opt.ToolList) + if err != nil { + return nil, fmt.Errorf("failed to convert tool list from call option: %w", err) + } + } - tasks, err := tn.genToolCallTasks(input) + tasks, err := genToolCallTasks(tuple, input) if err != nil { return nil, err } diff --git a/compose/tool_node_test.go b/compose/tool_node_test.go index 47e79bd..ead47e6 100644 --- a/compose/tool_node_test.go +++ b/compose/tool_node_test.go @@ -417,6 +417,76 @@ func TestToolsNodeOptions(t *testing.T) { assert.Len(t, msgs, 1) assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content) }) + t.Run("tool_list", func(t *testing.T) { + + g := NewGraph[*schema.Message, []*schema.Message]() + + mt := &mockTool{} + + tn, err := NewToolNode(ctx, &ToolsNodeConfig{ + Tools: []tool.BaseTool{}, + }) + assert.NoError(t, err) + + err = g.AddToolsNode("tools", tn) + assert.NoError(t, err) + + err = g.AddEdge(START, "tools") + assert.NoError(t, err) + err = g.AddEdge("tools", END) + assert.NoError(t, err) + + r, err := g.Compile(ctx) + assert.NoError(t, err) + + out, err := r.Invoke(ctx, &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: "mock_tool", + Arguments: `{"name": "jack"}`, + }, + }, + }, + }, WithToolsNodeOption(WithToolListOption(mt), WithToolOption(WithAge(10)))) + assert.NoError(t, err) + assert.Len(t, out, 1) + assert.JSONEq(t, `{"echo": "jack: 10"}`, out[0].Content) + + outMessages := make([][]*schema.Message, 0) + outStream, err := r.Stream(ctx, &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: "mock_tool", + Arguments: `{"name": "jack"}`, + }, + }, + }, + }, WithToolsNodeOption(WithToolListOption(mt), WithToolOption(WithAge(10)))) + + assert.NoError(t, err) + + for { + msgs, err := outStream.Recv() + if err == io.EOF { + break + } + assert.NoError(t, err) + outMessages = append(outMessages, msgs) + } + outStream.Close() + + msgs, err := internal.ConcatItems(outMessages) + assert.NoError(t, err) + + assert.Len(t, msgs, 1) + assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content) + }) }