diff --git a/src/lib/agents/search/api.ts b/src/lib/agents/search/api.ts index 94b70ae9b..27e0d5ed2 100644 --- a/src/lib/agents/search/api.ts +++ b/src/lib/agents/search/api.ts @@ -7,95 +7,102 @@ import { WidgetExecutor } from './widgets'; class APISearchAgent { async searchAsync(session: SessionManager, input: SearchAgentInput) { - const classification = await classify({ - chatHistory: input.chatHistory, - enabledSources: input.config.sources, - query: input.followUp, - llm: input.config.llm, - }); - - const widgetPromise = WidgetExecutor.executeAll({ - classification, - chatHistory: input.chatHistory, - followUp: input.followUp, - llm: input.config.llm, - }).catch((err) => { - console.error(`Error executing widgets: ${err}`); - return []; - }); - - let searchPromise: Promise | null = null; + try { + const classification = await classify({ + chatHistory: input.chatHistory, + enabledSources: input.config.sources, + query: input.followUp, + llm: input.config.llm, + }); - if (!classification.classification.skipSearch) { - const researcher = new Researcher(); - searchPromise = researcher.research(SessionManager.createSession(), { + const widgetPromise = WidgetExecutor.executeAll({ + classification, chatHistory: input.chatHistory, followUp: input.followUp, - classification: classification, - config: input.config, + llm: input.config.llm, + }).catch((err) => { + console.error(`Error executing widgets: ${err}`); + return []; }); - } - const [widgetOutputs, searchResults] = await Promise.all([ - widgetPromise, - searchPromise, - ]); + let searchPromise: Promise | null = null; + + if (!classification.classification.skipSearch) { + const researcher = new Researcher(); + searchPromise = researcher.research(SessionManager.createSession(), { + chatHistory: input.chatHistory, + followUp: input.followUp, + classification: classification, + config: input.config, + }); + } + + const [widgetOutputs, searchResults] = await Promise.all([ + widgetPromise, + searchPromise, + ]); + + if (searchResults) { + session.emit('data', { + type: 'searchResults', + data: searchResults.searchFindings, + }); + } - if (searchResults) { session.emit('data', { - type: 'searchResults', - data: searchResults.searchFindings, + type: 'researchComplete', }); - } - session.emit('data', { - type: 'researchComplete', - }); + const finalContext = + searchResults?.searchFindings + .map( + (f, index) => + `${f.content}`, + ) + .join('\n') || ''; - const finalContext = - searchResults?.searchFindings - .map( - (f, index) => - `${f.content}`, - ) - .join('\n') || ''; + const widgetContext = widgetOutputs + .map((o) => { + return `${o.llmContext}`; + }) + .join('\n-------------\n'); - const widgetContext = widgetOutputs - .map((o) => { - return `${o.llmContext}`; - }) - .join('\n-------------\n'); + const finalContextWithWidgets = `\n${finalContext}\n\n\n${widgetContext}\n`; - const finalContextWithWidgets = `\n${finalContext}\n\n\n${widgetContext}\n`; + const writerPrompt = getWriterPrompt( + finalContextWithWidgets, + input.config.systemInstructions, + input.config.mode, + ); - const writerPrompt = getWriterPrompt( - finalContextWithWidgets, - input.config.systemInstructions, - input.config.mode, - ); + const answerStream = input.config.llm.streamText({ + messages: [ + { + role: 'system', + content: writerPrompt, + }, + ...input.chatHistory, + { + role: 'user', + content: input.followUp, + }, + ], + }); - const answerStream = input.config.llm.streamText({ - messages: [ - { - role: 'system', - content: writerPrompt, - }, - ...input.chatHistory, - { - role: 'user', - content: input.followUp, - }, - ], - }); + for await (const chunk of answerStream) { + session.emit('data', { + type: 'response', + data: chunk.contentChunk, + }); + } - for await (const chunk of answerStream) { - session.emit('data', { - type: 'response', - data: chunk.contentChunk, + session.emit('end', {}); + } catch (error) { + console.error('Error while running API search:', error); + session.emit('error', { + data: error instanceof Error ? error.message : 'Search failed', }); } - - session.emit('end', {}); } } diff --git a/src/lib/agents/search/index.ts b/src/lib/agents/search/index.ts index 859183293..f2e43ebfc 100644 --- a/src/lib/agents/search/index.ts +++ b/src/lib/agents/search/index.ts @@ -11,175 +11,200 @@ import { TextBlock } from '@/lib/types'; class SearchAgent { async searchAsync(session: SessionManager, input: SearchAgentInput) { - const exists = await db.query.messages.findFirst({ - where: and( - eq(messages.chatId, input.chatId), - eq(messages.messageId, input.messageId), - ), - }); - - if (!exists) { - await db.insert(messages).values({ - chatId: input.chatId, - messageId: input.messageId, - backendId: session.id, - query: input.followUp, - createdAt: new Date().toISOString(), - status: 'answering', - responseBlocks: [], + try { + const exists = await db.query.messages.findFirst({ + where: and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), }); - } else { - await db - .delete(messages) - .where( - and(eq(messages.chatId, input.chatId), gt(messages.id, exists.id)), - ) - .execute(); - await db - .update(messages) - .set({ - status: 'answering', + + if (!exists) { + await db.insert(messages).values({ + chatId: input.chatId, + messageId: input.messageId, backendId: session.id, + query: input.followUp, + createdAt: new Date().toISOString(), + status: 'answering', responseBlocks: [], - }) - .where( - and( - eq(messages.chatId, input.chatId), - eq(messages.messageId, input.messageId), - ), - ) - .execute(); - } - - const classification = await classify({ - chatHistory: input.chatHistory, - enabledSources: input.config.sources, - query: input.followUp, - llm: input.config.llm, - }); - - const widgetPromise = WidgetExecutor.executeAll({ - classification, - chatHistory: input.chatHistory, - followUp: input.followUp, - llm: input.config.llm, - }).then((widgetOutputs) => { - widgetOutputs.forEach((o) => { - session.emitBlock({ - id: crypto.randomUUID(), - type: 'widget', - data: { - widgetType: o.type, - params: o.data, - }, }); - }); - return widgetOutputs; - }); + } else { + await db + .delete(messages) + .where( + and(eq(messages.chatId, input.chatId), gt(messages.id, exists.id)), + ) + .execute(); + await db + .update(messages) + .set({ + status: 'answering', + backendId: session.id, + responseBlocks: [], + }) + .where( + and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), + ) + .execute(); + } - let searchPromise: Promise | null = null; + const classification = await classify({ + chatHistory: input.chatHistory, + enabledSources: input.config.sources, + query: input.followUp, + llm: input.config.llm, + }); - if (!classification.classification.skipSearch) { - const researcher = new Researcher(); - searchPromise = researcher.research(session, { + const widgetPromise = WidgetExecutor.executeAll({ + classification, chatHistory: input.chatHistory, followUp: input.followUp, - classification: classification, - config: input.config, + llm: input.config.llm, + }).then((widgetOutputs) => { + widgetOutputs.forEach((o) => { + session.emitBlock({ + id: crypto.randomUUID(), + type: 'widget', + data: { + widgetType: o.type, + params: o.data, + }, + }); + }); + return widgetOutputs; }); - } - const [widgetOutputs, searchResults] = await Promise.all([ - widgetPromise, - searchPromise, - ]); + let searchPromise: Promise | null = null; - session.emit('data', { - type: 'researchComplete', - }); + if (!classification.classification.skipSearch) { + const researcher = new Researcher(); + searchPromise = researcher.research(session, { + chatHistory: input.chatHistory, + followUp: input.followUp, + classification: classification, + config: input.config, + }); + } - const finalContext = - searchResults?.searchFindings - .map( - (f, index) => - `${f.content}`, - ) - .join('\n') || ''; - - const widgetContext = widgetOutputs - .map((o) => { - return `${o.llmContext}`; - }) - .join('\n-------------\n'); - - const finalContextWithWidgets = `\n${finalContext}\n\n\n${widgetContext}\n`; - - const writerPrompt = getWriterPrompt( - finalContextWithWidgets, - input.config.systemInstructions, - input.config.mode, - ); - const answerStream = input.config.llm.streamText({ - messages: [ - { - role: 'system', - content: writerPrompt, - }, - ...input.chatHistory, - { - role: 'user', - content: input.followUp, - }, - ], - }); - - let responseBlockId = ''; - - for await (const chunk of answerStream) { - if (!responseBlockId) { - const block: TextBlock = { - id: crypto.randomUUID(), - type: 'text', - data: chunk.contentChunk, - }; - - session.emitBlock(block); - - responseBlockId = block.id; - } else { - const block = session.getBlock(responseBlockId) as TextBlock | null; + const [widgetOutputs, searchResults] = await Promise.all([ + widgetPromise, + searchPromise, + ]); - if (!block) { - continue; - } + session.emit('data', { + type: 'researchComplete', + }); + + const finalContext = + searchResults?.searchFindings + .map( + (f, index) => + `${f.content}`, + ) + .join('\n') || ''; + + const widgetContext = widgetOutputs + .map((o) => { + return `${o.llmContext}`; + }) + .join('\n-------------\n'); - block.data += chunk.contentChunk; + const finalContextWithWidgets = `\n${finalContext}\n\n\n${widgetContext}\n`; - session.updateBlock(block.id, [ + const writerPrompt = getWriterPrompt( + finalContextWithWidgets, + input.config.systemInstructions, + input.config.mode, + ); + const answerStream = input.config.llm.streamText({ + messages: [ { - op: 'replace', - path: '/data', - value: block.data, + role: 'system', + content: writerPrompt, }, - ]); + ...input.chatHistory, + { + role: 'user', + content: input.followUp, + }, + ], + }); + + let responseBlockId = ''; + + for await (const chunk of answerStream) { + if (!responseBlockId) { + const block: TextBlock = { + id: crypto.randomUUID(), + type: 'text', + data: chunk.contentChunk, + }; + + session.emitBlock(block); + + responseBlockId = block.id; + } else { + const block = session.getBlock(responseBlockId) as TextBlock | null; + + if (!block) { + continue; + } + + block.data += chunk.contentChunk; + + session.updateBlock(block.id, [ + { + op: 'replace', + path: '/data', + value: block.data, + }, + ]); + } } - } - session.emit('end', {}); + session.emit('end', {}); - await db - .update(messages) - .set({ - status: 'completed', - responseBlocks: session.getAllBlocks(), - }) - .where( - and( - eq(messages.chatId, input.chatId), - eq(messages.messageId, input.messageId), - ), - ) - .execute(); + await db + .update(messages) + .set({ + status: 'completed', + responseBlocks: session.getAllBlocks(), + }) + .where( + and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), + ) + .execute(); + } catch (error) { + console.error('Error while running search:', error); + session.emit('error', { + data: error instanceof Error ? error.message : 'Search failed', + }); + + try { + await db + .update(messages) + .set({ + status: 'error', + responseBlocks: session.getAllBlocks(), + }) + .where( + and( + eq(messages.chatId, input.chatId), + eq(messages.messageId, input.messageId), + ), + ) + .execute(); + } catch (dbError) { + console.error('Failed to persist errored search state:', dbError); + } + } } }