-
Notifications
You must be signed in to change notification settings - Fork 3
Fix Vercel AI SDK adapter bugs from code review #132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
45709e6
73b4875
f6bacfd
f5b2290
e8429f9
67b2772
06f0a7a
b93c88f
be957b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,12 @@ | |
| // packages/ai/src/generate-text/stream-text.ts and the SSE framing from | ||
| // packages/ai/src/ui-message-stream/json-to-sse-transform-stream.ts. | ||
| // | ||
| // Known limitations: | ||
| // - Inbound tool call history from multi-turn conversations is not yet | ||
| // reconstructed; only text parts are forwarded to the model. | ||
| // - The handler calls model.GenerateEvents directly; interceptor plugins | ||
| // (retry, OTel) must be wired at the model level. | ||
| // | ||
| // Reference: https://github.com/vercel/ai | ||
| package uimessagestream | ||
|
|
||
|
|
@@ -36,7 +42,11 @@ | |
| // protocol. It accepts POST requests with a JSON body containing messages | ||
| // and streams back SSE events compatible with useChat. | ||
| func Handler(model llm.Model, opts ...Option) http.Handler { | ||
| cfg := &config{logger: slog.Default()} | ||
| cfg := &config{ | ||
| logger: slog.Default(), | ||
| maxBodyBytes: 1 << 20, // 1MB | ||
| maxTurns: 10, | ||
| } | ||
| for _, o := range opts { | ||
| o(cfg) | ||
| } | ||
|
|
@@ -52,10 +62,12 @@ | |
| type ToolExecutor func(ctx context.Context, name string, args json.RawMessage) (json.RawMessage, error) | ||
|
|
||
| type config struct { | ||
| system string | ||
| logger *slog.Logger | ||
| tools []llm.ToolDefinition | ||
| executor ToolExecutor | ||
| system string | ||
| logger *slog.Logger | ||
| tools []llm.ToolDefinition | ||
| executor ToolExecutor | ||
| maxBodyBytes int64 | ||
| maxTurns int | ||
| } | ||
|
|
||
| // WithSystem sets the system prompt prepended to every request. | ||
|
|
@@ -79,6 +91,16 @@ | |
| } | ||
| } | ||
|
|
||
| // WithMaxBodyBytes sets the maximum request body size in bytes. Default is 1MB. | ||
| func WithMaxBodyBytes(n int64) Option { | ||
| return func(c *config) { c.maxBodyBytes = n } | ||
| } | ||
|
|
||
| // WithMaxTurns sets the maximum number of agentic tool-calling turns. Default is 10. | ||
| func WithMaxTurns(n int) Option { | ||
| return func(c *config) { c.maxTurns = n } | ||
| } | ||
|
|
||
| type handler struct { | ||
| model llm.Model | ||
| cfg *config | ||
|
|
@@ -88,7 +110,6 @@ | |
| type chatRequest struct { | ||
| ID string `json:"id"` | ||
| Messages []chatMessage `json:"messages"` | ||
| Trigger string `json:"trigger"` | ||
| } | ||
|
|
||
| type chatMessage struct { | ||
|
|
@@ -127,8 +148,8 @@ | |
| return | ||
| } | ||
|
|
||
| // Limit request body to 1MB to prevent abuse. | ||
| r.Body = http.MaxBytesReader(w, r.Body, 1<<20) | ||
| // Limit request body size to prevent abuse. | ||
| r.Body = http.MaxBytesReader(w, r.Body, h.cfg.maxBodyBytes) | ||
|
|
||
| var body chatRequest | ||
| if err := json.NewDecoder(r.Body).Decode(&body); err != nil { | ||
|
|
@@ -158,14 +179,14 @@ | |
| setSSEHeaders(w) | ||
|
|
||
| ew := &EventWriter{w: w, f: flusher} | ||
| StreamModelWithTools(r.Context(), h.model, req, ew, h.cfg.logger, h.cfg.executor) | ||
| StreamModelWithTools(r.Context(), h.model, req, ew, h.cfg.logger, h.cfg.executor, h.cfg.maxTurns) | ||
| } | ||
|
|
||
| // generateMessageID creates a random 16-character hex ID for use as a messageId. | ||
| func generateMessageID() string { | ||
| b := make([]byte, 8) | ||
| if _, err := rand.Read(b); err != nil { | ||
| return "msg-0000000000000000" | ||
| return "0000000000000000" | ||
| } | ||
|
|
||
| return hex.EncodeToString(b) | ||
|
|
@@ -182,16 +203,19 @@ | |
| // complexity of StreamModel and StreamModelWithTools. | ||
| type streamWriter struct { | ||
| ew *EventWriter | ||
| logger *slog.Logger | ||
| textID string | ||
| reasoningID string | ||
| textStarted bool | ||
| reasoningStarted bool | ||
| textCounter int | ||
| reasoningCounter int | ||
| } | ||
|
|
||
| func newStreamWriter(ew *EventWriter) *streamWriter { | ||
| func newStreamWriter(ew *EventWriter, logger *slog.Logger) *streamWriter { | ||
| return &streamWriter{ | ||
| ew: ew, | ||
| logger: logger, | ||
| textID: "text-0", | ||
| reasoningID: "reasoning-0", | ||
| } | ||
|
|
@@ -207,20 +231,8 @@ | |
| } | ||
|
|
||
| sw.reasoningStarted = false | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| func (sw *streamWriter) endText() error { | ||
| if !sw.textStarted { | ||
| return nil | ||
| } | ||
|
|
||
| if err := sw.ew.WriteChunk(Chunk{"type": "text-end", "id": sw.textID}); err != nil { | ||
| return err | ||
| } | ||
|
|
||
| sw.textStarted = false | ||
| sw.reasoningCounter++ | ||
| sw.reasoningID = fmt.Sprintf("reasoning-%d", sw.reasoningCounter) | ||
|
|
||
| return nil | ||
| } | ||
|
|
@@ -258,6 +270,10 @@ | |
| } | ||
|
|
||
| func (sw *streamWriter) writeReasoningDelta(trace *llm.ReasoningTrace) error { | ||
| if err := sw.endTextAndAdvance(); err != nil { | ||
| return err | ||
| } | ||
|
|
||
| if !sw.reasoningStarted { | ||
| if err := sw.ew.WriteChunk(Chunk{"type": "reasoning-start", "id": sw.reasoningID}); err != nil { | ||
| return err | ||
|
|
@@ -286,7 +302,9 @@ | |
|
|
||
| var input any | ||
| if len(tr.Arguments) > 0 { | ||
| _ = json.Unmarshal(tr.Arguments, &input) | ||
| if err := json.Unmarshal(tr.Arguments, &input); err != nil { | ||
| sw.logger.Warn("failed to unmarshal tool input", "toolCallId", tr.ID, "error", err) | ||
| } | ||
| } | ||
|
|
||
| return sw.ew.WriteChunk(Chunk{ | ||
|
|
@@ -307,7 +325,9 @@ | |
|
|
||
| var output any | ||
| if len(tr.Result) > 0 { | ||
| _ = json.Unmarshal(tr.Result, &output) | ||
| if err := json.Unmarshal(tr.Result, &output); err != nil { | ||
| sw.logger.Warn("failed to unmarshal tool output", "toolCallId", tr.ID, "error", err) | ||
| } | ||
| } | ||
|
|
||
| return sw.ew.WriteChunk(Chunk{ | ||
|
|
@@ -321,7 +341,7 @@ | |
| case llm.PartText: | ||
| return sw.writeTextDelta(part.Text) | ||
| case llm.PartToolRequest: | ||
| if err := sw.endText(); err != nil { | ||
| if err := sw.endTextAndAdvance(); err != nil { | ||
|
Check failure on line 344 in adapter/vercelaisdk/uimessagestream/handler.go
|
||
| return err | ||
| } | ||
|
|
||
|
claude[bot] marked this conversation as resolved.
|
||
|
|
@@ -340,12 +360,15 @@ | |
| if sw.reasoningStarted { | ||
| _ = sw.ew.WriteChunk(Chunk{"type": "reasoning-end", "id": sw.reasoningID}) | ||
| sw.reasoningStarted = false | ||
| sw.reasoningCounter++ | ||
| sw.reasoningID = fmt.Sprintf("reasoning-%d", sw.reasoningCounter) | ||
| } | ||
|
|
||
| if sw.textStarted { | ||
| _ = sw.ew.WriteChunk(Chunk{"type": "text-end", "id": sw.textID}) | ||
| sw.textStarted = false | ||
| sw.textCounter++ | ||
| sw.textID = fmt.Sprintf("text-%d", sw.textCounter) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -388,7 +411,7 @@ | |
| return | ||
| } | ||
|
|
||
| sw := newStreamWriter(ew) | ||
| sw := newStreamWriter(ew, logger) | ||
|
|
||
| for event, err := range model.GenerateEvents(ctx, req) { | ||
| if err != nil { | ||
|
|
@@ -398,6 +421,8 @@ | |
|
|
||
| logger.Error("stream error", "error", err) | ||
|
|
||
| sw.closeSpans() | ||
|
|
||
| _ = ew.WriteChunk(Chunk{"type": "error", "errorText": "An error occurred"}) | ||
| _ = ew.WriteChunk(Chunk{"type": "finish-step"}) | ||
| _ = ew.WriteChunk(Chunk{"type": "finish", "finishReason": finishReasonError}) | ||
|
|
@@ -419,7 +444,7 @@ | |
| } | ||
|
|
||
| case llm.StreamResetEvent: | ||
| if err := sw.endText(); err != nil { | ||
| if err := sw.endTextAndAdvance(); err != nil { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟣 🟣 Pre-existing (from #131): Extended reasoning… What the bug is
When any of three terminal events fires after such chunks have been written:
All three leave the previously-emitted spans unclosed. Why this PR is the natural moment: This PR's diff explicitly:
So the same hunks were touched, the same template applied to the executor sibling — but … Step-by-step proof (StreamResetEvent path): Provider yields
Wire output:
Reachability: Two routes through the public API:
Addressing the refutation: The refutation argues the no-executor path operates under a different contract:
Reachability via … Severity rationale:
How to fix: Option 1 (mirror streamToolTurn): track pending tool IDs in … Option 2 (smaller diff): reject …
||
| return | ||
| } | ||
|
|
||
|
|
@@ -438,8 +463,9 @@ | |
| // StreamModelWithTools is like StreamModel but supports agentic tool calling. | ||
| // When the model returns tool calls, the executor is invoked for each, results | ||
| // are streamed to the client, and the model is called again with the results | ||
| // appended to the conversation. This loops until the model stops calling tools. | ||
| func StreamModelWithTools(ctx context.Context, model llm.Model, req *llm.Request, ew *EventWriter, logger *slog.Logger, executor ToolExecutor) { | ||
| // appended to the conversation. This loops until the model stops calling tools | ||
| // or maxTurns is reached. If maxTurns is 0, it defaults to 10. | ||
| func StreamModelWithTools(ctx context.Context, model llm.Model, req *llm.Request, ew *EventWriter, logger *slog.Logger, executor ToolExecutor, maxTurns int) { | ||
| if executor == nil { | ||
| StreamModel(ctx, model, req, ew, logger) | ||
| return | ||
|
|
@@ -449,15 +475,17 @@ | |
| logger = slog.Default() | ||
| } | ||
|
|
||
| if maxTurns <= 0 { | ||
| maxTurns = 10 | ||
| } | ||
|
|
||
| messageID := generateMessageID() | ||
| if err := ew.WriteChunk(Chunk{"type": "start", "messageId": messageID}); err != nil { | ||
| return | ||
| } | ||
|
|
||
| messages := slices.Clone(req.Messages) | ||
| sw := newStreamWriter(ew) | ||
|
|
||
| const maxTurns = 10 | ||
| sw := newStreamWriter(ew, logger) | ||
|
|
||
| for range maxTurns { | ||
| finishReason, toolRequests := streamToolTurn(ctx, model, req, messages, sw, ew, logger) | ||
|
|
@@ -476,11 +504,14 @@ | |
|
|
||
| messages = append(messages, llm.Message{Role: llm.RoleAssistant, Content: assistantParts}) | ||
|
|
||
| if err := executeTools(ctx, toolRequests, &messages, ew, executor); err != nil { | ||
| if err := executeTools(ctx, toolRequests, &messages, ew, logger, executor); err != nil { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟣 🟡 Pre-existing bug (from #131, not introduced by this PR): In … Extended reasoning… What the bug is: In

case llm.PartText:
if err := sw.writeTextDelta(e.Part.Text); err != nil { ... } // streamed only
case llm.PartReasoning:
if err := sw.writeReasoningDelta(e.Part.ReasoningTrace); err != nil { ... } // streamed only
case llm.PartToolRequest:
...
*toolRequests = append(*toolRequests, tr) // ← accumulated
if err := sw.writeToolRequest(tr); err != nil { ... }

Only the accumulated tool requests are carried into the next assistant message — the streamed text and reasoning parts are not retained:

assistantParts := make([]*llm.Part, 0, len(toolRequests))
for _, tr := range toolRequests {
assistantParts = append(assistantParts, llm.NewToolRequestPart(tr))
}
messages = append(messages, llm.Message{Role: llm.RoleAssistant, Content: assistantParts})

So the assistant turn the model sees on iteration N+1 contains only its tool_use blocks — its own preceding text and reasoning are gone.

Specific code path that triggers it: A real model commonly emits something like …

Why existing code doesn't prevent it: The new …

Step-by-step proof: Consider a 2-turn agentic interaction with a model that emits text alongside its tool call. Initial
Impact
How to fix: Accumulate text and reasoning parts in …

// In streamToolTurn:
var assistantContent []*llm.Part
// pass &assistantContent into handleToolTurnPart
// In handleToolTurnPart:
case llm.PartText:
if err := sw.writeTextDelta(e.Part.Text); err != nil { return true }
*assistantContent = append(*assistantContent, e.Part)
case llm.PartReasoning:
if err := sw.writeReasoningDelta(e.Part.ReasoningTrace); err != nil { return true }
*assistantContent = append(*assistantContent, e.Part)
case llm.PartToolRequest:
...
*toolRequests = append(*toolRequests, tr)
*assistantContent = append(*assistantContent, e.Part)
...
// In StreamModelWithTools, replace the assistantParts loop with:
messages = append(messages, llm.Message{Role: llm.RoleAssistant, Content: assistantContent})

Note: text deltas from the same span should ideally be concatenated into a single …

Why this PR is the natural moment: This PR substantially modifies
claude[bot] marked this conversation as resolved.
|
||
| return | ||
| } | ||
| } | ||
|
|
||
| sw.closeSpans() | ||
|
|
||
| _ = ew.WriteChunk(Chunk{"type": "finish-step"}) | ||
|
Check failure on line 514 in adapter/vercelaisdk/uimessagestream/handler.go
|
||
|
claude[bot] marked this conversation as resolved.
Outdated
|
||
| _ = ew.WriteChunk(Chunk{"type": "finish", "finishReason": finishReasonOther}) | ||
| _ = ew.WriteDone() | ||
| } | ||
|
|
@@ -520,7 +551,9 @@ | |
| } | ||
|
|
||
| logger.Error("stream error", "error", err) | ||
|
|
||
| sw.closeSpans() | ||
|
Check failure on line 555 in adapter/vercelaisdk/uimessagestream/handler.go
|
||
|
claude[bot] marked this conversation as resolved.
|
||
|
|
||
| _ = ew.WriteChunk(Chunk{"type": "error", "errorText": "An error occurred"}) | ||
| _ = ew.WriteChunk(Chunk{"type": "finish-step"}) | ||
| _ = ew.WriteChunk(Chunk{"type": "finish", "finishReason": finishReasonError}) | ||
|
|
@@ -601,7 +634,7 @@ | |
| return reason | ||
| } | ||
|
|
||
| func executeTools(ctx context.Context, toolRequests []*llm.ToolRequest, messages *[]llm.Message, ew *EventWriter, executor ToolExecutor) error { | ||
| func executeTools(ctx context.Context, toolRequests []*llm.ToolRequest, messages *[]llm.Message, ew *EventWriter, logger *slog.Logger, executor ToolExecutor) error { | ||
| toolResponseParts := make([]*llm.Part, 0, len(toolRequests)) | ||
|
|
||
| for _, tr := range toolRequests { | ||
|
|
@@ -620,7 +653,9 @@ | |
|
|
||
| var output any | ||
| if len(result) > 0 { | ||
| _ = json.Unmarshal(result, &output) | ||
| if err := json.Unmarshal(result, &output); err != nil { | ||
| logger.Warn("failed to unmarshal tool result", "toolCallId", tr.ID, "error", err) | ||
| } | ||
| } | ||
|
|
||
| if err := ew.WriteChunk(Chunk{ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.