diff --git a/frontend/src/components/workspace/messages/message-group.tsx b/frontend/src/components/workspace/messages/message-group.tsx index 0b2d2519c0..67d9f9519e 100644 --- a/frontend/src/components/workspace/messages/message-group.tsx +++ b/frontend/src/components/workspace/messages/message-group.tsx @@ -24,11 +24,11 @@ import { import { CodeBlock } from "@/components/ai-elements/code-block"; import { Button } from "@/components/ui/button"; import { useI18n } from "@/core/i18n/hooks"; -import { - extractReasoningContentFromMessage, - findToolCallResult, -} from "@/core/messages/utils"; import { useRehypeSplitWordsIntoSpans } from "@/core/rehype"; +import { + convertToToolCallSteps, + partitionStepsForDisplay, +} from "@/core/tools/utils"; import { extractTitleFromMarkdown } from "@/core/utils/markdown"; import { env } from "@/env"; import { cn } from "@/lib/utils"; @@ -55,18 +55,15 @@ export function MessageGroup({ const [showLastThinking, setShowLastThinking] = useState( env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true", ); - const steps = useMemo(() => convertToSteps(messages), [messages]); - const lastToolCallStep = useMemo(() => { - const filteredSteps = steps.filter((step) => step.type === "toolCall"); - return filteredSteps[filteredSteps.length - 1]; - }, [steps]); - const aboveLastToolCallSteps = useMemo(() => { - if (lastToolCallStep) { - const index = steps.indexOf(lastToolCallStep); - return steps.slice(0, index); - } - return []; - }, [lastToolCallStep, steps]); + const steps = useMemo(() => convertToToolCallSteps(messages), [messages]); + const { + aboveSteps: aboveLastToolCallSteps, + activeSteps: activeToolCallSteps, + } = useMemo(() => partitionStepsForDisplay(steps), [steps]); + const lastToolCallStep = useMemo( + () => activeToolCallSteps[activeToolCallSteps.length - 1], + [activeToolCallSteps], + ); const lastReasoningStep = useMemo(() => { if (lastToolCallStep) { const index = steps.indexOf(lastToolCallStep); @@ -127,16 +124,19 @@ export function MessageGroup({ ), )} - {lastToolCallStep && ( - - - - )} + {activeToolCallSteps.map((step, index) => { + const isLast = index === activeToolCallSteps.length - 1; + return ( + + + + ); + })} )} {lastReasoningStep && ( @@ -422,65 +422,3 @@ function ToolCall({ ); } } - -interface GenericCoTStep { - id?: string; - messageId?: string; - type: T; -} - -interface CoTReasoningStep extends GenericCoTStep<"reasoning"> { - reasoning: string | null; -} - -interface CoTToolCallStep extends GenericCoTStep<"toolCall"> { - name: string; - args: Record; - result?: string; -} - -type CoTStep = CoTReasoningStep | CoTToolCallStep; - -function convertToSteps(messages: Message[]): CoTStep[] { - const steps: CoTStep[] = []; - for (const message of messages) { - if (message.type === "ai") { - const reasoning = extractReasoningContentFromMessage(message); - if (reasoning) { - const step: CoTReasoningStep = { - id: message.id, - messageId: message.id, - type: "reasoning", - reasoning, - }; - steps.push(step); - } - for (const tool_call of message.tool_calls ?? []) { - if (tool_call.name === "task") { - continue; - } - const step: CoTToolCallStep = { - id: tool_call.id, - messageId: message.id, - type: "toolCall", - name: tool_call.name, - args: tool_call.args, - }; - const toolCallId = tool_call.id; - if (toolCallId) { - const toolCallResult = findToolCallResult(toolCallId, messages); - if (toolCallResult) { - try { - const json = JSON.parse(toolCallResult); - step.result = json; - } catch { - step.result = toolCallResult; - } - } - } - steps.push(step); - } - } - } - return steps; -} diff --git a/frontend/src/core/tools/utils.ts b/frontend/src/core/tools/utils.ts index 10f8c6ffa8..c0c69608dd 100644 --- a/frontend/src/core/tools/utils.ts +++ b/frontend/src/core/tools/utils.ts @@ -1,8 +1,12 @@ import type { ToolCall } from "@langchain/core/messages"; -import type { AIMessage } from "@langchain/langgraph-sdk"; +import type { AIMessage, Message } from "@langchain/langgraph-sdk"; import type { Translations } from "../i18n"; -import { hasToolCalls } from "../messages/utils"; +import { + extractReasoningContentFromMessage, + findToolCallResult, + hasToolCalls, +} from "../messages/utils"; export function explainLastToolCall(message: AIMessage, t: Translations) { if (hasToolCalls(message)) { @@ -27,3 +31,90 @@ export function explainToolCall(toolCall: ToolCall, t: Translations) { return t.toolCalls.useTool(toolCall.name); } } + +interface GenericCoTStep { + id?: string; + messageId?: string; + type: T; +} + +export interface CoTReasoningStep extends GenericCoTStep<"reasoning"> { + reasoning: string | null; +} + +export interface CoTToolCallStep extends GenericCoTStep<"toolCall"> { + name: string; + args: Record; + result?: string | Record; +} + +export type CoTStep = CoTReasoningStep | CoTToolCallStep; + +export function convertToToolCallSteps(messages: Message[]): CoTStep[] { + const steps: CoTStep[] = []; + for (const message of messages) { + if (message.type !== "ai") { + continue; + } + const reasoning = extractReasoningContentFromMessage(message); + if (reasoning) { + steps.push({ + id: message.id, + messageId: message.id, + type: "reasoning", + reasoning, + }); + } + for (const tool_call of message.tool_calls ?? []) { + if (tool_call.name === "task") { + continue; + } + const step: CoTToolCallStep = { + id: tool_call.id, + messageId: message.id, + type: "toolCall", + name: tool_call.name, + args: tool_call.args, + }; + const toolCallId = tool_call.id; + if (toolCallId) { + const toolCallResult = findToolCallResult(toolCallId, messages); + if (toolCallResult) { + try { + step.result = JSON.parse(toolCallResult); + } catch { + step.result = toolCallResult; + } + } + } + steps.push(step); + } + } + return steps; +} + +export function partitionStepsForDisplay(steps: CoTStep[]): { + aboveSteps: CoTStep[]; + activeSteps: CoTToolCallStep[]; +} { + const toolCallSteps = steps.filter( + (step): step is CoTToolCallStep => step.type === "toolCall", + ); + if (toolCallSteps.length === 0) { + return { aboveSteps: [], activeSteps: [] }; + } + + const lastAIMessageId = toolCallSteps[toolCallSteps.length - 1]!.messageId; + const activeSteps = toolCallSteps.filter( + (step) => + step.messageId !== undefined && step.messageId === lastAIMessageId, + ); + if (activeSteps.length === 0) { + return { aboveSteps: [], activeSteps: [] }; + } + + const firstActiveIndex = steps.indexOf(activeSteps[0]!); + const aboveSteps = steps.slice(0, firstActiveIndex); + + return { aboveSteps, activeSteps }; +} diff --git a/frontend/tests/unit/core/tools/parallel-tool-calls.test.ts b/frontend/tests/unit/core/tools/parallel-tool-calls.test.ts new file mode 100644 index 0000000000..b61d8f42eb --- /dev/null +++ b/frontend/tests/unit/core/tools/parallel-tool-calls.test.ts @@ -0,0 +1,117 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { expect, test } from "vitest"; + +import { + convertToToolCallSteps, + partitionStepsForDisplay, +} from "@/core/tools/utils"; + +function aiMessage( + id: string, + toolCalls: { id: string; name: string; args?: Record }[], + reasoning?: string, +): Message { + return { + type: "ai", + id, + content: "", + tool_calls: toolCalls.map((tc) => ({ + id: tc.id, + name: tc.name, + args: tc.args ?? {}, + })), + additional_kwargs: reasoning ? { reasoning_content: reasoning } : {}, + } as Message; +} + +function toolMessage(toolCallId: string, content: string): Message { + return { + type: "tool", + id: `tool-msg-${toolCallId}`, + tool_call_id: toolCallId, + content, + } as Message; +} + +test("a single tool call in the latest AI message stays the only active step", () => { + const messages: Message[] = [ + aiMessage("ai-1", [ + { id: "call-1", name: "web_search", args: { query: "x" } }, + ]), + toolMessage("call-1", "[]"), + aiMessage("ai-2", [ + { id: "call-2", name: "web_fetch", args: { url: "u" } }, + ]), + ]; + const steps = convertToToolCallSteps(messages); + const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps); + + expect(activeSteps.map((s) => s.id)).toEqual(["call-2"]); + expect(aboveSteps.map((s) => s.id)).toEqual(["call-1"]); +}); + +test("all parallel siblings stay active until each one completes", () => { + const messages: Message[] = [ + aiMessage("ai-1", [ + { id: "call-1", name: "web_search", args: { query: "a" } }, + { id: "call-2", name: "web_search", args: { query: "b" } }, + { id: "call-3", name: "web_search", args: { query: "c" } }, + ]), + toolMessage("call-2", "[]"), + ]; + const steps = convertToToolCallSteps(messages); + const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps); + + expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2", "call-3"]); + expect(aboveSteps).toEqual([]); +}); + +test("parallel tool results pair by tool_call_id regardless of arrival order", () => { + const messages: Message[] = [ + aiMessage("ai-1", [ + { id: "call-1", name: "web_search", args: { query: "a" } }, + { id: "call-2", name: "web_search", args: { query: "b" } }, + ]), + toolMessage("call-2", '[{"url":"u2","title":"t2"}]'), + toolMessage("call-1", '[{"url":"u1","title":"t1"}]'), + ]; + const steps = convertToToolCallSteps(messages); + const toolSteps = steps.filter((s) => s.type === "toolCall"); + const byId = new Map(toolSteps.map((s) => [s.id, s])); + expect(byId.get("call-1")?.result).toEqual([{ url: "u1", title: "t1" }]); + expect(byId.get("call-2")?.result).toEqual([{ url: "u2", title: "t2" }]); +}); + +test("reasoning emitted with parallel tool calls stays visible above the active batch", () => { + const messages: Message[] = [ + aiMessage( + "ai-1", + [ + { id: "call-1", name: "web_search", args: { query: "a" } }, + { id: "call-2", name: "web_search", args: { query: "b" } }, + ], + "considering both queries in parallel", + ), + ]; + const steps = convertToToolCallSteps(messages); + const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps); + + expect(aboveSteps.map((s) => s.type)).toEqual(["reasoning"]); + expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2"]); +}); + +test("earlier serial tool calls collapse above a fresh parallel batch", () => { + const messages: Message[] = [ + aiMessage("ai-1", [{ id: "call-0", name: "ls", args: { path: "/" } }]), + toolMessage("call-0", "ok"), + aiMessage("ai-2", [ + { id: "call-1", name: "web_search", args: { query: "a" } }, + { id: "call-2", name: "web_search", args: { query: "b" } }, + ]), + ]; + const steps = convertToToolCallSteps(messages); + const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps); + + expect(aboveSteps.map((s) => s.id)).toEqual(["call-0"]); + expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2"]); +});