/** Upper bound on the tokens spent on all search results combined. */
const MAX_TOTAL_SEARCH_CONTEXT_TOKENS = 20000;

/** Upper bound on the tokens spent on the content of a single result. */
const MAX_RESULT_CONTEXT_TOKENS = 2500;

/** Marker appended to a result whose content was shortened. */
const TRUNCATION_NOTE =
  '\n[Result content truncated to fit the model context window.]';

/**
 * Makes a value safe to embed inside a double-quoted attribute of the
 * result wrapper: angle brackets are removed and double quotes are replaced
 * with `&quot;` so the value cannot close the attribute or the tag early.
 */
const escapeAttribute = (value: string): string =>
  value.replace(/[<>]/g, '').replace(/"/g, '&quot;');

/**
 * Builds the search-results section of the writer context while staying
 * within a fixed token budget.
 *
 * Findings are consumed in order. Each one is wrapped in a
 * `<result index=N title="...">...</result>` element, its content is capped
 * at `MAX_RESULT_CONTEXT_TOKENS`, and iteration stops as soon as the next
 * entry would no longer fit into `MAX_TOTAL_SEARCH_CONTEXT_TOKENS`.
 * Shortened content ends with `TRUNCATION_NOTE`.
 *
 * NOTE(review): the wrapper element name and attributes are chosen here;
 * confirm they match what the writer prompt expects.
 *
 * @param searchFindings - Findings to render; defaults to an empty list.
 * @returns The rendered entries joined by newlines, or `''` when nothing fits.
 */
export const buildSearchResultsContext = (
  searchFindings: Chunk[] = [],
): string => {
  let remainingTokens = MAX_TOTAL_SEARCH_CONTEXT_TOKENS;
  const contextParts: string[] = [];
  // The note never changes, so its token cost is measured once.
  const truncationNoteTokens = getTokenCount(TRUNCATION_NOTE);

  for (const [index, finding] of searchFindings.entries()) {
    if (remainingTokens <= 0) {
      break;
    }

    // `||` on purpose: an empty title falls back to the positional label.
    const title = escapeAttribute(
      String(finding.metadata?.title || `Result ${index + 1}`),
    );
    const prefix = `<result index=${index + 1} title="${title}">`;
    const suffix = `</result>`;
    const wrapperTokens = getTokenCount(prefix) + getTokenCount(suffix);
    const availableContentTokens = Math.min(
      MAX_RESULT_CONTEXT_TOKENS,
      remainingTokens - wrapperTokens,
    );

    if (availableContentTokens <= 0) {
      break;
    }

    const fullContent = String(finding.content || '');
    let content = fullContent;

    // Truncate only when needed, and only once: reserve room for the note
    // first, then cut the content down to what is left.
    if (getTokenCount(fullContent) > availableContentTokens) {
      const noteBudget = Math.max(
        0,
        availableContentTokens - truncationNoteTokens,
      );
      content = `${truncateTextByTokens(fullContent, noteBudget)}${TRUNCATION_NOTE}`;
    }

    const entry = `${prefix}${content}${suffix}`;
    const entryTokens = getTokenCount(entry);

    // Token counts are not additive across string boundaries, so the
    // assembled entry is re-measured before it is accepted.
    if (entryTokens > remainingTokens) {
      break;
    }

    contextParts.push(entry);
    remainingTokens -= entryTokens;
  }

  return contextParts.join('\n');
};
/**
 * Returns the longest prefix of `text` whose token count, as reported by
 * `getTokenCount`, does not exceed `maxTokens`, with trailing whitespace
 * removed.
 *
 * The prefix length is found by binary search over UTF-16 code units. A
 * prefix that would end in the first half of a surrogate pair is shortened
 * by one unit so the result never ends in half of a character.
 *
 * @param text - Text to shorten.
 * @param maxTokens - Token budget; values `<= 0` yield an empty string.
 * @returns `text` unchanged when it already fits, otherwise the trimmed prefix.
 */
export const truncateTextByTokens = (
  text: string,
  maxTokens: number,
): string => {
  if (maxTokens <= 0 || text.length === 0) {
    return '';
  }

  if (getTokenCount(text) <= maxTokens) {
    return text;
  }

  let low = 0;
  // The full text is already known not to fit, so start one unit short.
  let high = text.length - 1;
  let best = '';

  while (low <= high) {
    const mid = Math.floor((low + high) / 2);
    const candidate = text.slice(0, mid);

    if (getTokenCount(candidate) <= maxTokens) {
      best = candidate;
      low = mid + 1;
    } else {
      high = mid - 1;
    }
  }

  // `slice` cuts by code unit: drop a dangling high surrogate so the result
  // does not end in an unpaired half of an astral character.
  const lastUnit = best.charCodeAt(best.length - 1);
  if (lastUnit >= 0xd800 && lastUnit <= 0xdbff) {
    best = best.slice(0, -1);
  }

  return best.trimEnd();
};