diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index e6847c50fa..30365fb7d9 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -68,6 +68,27 @@ class RunResponse(BaseModel): updated_at: str = "" +class ThreadTokenUsageModelBreakdown(BaseModel): + tokens: int = 0 + runs: int = 0 + + +class ThreadTokenUsageCallerBreakdown(BaseModel): + lead_agent: int = 0 + subagent: int = 0 + middleware: int = 0 + + +class ThreadTokenUsageResponse(BaseModel): + thread_id: str + total_tokens: int = 0 + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_runs: int = 0 + by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict) + by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -368,10 +389,10 @@ async def list_run_events( return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit) -@router.get("/{thread_id}/token-usage") +@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @require_permission("threads", "read", owner_check=True) -async def thread_token_usage(thread_id: str, request: Request) -> dict: +async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse: """Thread-level token usage aggregation.""" run_store = get_run_store(request) agg = await run_store.aggregate_tokens_by_thread(thread_id) - return {"thread_id": thread_id, **agg} + return ThreadTokenUsageResponse(thread_id=thread_id, **agg) diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 34ab9b492f..bff49206d9 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -166,6 +166,61 @@ async 
def test_update_run_completion_preserves_existing_fields(self, tmp_path): assert row["total_tokens"] == 100 await _cleanup() + @pytest.mark.anyio + async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("success-run", thread_id="t1", status="running") + await repo.update_run_completion( + "success-run", + status="success", + total_input_tokens=70, + total_output_tokens=30, + total_tokens=100, + lead_agent_tokens=80, + subagent_tokens=15, + middleware_tokens=5, + ) + await repo.put("error-run", thread_id="t1", status="running") + await repo.update_run_completion( + "error-run", + status="error", + total_input_tokens=20, + total_output_tokens=30, + total_tokens=50, + lead_agent_tokens=40, + subagent_tokens=10, + ) + await repo.put("running-run", thread_id="t1", status="running") + await repo.update_run_completion( + "running-run", + status="running", + total_input_tokens=900, + total_output_tokens=99, + total_tokens=999, + lead_agent_tokens=999, + ) + await repo.put("other-thread-run", thread_id="t2", status="running") + await repo.update_run_completion( + "other-thread-run", + status="success", + total_tokens=888, + lead_agent_tokens=888, + ) + + agg = await repo.aggregate_tokens_by_thread("t1") + + assert agg["total_tokens"] == 150 + assert agg["total_input_tokens"] == 90 + assert agg["total_output_tokens"] == 60 + assert agg["total_runs"] == 2 + assert agg["by_model"] == {"unknown": {"tokens": 150, "runs": 2}} + assert agg["by_caller"] == { + "lead_agent": 120, + "subagent": 25, + "middleware": 5, + } + await _cleanup() + @pytest.mark.anyio async def test_list_by_thread_ordered_desc(self, tmp_path): """list_by_thread returns newest first.""" diff --git a/backend/tests/test_thread_token_usage.py b/backend/tests/test_thread_token_usage.py new file mode 100644 index 0000000000..713f6aa5ff --- /dev/null +++ b/backend/tests/test_thread_token_usage.py @@ -0,0 +1,55 @@ +"""Tests for 
thread-level token usage aggregation API.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +from _router_auth_helpers import make_authed_test_app +from fastapi.testclient import TestClient + +from app.gateway.routers import thread_runs + + +def _make_app(run_store: MagicMock): + app = make_authed_test_app() + app.include_router(thread_runs.router) + app.state.run_store = run_store + return app + + +def test_thread_token_usage_returns_stable_shape(): + run_store = MagicMock() + run_store.aggregate_tokens_by_thread = AsyncMock( + return_value={ + "total_tokens": 150, + "total_input_tokens": 90, + "total_output_tokens": 60, + "total_runs": 2, + "by_model": {"unknown": {"tokens": 150, "runs": 2}}, + "by_caller": { + "lead_agent": 120, + "subagent": 25, + "middleware": 5, + }, + }, + ) + app = _make_app(run_store) + + with TestClient(app) as client: + response = client.get("/api/threads/thread-1/token-usage") + + assert response.status_code == 200 + assert response.json() == { + "thread_id": "thread-1", + "total_tokens": 150, + "total_input_tokens": 90, + "total_output_tokens": 60, + "total_runs": 2, + "by_model": {"unknown": {"tokens": 150, "runs": 2}}, + "by_caller": { + "lead_agent": 120, + "subagent": 25, + "middleware": 5, + }, + } + run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index dbef7963b1..cfd444e4d2 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -26,7 +26,8 @@ import { useI18n } from "@/core/i18n/hooks"; import { useModels } from "@/core/models/hooks"; import { useNotification } from "@/core/notification/hooks"; import { useLocalSettings, useThreadSettings } from "@/core/settings"; -import { 
useThreadStream } from "@/core/threads/hooks"; +import { useThreadStream, useThreadTokenUsage } from "@/core/threads/hooks"; +import { threadTokenUsageToTokenUsage } from "@/core/threads/token-usage"; import { textOfMessage } from "@/core/threads/utils"; import { env } from "@/env"; import { cn } from "@/lib/utils"; @@ -42,15 +43,21 @@ export default function AgentChatPage() { const { agent } = useAgent(agent_name); - const { threadId, setThreadId, isNewThread, setIsNewThread } = + const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } = useThreadChat(); const [settings, setSettings] = useThreadSettings(threadId); const [localSettings, setLocalSettings] = useLocalSettings(); const { tokenUsageEnabled } = useModels(); + const threadTokenUsage = useThreadTokenUsage( + isNewThread || isMock ? undefined : threadId, + { enabled: tokenUsageEnabled && !isMock }, + ); + const backendTokenUsage = threadTokenUsageToTokenUsage(threadTokenUsage.data); const { showNotification } = useNotification(); const { thread, + pendingUsageMessages, sendMessage, isHistoryLoading, hasMoreHistory, @@ -58,6 +65,7 @@ export default function AgentChatPage() { } = useThreadStream({ threadId: isNewThread ? 
undefined : threadId, context: { ...settings.context, agent_name: agent_name }, + isMock, onStart: (createdThreadId) => { setThreadId(createdThreadId); setIsNewThread(false); @@ -141,8 +149,11 @@ export default function AgentChatPage() { setLocalSettings("tokenUsage", preferences) diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index a97a6b2e24..a04fe994d3 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -25,7 +25,8 @@ import { useI18n } from "@/core/i18n/hooks"; import { useModels } from "@/core/models/hooks"; import { useNotification } from "@/core/notification/hooks"; import { useLocalSettings, useThreadSettings } from "@/core/settings"; -import { useThreadStream } from "@/core/threads/hooks"; +import { useThreadStream, useThreadTokenUsage } from "@/core/threads/hooks"; +import { threadTokenUsageToTokenUsage } from "@/core/threads/token-usage"; import { textOfMessage } from "@/core/threads/utils"; import { env } from "@/env"; import { cn } from "@/lib/utils"; @@ -44,6 +45,11 @@ export default function ChatPage() { const [settings, setSettings] = useThreadSettings(threadId); const [localSettings, setLocalSettings] = useLocalSettings(); const { tokenUsageEnabled } = useModels(); + const threadTokenUsage = useThreadTokenUsage( + isNewThread || isMock ? undefined : threadId, + { enabled: tokenUsageEnabled && !isMock }, + ); + const backendTokenUsage = threadTokenUsageToTokenUsage(threadTokenUsage.data); const mountedRef = useRef(false); useSpecificChatMode(); @@ -63,6 +69,7 @@ export default function ChatPage() { const { thread, + pendingUsageMessages, sendMessage, isUploading, isHistoryLoading, @@ -137,8 +144,11 @@ export default function ChatPage() {
setLocalSettings("tokenUsage", preferences) diff --git a/frontend/src/components/workspace/token-usage-indicator.tsx b/frontend/src/components/workspace/token-usage-indicator.tsx index 4e99c2f5b4..0fef98ab0e 100644 --- a/frontend/src/components/workspace/token-usage-indicator.tsx +++ b/frontend/src/components/workspace/token-usage-indicator.tsx @@ -15,7 +15,11 @@ import { DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { useI18n } from "@/core/i18n/hooks"; -import { accumulateUsage, formatTokenCount } from "@/core/messages/usage"; +import { + formatTokenCount, + selectHeaderTokenUsage, + type TokenUsage, +} from "@/core/messages/usage"; import { getTokenUsageViewPreset, tokenUsagePreferencesFromPreset, @@ -25,7 +29,10 @@ import { import { cn } from "@/lib/utils"; interface TokenUsageIndicatorProps { + threadId?: string; messages: Message[]; + pendingMessages?: Message[]; + backendUsage?: TokenUsage | null; enabled?: boolean; preferences: TokenUsagePreferences; onPreferencesChange: (preferences: TokenUsagePreferences) => void; @@ -33,7 +40,10 @@ interface TokenUsageIndicatorProps { } export function TokenUsageIndicator({ + threadId, messages, + pendingMessages, + backendUsage, enabled = false, preferences, onPreferencesChange, @@ -41,7 +51,15 @@ export function TokenUsageIndicator({ }: TokenUsageIndicatorProps) { const { t } = useI18n(); - const usage = useMemo(() => accumulateUsage(messages), [messages]); + const usage = useMemo( + () => + selectHeaderTokenUsage({ + backendUsage: threadId ? 
backendUsage : null, + messages, + pendingMessages, + }), + [backendUsage, messages, pendingMessages, threadId], + ); const preset = getTokenUsageViewPreset(preferences); if (!enabled) { diff --git a/frontend/src/core/i18n/locales/en-US.ts b/frontend/src/core/i18n/locales/en-US.ts index d2b9538b94..1e96fc3cb8 100644 --- a/frontend/src/core/i18n/locales/en-US.ts +++ b/frontend/src/core/i18n/locales/en-US.ts @@ -310,7 +310,7 @@ export const enUS: Translations = { unavailable: "No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.", unavailableShort: "No usage returned", - note: "Shown from provider-returned usage_metadata. Totals are best-effort conversation totals and may differ from provider billing pages.", + note: "Header totals use persisted thread usage when available. Per-turn and debug usage come from visible messages. Totals may differ from provider billing pages.", presets: { off: "Off", summary: "Summary", diff --git a/frontend/src/core/i18n/locales/zh-CN.ts b/frontend/src/core/i18n/locales/zh-CN.ts index c4fc6b9457..634d89d47e 100644 --- a/frontend/src/core/i18n/locales/zh-CN.ts +++ b/frontend/src/core/i18n/locales/zh-CN.ts @@ -296,7 +296,7 @@ export const zhCN: Translations = { unavailable: "暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。", unavailableShort: "未返回用量", - note: "基于供应商返回的 usage_metadata 展示。当前总量是 best-effort 的会话参考值,可能与平台账单页不完全一致。", + note: "顶部总量优先使用后端持久化的线程用量。每轮和调试用量来自当前可见消息,可能与平台账单页不完全一致。", presets: { off: "关闭", summary: "总览", diff --git a/frontend/src/core/messages/usage.ts b/frontend/src/core/messages/usage.ts index e9311b92c6..4679dffa5a 100644 --- a/frontend/src/core/messages/usage.ts +++ b/frontend/src/core/messages/usage.ts @@ -65,6 +65,40 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null { return hasUsage ? 
cumulative : null; } +function hasNonZeroUsage( + usage: TokenUsage | null | undefined, +): usage is TokenUsage { + return ( + usage !== null && + usage !== undefined && + (usage.inputTokens > 0 || usage.outputTokens > 0 || usage.totalTokens > 0) + ); +} + +function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { + return { + inputTokens: base.inputTokens + delta.inputTokens, + outputTokens: base.outputTokens + delta.outputTokens, + totalTokens: base.totalTokens + delta.totalTokens, + }; +} + +export function selectHeaderTokenUsage({ + backendUsage, + messages, + pendingMessages = [], +}: { + backendUsage?: TokenUsage | null; + messages: Message[]; + pendingMessages?: Message[]; +}): TokenUsage | null { + if (hasNonZeroUsage(backendUsage)) { + const pendingUsage = accumulateUsage(pendingMessages); + return pendingUsage ? addUsage(backendUsage, pendingUsage) : backendUsage; + } + return accumulateUsage(messages); +} + /** * Format a token count for display: 1234 -> "1,234", 12345 -> "12.3K" */ diff --git a/frontend/src/core/threads/api.ts b/frontend/src/core/threads/api.ts new file mode 100644 index 0000000000..1d1feb40f7 --- /dev/null +++ b/frontend/src/core/threads/api.ts @@ -0,0 +1,24 @@ +import { fetch as fetchWithAuth } from "@/core/api/fetcher"; +import { getBackendBaseURL } from "@/core/config"; + +import type { ThreadTokenUsageResponse } from "./types"; + +export async function fetchThreadTokenUsage( + threadId: string, +): Promise<ThreadTokenUsageResponse | null> { + const response = await fetchWithAuth( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/token-usage`, + { + method: "GET", + }, + ); + + if (!response.ok) { + if (response.status === 403 || response.status === 404) { + return null; + } + throw new Error("Failed to load thread token usage."); + } + + return (await response.json()) as ThreadTokenUsageResponse; +} diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 249ea366b1..0ac790eb2b 100644 ---
a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -17,7 +17,14 @@ import { useUpdateSubtask } from "../tasks/context"; import type { UploadedFileInfo } from "../uploads"; import { promptInputFilePartToFile, uploadFiles } from "../uploads"; -import type { AgentThread, AgentThreadState, RunMessage } from "./types"; +import { fetchThreadTokenUsage } from "./api"; +import { threadTokenUsageQueryKey } from "./token-usage"; +import type { + AgentThread, + AgentThreadState, + RunMessage, + ThreadTokenUsageResponse, +} from "./types"; export type ToolEndEvent = { name: string; @@ -75,6 +82,23 @@ function mergeMessages( ]; } +function messageIdentity(message: Message): string | undefined { + if ("tool_call_id" in message) { + return message.tool_call_id; + } + return message.id; +} + +function getMessagesAfterBaseline( + messages: Message[], + baselineMessageIds: ReadonlySet<string>, +): Message[] { + return messages.filter((message) => { + const id = messageIdentity(message); + return !id || !baselineMessageIds.has(id); + }); +} + function getStreamErrorMessage(error: unknown): string { if (typeof error === "string" && error.trim()) { return error; } @@ -114,6 +138,7 @@ export function useThreadStream({ // and to allow access to the current thread id in onUpdateEvent const threadIdRef = useRef(threadId ??
null); const startedRef = useRef(false); + const pendingUsageBaselineMessageIdsRef = useRef<Set<string>>(new Set()); const listeners = useRef({ onSend, onStart, @@ -271,29 +296,42 @@ onError(error) { setOptimisticMessages([]); toast.error(getStreamErrorMessage(error)); + pendingUsageBaselineMessageIdsRef.current = new Set(); + if (threadIdRef.current && !isMock) { + void queryClient.invalidateQueries({ + queryKey: threadTokenUsageQueryKey(threadIdRef.current), + }); + } }, onFinish(state) { listeners.current.onFinish?.(state.values); + pendingUsageBaselineMessageIdsRef.current = new Set(); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); + if (threadIdRef.current && !isMock) { + void queryClient.invalidateQueries({ + queryKey: threadTokenUsageQueryKey(threadIdRef.current), + }); + } }, }); // Optimistic messages shown before the server stream responds const [optimisticMessages, setOptimisticMessages] = useState([]); const [isUploading, setIsUploading] = useState(false); + const humanMessageCount = thread.messages.filter( + (m) => m.type === "human", + ).length; + const latestMessageCountsRef = useRef({ humanMessageCount }); const sendInFlightRef = useRef(false); const messagesRef = useRef([]); const summarizedRef = useRef>(null); - // Track message count before sending so we know when server has responded - const prevMsgCountRef = useRef(thread.messages.length); // Track human message count before sending to prevent clearing optimistic // messages before the server's human message arrives (e.g. when AI messages // from "messages-tuple" events arrive before the input human message from // "values" events).
- const prevHumanMsgCountRef = useRef( - thread.messages.filter((m) => m.type === "human").length, - ); + const prevHumanMsgCountRef = useRef(humanMessageCount); + latestMessageCountsRef.current = { humanMessageCount }; summarizedRef.current ??= new Set(); // Reset thread-local pending UI state when switching between threads so @@ -301,31 +339,43 @@ export function useThreadStream({ useEffect(() => { startedRef.current = false; sendInFlightRef.current = false; - prevMsgCountRef.current = thread.messages.length; - prevHumanMsgCountRef.current = thread.messages.filter( - (m) => m.type === "human", - ).length; + pendingUsageBaselineMessageIdsRef.current = new Set(); + prevHumanMsgCountRef.current = + latestMessageCountsRef.current.humanMessageCount; }, [threadId]); + // When streaming starts without a baseline (e.g. reconnection, run started + // from another client, or page reload mid-stream), snapshot the current + // messages so only *new* messages are treated as "pending" for token usage. + useEffect(() => { + if ( + thread.isLoading && + pendingUsageBaselineMessageIdsRef.current.size === 0 + ) { + pendingUsageBaselineMessageIdsRef.current = new Set( + thread.messages + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); + } + }, [thread.isLoading, thread.messages]); + // Clear optimistic when server messages arrive. // For messages with a human optimistic message, wait until the server's // human message has arrived to avoid clearing before the input message // appears in the stream (the input message may arrive via "values" events // after individual "messages-tuple" events for AI messages). 
+ const optimisticMessageCount = optimisticMessages.length; + const hasHumanOptimistic = optimisticMessages.some((m) => m.type === "human"); useEffect(() => { - if (optimisticMessages.length === 0) return; + if (optimisticMessageCount === 0) return; - const hasHumanOptimistic = optimisticMessages.some( - (m) => m.type === "human", - ); - const newHumanMsgArrived = - thread.messages.filter((m) => m.type === "human").length > - prevHumanMsgCountRef.current; + const newHumanMsgArrived = humanMessageCount > prevHumanMsgCountRef.current; if (!hasHumanOptimistic || newHumanMsgArrived) { setOptimisticMessages([]); } - }, [thread.messages.length, optimisticMessages.length]); + }, [hasHumanOptimistic, humanMessageCount, optimisticMessageCount]); const sendMessage = useCallback( async ( @@ -341,11 +391,14 @@ export function useThreadStream({ const text = message.text.trim(); - // Capture current count before showing optimistic messages - prevMsgCountRef.current = thread.messages.length; - prevHumanMsgCountRef.current = thread.messages.filter( - (m) => m.type === "human", - ).length; + // Capture the current human message count before showing optimistic + // messages so we can wait for the server's copy of the user input. + prevHumanMsgCountRef.current = humanMessageCount; + pendingUsageBaselineMessageIdsRef.current = new Set( + thread.messages + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); // Build optimistic files list with uploading status const optimisticFiles: FileInMessage[] = (message.files ?? 
[]).map( @@ -517,7 +570,7 @@ export function useThreadStream({ sendInFlightRef.current = false; } }, - [thread, t.uploads.uploadingFiles, context, queryClient], + [thread, t.uploads.uploadingFiles, context, queryClient, humanMessageCount], ); // Cache the latest thread messages in a ref to compare against incoming history messages for deduplication, @@ -531,6 +584,12 @@ export function useThreadStream({ thread.messages, optimisticMessages, ); + const pendingUsageMessages = thread.isLoading + ? getMessagesAfterBaseline( + thread.messages, + pendingUsageBaselineMessageIdsRef.current, + ) + : []; // Merge history, live stream, and optimistic messages for display // History messages may overlap with thread.messages; thread.messages take precedence @@ -541,6 +600,7 @@ export function useThreadStream({ return { thread: mergedThread, + pendingUsageMessages, sendMessage, isUploading, isHistoryLoading, @@ -701,6 +761,24 @@ export function useThreadRuns(threadId?: string) { }); } +export function useThreadTokenUsage( + threadId?: string | null, + { enabled = true }: { enabled?: boolean } = {}, +) { + return useQuery({ + queryKey: threadTokenUsageQueryKey(threadId), + queryFn: async () => { + if (!threadId) { + return null; + } + return fetchThreadTokenUsage(threadId); + }, + enabled: enabled && Boolean(threadId), + retry: false, + refetchOnWindowFocus: false, + }); +} + export function useRunDetail(threadId: string, runId: string) { const apiClient = getAPIClient(); return useQuery({ diff --git a/frontend/src/core/threads/token-usage.ts b/frontend/src/core/threads/token-usage.ts new file mode 100644 index 0000000000..89455eef9e --- /dev/null +++ b/frontend/src/core/threads/token-usage.ts @@ -0,0 +1,20 @@ +import type { TokenUsage } from "@/core/messages/usage"; + +import type { ThreadTokenUsageResponse } from "./types"; + +export function threadTokenUsageQueryKey(threadId?: string | null) { + return ["thread-token-usage", threadId] as const; +} + +export function 
threadTokenUsageToTokenUsage( + usage: ThreadTokenUsageResponse | null | undefined, +): TokenUsage | null { + if (!usage) { + return null; + } + return { + inputTokens: usage.total_input_tokens ?? 0, + outputTokens: usage.total_output_tokens ?? 0, + totalTokens: usage.total_tokens ?? 0, + }; +} diff --git a/frontend/src/core/threads/types.ts b/frontend/src/core/threads/types.ts index 2c0263e53e..dafb073494 100644 --- a/frontend/src/core/threads/types.ts +++ b/frontend/src/core/threads/types.ts @@ -31,3 +31,17 @@ export interface RunMessage { }; created_at: string; } + +export interface ThreadTokenUsageResponse { + thread_id: string; + total_tokens: number; + total_input_tokens: number; + total_output_tokens: number; + total_runs: number; + by_model: Record<string, { tokens: number; runs: number }>; + by_caller: { + lead_agent: number; + subagent: number; + middleware: number; + }; +} diff --git a/frontend/tests/unit/core/messages/usage.test.ts b/frontend/tests/unit/core/messages/usage.test.ts index 1ec3756c45..6a4144f87d 100644 --- a/frontend/tests/unit/core/messages/usage.test.ts +++ b/frontend/tests/unit/core/messages/usage.test.ts @@ -1,7 +1,7 @@ import type { Message } from "@langchain/langgraph-sdk"; import { expect, test } from "vitest"; -import { accumulateUsage } from "@/core/messages/usage"; +import { accumulateUsage, selectHeaderTokenUsage } from "@/core/messages/usage"; import { getAssistantTurnUsageMessages, getMessageGroups, } from "@/core/messages/utils"; @@ -79,3 +79,86 @@ test("keeps header and per-turn aggregation consistent for duplicated UI groups" totalTokens: 27, }); }); + +test("prefers backend thread usage for header totals", () => { + const messages = [ + { + id: "ai-visible", + type: "ai", + content: "Visible answer", + usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + ] as Message[]; + + expect( + selectHeaderTokenUsage({ + backendUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + messages, + }), + ).toEqual({ + inputTokens: 100, + outputTokens: 50, + totalTokens:
150, + }); +}); + +test("adds current in-flight message usage to backend header totals", () => { + const completedMessages = [ + { + id: "ai-completed", + type: "ai", + content: "Completed answer", + usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + { + id: "ai-pending", + type: "ai", + content: "Streaming answer", + usage_metadata: { input_tokens: 4, output_tokens: 6, total_tokens: 10 }, + }, + ] as Message[]; + + expect( + selectHeaderTokenUsage({ + backendUsage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + messages: completedMessages, + pendingMessages: [completedMessages[1]!], + }), + ).toEqual({ + inputTokens: 104, + outputTokens: 56, + totalTokens: 160, + }); +}); + +test("falls back to visible messages when backend usage is unavailable or zero", () => { + const messages = [ + { + id: "ai-visible", + type: "ai", + content: "Visible answer", + usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + ] as Message[]; + + expect( + selectHeaderTokenUsage({ + backendUsage: null, + messages, + }), + ).toEqual({ + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + }); + expect( + selectHeaderTokenUsage({ + backendUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + messages, + }), + ).toEqual({ + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + }); +}); diff --git a/frontend/tests/unit/core/threads/api.test.ts b/frontend/tests/unit/core/threads/api.test.ts new file mode 100644 index 0000000000..d91a2bcdfc --- /dev/null +++ b/frontend/tests/unit/core/threads/api.test.ts @@ -0,0 +1,51 @@ +import { beforeEach, expect, test, vi } from "vitest"; + +const fetchWithAuth = vi.fn(); + +vi.mock("@/core/api/fetcher", () => ({ + fetch: fetchWithAuth, +})); + +beforeEach(() => { + fetchWithAuth.mockReset(); +}); + +test("fetchThreadTokenUsage uses shared auth fetch without JSON GET headers", async () => { + fetchWithAuth.mockResolvedValue({ + ok: true, + json: async () => ({ + thread_id: 
"thread-1", + total_input_tokens: 3, + total_output_tokens: 4, + total_tokens: 7, + total_runs: 1, + by_model: { unknown: { tokens: 7, runs: 1 } }, + by_caller: {}, + }), + }); + + const { fetchThreadTokenUsage } = await import("@/core/threads/api"); + + await expect(fetchThreadTokenUsage("thread-1")).resolves.toMatchObject({ + thread_id: "thread-1", + total_tokens: 7, + }); + + expect(fetchWithAuth).toHaveBeenCalledWith( + expect.stringContaining("/api/threads/thread-1/token-usage"), + { + method: "GET", + }, + ); +}); + +test("fetchThreadTokenUsage returns null for unavailable token usage", async () => { + fetchWithAuth.mockResolvedValue({ + ok: false, + status: 404, + }); + + const { fetchThreadTokenUsage } = await import("@/core/threads/api"); + + await expect(fetchThreadTokenUsage("thread-1")).resolves.toBeNull(); +}); diff --git a/frontend/tests/unit/core/threads/token-usage.test.ts b/frontend/tests/unit/core/threads/token-usage.test.ts new file mode 100644 index 0000000000..c6ba0978e1 --- /dev/null +++ b/frontend/tests/unit/core/threads/token-usage.test.ts @@ -0,0 +1,31 @@ +import { expect, test } from "vitest"; + +import { threadTokenUsageToTokenUsage } from "@/core/threads/token-usage"; +import type { ThreadTokenUsageResponse } from "@/core/threads/types"; + +test("maps backend thread token usage to UI token usage", () => { + const response: ThreadTokenUsageResponse = { + thread_id: "thread-1", + total_input_tokens: 90, + total_output_tokens: 60, + total_tokens: 150, + total_runs: 2, + by_model: { unknown: { tokens: 150, runs: 2 } }, + by_caller: { + lead_agent: 120, + subagent: 25, + middleware: 5, + }, + }; + + expect(threadTokenUsageToTokenUsage(response)).toEqual({ + inputTokens: 90, + outputTokens: 60, + totalTokens: 150, + }); +}); + +test("returns null when backend thread token usage is unavailable", () => { + expect(threadTokenUsageToTokenUsage(null)).toBeNull(); + expect(threadTokenUsageToTokenUsage(undefined)).toBeNull(); +});