Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/acp-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import {
query,
Settings,
SDKAssistantMessageError,
SDKMessageOrigin,
SDKPartialAssistantMessage,
SDKUserMessage,
SlashCommand,
Expand Down Expand Up @@ -179,6 +180,7 @@ type BackgroundTerminal =
export type SDKMessageFilter = {
type: string;
subtype?: string;
origin?: SDKMessageOrigin["kind"];
};

/**
Expand Down Expand Up @@ -873,6 +875,12 @@ export class ClaudeAcpAgent implements Agent {
session.contextWindowSize = matchingModelUsage.contextWindow;
}

// Task-notification followups are autonomous work triggered by a
// task-notification system message, not by the user's prompt.
// They should not influence the user-turn lifecycle (stop reason,
// slash-command output forwarding) but their cost is real.
const isTaskNotification = message.origin?.kind === "task-notification";

// Send usage_update notification
if (lastAssistantTotalUsage !== null) {
await this.client.sessionUpdate({
Expand All @@ -885,12 +893,17 @@ export class ClaudeAcpAgent implements Agent {
amount: message.total_cost_usd,
currency: "USD",
},
...(message.origin && {
_meta: { "_claude/origin": message.origin },
}),
},
});
}

if (session.cancelled) {
stopReason = "cancelled";
if (!isTaskNotification) {
stopReason = "cancelled";
}
break;
}

Expand All @@ -900,7 +913,9 @@ export class ClaudeAcpAgent implements Agent {
throw RequestError.authRequired();
}
if (message.stop_reason === "max_tokens") {
stopReason = "max_tokens";
if (!isTaskNotification) {
stopReason = "max_tokens";
}
break;
}
if (message.is_error) {
Expand All @@ -911,7 +926,9 @@ export class ClaudeAcpAgent implements Agent {
}
// For local-only commands (no model invocation), the result
// text is the command output — forward it to the client.
if (isLocalOnlyCommand) {
// Task-notification followups never originate from a user
// slash command, so skip the forwarding for them.
if (isLocalOnlyCommand && !isTaskNotification) {
for (const notification of toAcpNotifications(
message.result,
"assistant",
Expand All @@ -927,7 +944,9 @@ export class ClaudeAcpAgent implements Agent {
}
case "error_during_execution": {
if (message.stop_reason === "max_tokens") {
stopReason = "max_tokens";
if (!isTaskNotification) {
stopReason = "max_tokens";
}
break;
}
if (message.is_error) {
Expand All @@ -936,7 +955,9 @@ export class ClaudeAcpAgent implements Agent {
message.errors.join(", ") || message.subtype,
);
}
stopReason = "end_turn";
if (!isTaskNotification) {
stopReason = "end_turn";
}
break;
}
case "error_max_budget_usd":
Expand All @@ -948,7 +969,9 @@ export class ClaudeAcpAgent implements Agent {
message.errors.join(", ") || message.subtype,
);
}
stopReason = "max_turn_requests";
if (!isTaskNotification) {
stopReason = "max_turn_requests";
}
break;
default:
unreachable(message, this.logger);
Expand Down Expand Up @@ -2034,12 +2057,15 @@ export class ClaudeAcpAgent implements Agent {

function shouldEmitRawMessage(
config: boolean | SDKMessageFilter[],
message: { type: string; subtype?: string },
message: { type: string; subtype?: string; origin?: SDKMessageOrigin },
): boolean {
if (config === true) return true;
if (config === false) return false;
return config.some(
(f) => f.type === message.type && (f.subtype === undefined || f.subtype === message.subtype),
(f) =>
f.type === message.type &&
(f.subtype === undefined || f.subtype === message.subtype) &&
(f.origin === undefined || f.origin === message.origin?.kind),
);
}

Expand Down
214 changes: 213 additions & 1 deletion src/tests/acp-agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
ClaudeAcpAgent,
claudeCliPath,
describeAlwaysAllow,
type SDKMessageFilter,
} from "../acp-agent.js";
import { Pushable } from "../utils.js";
import { query, SDKAssistantMessage } from "@anthropic-ai/claude-agent-sdk";
Expand Down Expand Up @@ -3021,7 +3022,7 @@ describe("emitRawSDKMessages", () => {
function injectSession(
agent: ClaudeAcpAgent,
messages: any[],
emitRawSDKMessages: boolean | { type: string; subtype?: string }[],
emitRawSDKMessages: boolean | SDKMessageFilter[],
) {
const input = new Pushable<any>();
async function* messageGenerator() {
Expand Down Expand Up @@ -3198,4 +3199,215 @@ describe("emitRawSDKMessages", () => {
expect(sdkMessages[0].params.message.subtype).toBe("compact_boundary");
expect(sdkMessages[1].params.message.type).toBe("result");
});

it("filter by origin kind only emits matching results", async () => {
const { agent, extNotifications } = createMockAgentWithExtNotification();
injectSession(
agent,
[
{ ...createResultMessage(), origin: { kind: "channel", server: "acp" } },
{ ...createResultMessage(), origin: { kind: "task-notification" } },
{ type: "system", subtype: "session_state_changed", state: "idle" },
],
[{ type: "result", origin: "task-notification" }],
);

await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] });

const sdkMessages = extNotifications.filter((n) => n.method === "_claude/sdkMessage");
expect(sdkMessages).toHaveLength(1);
expect(sdkMessages[0].params.message.origin.kind).toBe("task-notification");
});

it("filter without origin matches results regardless of origin", async () => {
const { agent, extNotifications } = createMockAgentWithExtNotification();
injectSession(
agent,
[
{ ...createResultMessage(), origin: { kind: "channel", server: "acp" } },
{ ...createResultMessage(), origin: { kind: "task-notification" } },
{ type: "system", subtype: "session_state_changed", state: "idle" },
],
[{ type: "result" }],
);

await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] });

const sdkMessages = extNotifications.filter((n) => n.method === "_claude/sdkMessage");
expect(sdkMessages).toHaveLength(2);
});
});

describe("result origin handling", () => {
function createMockAgentWithCapture() {
const updates: any[] = [];
const mockClient = {
sessionUpdate: async (notification: any) => {
updates.push(notification);
},
} as unknown as AgentSideConnection;
const agent = new ClaudeAcpAgent(mockClient, { log: () => {}, error: () => {} });
return { agent, updates };
}

function injectSession(agent: ClaudeAcpAgent, messages: any[]) {
const input = new Pushable<any>();
async function* messageGenerator() {
const iter = input[Symbol.asyncIterator]();
const { value: userMessage, done } = await iter.next();
if (!done && userMessage) {
yield {
type: "user",
message: userMessage.message,
parent_tool_use_id: null,
uuid: userMessage.uuid,
session_id: "test-session",
isReplay: true,
};
}
yield* messages;
}
agent.sessions["test-session"] = {
query: messageGenerator() as any,
input,
cancelled: false,
cwd: "/test",
sessionFingerprint: JSON.stringify({ cwd: "/test", mcpServers: [] }),
modes: { currentModeId: "default", availableModes: [] },
models: { currentModelId: "default", availableModels: [] },
modelInfos: [],
settingsManager: { dispose: vi.fn() } as any,
accumulatedUsage: {
inputTokens: 0,
outputTokens: 0,
cachedReadTokens: 0,
cachedWriteTokens: 0,
},
configOptions: [],
promptRunning: false,
pendingMessages: new Map(),
nextPendingOrder: 0,
abortController: new AbortController(),
emitRawSDKMessages: false,
contextWindowSize: 200000,
};
}

function createAssistantMessage() {
return {
type: "assistant" as const,
parent_tool_use_id: null,
uuid: randomUUID(),
session_id: "test-session",
message: {
model: "claude-sonnet-4-6",
content: [{ type: "text", text: "hello" }],
usage: {
input_tokens: 10,
output_tokens: 5,
cache_read_input_tokens: 0,
cache_creation_input_tokens: 0,
},
},
};
}

function createResult(overrides: Record<string, unknown> = {}) {
return {
type: "result" as const,
subtype: "success" as const,
stop_reason: "end_turn",
is_error: false,
result: "",
errors: [],
duration_ms: 0,
duration_api_ms: 0,
num_turns: 1,
total_cost_usd: 0.01,
usage: {
input_tokens: 10,
output_tokens: 5,
cache_read_input_tokens: 0,
cache_creation_input_tokens: 0,
},
modelUsage: {},
permission_denials: [],
uuid: randomUUID(),
session_id: "test-session",
...overrides,
};
}

it("forwards origin in usage_update _meta", async () => {
const { agent, updates } = createMockAgentWithCapture();
injectSession(agent, [
createAssistantMessage(),
createResult({ origin: { kind: "channel", server: "acp" } }),
{ type: "system", subtype: "session_state_changed", state: "idle" },
]);

await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] });

const usageUpdate = updates.find((u: any) => u.update?.sessionUpdate === "usage_update");
expect(usageUpdate).toBeDefined();
expect(usageUpdate.update._meta).toEqual({
"_claude/origin": { kind: "channel", server: "acp" },
});
});

it("omits _meta when origin is absent", async () => {
const { agent, updates } = createMockAgentWithCapture();
injectSession(agent, [
createAssistantMessage(),
createResult(),
{ type: "system", subtype: "session_state_changed", state: "idle" },
]);

await agent.prompt({ sessionId: "test-session", prompt: [{ type: "text", text: "test" }] });

const usageUpdate = updates.find((u: any) => u.update?.sessionUpdate === "usage_update");
expect(usageUpdate).toBeDefined();
expect(usageUpdate.update._meta).toBeUndefined();
});

it("task-notification result with max_tokens does not override the user-turn stopReason", async () => {
const { agent } = createMockAgentWithCapture();
injectSession(agent, [
createAssistantMessage(),
// User-turn result completes normally
createResult({ origin: { kind: "channel", server: "acp" } }),
// Task-notification followup hits max_tokens — must not bleed into the user's stopReason
createResult({
stop_reason: "max_tokens",
origin: { kind: "task-notification" },
}),
{ type: "system", subtype: "session_state_changed", state: "idle" },
]);

const response = await agent.prompt({
sessionId: "test-session",
prompt: [{ type: "text", text: "test" }],
});

expect(response.stopReason).toBe("end_turn");
});

it("user-prompted result with max_tokens still sets stopReason", async () => {
const { agent } = createMockAgentWithCapture();
injectSession(agent, [
createAssistantMessage(),
createResult({
stop_reason: "max_tokens",
origin: { kind: "channel", server: "acp" },
}),
{ type: "system", subtype: "session_state_changed", state: "idle" },
]);

const response = await agent.prompt({
sessionId: "test-session",
prompt: [{ type: "text", text: "test" }],
});

expect(response.stopReason).toBe("max_tokens");
});
});
Loading