From 27d8422866edb81d734992d60767572257273da9 Mon Sep 17 00:00:00 2001 From: n4ze3m Date: Sat, 7 Sep 2024 10:19:44 +0530 Subject: [PATCH] feat: Allow specifying knowledge bases for chat --- .../handlers/api/v1/bot/bot/chat.handler.ts | 21 +++++++++++++++---- server/src/handlers/api/v1/bot/bot/types.ts | 1 + server/src/routes/api/v1/bot/root.ts | 4 ++++ server/src/schema/api/v1/openai/index.ts | 3 ++- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/server/src/handlers/api/v1/bot/bot/chat.handler.ts b/server/src/handlers/api/v1/bot/bot/chat.handler.ts index 709b8ba6..098fe7cf 100644 --- a/server/src/handlers/api/v1/bot/bot/chat.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/chat.handler.ts @@ -10,7 +10,6 @@ import { createChain, groupMessagesByConversation } from "../../../../../chain"; import { getModelInfo } from "../../../../../utils/get-model-info"; import { nextTick } from "../../../../../utils/nextTick"; - async function getBotAndEmbedding(request: FastifyRequest) { const bot_id = request.params.id; const user_id = request.user.user_id; @@ -43,7 +42,7 @@ async function getBotAndEmbedding(request: FastifyRequest) { return { bot, embeddingModel }; } -async function getRetriever(bot, embeddingModel) { +async function getRetriever(bot, embeddingModel, knowledge_base_ids) { let resolveWithDocuments: (value: Document[]) => void; const documentPromise = new Promise((resolve) => { resolveWithDocuments = resolve; @@ -63,11 +62,16 @@ async function getRetriever(bot, embeddingModel) { botId: bot.id, sourceId: null, callbacks, + knowledge_base_ids, }); } else { const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( embeddingModel, - { botId: bot.id, sourceId: null } + { + botId: bot.id, + sourceId: null, + knowledge_base_ids, + } ); retriever = vectorstore.asRetriever({ callbacks }); } @@ -105,9 +109,18 @@ async function handleChatRequest( try { const { message, history } = request.body; const { bot, embeddingModel } = await getBotAndEmbedding(request); + let knowledge_base_ids: string[] = []; + + if ( + request.body.knowledge_base_ids && + request.body.knowledge_base_ids.length > 0 + ) { + knowledge_base_ids = request.body.knowledge_base_ids; + } const { retriever, documentPromise } = await getRetriever( bot, - embeddingModel + embeddingModel, + knowledge_base_ids ); const model = await getModel(bot, request.server.prisma); diff --git a/server/src/handlers/api/v1/bot/bot/types.ts b/server/src/handlers/api/v1/bot/bot/types.ts index d4ea5dcf..84f21c20 100644 --- a/server/src/handlers/api/v1/bot/bot/types.ts +++ b/server/src/handlers/api/v1/bot/bot/types.ts @@ -130,6 +130,7 @@ export interface ChatAPIRequest { role: string; text: string; }[]; + knowledge_base_ids?: string[] }; } diff --git a/server/src/routes/api/v1/bot/root.ts b/server/src/routes/api/v1/bot/root.ts index 0f3ad911..b1e1caa8 100644 --- a/server/src/routes/api/v1/bot/root.ts +++ b/server/src/routes/api/v1/bot/root.ts @@ -302,6 +302,10 @@ const root: FastifyPluginAsync = async (fastify, _): Promise => { stream: { type: "boolean", }, + knowledge_base_ids: { + type: "array", + default: [], + } }, }, }, diff --git a/server/src/schema/api/v1/openai/index.ts b/server/src/schema/api/v1/openai/index.ts index 76e57809..4d829a23 100644 --- a/server/src/schema/api/v1/openai/index.ts +++ b/server/src/schema/api/v1/openai/index.ts @@ -47,7 +47,8 @@ export const createChatCompletionSchema: FastifySchema = { } } } - } + }, + default: [] } } }