From 08ecccd7c29b1e03dbd6021a60ea5ae93aea0ad4 Mon Sep 17 00:00:00 2001 From: Rohan Patra Date: Tue, 10 Mar 2026 15:08:43 -0400 Subject: [PATCH] feat: add conversation forking support with message ID tracking --- src/acp-agent.ts | 91 ++++++++++++++++++++++++++++++++----- src/tests/acp-agent.test.ts | 1 + 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/src/acp-agent.ts b/src/acp-agent.ts index 04badb0d..1a33f79d 100644 --- a/src/acp-agent.ts +++ b/src/acp-agent.ts @@ -117,6 +117,8 @@ type Session = { promptRunning: boolean; pendingMessages: Map void; order: number }>; nextPendingOrder: number; + messageIdMap: Map; // ACP user messageId -> preceding SDK assistant UUID + lastAssistantUuid?: string; }; type BackgroundTerminal = @@ -375,6 +377,18 @@ export class ClaudeAcpAgent implements Agent { } async unstable_forkSession(params: ForkSessionRequest): Promise { + // Support forking at a specific message for "edit previous message" UX. + // Zed passes atMessageId via _meta to indicate where to fork. + const atMessageId = (params._meta as any)?.atMessageId as string | undefined; + let resumeSessionAt: string | undefined; + if (atMessageId) { + const sourceSession = this.sessions[params.sessionId]; + const precedingAssistantUuid = sourceSession?.messageIdMap.get(atMessageId); + if (precedingAssistantUuid) { + resumeSessionAt = precedingAssistantUuid; + } + } + const response = await this.createSession( { cwd: params.cwd, @@ -384,6 +398,7 @@ export class ClaudeAcpAgent implements Agent { { resume: params.sessionId, forkSession: true, + ...(resumeSessionAt && { resumeSessionAt }), }, ); // Needs to happen after we return the session @@ -478,11 +493,18 @@ export class ClaudeAcpAgent implements Agent { }; let lastAssistantTotalUsage: number | null = null; + const userMessageId = params.messageId ?? undefined; const userMessage = promptToClaude(params); + // Track message ID mapping: this user message -> preceding assistant UUID + if (params.messageId && session.lastAssistantUuid) { + session.messageIdMap.set(params.messageId, session.lastAssistantUuid); + } + if (session.promptRunning) { - const uuid = randomUUID(); + const uuid = + (params.messageId as `${string}-${string}-${string}-${string}-${string}`) ?? randomUUID(); userMessage.uuid = uuid; session.input.push(userMessage); const order = session.nextPendingOrder++; @@ -490,7 +512,7 @@ export class ClaudeAcpAgent implements Agent { session.pendingMessages.set(uuid, { resolve, order }); }); if (cancelled) { - return { stopReason: "cancelled" }; + return { stopReason: "cancelled", userMessageId }; } } else { session.input.push(userMessage); @@ -505,7 +527,7 @@ export class ClaudeAcpAgent implements Agent { if (done || !message) { if (session.cancelled) { - return { stopReason: "cancelled" }; + return { stopReason: "cancelled", userMessageId }; } break; } @@ -567,7 +589,7 @@ export class ClaudeAcpAgent implements Agent { break; case "result": { if (session.cancelled) { - return { stopReason: "cancelled" }; + return { stopReason: "cancelled", userMessageId }; } // Accumulate usage from this result @@ -616,16 +638,28 @@ export class ClaudeAcpAgent implements Agent { throw RequestError.authRequired(); } if (message.stop_reason === "max_tokens") { - return { stopReason: "max_tokens", usage }; + return { + stopReason: "max_tokens", + usage, + userMessageId, + }; } if (message.is_error) { throw RequestError.internalError(undefined, message.result); } - return { stopReason: "end_turn", usage }; + return { + stopReason: "end_turn", + usage, + userMessageId, + }; } case "error_during_execution": if (message.stop_reason === "max_tokens") { - return { stopReason: "max_tokens", usage }; + return { + stopReason: "max_tokens", + usage, + userMessageId, + }; } if (message.is_error) { throw RequestError.internalError( @@ -633,7 +667,11 @@ export class ClaudeAcpAgent implements Agent { message.errors.join(", ") || message.subtype, ); } - return { stopReason: "end_turn", usage }; + return { + stopReason: "end_turn", + usage, + userMessageId, + }; case "error_max_budget_usd": case "error_max_turns": case "error_max_structured_output_retries": @@ -643,7 +681,11 @@ export class ClaudeAcpAgent implements Agent { message.errors.join(", ") || message.subtype, ); } - return { stopReason: "max_turn_requests", usage }; + return { + stopReason: "max_turn_requests", + usage, + userMessageId, + }; default: unreachable(message, this.logger); break; @@ -660,6 +702,7 @@ export class ClaudeAcpAgent implements Agent { { clientCapabilities: this.clientCapabilities, cwd: session.cwd, + messageId: session.lastAssistantUuid, }, )) { await this.client.sessionUpdate(notification); @@ -681,7 +724,7 @@ export class ClaudeAcpAgent implements Agent { handedOff = true; // the current loop stops with end_turn, // the loop of the next prompt continues running - return { stopReason: "end_turn" }; + return { stopReason: "end_turn", userMessageId }; } if ("isReplay" in message && message.isReplay) { // not pending or unrelated replay message @@ -689,6 +732,16 @@ export class ClaudeAcpAgent implements Agent { } } + // Track top-level assistant message UUIDs for message editing (fork-at support) + if ( + message.type === "assistant" && + message.parent_tool_use_id === null && + "uuid" in message && + message.uuid + ) { + session.lastAssistantUuid = message.uuid as string; + } + // Store latest assistant usage (excluding subagents) if ((message.message as any).usage && message.parent_tool_use_id === null) { const messageWithUsage = message.message as unknown as SDKResultMessage; @@ -773,6 +826,7 @@ export class ClaudeAcpAgent implements Agent { clientCapabilities: this.clientCapabilities, parentToolUseId: message.parent_tool_use_id, cwd: session.cwd, + messageId: "uuid" in message ? (message.uuid as string) : undefined, }, )) { await this.client.sessionUpdate(notification); @@ -1139,7 +1193,7 @@ export class ClaudeAcpAgent implements Agent { private async createSession( params: NewSessionRequest, - creationOpts: { resume?: string; forkSession?: boolean } = {}, + creationOpts: { resume?: string; forkSession?: boolean; resumeSessionAt?: string } = {}, ): Promise { // We want to create a new session id unless it is resume, // but not resume + forkSession. @@ -1376,6 +1430,7 @@ export class ClaudeAcpAgent implements Agent { promptRunning: false, pendingMessages: new Map(), nextPendingOrder: 0, + messageIdMap: new Map(), }; return { @@ -1653,7 +1708,7 @@ export function promptToClaude(prompt: PromptRequest): SDKUserMessage { content.push(...context); - return { + const msg: SDKUserMessage = { type: "user", message: { role: "user", @@ -1662,6 +1717,10 @@ export function promptToClaude(prompt: PromptRequest): SDKUserMessage { session_id: prompt.sessionId, parent_tool_use_id: null, }; + if (prompt.messageId) { + msg.uuid = prompt.messageId as `${string}-${string}-${string}-${string}-${string}`; + } + return msg; } /** @@ -1680,6 +1739,7 @@ export function toAcpNotifications( clientCapabilities?: ClientCapabilities; parentToolUseId?: string | null; cwd?: string; + messageId?: string; }, ): SessionNotification[] { const registerHooks = options?.registerHooks !== false; @@ -1691,6 +1751,7 @@ export function toAcpNotifications( type: "text", text: content, }, + ...(options?.messageId && { messageId: options.messageId }), }; if (options?.parentToolUseId) { @@ -1719,6 +1780,7 @@ export function toAcpNotifications( type: "text", text: chunk.text, }, + ...(options?.messageId && { messageId: options.messageId }), }; break; case "image": @@ -1730,6 +1792,7 @@ export function toAcpNotifications( mimeType: chunk.source.type === "base64" ? chunk.source.media_type : "", uri: chunk.source.type === "url" ? chunk.source.url : undefined, }, + ...(options?.messageId && { messageId: options.messageId }), }; break; case "thinking": @@ -1740,6 +1803,7 @@ export function toAcpNotifications( type: "text", text: chunk.thinking, }, + ...(options?.messageId && { messageId: options.messageId }), }; break; case "tool_use": @@ -1935,6 +1999,7 @@ export function streamEventToAcpNotifications( options?: { clientCapabilities?: ClientCapabilities; cwd?: string; + messageId?: string; }, ): SessionNotification[] { const event = message.event; @@ -1951,6 +2016,7 @@ export function streamEventToAcpNotifications( clientCapabilities: options?.clientCapabilities, parentToolUseId: message.parent_tool_use_id, cwd: options?.cwd, + messageId: options?.messageId, }, ); case "content_block_delta": @@ -1965,6 +2031,7 @@ export function streamEventToAcpNotifications( clientCapabilities: options?.clientCapabilities, parentToolUseId: message.parent_tool_use_id, cwd: options?.cwd, + messageId: options?.messageId, }, ); // No content diff --git a/src/tests/acp-agent.test.ts b/src/tests/acp-agent.test.ts index d665b739..769f8779 100644 --- a/src/tests/acp-agent.test.ts +++ b/src/tests/acp-agent.test.ts @@ -1328,6 +1328,7 @@ describe("stop reason propagation", () => { promptRunning: false, pendingMessages: new Map(), nextPendingOrder: 0, + messageIdMap: new Map(), }; }