Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 26 additions & 88 deletions frontend/src/components/workspace/messages/message-group.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ import {
import { CodeBlock } from "@/components/ai-elements/code-block";
import { Button } from "@/components/ui/button";
import { useI18n } from "@/core/i18n/hooks";
import {
extractReasoningContentFromMessage,
findToolCallResult,
} from "@/core/messages/utils";
import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
import {
convertToToolCallSteps,
partitionStepsForDisplay,
} from "@/core/tools/utils";
import { extractTitleFromMarkdown } from "@/core/utils/markdown";
import { env } from "@/env";
import { cn } from "@/lib/utils";
Expand All @@ -55,18 +55,15 @@ export function MessageGroup({
const [showLastThinking, setShowLastThinking] = useState(
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true",
);
const steps = useMemo(() => convertToSteps(messages), [messages]);
const lastToolCallStep = useMemo(() => {
const filteredSteps = steps.filter((step) => step.type === "toolCall");
return filteredSteps[filteredSteps.length - 1];
}, [steps]);
const aboveLastToolCallSteps = useMemo(() => {
if (lastToolCallStep) {
const index = steps.indexOf(lastToolCallStep);
return steps.slice(0, index);
}
return [];
}, [lastToolCallStep, steps]);
const steps = useMemo(() => convertToToolCallSteps(messages), [messages]);
const {
aboveSteps: aboveLastToolCallSteps,
activeSteps: activeToolCallSteps,
} = useMemo(() => partitionStepsForDisplay(steps), [steps]);
const lastToolCallStep = useMemo(
() => activeToolCallSteps[activeToolCallSteps.length - 1],
[activeToolCallSteps],
);
const lastReasoningStep = useMemo(() => {
if (lastToolCallStep) {
const index = steps.indexOf(lastToolCallStep);
Expand Down Expand Up @@ -127,16 +124,19 @@ export function MessageGroup({
<ToolCall key={step.id} {...step} isLoading={isLoading} />
),
)}
{lastToolCallStep && (
<FlipDisplay uniqueKey={lastToolCallStep.id ?? ""}>
<ToolCall
key={lastToolCallStep.id}
{...lastToolCallStep}
isLast={true}
isLoading={isLoading}
/>
</FlipDisplay>
)}
{activeToolCallSteps.map((step, index) => {
const isLast = index === activeToolCallSteps.length - 1;
return (
<FlipDisplay key={step.id} uniqueKey={step.id ?? ""}>
<ToolCall
key={step.id}
{...step}
isLast={isLast}
isLoading={isLoading}
/>
</FlipDisplay>
);
})}
</ChainOfThoughtContent>
)}
{lastReasoningStep && (
Expand Down Expand Up @@ -422,65 +422,3 @@ function ToolCall({
);
}
}

interface GenericCoTStep<T extends string = string> {
id?: string;
messageId?: string;
type: T;
}

interface CoTReasoningStep extends GenericCoTStep<"reasoning"> {
reasoning: string | null;
}

interface CoTToolCallStep extends GenericCoTStep<"toolCall"> {
name: string;
args: Record<string, unknown>;
result?: string;
}

type CoTStep = CoTReasoningStep | CoTToolCallStep;

function convertToSteps(messages: Message[]): CoTStep[] {
const steps: CoTStep[] = [];
for (const message of messages) {
if (message.type === "ai") {
const reasoning = extractReasoningContentFromMessage(message);
if (reasoning) {
const step: CoTReasoningStep = {
id: message.id,
messageId: message.id,
type: "reasoning",
reasoning,
};
steps.push(step);
}
for (const tool_call of message.tool_calls ?? []) {
if (tool_call.name === "task") {
continue;
}
const step: CoTToolCallStep = {
id: tool_call.id,
messageId: message.id,
type: "toolCall",
name: tool_call.name,
args: tool_call.args,
};
const toolCallId = tool_call.id;
if (toolCallId) {
const toolCallResult = findToolCallResult(toolCallId, messages);
if (toolCallResult) {
try {
const json = JSON.parse(toolCallResult);
step.result = json;
} catch {
step.result = toolCallResult;
}
}
}
steps.push(step);
}
}
}
return steps;
}
95 changes: 93 additions & 2 deletions frontend/src/core/tools/utils.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import type { ToolCall } from "@langchain/core/messages";
import type { AIMessage } from "@langchain/langgraph-sdk";
import type { AIMessage, Message } from "@langchain/langgraph-sdk";

import type { Translations } from "../i18n";
import { hasToolCalls } from "../messages/utils";
import {
extractReasoningContentFromMessage,
findToolCallResult,
hasToolCalls,
} from "../messages/utils";

export function explainLastToolCall(message: AIMessage, t: Translations) {
if (hasToolCalls(message)) {
Expand All @@ -27,3 +31,90 @@ export function explainToolCall(toolCall: ToolCall, t: Translations) {
return t.toolCalls.useTool(toolCall.name);
}
}

interface GenericCoTStep<T extends string = string> {
id?: string;
messageId?: string;
type: T;
}

export interface CoTReasoningStep extends GenericCoTStep<"reasoning"> {
reasoning: string | null;
}

export interface CoTToolCallStep extends GenericCoTStep<"toolCall"> {
name: string;
args: Record<string, unknown>;
result?: string | Record<string, unknown>;
}

export type CoTStep = CoTReasoningStep | CoTToolCallStep;

export function convertToToolCallSteps(messages: Message[]): CoTStep[] {
const steps: CoTStep[] = [];
for (const message of messages) {
if (message.type !== "ai") {
continue;
}
const reasoning = extractReasoningContentFromMessage(message);
if (reasoning) {
steps.push({
id: message.id,
messageId: message.id,
type: "reasoning",
reasoning,
});
}
for (const tool_call of message.tool_calls ?? []) {
if (tool_call.name === "task") {
continue;
}
const step: CoTToolCallStep = {
id: tool_call.id,
messageId: message.id,
type: "toolCall",
name: tool_call.name,
args: tool_call.args,
};
const toolCallId = tool_call.id;
if (toolCallId) {
const toolCallResult = findToolCallResult(toolCallId, messages);
if (toolCallResult) {
try {
step.result = JSON.parse(toolCallResult);
} catch {
step.result = toolCallResult;
}
}
}
steps.push(step);
}
}
return steps;
}

export function partitionStepsForDisplay(steps: CoTStep[]): {
aboveSteps: CoTStep[];
activeSteps: CoTToolCallStep[];
} {
const toolCallSteps = steps.filter(
(step): step is CoTToolCallStep => step.type === "toolCall",
);
if (toolCallSteps.length === 0) {
return { aboveSteps: [], activeSteps: [] };
}

const lastAIMessageId = toolCallSteps[toolCallSteps.length - 1]!.messageId;
const activeSteps = toolCallSteps.filter(
(step) =>
step.messageId !== undefined && step.messageId === lastAIMessageId,
);
if (activeSteps.length === 0) {
return { aboveSteps: [], activeSteps: [] };
}

const firstActiveIndex = steps.indexOf(activeSteps[0]!);
const aboveSteps = steps.slice(0, firstActiveIndex);

return { aboveSteps, activeSteps };
}
117 changes: 117 additions & 0 deletions frontend/tests/unit/core/tools/parallel-tool-calls.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";

import {
convertToToolCallSteps,
partitionStepsForDisplay,
} from "@/core/tools/utils";

function aiMessage(
id: string,
toolCalls: { id: string; name: string; args?: Record<string, unknown> }[],
reasoning?: string,
): Message {
return {
type: "ai",
id,
content: "",
tool_calls: toolCalls.map((tc) => ({
id: tc.id,
name: tc.name,
args: tc.args ?? {},
})),
additional_kwargs: reasoning ? { reasoning_content: reasoning } : {},
} as Message;
}

function toolMessage(toolCallId: string, content: string): Message {
return {
type: "tool",
id: `tool-msg-${toolCallId}`,
tool_call_id: toolCallId,
content,
} as Message;
}

test("a single tool call in the latest AI message stays the only active step", () => {
const messages: Message[] = [
aiMessage("ai-1", [
{ id: "call-1", name: "web_search", args: { query: "x" } },
]),
toolMessage("call-1", "[]"),
aiMessage("ai-2", [
{ id: "call-2", name: "web_fetch", args: { url: "u" } },
]),
];
const steps = convertToToolCallSteps(messages);
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);

expect(activeSteps.map((s) => s.id)).toEqual(["call-2"]);
expect(aboveSteps.map((s) => s.id)).toEqual(["call-1"]);
});

test("all parallel siblings stay active until each one completes", () => {
const messages: Message[] = [
aiMessage("ai-1", [
{ id: "call-1", name: "web_search", args: { query: "a" } },
{ id: "call-2", name: "web_search", args: { query: "b" } },
{ id: "call-3", name: "web_search", args: { query: "c" } },
]),
toolMessage("call-2", "[]"),
];
const steps = convertToToolCallSteps(messages);
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);

expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2", "call-3"]);
expect(aboveSteps).toEqual([]);
});

test("parallel tool results pair by tool_call_id regardless of arrival order", () => {
const messages: Message[] = [
aiMessage("ai-1", [
{ id: "call-1", name: "web_search", args: { query: "a" } },
{ id: "call-2", name: "web_search", args: { query: "b" } },
]),
toolMessage("call-2", '[{"url":"u2","title":"t2"}]'),
toolMessage("call-1", '[{"url":"u1","title":"t1"}]'),
];
const steps = convertToToolCallSteps(messages);
const toolSteps = steps.filter((s) => s.type === "toolCall");
const byId = new Map(toolSteps.map((s) => [s.id, s]));
expect(byId.get("call-1")?.result).toEqual([{ url: "u1", title: "t1" }]);
expect(byId.get("call-2")?.result).toEqual([{ url: "u2", title: "t2" }]);
});

test("reasoning emitted with parallel tool calls stays visible above the active batch", () => {
const messages: Message[] = [
aiMessage(
"ai-1",
[
{ id: "call-1", name: "web_search", args: { query: "a" } },
{ id: "call-2", name: "web_search", args: { query: "b" } },
],
"considering both queries in parallel",
),
];
const steps = convertToToolCallSteps(messages);
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);

expect(aboveSteps.map((s) => s.type)).toEqual(["reasoning"]);
expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2"]);
});

test("earlier serial tool calls collapse above a fresh parallel batch", () => {
const messages: Message[] = [
aiMessage("ai-1", [{ id: "call-0", name: "ls", args: { path: "/" } }]),
toolMessage("call-0", "ok"),
aiMessage("ai-2", [
{ id: "call-1", name: "web_search", args: { query: "a" } },
{ id: "call-2", name: "web_search", args: { query: "b" } },
]),
];
const steps = convertToToolCallSteps(messages);
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);

expect(aboveSteps.map((s) => s.id)).toEqual(["call-0"]);
expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2"]);
});