diff --git a/adapter/vercelaisdk/uimessagestream/handler.go b/adapter/vercelaisdk/uimessagestream/handler.go index b207c81..6cdfc00 100644 --- a/adapter/vercelaisdk/uimessagestream/handler.go +++ b/adapter/vercelaisdk/uimessagestream/handler.go @@ -14,6 +14,12 @@ // packages/ai/src/generate-text/stream-text.ts and the SSE framing from // packages/ai/src/ui-message-stream/json-to-sse-transform-stream.ts. // +// Known limitations: +// - Inbound tool call history from multi-turn conversations is not yet +// reconstructed; only text parts are forwarded to the model. +// - The handler calls model.GenerateEvents directly; interceptor plugins +// (retry, OTel) must be wired at the model level. +// // Reference: https://github.com/vercel/ai package uimessagestream @@ -36,7 +42,11 @@ import ( // protocol. It accepts POST requests with a JSON body containing messages // and streams back SSE events compatible with useChat. func Handler(model llm.Model, opts ...Option) http.Handler { - cfg := &config{logger: slog.Default()} + cfg := &config{ + logger: slog.Default(), + maxBodyBytes: 1 << 20, // 1MB + maxTurns: 10, + } for _, o := range opts { o(cfg) } @@ -52,10 +62,12 @@ type Option func(*config) type ToolExecutor func(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) type config struct { - system string - logger *slog.Logger - tools []llm.ToolDefinition - executor ToolExecutor + system string + logger *slog.Logger + tools []llm.ToolDefinition + executor ToolExecutor + maxBodyBytes int64 + maxTurns int } // WithSystem sets the system prompt prepended to every request. @@ -79,6 +91,16 @@ func WithTools(tools []llm.ToolDefinition, executor ToolExecutor) Option { } } +// WithMaxBodyBytes sets the maximum request body size in bytes. Default is 1MB. +func WithMaxBodyBytes(n int64) Option { + return func(c *config) { c.maxBodyBytes = n } +} + +// WithMaxTurns sets the maximum number of agentic tool-calling turns. Default is 10. 
+func WithMaxTurns(n int) Option { + return func(c *config) { c.maxTurns = n } +} + type handler struct { model llm.Model cfg *config @@ -88,7 +110,6 @@ type handler struct { type chatRequest struct { ID string `json:"id"` Messages []chatMessage `json:"messages"` - Trigger string `json:"trigger"` } type chatMessage struct { @@ -127,8 +148,8 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Limit request body to 1MB to prevent abuse. - r.Body = http.MaxBytesReader(w, r.Body, 1<<20) + // Limit request body size to prevent abuse. + r.Body = http.MaxBytesReader(w, r.Body, h.cfg.maxBodyBytes) var body chatRequest if err := json.NewDecoder(r.Body).Decode(&body); err != nil { @@ -158,14 +179,14 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { setSSEHeaders(w) ew := &EventWriter{w: w, f: flusher} - StreamModelWithTools(r.Context(), h.model, req, ew, h.cfg.logger, h.cfg.executor) + StreamModelWithTools(r.Context(), h.model, req, ew, h.cfg.logger, h.cfg.executor, h.cfg.maxTurns) } // generateMessageID creates a random 16-character hex ID for use as a messageId. func generateMessageID() string { b := make([]byte, 8) if _, err := rand.Read(b); err != nil { - return "msg-0000000000000000" + return "0000000000000000" } return hex.EncodeToString(b) @@ -182,16 +203,19 @@ const ( // complexity of StreamModel and StreamModelWithTools. 
type streamWriter struct { ew *EventWriter + logger *slog.Logger textID string reasoningID string textStarted bool reasoningStarted bool textCounter int + reasoningCounter int } -func newStreamWriter(ew *EventWriter) *streamWriter { +func newStreamWriter(ew *EventWriter, logger *slog.Logger) *streamWriter { return &streamWriter{ ew: ew, + logger: logger, textID: "text-0", reasoningID: "reasoning-0", } @@ -207,20 +231,8 @@ func (sw *streamWriter) endReasoning() error { } sw.reasoningStarted = false - - return nil -} - -func (sw *streamWriter) endText() error { - if !sw.textStarted { - return nil - } - - if err := sw.ew.WriteChunk(Chunk{"type": "text-end", "id": sw.textID}); err != nil { - return err - } - - sw.textStarted = false + sw.reasoningCounter++ + sw.reasoningID = fmt.Sprintf("reasoning-%d", sw.reasoningCounter) return nil } @@ -258,6 +270,10 @@ func (sw *streamWriter) writeTextDelta(text string) error { } func (sw *streamWriter) writeReasoningDelta(trace *llm.ReasoningTrace) error { + if err := sw.endTextAndAdvance(); err != nil { + return err + } + if !sw.reasoningStarted { if err := sw.ew.WriteChunk(Chunk{"type": "reasoning-start", "id": sw.reasoningID}); err != nil { return err @@ -286,7 +302,9 @@ func (sw *streamWriter) writeToolRequest(tr *llm.ToolRequest) error { var input any if len(tr.Arguments) > 0 { - _ = json.Unmarshal(tr.Arguments, &input) + if err := json.Unmarshal(tr.Arguments, &input); err != nil { + sw.logger.Warn("failed to unmarshal tool input", "toolCallId", tr.ID, "error", err) + } } return sw.ew.WriteChunk(Chunk{ @@ -307,7 +325,9 @@ func (sw *streamWriter) writeToolResponse(tr *llm.ToolResponse) error { var output any if len(tr.Result) > 0 { - _ = json.Unmarshal(tr.Result, &output) + if err := json.Unmarshal(tr.Result, &output); err != nil { + sw.logger.Warn("failed to unmarshal tool output", "toolCallId", tr.ID, "error", err) + } } return sw.ew.WriteChunk(Chunk{ @@ -321,12 +341,24 @@ func (sw *streamWriter) handleContentPart(part 
*llm.Part) error { case llm.PartText: return sw.writeTextDelta(part.Text) case llm.PartToolRequest: - if err := sw.endText(); err != nil { + if err := sw.endReasoning(); err != nil { + return err + } + + if err := sw.endTextAndAdvance(); err != nil { return err } return sw.writeToolRequest(part.ToolRequest) case llm.PartToolResponse: + if err := sw.endReasoning(); err != nil { + return err + } + + if err := sw.endTextAndAdvance(); err != nil { + return err + } + return sw.writeToolResponse(part.ToolResponse) case llm.PartReasoning: return sw.writeReasoningDelta(part.ReasoningTrace) @@ -340,12 +372,15 @@ func (sw *streamWriter) closeSpans() { if sw.reasoningStarted { _ = sw.ew.WriteChunk(Chunk{"type": "reasoning-end", "id": sw.reasoningID}) sw.reasoningStarted = false + sw.reasoningCounter++ + sw.reasoningID = fmt.Sprintf("reasoning-%d", sw.reasoningCounter) } if sw.textStarted { _ = sw.ew.WriteChunk(Chunk{"type": "text-end", "id": sw.textID}) sw.textStarted = false sw.textCounter++ + sw.textID = fmt.Sprintf("text-%d", sw.textCounter) } } @@ -388,7 +423,7 @@ func StreamModel(ctx context.Context, model llm.Model, req *llm.Request, ew *Eve return } - sw := newStreamWriter(ew) + sw := newStreamWriter(ew, logger) for event, err := range model.GenerateEvents(ctx, req) { if err != nil { @@ -398,6 +433,8 @@ func StreamModel(ctx context.Context, model llm.Model, req *llm.Request, ew *Eve logger.Error("stream error", "error", err) + sw.closeSpans() + _ = ew.WriteChunk(Chunk{"type": "error", "errorText": "An error occurred"}) _ = ew.WriteChunk(Chunk{"type": "finish-step"}) _ = ew.WriteChunk(Chunk{"type": "finish", "finishReason": finishReasonError}) @@ -419,7 +456,7 @@ func StreamModel(ctx context.Context, model llm.Model, req *llm.Request, ew *Eve } case llm.StreamResetEvent: - if err := sw.endText(); err != nil { + if err := sw.endTextAndAdvance(); err != nil { return } @@ -438,8 +475,9 @@ func StreamModel(ctx context.Context, model llm.Model, req *llm.Request, ew *Eve // 
StreamModelWithTools is like StreamModel but supports agentic tool calling. // When the model returns tool calls, the executor is invoked for each, results // are streamed to the client, and the model is called again with the results -// appended to the conversation. This loops until the model stops calling tools. -func StreamModelWithTools(ctx context.Context, model llm.Model, req *llm.Request, ew *EventWriter, logger *slog.Logger, executor ToolExecutor) { +// appended to the conversation. This loops until the model stops calling tools +// or maxTurns is reached. If maxTurns is zero or negative, it defaults to 10. +func StreamModelWithTools(ctx context.Context, model llm.Model, req *llm.Request, ew *EventWriter, logger *slog.Logger, executor ToolExecutor, maxTurns int) { if executor == nil { StreamModel(ctx, model, req, ew, logger) return @@ -449,19 +487,26 @@ func StreamModelWithTools(ctx context.Context, model llm.Model, req *llm.Request logger = slog.Default() } + if maxTurns <= 0 { + maxTurns = 10 + } + messageID := generateMessageID() if err := ew.WriteChunk(Chunk{"type": "start", "messageId": messageID}); err != nil { return } messages := slices.Clone(req.Messages) - sw := newStreamWriter(ew) - - const maxTurns = 10 + sw := newStreamWriter(ew, logger) for range maxTurns { finishReason, toolRequests := streamToolTurn(ctx, model, req, messages, sw, ew, logger) + // Empty finishReason means the stream was aborted (ctx cancel or write failure). 
+ if finishReason == "" { + return + } + if len(toolRequests) == 0 || finishReason != "tool-calls" { _ = ew.WriteChunk(Chunk{"type": "finish", "finishReason": finishReason}) _ = ew.WriteDone() @@ -476,7 +521,7 @@ func StreamModelWithTools(ctx context.Context, model llm.Model, req *llm.Request messages = append(messages, llm.Message{Role: llm.RoleAssistant, Content: assistantParts}) - if err := executeTools(ctx, toolRequests, &messages, ew, executor); err != nil { + if err := executeTools(ctx, toolRequests, &messages, ew, logger, executor); err != nil { return } } @@ -499,21 +544,14 @@ func streamToolTurn( return "", nil } - sw.textID = fmt.Sprintf("text-%d", sw.textCounter) - sw.textStarted = false - sw.reasoningStarted = false - var toolRequests []*llm.ToolRequest - iterReq := &llm.Request{ - Messages: messages, - Tools: req.Tools, - ToolChoice: req.ToolChoice, - } + iterReq := *req + iterReq.Messages = messages var finishReason string - for event, err := range model.GenerateEvents(ctx, iterReq) { + for event, err := range model.GenerateEvents(ctx, &iterReq) { if err != nil { if ctx.Err() != nil { return "", nil @@ -521,10 +559,17 @@ func streamToolTurn( logger.Error("stream error", "error", err) + sw.closeSpans() + + for _, tr := range toolRequests { + _ = ew.WriteChunk(Chunk{ + "type": "tool-output-error", "toolCallId": tr.ID, + "errorText": "stream error; tool call discarded", + }) + } + _ = ew.WriteChunk(Chunk{"type": "error", "errorText": "An error occurred"}) _ = ew.WriteChunk(Chunk{"type": "finish-step"}) - _ = ew.WriteChunk(Chunk{"type": "finish", "finishReason": finishReasonError}) - _ = ew.WriteDone() return finishReasonError, nil } @@ -535,8 +580,44 @@ func streamToolTurn( return "", nil } + case llm.ErrorEvent: + logger.Warn("recoverable LLM error", "message", e.Message) + + if err := ew.WriteChunk(Chunk{"type": "error", "errorText": "An error occurred"}); err != nil { + return "", nil + } + + case llm.StreamResetEvent: + if err := 
sw.endTextAndAdvance(); err != nil { + return "", nil + } + + if err := sw.endReasoning(); err != nil { + return "", nil + } + + for _, tr := range toolRequests { + _ = ew.WriteChunk(Chunk{ + "type": "tool-output-error", "toolCallId": tr.ID, + "errorText": "stream reset; tool call discarded", + }) + } + + toolRequests = nil + case llm.StreamEndEvent: finishReason = writeToolTurnEnd(e, sw, ew, logger) + + if e.Error != nil { + for _, tr := range toolRequests { + _ = ew.WriteChunk(Chunk{ + "type": "tool-output-error", "toolCallId": tr.ID, + "errorText": "stream error; tool call discarded", + }) + } + + toolRequests = nil + } } } @@ -601,7 +682,7 @@ func writeToolTurnEnd(e llm.StreamEndEvent, sw *streamWriter, ew *EventWriter, l return reason } -func executeTools(ctx context.Context, toolRequests []*llm.ToolRequest, messages *[]llm.Message, ew *EventWriter, executor ToolExecutor) error { +func executeTools(ctx context.Context, toolRequests []*llm.ToolRequest, messages *[]llm.Message, ew *EventWriter, logger *slog.Logger, executor ToolExecutor) error { toolResponseParts := make([]*llm.Part, 0, len(toolRequests)) for _, tr := range toolRequests { @@ -620,7 +701,9 @@ func executeTools(ctx context.Context, toolRequests []*llm.ToolRequest, messages var output any if len(result) > 0 { - _ = json.Unmarshal(result, &output) + if err := json.Unmarshal(result, &output); err != nil { + logger.Warn("failed to unmarshal tool result", "toolCallId", tr.ID, "error", err) + } } if err := ew.WriteChunk(Chunk{ diff --git a/adapter/vercelaisdk/uimessagestream/handler_test.go b/adapter/vercelaisdk/uimessagestream/handler_test.go index 19c9ca9..8dd9327 100644 --- a/adapter/vercelaisdk/uimessagestream/handler_test.go +++ b/adapter/vercelaisdk/uimessagestream/handler_test.go @@ -13,22 +13,26 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/redpanda-data/ai-sdk-go/llm" "github.com/redpanda-data/ai-sdk-go/llm/fakellm" 
) const ( - typeStart = "start" - typeStartStep = "start-step" - typeTextStart = "text-start" - typeTextDelta = "text-delta" - typeTextEnd = "text-end" - typeFinishStep = "finish-step" - typeFinish = "finish" - typeError = "error" - typeReasoningStart = "reasoning-start" - typeReasoningDelta = "reasoning-delta" - typeReasoningEnd = "reasoning-end" + typeStart = "start" + typeStartStep = "start-step" + typeTextStart = "text-start" + typeTextDelta = "text-delta" + typeTextEnd = "text-end" + typeFinishStep = "finish-step" + typeFinish = "finish" + typeError = "error" + typeReasoningStart = "reasoning-start" + typeReasoningDelta = "reasoning-delta" + typeReasoningEnd = "reasoning-end" + typeToolOutputError = "tool-output-error" sanitizedError = "An error occurred" ) @@ -116,44 +120,30 @@ func TestHandler_SimpleTextResponse(t *testing.T) { h.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) - } - - if ct := rec.Header().Get("Content-Type"); ct != "text/event-stream" { - t.Errorf("Content-Type = %q, want text/event-stream", ct) - } - - if v := rec.Header().Get("X-Vercel-Ai-Ui-Message-Stream"); v != "v1" { - t.Errorf("X-Vercel-Ai-Ui-Message-Stream = %q, want v1", v) - } + require.Equal(t, http.StatusOK, rec.Code, "response body: %s", rec.Body.String()) + assert.Equal(t, "text/event-stream", rec.Header().Get("Content-Type")) + assert.Equal(t, "v1", rec.Header().Get("X-Vercel-Ai-Ui-Message-Stream")) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expectedPrefix := []string{typeStart, typeStartStep, typeTextStart} for i, exp := range expectedPrefix { - if i >= len(types) || types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall types: %v", i, types[i], exp, types) - } + require.Greater(t, len(types), i, "not enough chunks, all types: %v", types) + 
require.Equal(t, exp, types[i], "chunk[%d] type mismatch, all types: %v", i, types) } for i := 3; i < len(types)-3; i++ { - if types[i] != typeTextDelta { - t.Errorf("chunk[%d] type = %q, want text-delta", i, types[i]) - } + assert.Equal(t, typeTextDelta, types[i], "chunk[%d] type mismatch", i) } expectedSuffix := []string{typeTextEnd, typeFinishStep, typeFinish} for i, exp := range expectedSuffix { idx := len(types) - 3 + i - if idx < 0 || idx >= len(types) || types[idx] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall types: %v", idx, types[idx], exp, types) - } + require.True(t, idx >= 0 && idx < len(types), "suffix index %d out of range, all types: %v", idx, types) + require.Equal(t, exp, types[idx], "chunk[%d] type mismatch, all types: %v", idx, types) } var text strings.Builder @@ -164,9 +154,7 @@ func TestHandler_SimpleTextResponse(t *testing.T) { } } - if got := text.String(); got != "Hello, world!" { - t.Errorf("assembled text = %q, want %q", got, "Hello, world!") - } + assert.Equal(t, "Hello, world!", text.String(), "assembled text mismatch") var startID, endID string @@ -180,15 +168,11 @@ func TestHandler_SimpleTextResponse(t *testing.T) { } } - if startID != endID { - t.Errorf("text-start id = %q, text-end id = %q, want match", startID, endID) - } + assert.Equal(t, startID, endID, "text-start id and text-end id should match") for _, c := range chunks { if c["type"] == typeFinish { - if reason := chunkStr(t, c, "finishReason"); reason != finishReasonStop { - t.Errorf("finish.finishReason = %v, want 'stop'", c["finishReason"]) - } + assert.Equal(t, finishReasonStop, chunkStr(t, c, "finishReason")) } } } @@ -210,9 +194,7 @@ func TestHandler_StreamingTextResponse(t *testing.T) { h.ServeHTTP(rec, req) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") var text strings.Builder @@ -226,13 +208,8 @@ func TestHandler_StreamingTextResponse(t 
*testing.T) { } } - if got := text.String(); got != "Streaming works!" { - t.Errorf("assembled text = %q, want %q", got, "Streaming works!") - } - - if deltaCount < 2 { - t.Errorf("expected multiple text-delta chunks for streaming, got %d", deltaCount) - } + assert.Equal(t, "Streaming works!", text.String(), "assembled text mismatch") + assert.GreaterOrEqual(t, deltaCount, 2, "expected multiple text-delta chunks for streaming") } func TestHandler_ErrorResponse(t *testing.T) { @@ -252,30 +229,15 @@ func TestHandler_ErrorResponse(t *testing.T) { h.ServeHTTP(rec, req) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expected := []string{typeStart, typeStartStep, typeError, typeFinishStep, typeFinish} - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall types: %v", i, types[i], exp, types) - } - } - - if et := chunkStr(t, chunks[2], "errorText"); et != sanitizedError { - t.Errorf("error chunk errorText = %q, want sanitized %q", et, sanitizedError) - } + require.Equal(t, expected, types, "chunk types mismatch") - if reason := chunkStr(t, chunks[4], "finishReason"); reason != finishReasonError { - t.Errorf("finish.finishReason = %v, want %q", chunks[4]["finishReason"], finishReasonError) - } + assert.Equal(t, sanitizedError, chunkStr(t, chunks[2], "errorText")) + assert.Equal(t, finishReasonError, chunkStr(t, chunks[4], "finishReason")) } func TestHandler_SystemPrompt(t *testing.T) { @@ -303,17 +265,9 @@ func TestHandler_SystemPrompt(t *testing.T) { h.ServeHTTP(rec, req) - if len(capturedMessages) != 2 { - t.Fatalf("expected 2 messages (system + user), got %d", len(capturedMessages)) - } - - if capturedMessages[0].Role != llm.RoleSystem { - t.Errorf("first message role = %q, want system", 
capturedMessages[0].Role) - } - - if capturedMessages[0].Content[0].Text != "Be concise." { - t.Errorf("system prompt = %q, want 'Be concise.'", capturedMessages[0].Content[0].Text) - } + require.Len(t, capturedMessages, 2, "expected system + user messages") + assert.Equal(t, llm.RoleSystem, capturedMessages[0].Role) + assert.Equal(t, "Be concise.", capturedMessages[0].Content[0].Text) } func TestHandler_MultiTurnConversation(t *testing.T) { @@ -345,21 +299,10 @@ func TestHandler_MultiTurnConversation(t *testing.T) { h.ServeHTTP(rec, req) - if len(capturedMessages) != 3 { - t.Fatalf("expected 3 messages, got %d", len(capturedMessages)) - } - - if capturedMessages[0].Role != llm.RoleUser { - t.Errorf("msg[0] role = %q, want user", capturedMessages[0].Role) - } - - if capturedMessages[1].Role != llm.RoleAssistant { - t.Errorf("msg[1] role = %q, want assistant", capturedMessages[1].Role) - } - - if capturedMessages[2].Role != llm.RoleUser { - t.Errorf("msg[2] role = %q, want user", capturedMessages[2].Role) - } + require.Len(t, capturedMessages, 3) + assert.Equal(t, llm.RoleUser, capturedMessages[0].Role) + assert.Equal(t, llm.RoleAssistant, capturedMessages[1].Role) + assert.Equal(t, llm.RoleUser, capturedMessages[2].Role) } func TestHandler_MethodNotAllowed(t *testing.T) { @@ -373,9 +316,7 @@ func TestHandler_MethodNotAllowed(t *testing.T) { rec := httptest.NewRecorder() h.ServeHTTP(rec, req) - if rec.Code != http.StatusMethodNotAllowed { - t.Errorf("expected 405, got %d", rec.Code) - } + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) } func TestHandler_InvalidBody(t *testing.T) { @@ -390,9 +331,7 @@ func TestHandler_InvalidBody(t *testing.T) { rec := httptest.NewRecorder() h.ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Errorf("expected 400, got %d", rec.Code) - } + assert.Equal(t, http.StatusBadRequest, rec.Code) } func TestHandler_V6PartsFormat(t *testing.T) { @@ -440,21 +379,10 @@ func TestHandler_V6PartsFormat(t *testing.T) { 
h.ServeHTTP(rec, req) - if len(capturedMessages) != 3 { - t.Fatalf("expected 3 messages, got %d", len(capturedMessages)) - } - - if capturedMessages[0].Content[0].Text != "hello from v6" { - t.Errorf("msg[0] text = %q, want 'hello from v6'", capturedMessages[0].Content[0].Text) - } - - if capturedMessages[1].Content[0].Text != "hi there" { - t.Errorf("msg[1] text = %q, want 'hi there'", capturedMessages[1].Content[0].Text) - } - - if capturedMessages[2].Content[0].Text != "follow up" { - t.Errorf("msg[2] text = %q, want 'follow up'", capturedMessages[2].Content[0].Text) - } + require.Len(t, capturedMessages, 3) + assert.Equal(t, "hello from v6", capturedMessages[0].Content[0].Text) + assert.Equal(t, "hi there", capturedMessages[1].Content[0].Text) + assert.Equal(t, "follow up", capturedMessages[2].Content[0].Text) } // errorStreamModel is a minimal llm.Model that yields specific event sequences @@ -504,35 +432,21 @@ func TestStreamModel_StreamEndEventWithError(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expected := []string{typeStart, typeStartStep, typeTextStart, typeTextDelta, typeError, typeTextEnd, typeFinishStep, typeFinish} - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall types: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types mismatch") for _, c := range chunks { if c["type"] == typeError { - if et := chunkStr(t, c, "errorText"); et != sanitizedError { - t.Errorf("error.errorText = %v, want %q", c["errorText"], sanitizedError) - } + assert.Equal(t, sanitizedError, chunkStr(t, c, "errorText")) } } finishChunk := chunks[len(chunks)-1] - if reason := chunkStr(t, finishChunk, "finishReason"); reason != 
finishReasonError { - t.Errorf("finish.finishReason = %v, want %q", finishChunk["finishReason"], finishReasonError) - } + assert.Equal(t, finishReasonError, chunkStr(t, finishChunk, "finishReason")) } func TestHandler_FinishReasonMapping(t *testing.T) { @@ -555,9 +469,7 @@ func TestHandler_FinishReasonMapping(t *testing.T) { t.Run(tt.want, func(t *testing.T) { t.Parallel() - if got := mapFinishReason(tt.reason); got != tt.want { - t.Errorf("mapFinishReason(%q) = %q, want %q", tt.reason, got, tt.want) - } + assert.Equal(t, tt.want, mapFinishReason(tt.reason)) }) } } @@ -579,9 +491,7 @@ func TestHandler_WithLogger(t *testing.T) { h.ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", rec.Code) - } + require.Equal(t, http.StatusOK, rec.Code) } func TestHandler_SystemRoleMessage(t *testing.T) { @@ -609,13 +519,8 @@ func TestHandler_SystemRoleMessage(t *testing.T) { h.ServeHTTP(rec, req) - if len(capturedMessages) != 2 { - t.Fatalf("expected 2 messages, got %d", len(capturedMessages)) - } - - if capturedMessages[0].Role != llm.RoleSystem { - t.Errorf("msg[0] role = %q, want system", capturedMessages[0].Role) - } + require.Len(t, capturedMessages, 2) + assert.Equal(t, llm.RoleSystem, capturedMessages[0].Role) } func TestHandler_EmptyMessagesSkipped(t *testing.T) { @@ -646,13 +551,8 @@ func TestHandler_EmptyMessagesSkipped(t *testing.T) { h.ServeHTTP(rec, req) - if len(capturedMessages) != 1 { - t.Fatalf("expected 1 message (empty skipped), got %d", len(capturedMessages)) - } - - if capturedMessages[0].Content[0].Text != "real message" { - t.Errorf("msg text = %q, want 'real message'", capturedMessages[0].Content[0].Text) - } + require.Len(t, capturedMessages, 1, "empty message should be skipped") + assert.Equal(t, "real message", capturedMessages[0].Content[0].Text) } func TestHandler_ContextCancellation(t *testing.T) { @@ -673,14 +573,12 @@ func TestHandler_ContextCancellation(t *testing.T) { chunks, _ := parseSSE(t, rec.Body) types := 
chunkTypes(chunks) - if len(types) < 2 || types[0] != typeStart || types[1] != typeStartStep { - t.Fatalf("expected at least start + start-step, got %v", types) - } + require.GreaterOrEqual(t, len(types), 2, "expected at least start + start-step, got %v", types) + require.Equal(t, typeStart, types[0]) + require.Equal(t, typeStartStep, types[1]) for _, tp := range types { - if tp == typeFinish { - t.Error("should not have finish chunk when context is cancelled") - } + assert.NotEqual(t, typeFinish, tp, "should not have finish chunk when context is cancelled") } } @@ -704,28 +602,16 @@ func TestHandler_ErrorEventNonTerminal(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expected := []string{typeStart, typeStartStep, typeTextStart, typeTextDelta, typeError, typeTextDelta, typeTextEnd, typeFinishStep, typeFinish} - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types mismatch") for _, c := range chunks { if c["type"] == typeError { - if et := chunkStr(t, c, "errorText"); et != sanitizedError { - t.Errorf("errorText = %q, want %q (sanitized)", et, sanitizedError) - } + assert.Equal(t, sanitizedError, chunkStr(t, c, "errorText"), "error text should be sanitized") } } @@ -737,9 +623,7 @@ func TestHandler_ErrorEventNonTerminal(t *testing.T) { } } - if got := text.String(); got != "before after" { - t.Errorf("assembled text = %q, want 'before after'", got) - } + assert.Equal(t, "before after", text.String(), "assembled text mismatch") } func TestHandler_TextContentPrefersPartsOverContent(t *testing.T) { @@ -751,9 +635,7 @@ func TestHandler_TextContentPrefersPartsOverContent(t 
*testing.T) { Parts: []messagePart{{Type: "text", Text: "parts content"}}, } - if got := msg.textContent(); got != "parts content" { - t.Errorf("textContent() = %q, want 'parts content'", got) - } + assert.Equal(t, "parts content", msg.textContent()) } func TestHandler_TextContentFallsBackToContent(t *testing.T) { @@ -764,9 +646,7 @@ func TestHandler_TextContentFallsBackToContent(t *testing.T) { Content: "legacy content", } - if got := msg.textContent(); got != "legacy content" { - t.Errorf("textContent() = %q, want 'legacy content'", got) - } + assert.Equal(t, "legacy content", msg.textContent()) } func TestStreamModel_ReasoningTrace(t *testing.T) { @@ -788,28 +668,16 @@ func TestStreamModel_ReasoningTrace(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expected := []string{typeStart, typeStartStep, typeReasoningStart, typeReasoningDelta, typeReasoningEnd, typeTextStart, typeTextDelta, typeTextEnd, typeFinishStep, typeFinish} - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types mismatch") for _, c := range chunks { if c["type"] == typeReasoningDelta { - if d := chunkStr(t, c, "delta"); d != "thinking..." 
{ - t.Errorf("reasoning-delta = %q, want 'thinking...'", d) - } + assert.Equal(t, "thinking...", chunkStr(t, c, "delta")) } } } @@ -835,9 +703,7 @@ func TestStreamModel_ReasoningStatefulTracking(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) @@ -847,15 +713,7 @@ func TestStreamModel_ReasoningStatefulTracking(t *testing.T) { typeTextStart, typeTextDelta, typeTextEnd, typeFinishStep, typeFinish, } - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types mismatch") var rStartID, rEndID string @@ -872,14 +730,10 @@ func TestStreamModel_ReasoningStatefulTracking(t *testing.T) { } } - if rStartID != rEndID { - t.Errorf("reasoning-start id = %q, reasoning-end id = %q, want match", rStartID, rEndID) - } + assert.Equal(t, rStartID, rEndID, "reasoning-start id and reasoning-end id should match") for i, id := range deltaIDs { - if id != rStartID { - t.Errorf("reasoning-delta[%d] id = %q, want %q", i, id, rStartID) - } + assert.Equal(t, rStartID, id, "reasoning-delta[%d] id mismatch", i) } var reasoning strings.Builder @@ -890,9 +744,7 @@ func TestStreamModel_ReasoningStatefulTracking(t *testing.T) { } } - if got := reasoning.String(); got != "step 1 step 2 step 3" { - t.Errorf("assembled reasoning = %q, want 'step 1 step 2 step 3'", got) - } + assert.Equal(t, "step 1 step 2 step 3", reasoning.String(), "assembled reasoning mismatch") } func TestStreamModel_ReasoningNilTrace(t *testing.T) { @@ -913,16 +765,12 @@ func TestStreamModel_ReasoningNilTrace(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + 
assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expected := []string{typeStart, typeStartStep, typeReasoningStart, typeReasoningEnd, typeFinishStep, typeFinish} - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } + require.Equal(t, expected, types, "chunk types mismatch") } type noFlushResponseWriter struct { @@ -943,9 +791,7 @@ func TestHandler_NoFlusherSupport(t *testing.T) { h.ServeHTTP(rec, req) - if inner.Code != http.StatusInternalServerError { - t.Errorf("expected 500, got %d", inner.Code) - } + assert.Equal(t, http.StatusInternalServerError, inner.Code) } func TestStreamModel_IteratorErrorWithCancelledContext(t *testing.T) { @@ -966,14 +812,12 @@ func TestStreamModel_IteratorErrorWithCancelledContext(t *testing.T) { chunks, _ := parseSSE(t, rec.Body) types := chunkTypes(chunks) - if len(types) < 2 || types[0] != typeStart || types[1] != typeStartStep { - t.Fatalf("expected at least start + start-step, got %v", types) - } + require.GreaterOrEqual(t, len(types), 2, "expected at least start + start-step, got %v", types) + require.Equal(t, typeStart, types[0]) + require.Equal(t, typeStartStep, types[1]) for _, tp := range types { - if tp == typeFinish { - t.Error("should not have finish chunk when context is cancelled") - } + assert.NotEqual(t, typeFinish, tp, "should not have finish chunk when context is cancelled") } } @@ -1016,9 +860,7 @@ func TestHandler_AllRequiredHeaders(t *testing.T) { } for k, want := range headers { - if got := rec.Header().Get(k); got != want { - t.Errorf("header %q = %q, want %q", k, got, want) - } + assert.Equal(t, want, rec.Header().Get(k), "header %q mismatch", k) } } @@ -1039,27 +881,16 @@ func TestHandler_StartChunkHasMessageID(t *testing.T) { h.ServeHTTP(rec, req) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") - if len(chunks) == 0 { 
- t.Fatal("no chunks") - } + require.NotEmpty(t, chunks, "no chunks") startChunk := chunks[0] - if startChunk["type"] != typeStart { - t.Fatalf("first chunk type = %q, want 'start'", startChunk["type"]) - } + require.Equal(t, typeStart, startChunk["type"]) mid := chunkStr(t, startChunk, "messageId") - if mid == "" { - t.Fatalf("start chunk has no messageId") - } - - if len(mid) != 16 { - t.Errorf("messageId length = %d, want 16", len(mid)) - } + require.NotEmpty(t, mid, "start chunk has no messageId") + assert.Len(t, mid, 16) } func TestHandler_RequestBodySizeLimit(t *testing.T) { @@ -1076,9 +907,7 @@ func TestHandler_RequestBodySizeLimit(t *testing.T) { h.ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Errorf("expected 400 for oversized body, got %d", rec.Code) - } + assert.Equal(t, http.StatusBadRequest, rec.Code, "oversized body should be rejected") } func TestHandler_EmptyMessagesArray(t *testing.T) { @@ -1094,9 +923,7 @@ func TestHandler_EmptyMessagesArray(t *testing.T) { h.ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Errorf("expected 400 for empty messages, got %d", rec.Code) - } + assert.Equal(t, http.StatusBadRequest, rec.Code, "empty messages should be rejected") } func TestStreamModel_StreamResetEvent(t *testing.T) { @@ -1119,9 +946,7 @@ func TestStreamModel_StreamResetEvent(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) @@ -1131,15 +956,7 @@ func TestStreamModel_StreamResetEvent(t *testing.T) { typeTextStart, typeTextDelta, typeTextEnd, typeFinishStep, typeFinish, } - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types 
mismatch") } func TestStreamModel_StreamResetEventWithReasoning(t *testing.T) { @@ -1163,9 +980,7 @@ func TestStreamModel_StreamResetEventWithReasoning(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) @@ -1178,15 +993,7 @@ func TestStreamModel_StreamResetEventWithReasoning(t *testing.T) { typeTextStart, typeTextDelta, typeTextEnd, typeFinishStep, typeFinish, } - if len(types) != len(expected) { - t.Fatalf("chunk types = %v, want %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types mismatch") } func TestHandler_TextContentConcatenatesAllParts(t *testing.T) { @@ -1201,9 +1008,7 @@ func TestHandler_TextContentConcatenatesAllParts(t *testing.T) { }, } - if got := msg.textContent(); got != "Hello World" { - t.Errorf("textContent() = %q, want 'Hello World'", got) - } + assert.Equal(t, "Hello World", msg.textContent()) } func TestHandler_ErrorResponseSanitized(t *testing.T) { @@ -1227,13 +1032,9 @@ func TestHandler_ErrorResponseSanitized(t *testing.T) { for _, c := range chunks { if c["type"] == typeError { et := chunkStr(t, c, "errorText") - if strings.Contains(et, "secret") || strings.Contains(et, "password") { - t.Errorf("error text leaks internals: %q", et) - } - - if et != sanitizedError { - t.Errorf("error text = %q, want %q", et, sanitizedError) - } + assert.NotContains(t, et, "secret", "error text leaks internals") + assert.NotContains(t, et, "password", "error text leaks internals") + assert.Equal(t, sanitizedError, et) } } } @@ -1267,9 +1068,7 @@ func TestStreamModel_ToolCall(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + assert.True(t, done, "stream 
not terminated with [DONE]") types := chunkTypes(chunks) @@ -1281,51 +1080,28 @@ func TestStreamModel_ToolCall(t *testing.T) { typeTextStart, typeTextDelta, typeTextEnd, typeFinishStep, typeFinish, } - if len(types) != len(expected) { - t.Fatalf("chunk types = %v\nwant = %v", types, expected) - } - - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] type = %q, want %q\nall: %v", i, types[i], exp, types) - } - } + require.Equal(t, expected, types, "chunk types mismatch") for _, c := range chunks { if c["type"] == "tool-input-start" { - if c["toolCallId"] != "call-1" { - t.Errorf("tool-input-start toolCallId = %v, want call-1", c["toolCallId"]) - } - - if c["toolName"] != "getWeather" { - t.Errorf("tool-input-start toolName = %v, want getWeather", c["toolName"]) - } + assert.Equal(t, "call-1", c["toolCallId"]) + assert.Equal(t, "getWeather", c["toolName"]) } } for _, c := range chunks { if c["type"] == "tool-input-available" { input, ok := c["input"].(map[string]any) - if !ok { - t.Fatalf("tool-input-available input is not map: %T", c["input"]) - } - - if input["location"] != "San Francisco" { - t.Errorf("input.location = %v, want San Francisco", input["location"]) - } + require.True(t, ok, "tool-input-available input is not map: %T", c["input"]) + assert.Equal(t, "San Francisco", input["location"]) } } for _, c := range chunks { if c["type"] == "tool-output-available" { output, ok := c["output"].(map[string]any) - if !ok { - t.Fatalf("tool-output-available output is not map: %T", c["output"]) - } - - if output["condition"] != "sunny" { - t.Errorf("output.condition = %v, want sunny", output["condition"]) - } + require.True(t, ok, "tool-output-available output is not map: %T", c["output"]) + assert.Equal(t, "sunny", output["condition"]) } } } @@ -1357,37 +1133,776 @@ func TestStreamModel_ToolCallError(t *testing.T) { }, ew, nil) chunks, done := parseSSE(t, rec.Body) - if !done { - t.Error("stream not terminated with [DONE]") - } + 
assert.True(t, done, "stream not terminated with [DONE]") types := chunkTypes(chunks) expected := []string{ typeStart, typeStartStep, "tool-input-start", "tool-input-available", - "tool-output-error", + typeToolOutputError, typeFinishStep, typeFinish, } - if len(types) != len(expected) { - t.Fatalf("chunk types = %v\nwant = %v", types, expected) + require.Equal(t, expected, types, "chunk types mismatch") + + for _, c := range chunks { + if c["type"] == typeToolOutputError { + assert.Equal(t, "tool execution failed", c["errorText"]) + assert.Equal(t, "call-2", c["toolCallId"]) + } + } +} + +func TestStreamModel_IteratorErrorClosesSpans(t *testing.T) { + t.Parallel() + + // When the iterator yields an error mid-stream, open text/reasoning spans + // must be closed before the error/finish-step/finish sequence. + model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("partial")}, + }, + } + + // Wrap with an iterator that yields the events then an error. + iterErrModel := &iteratorErrorModel{ + inner: model.events, + err: errors.New("network blip"), + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModel(context.Background(), iterErrModel, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("hi"))}, + }, ew, nil) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + // text-end must appear before error to close the open span. + expected := []string{ + typeStart, typeStartStep, + typeTextStart, typeTextDelta, typeTextEnd, + typeError, typeFinishStep, typeFinish, } + require.Equal(t, expected, types, "chunk types mismatch") +} + +// iteratorErrorModel yields events then a terminal error from the iterator. 
+type iteratorErrorModel struct { + errorStreamModel + + inner []llm.Event + err error +} - for i, exp := range expected { - if types[i] != exp { - t.Fatalf("chunk[%d] = %q, want %q", i, types[i], exp) +func (m *iteratorErrorModel) GenerateEvents(_ context.Context, _ *llm.Request) iter.Seq2[llm.Event, error] { + return func(yield func(llm.Event, error) bool) { + for _, e := range m.inner { + if !yield(e, nil) { + return + } } + + yield(nil, m.err) + } +} + +func TestStreamModel_ReasoningIDsIncrement(t *testing.T) { + t.Parallel() + + // reasoning → text → reasoning should produce distinct reasoning IDs. + model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewReasoningPart(&llm.ReasoningTrace{Text: "think1"})}, + llm.ContentPartEvent{Index: 1, Part: llm.NewTextPart("middle")}, + llm.ContentPartEvent{Index: 2, Part: llm.NewReasoningPart(&llm.ReasoningTrace{Text: "think2"})}, + llm.ContentPartEvent{Index: 3, Part: llm.NewTextPart("final")}, + llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, + }, + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModel(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("think"))}, + }, ew, nil) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, typeStartStep, + typeReasoningStart, typeReasoningDelta, typeReasoningEnd, + typeTextStart, typeTextDelta, typeTextEnd, + typeReasoningStart, typeReasoningDelta, typeReasoningEnd, + typeTextStart, typeTextDelta, typeTextEnd, + typeFinishStep, typeFinish, } + require.Equal(t, expected, types, "chunk types mismatch") + + // Collect reasoning-start IDs and verify they differ. 
+ var reasoningIDs []string for _, c := range chunks { - if c["type"] == "tool-output-error" { - if c["errorText"] != "tool execution failed" { - t.Errorf("errorText = %v, want 'tool execution failed'", c["errorText"]) - } + if c["type"] == typeReasoningStart { + reasoningIDs = append(reasoningIDs, chunkStr(t, c, "id")) + } + } - if c["toolCallId"] != "call-2" { - t.Errorf("toolCallId = %v, want call-2", c["toolCallId"]) - } + require.Len(t, reasoningIDs, 2, "expected 2 reasoning-start chunks") + assert.NotEqual(t, reasoningIDs[0], reasoningIDs[1], "reasoning IDs should differ") +} + +func TestStreamModel_TextIDsIncrementAcrossToolCalls(t *testing.T) { + t.Parallel() + + // text → tool → text should produce distinct text IDs. + model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("before tool")}, + llm.ContentPartEvent{Index: 1, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-1", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, + llm.ContentPartEvent{Index: 2, Part: llm.NewToolResponsePart(&llm.ToolResponse{ + ID: "call-1", Name: "lookup", Result: json.RawMessage(`"ok"`), + })}, + llm.ContentPartEvent{Index: 3, Part: llm.NewTextPart("after tool")}, + llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, + }, + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModel(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + }, ew, nil) + + chunks, _ := parseSSE(t, rec.Body) + + // Collect text-start IDs and verify they differ. 
+ var textIDs []string + + for _, c := range chunks { + if c["type"] == typeTextStart { + textIDs = append(textIDs, chunkStr(t, c, "id")) } } + + require.Len(t, textIDs, 2, "expected 2 text-start chunks") + assert.NotEqual(t, textIDs[0], textIDs[1], "text IDs should differ") +} + +func TestStreamModel_ReasoningThenToolRequest(t *testing.T) { + t.Parallel() + + // reasoning → tool-request must close the reasoning span before emitting tool events. + model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewReasoningPart(&llm.ReasoningTrace{Text: "thinking about tool"})}, + llm.ContentPartEvent{Index: 1, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-1", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, + llm.ContentPartEvent{Index: 2, Part: llm.NewToolResponsePart(&llm.ToolResponse{ + ID: "call-1", Name: "lookup", Result: json.RawMessage(`"found"`), + })}, + llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, + }, + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModel(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + }, ew, nil) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, typeStartStep, + typeReasoningStart, typeReasoningDelta, typeReasoningEnd, + "tool-input-start", "tool-input-available", + "tool-output-available", + typeFinishStep, typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") +} + +func TestHandler_WithMaxBodyBytes(t *testing.T) { + t.Parallel() + + model := fakellm.NewFakeModel() + h := Handler(model, WithMaxBodyBytes(50)) + + // Valid JSON structure that exceeds the 50-byte limit. 
+ body := `{"id":"x","messages":[{"role":"user","content":"this payload is definitely longer than fifty bytes"}]}` + req := httptest.NewRequestWithContext(context.Background(), http.MethodPost, "/api/chat", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + + h.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code, "body exceeding custom limit should be rejected") +} + +func TestGenerateMessageID_Length(t *testing.T) { + t.Parallel() + + id := generateMessageID() + assert.Len(t, id, 16, "generateMessageID() value = %q", id) +} + +// toolCallModel returns tool calls on the first N turns, then stops. +type toolCallModel struct { + errorStreamModel + + turnsWithTools int + callCount int +} + +func (m *toolCallModel) GenerateEvents(_ context.Context, _ *llm.Request) iter.Seq2[llm.Event, error] { + m.callCount++ + turn := m.callCount + + return func(yield func(llm.Event, error) bool) { + if turn <= m.turnsWithTools { + if !yield(llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("calling tool")}, nil) { + return + } + + if !yield(llm.ContentPartEvent{Index: 1, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-1", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, nil) { + return + } + + yield(llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonToolCalls}}, nil) + + return + } + + if !yield(llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("done")}, nil) { + return + } + + yield(llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, nil) + } +} + +func TestStreamModelWithTools_HappyPath(t *testing.T) { + t.Parallel() + + model := &toolCallModel{turnsWithTools: 1} + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"result"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, 
&llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, + // Turn 1: tool call + typeStartStep, typeTextStart, typeTextDelta, typeTextEnd, + "tool-input-start", "tool-input-available", + typeFinishStep, + "tool-output-available", + // Turn 2: final answer + typeStartStep, typeTextStart, typeTextDelta, typeTextEnd, + typeFinishStep, + typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") + + finishChunk := chunks[len(chunks)-1] + assert.Equal(t, finishReasonStop, chunkStr(t, finishChunk, "finishReason")) +} + +func TestStreamModelWithTools_MaxTurnsExhaustion(t *testing.T) { + t.Parallel() + + // Model always returns tool calls — will exhaust maxTurns. + model := &toolCallModel{turnsWithTools: 100} + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 2) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + // Count paired start-step/finish-step. + startSteps := 0 + finishSteps := 0 + + for _, tp := range types { + switch tp { + case typeStartStep: + startSteps++ + case typeFinishStep: + finishSteps++ + } + } + + assert.Equal(t, 2, startSteps, "expected 2 start-step chunks") + assert.Equal(t, 2, finishSteps, "expected 2 finish-step chunks") + + // Exactly one finish chunk. 
+ finishCount := 0 + + for _, tp := range types { + if tp == typeFinish { + finishCount++ + } + } + + assert.Equal(t, 1, finishCount, "expected 1 finish chunk") + + finishChunk := chunks[len(chunks)-1] + assert.Equal(t, finishReasonOther, chunkStr(t, finishChunk, "finishReason")) +} + +func TestStreamModelWithTools_IteratorErrorNoDuplicateTerminal(t *testing.T) { + t.Parallel() + + // Model yields text then iterator error — exercises the streamToolTurn error + // path and verifies the caller doesn't emit a second finish+[DONE]. + model := &iteratorErrorModel{ + inner: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("partial")}, + }, + err: errors.New("network blip"), + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + // Exactly one finish chunk. + finishCount := 0 + + for _, c := range chunks { + if c["type"] == typeFinish { + finishCount++ + } + } + + require.Equal(t, 1, finishCount, "expected exactly 1 finish chunk") + + // Verify it's an error finish. + for _, c := range chunks { + if c["type"] == typeFinish { + assert.Equal(t, finishReasonError, chunkStr(t, c, "finishReason")) + } + } +} + +func TestStreamModelWithTools_IteratorErrorPairsToolChunks(t *testing.T) { + t.Parallel() + + // When a tool-input has been emitted and then the iterator errors, + // the orphaned tool-input must get a matching tool-output-error. 
+ model := &iteratorErrorModel{ + inner: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-orphan", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, + }, + err: errors.New("connection reset"), + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, typeStartStep, + "tool-input-start", "tool-input-available", + typeToolOutputError, + typeError, typeFinishStep, typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") + + // Verify the tool-output-error references the right call. + for _, c := range chunks { + if c["type"] == typeToolOutputError { + assert.Equal(t, "call-orphan", c["toolCallId"]) + } + } +} + +func TestStreamModelWithTools_StreamEndEventErrorPairsToolChunks(t *testing.T) { + t.Parallel() + + // StreamEndEvent{Error} after tool-input has been emitted must pair + // the orphaned tool-input with a tool-output-error. 
+ model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-orphan", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, + llm.StreamEndEvent{Error: errors.New("message too long")}, + }, + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, typeStartStep, + "tool-input-start", "tool-input-available", + typeError, typeFinishStep, + typeToolOutputError, + typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") + + for _, c := range chunks { + if c["type"] == typeToolOutputError { + assert.Equal(t, "call-orphan", c["toolCallId"]) + } + } +} + +func TestStreamModelWithTools_StreamEndEventWithError(t *testing.T) { + t.Parallel() + + // StreamEndEvent{Error: ...} (as opposed to iterator error) must also + // produce exactly one finish+[DONE]. This exercises the writeToolTurnEnd + // path where the caller emits finish, not streamToolTurn itself. 
+ model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("partial")}, + llm.StreamEndEvent{Error: errors.New("provider exploded")}, + }, + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + finishCount := 0 + + for _, c := range chunks { + if c["type"] == typeFinish { + finishCount++ + } + } + + require.Equal(t, 1, finishCount, "expected exactly 1 finish chunk") + + for _, c := range chunks { + if c["type"] == typeFinish { + assert.Equal(t, finishReasonError, chunkStr(t, c, "finishReason")) + } + } +} + +func TestStreamModelWithTools_StreamResetEvent(t *testing.T) { + t.Parallel() + + // StreamResetEvent during a tool-calling turn must close spans, reset + // collected tool requests, and not double-fire tools from the discarded attempt. + executorCalls := 0 + + model := &errorStreamModel{ + events: []llm.Event{ + // First attempt: text + tool request, then reset. + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("attempt1")}, + llm.ContentPartEvent{Index: 1, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-discarded", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, + llm.StreamResetEvent{Attempt: 1, Reason: "retrying"}, + // Second attempt: just text, no tool call. 
+ llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("attempt2")}, + llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, + }, + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + executorCalls++ + + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + // Tool executor must NOT have been called — the tool request was from the + // discarded attempt before the reset. + assert.Equal(t, 0, executorCalls, "executor should not be called for discarded tool requests") + + types := chunkTypes(chunks) + + // Verify text spans are properly closed across the reset, and the + // discarded tool request gets a tool-output-error to satisfy the + // protocol pairing invariant. + expected := []string{ + typeStart, + typeStartStep, + typeTextStart, typeTextDelta, typeTextEnd, + "tool-input-start", "tool-input-available", + typeToolOutputError, + typeTextStart, typeTextDelta, typeTextEnd, + typeFinishStep, typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") + + assert.Equal(t, finishReasonStop, chunkStr(t, chunks[len(chunks)-1], "finishReason")) +} + +func TestStreamModel_ToolResponseClosesSpans(t *testing.T) { + t.Parallel() + + // PartToolResponse should close any open text/reasoning span. 
+ model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("before")}, + llm.ContentPartEvent{Index: 1, Part: llm.NewToolResponsePart(&llm.ToolResponse{ + ID: "call-1", Name: "lookup", Result: json.RawMessage(`"found"`), + })}, + llm.ContentPartEvent{Index: 2, Part: llm.NewTextPart("after")}, + llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, + }, + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModel(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + }, ew, nil) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, typeStartStep, + typeTextStart, typeTextDelta, typeTextEnd, + "tool-output-available", + typeTextStart, typeTextDelta, typeTextEnd, + typeFinishStep, typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") + + // Verify text IDs differ across the tool response boundary. + var textIDs []string + + for _, c := range chunks { + if c["type"] == typeTextStart { + textIDs = append(textIDs, chunkStr(t, c, "id")) + } + } + + require.Len(t, textIDs, 2) + assert.NotEqual(t, textIDs[0], textIDs[1], "text IDs should differ across tool response") +} + +func TestStreamModelWithTools_ErrorEventNonTerminal(t *testing.T) { + t.Parallel() + + // ErrorEvent in tool-calling mode must be forwarded to the wire, + // not silently dropped. 
+ model := &errorStreamModel{ + events: []llm.Event{ + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("before")}, + llm.ErrorEvent{Message: "recoverable warning"}, + llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart(" after")}, + llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, + }, + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + }, ew, nil, executor, 0) + + chunks, done := parseSSE(t, rec.Body) + assert.True(t, done, "stream not terminated with [DONE]") + + types := chunkTypes(chunks) + + expected := []string{ + typeStart, typeStartStep, + typeTextStart, typeTextDelta, typeError, typeTextDelta, typeTextEnd, + typeFinishStep, typeFinish, + } + require.Equal(t, expected, types, "chunk types mismatch") + + for _, c := range chunks { + if c["type"] == typeError { + assert.Equal(t, sanitizedError, chunkStr(t, c, "errorText")) + } + } +} + +func TestStreamModelWithTools_RequestFieldsForwarded(t *testing.T) { + t.Parallel() + + // Verify that Options, ResponseFormat, and Metadata from the original + // request are forwarded to subsequent model turns, not dropped. + type testOptions struct { + MaxTokens int + } + + var capturedOptions []any + + // Model that returns tool call on first turn, stop on second. 
+ callCount := 0 + model := &callbackModel{ + fn: func(_ context.Context, req *llm.Request) iter.Seq2[llm.Event, error] { + callCount++ + + capturedOptions = append(capturedOptions, req.Options) + + return func(yield func(llm.Event, error) bool) { + if callCount == 1 { + if !yield(llm.ContentPartEvent{Index: 0, Part: llm.NewToolRequestPart(&llm.ToolRequest{ + ID: "call-1", Name: "lookup", Arguments: json.RawMessage(`{}`), + })}, nil) { + return + } + + yield(llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonToolCalls}}, nil) + + return + } + + if !yield(llm.ContentPartEvent{Index: 0, Part: llm.NewTextPart("done")}, nil) { + return + } + + yield(llm.StreamEndEvent{Response: &llm.Response{FinishReason: llm.FinishReasonStop}}, nil) + } + }, + } + + executor := func(_ context.Context, _ string, _ json.RawMessage) (json.RawMessage, error) { + return json.RawMessage(`"ok"`), nil + } + + rec := httptest.NewRecorder() + ew := NewEventWriter(rec) + + opts := testOptions{MaxTokens: 1000} + + StreamModelWithTools(context.Background(), model, &llm.Request{ + Messages: []llm.Message{llm.NewMessage(llm.RoleUser, llm.NewTextPart("go"))}, + Tools: []llm.ToolDefinition{{Name: "lookup"}}, + Options: opts, + }, ew, nil, executor, 0) + + require.Len(t, capturedOptions, 2, "expected 2 model calls") + assert.Equal(t, opts, capturedOptions[0], "turn 1 should have options") + assert.Equal(t, opts, capturedOptions[1], "turn 2 should have options") +} + +// callbackModel delegates GenerateEvents to a callback function. +type callbackModel struct { + errorStreamModel + + fn func(context.Context, *llm.Request) iter.Seq2[llm.Event, error] +} + +func (m *callbackModel) GenerateEvents(ctx context.Context, req *llm.Request) iter.Seq2[llm.Event, error] { + return m.fn(ctx, req) }