Skip to content

Commit 6ba1e1b

Browse files
authored
Merge branch 'main' into feat/upgrade-secretstore
2 parents fa75eba + 9f0aec6 commit 6ba1e1b

File tree

3 files changed

+93
-46
lines changed

3 files changed

+93
-46
lines changed

conversation/echo/echo.go

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,31 @@ func (e *Echo) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
5959
return
6060
}
6161

62-
// Converse returns the last message's content directly.
62+
// Converse returns one output per input message.
6363
func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
64-
var content string
65-
var toolCalls []llms.ToolCall
64+
if r.Message == nil {
65+
return &conversation.Response{
66+
ConversationContext: r.ConversationContext,
67+
Outputs: []conversation.Result{},
68+
}, nil
69+
}
70+
71+
outputs := make([]conversation.Result, 0, len(*r.Message))
6672

67-
if r.Message != nil && len(*r.Message) > 0 {
68-
lastMessage := (*r.Message)[len(*r.Message)-1]
73+
for _, message := range *r.Message {
74+
var content string
75+
var toolCalls []llms.ToolCall
6976

70-
for _, part := range lastMessage.Parts {
77+
for i, part := range message.Parts {
7178
switch p := part.(type) {
7279
case llms.TextContent:
80+
// end with space if not the first part
81+
if i > 0 && content != "" {
82+
content += " "
83+
}
7384
content += p.Text
74-
case *llms.ToolCall:
75-
toolCalls = append(toolCalls, *p)
85+
case llms.ToolCall:
86+
toolCalls = append(toolCalls, p)
7687
case llms.ToolCallResponse:
7788
content = p.Content
7889
toolCalls = append(toolCalls, llms.ToolCall{
@@ -87,25 +98,25 @@ func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conv
8798
return nil, fmt.Errorf("found invalid content type as input for %v", p)
8899
}
89100
}
90-
}
91101

92-
choice := conversation.Choice{
93-
FinishReason: "stop",
94-
Index: 0,
95-
Message: conversation.Message{
96-
Content: content,
97-
},
98-
}
102+
choice := conversation.Choice{
103+
FinishReason: "stop",
104+
Index: 0,
105+
Message: conversation.Message{
106+
Content: content,
107+
},
108+
}
99109

100-
if len(toolCalls) > 0 {
101-
choice.Message.ToolCallRequest = &toolCalls
102-
}
110+
if len(toolCalls) > 0 {
111+
choice.Message.ToolCallRequest = &toolCalls
112+
}
103113

104-
outputs := []conversation.Result{
105-
{
114+
output := conversation.Result{
106115
StopReason: "stop",
107116
Choices: []conversation.Choice{choice},
108-
},
117+
}
118+
119+
outputs = append(outputs, output)
109120
}
110121

111122
res = &conversation.Response{

conversation/echo/echo_test.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ func TestConverse(t *testing.T) {
9090
},
9191
expected: &conversation.Response{
9292
Outputs: []conversation.Result{
93+
{
94+
StopReason: "stop",
95+
Choices: []conversation.Choice{
96+
{
97+
FinishReason: "stop",
98+
Index: 0,
99+
Message: conversation.Message{
100+
Content: "first message second message",
101+
},
102+
},
103+
},
104+
},
93105
{
94106
StopReason: "stop",
95107
Choices: []conversation.Choice{
@@ -133,7 +145,7 @@ func TestConverseAlpha2(t *testing.T) {
133145
{
134146
Role: llms.ChatMessageTypeAI,
135147
Parts: []llms.ContentPart{
136-
&llms.ToolCall{
148+
llms.ToolCall{
137149
ID: "myid",
138150
Type: "function",
139151
FunctionCall: &llms.FunctionCall{
@@ -218,7 +230,7 @@ func TestConverseAlpha2(t *testing.T) {
218230
Role: llms.ChatMessageTypeAI,
219231
Parts: []llms.ContentPart{
220232
llms.TextContent{Text: "text msg"},
221-
&llms.ToolCall{
233+
llms.ToolCall{
222234
ID: "myid",
223235
Type: "function",
224236
FunctionCall: &llms.FunctionCall{

tests/conformance/conversation/conversation.go

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,20 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
133133
resp, err := conv.Converse(ctx, req)
134134

135135
require.NoError(t, err)
136-
assert.Len(t, resp.Outputs, 1)
137-
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
138-
// anthropic responds with end_turn but other llm providers return with stop
139-
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
140-
assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
136+
// Echo component returns one output per message, other components return one output
137+
if component == "echo" {
138+
assert.Len(t, resp.Outputs, 2)
139+
// Check the last output - system message
140+
assert.NotEmpty(t, resp.Outputs[1].Choices[0].Message.Content)
141+
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[1].StopReason))
142+
assert.Empty(t, resp.Outputs[1].Choices[0].Message.ToolCallRequest)
143+
} else {
144+
assert.Len(t, resp.Outputs, 1)
145+
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
146+
// anthropic responds with end_turn but other llm providers return with stop
147+
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
148+
assert.Empty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
149+
}
141150
})
142151
t.Run("test assistant message type", func(t *testing.T) {
143152
ctx, cancel := context.WithTimeout(t.Context(), 25*time.Second)
@@ -224,13 +233,26 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
224233
resp, err := conv.Converse(ctx, req)
225234

226235
require.NoError(t, err)
227-
assert.Len(t, resp.Outputs, 1)
228-
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
229-
// anthropic responds with end_turn but other llm providers return with stop
230-
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
231-
if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest) > 0 {
232-
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
233-
require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
236+
// Echo component returns one output per message, other components return one output
237+
if component == "echo" {
238+
assert.Len(t, resp.Outputs, 4)
239+
// Check the last output - human message
240+
assert.NotEmpty(t, resp.Outputs[3].Choices[0].Message.Content)
241+
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[3].StopReason))
242+
// Check the tool call output - second output
243+
if resp.Outputs[1].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[1].Choices[0].Message.ToolCallRequest) > 0 {
244+
assert.NotEmpty(t, resp.Outputs[1].Choices[0].Message.ToolCallRequest)
245+
require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[1].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
246+
}
247+
} else {
248+
assert.Len(t, resp.Outputs, 1)
249+
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.Content)
250+
// anthropic responds with end_turn but other llm providers return with stop
251+
assert.True(t, slices.Contains([]string{"stop", "end_turn"}, resp.Outputs[0].StopReason))
252+
if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil && len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest) > 0 {
253+
assert.NotEmpty(t, resp.Outputs[0].Choices[0].Message.ToolCallRequest)
254+
require.JSONEq(t, `{"test": "value"}`, (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[0].FunctionCall.Arguments)
255+
}
234256
}
235257
})
236258

@@ -306,14 +328,14 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
306328
if resp.Outputs[0].Choices[0].Message.ToolCallRequest != nil {
307329
assert.GreaterOrEqual(t, len(*resp.Outputs[0].Choices[0].Message.ToolCallRequest), 1)
308330

309-
var toolCall *llms.ToolCall
331+
var toolCall llms.ToolCall
310332
for i := range *resp.Outputs[0].Choices[0].Message.ToolCallRequest {
311333
if (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[i].FunctionCall.Name == "get_project_name" {
312-
toolCall = &(*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[i]
334+
toolCall = (*resp.Outputs[0].Choices[0].Message.ToolCallRequest)[i]
313335
break
314336
}
315337
}
316-
require.NotNil(t, toolCall)
338+
require.NotEmpty(t, toolCall.ID)
317339
assert.Equal(t, "get_project_name", toolCall.FunctionCall.Name)
318340
assert.Contains(t, toolCall.FunctionCall.Arguments, "repo_link")
319341

@@ -348,7 +370,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
348370
responseMessages = append(responseMessages,
349371
llms.MessageContent{
350372
Role: llms.ChatMessageTypeAI,
351-
Parts: []llms.ContentPart{mistral.CreateToolCallPart(toolCall)},
373+
Parts: []llms.ContentPart{mistral.CreateToolCallPart(&toolCall)},
352374
},
353375
mistral.CreateToolResponseMessage(toolResponse),
354376
)
@@ -413,24 +435,26 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
413435
require.NoError(t, err)
414436

415437
// handle potentially multiple outputs from different llm providers
416-
var toolCall *llms.ToolCall
438+
var toolCall llms.ToolCall
439+
found := false
417440
for _, output := range resp1.Outputs {
418441
if output.Choices[0].Message.ToolCallRequest != nil {
419442
// find the tool call with the expected function name
420443
for i := range *output.Choices[0].Message.ToolCallRequest {
421444
if (*output.Choices[0].Message.ToolCallRequest)[i].FunctionCall.Name == "retrieve_payment_status" {
422-
toolCall = &(*output.Choices[0].Message.ToolCallRequest)[i]
445+
toolCall = (*output.Choices[0].Message.ToolCallRequest)[i]
446+
found = true
423447
break
424448
}
425449
}
426-
if toolCall != nil {
450+
if found {
427451
break
428452
}
429453
}
430454
}
431455

432456
// check if we got a tool call request
433-
if toolCall != nil {
457+
if found {
434458
assert.Equal(t, "retrieve_payment_status", toolCall.FunctionCall.Name)
435459
assert.Contains(t, toolCall.FunctionCall.Arguments, "T1001")
436460

@@ -451,7 +475,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
451475
},
452476
{
453477
Role: llms.ChatMessageTypeAI,
454-
Parts: []llms.ContentPart{*toolCall},
478+
Parts: []llms.ContentPart{toolCall},
455479
},
456480
{
457481
Role: llms.ChatMessageTypeTool,
@@ -470,7 +494,7 @@ func ConformanceTests(t *testing.T, props map[string]string, conv conversation.C
470494
},
471495
{
472496
Role: llms.ChatMessageTypeAI,
473-
Parts: []llms.ContentPart{mistral.CreateToolCallPart(toolCall)},
497+
Parts: []llms.ContentPart{mistral.CreateToolCallPart(&toolCall)},
474498
},
475499
mistral.CreateToolResponseMessage(toolResponse),
476500
}

0 commit comments

Comments
 (0)