Skip to content

Commit

Permalink
feat: add tool list option to tools node
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn committed Feb 28, 2025
1 parent 93cb521 commit 7830492
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 22 deletions.
78 changes: 56 additions & 22 deletions compose/tool_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

type toolsNodeOptions struct {
ToolOptions []tool.Option
ToolList []tool.BaseTool
}

// ToolsNodeOption is the option func type for ToolsNode.
Expand All @@ -44,14 +45,19 @@ func WithToolOption(opts ...tool.Option) ToolsNodeOption {
}
}

// WithToolListOption sets the tool list for the ToolsNode.
func WithToolListOption(tool ...tool.BaseTool) ToolsNodeOption {
return func(o *toolsNodeOptions) {
o.ToolList = tool
}
}

// ToolsNode a node that can run tools in a graph. the interface in Graph Node as below:
//
// Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error)
// Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error)
type ToolsNode struct {
runners []*runnablePacker[string, string, tool.Option]
toolsMeta []*executorMeta
indexes map[string]int // toolName vs index in runners
tuple *toolsTuple
}

// ToolsNodeConfig is the config for ToolsNode. It requires a list of tools.
Expand All @@ -68,12 +74,29 @@ type ToolsNodeConfig struct {
// }
// toolsNode, err := NewToolNode(ctx, conf)
func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) {
rps := make([]*runnablePacker[string, string, tool.Option], len(conf.Tools))
toolsMeta := make([]*executorMeta, len(conf.Tools))
indexes := make(map[string]int)
tuple, err := convTools(ctx, conf.Tools)
if err != nil {
return nil, err
}

return &ToolsNode{
tuple: tuple,
}, nil
}

for idx, bt := range conf.Tools {
type toolsTuple struct {
indexes map[string]int
meta []*executorMeta
rps []*runnablePacker[string, string, tool.Option]
}

func convTools(ctx context.Context, tools []tool.BaseTool) (*toolsTuple, error) {
ret := &toolsTuple{
indexes: make(map[string]int),
meta: make([]*executorMeta, len(tools)),
rps: make([]*runnablePacker[string, string, tool.Option], len(tools)),
}
for idx, bt := range tools {
tl, err := bt.Info(ctx)
if err != nil {
return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err)
Expand Down Expand Up @@ -110,17 +133,12 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error)
meta = parseExecutorInfoFromComponent(components.ComponentOfTool, it)
}

toolsMeta[idx] = meta
rps[idx] = newRunnablePacker(invokable, streamable,
ret.indexes[toolName] = idx
ret.meta[idx] = meta
ret.rps[idx] = newRunnablePacker(invokable, streamable,
nil, nil, !meta.isComponentCallbackEnabled)
indexes[toolName] = idx
}

return &ToolsNode{
runners: rps,
toolsMeta: toolsMeta,
indexes: indexes,
}, nil
return ret, nil
}

type toolCallTask struct {
Expand All @@ -137,7 +155,7 @@ type toolCallTask struct {
err error
}

func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, error) {
func genToolCallTasks(tuple *toolsTuple, input *schema.Message) ([]toolCallTask, error) {
if input.Role != schema.Assistant {
return nil, fmt.Errorf("expected message role is Assistant, got %s", input.Role)
}
Expand All @@ -151,13 +169,13 @@ func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, er

for i := 0; i < n; i++ {
toolCall := input.ToolCalls[i]
index, ok := tn.indexes[toolCall.Function.Name]
index, ok := tuple.indexes[toolCall.Function.Name]
if !ok {
return nil, fmt.Errorf("tool %s not found in toolsNode indexes", toolCall.Function.Name)
}

toolCallTasks[i].r = tn.runners[index]
toolCallTasks[i].meta = tn.toolsMeta[index]
toolCallTasks[i].r = tuple.rps[index]
toolCallTasks[i].meta = tuple.meta[index]
toolCallTasks[i].name = toolCall.Function.Name
toolCallTasks[i].arg = toolCall.Function.Arguments
toolCallTasks[i].callID = toolCall.ID
Expand Down Expand Up @@ -217,8 +235,16 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message,
opts ...ToolsNodeOption) ([]*schema.Message, error) {

opt := getToolsNodeOptions(opts...)
tuple := tn.tuple
if opt.ToolList != nil {
var err error
tuple, err = convTools(ctx, opt.ToolList)
if err != nil {
return nil, fmt.Errorf("failed to convert tool list from call option: %w", err)
}
}

tasks, err := tn.genToolCallTasks(input)
tasks, err := genToolCallTasks(tuple, input)
if err != nil {
return nil, err
}
Expand All @@ -244,8 +270,16 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message,
opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) {

opt := getToolsNodeOptions(opts...)
tuple := tn.tuple
if opt.ToolList != nil {
var err error
tuple, err = convTools(ctx, opt.ToolList)
if err != nil {
return nil, fmt.Errorf("failed to convert tool list from call option: %w", err)
}
}

tasks, err := tn.genToolCallTasks(input)
tasks, err := genToolCallTasks(tuple, input)
if err != nil {
return nil, err
}
Expand Down
70 changes: 70 additions & 0 deletions compose/tool_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,76 @@ func TestToolsNodeOptions(t *testing.T) {
assert.Len(t, msgs, 1)
assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content)
})
t.Run("tool_list", func(t *testing.T) {

g := NewGraph[*schema.Message, []*schema.Message]()

mt := &mockTool{}

tn, err := NewToolNode(ctx, &ToolsNodeConfig{
Tools: []tool.BaseTool{},
})
assert.NoError(t, err)

err = g.AddToolsNode("tools", tn)
assert.NoError(t, err)

err = g.AddEdge(START, "tools")
assert.NoError(t, err)
err = g.AddEdge("tools", END)
assert.NoError(t, err)

r, err := g.Compile(ctx)
assert.NoError(t, err)

out, err := r.Invoke(ctx, &schema.Message{
Role: schema.Assistant,
ToolCalls: []schema.ToolCall{
{
ID: toolIDOfUserCompany,
Function: schema.FunctionCall{
Name: "mock_tool",
Arguments: `{"name": "jack"}`,
},
},
},
}, WithToolsNodeOption(WithToolListOption(mt), WithToolOption(WithAge(10))))
assert.NoError(t, err)
assert.Len(t, out, 1)
assert.JSONEq(t, `{"echo": "jack: 10"}`, out[0].Content)

outMessages := make([][]*schema.Message, 0)
outStream, err := r.Stream(ctx, &schema.Message{
Role: schema.Assistant,
ToolCalls: []schema.ToolCall{
{
ID: toolIDOfUserCompany,
Function: schema.FunctionCall{
Name: "mock_tool",
Arguments: `{"name": "jack"}`,
},
},
},
}, WithToolsNodeOption(WithToolListOption(mt), WithToolOption(WithAge(10))))

assert.NoError(t, err)

for {
msgs, err := outStream.Recv()
if err == io.EOF {
break
}
assert.NoError(t, err)
outMessages = append(outMessages, msgs)
}
outStream.Close()

msgs, err := internal.ConcatItems(outMessages)
assert.NoError(t, err)

assert.Len(t, msgs, 1)
assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content)
})

}

Expand Down

0 comments on commit 7830492

Please sign in to comment.