diff --git a/server/src/handlers/api/v1/bot/bot/chat.handler.ts b/server/src/handlers/api/v1/bot/bot/chat.handler.ts index 709b8ba..098fe7c 100644 --- a/server/src/handlers/api/v1/bot/bot/chat.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/chat.handler.ts @@ -10,7 +10,6 @@ import { createChain, groupMessagesByConversation } from "../../../../../chain"; import { getModelInfo } from "../../../../../utils/get-model-info"; import { nextTick } from "../../../../../utils/nextTick"; - async function getBotAndEmbedding(request: FastifyRequest) { const bot_id = request.params.id; const user_id = request.user.user_id; @@ -43,7 +42,7 @@ async function getBotAndEmbedding(request: FastifyRequest) { return { bot, embeddingModel }; } -async function getRetriever(bot, embeddingModel) { +async function getRetriever(bot, embeddingModel, knowledge_base_ids) { let resolveWithDocuments: (value: Document[]) => void; const documentPromise = new Promise((resolve) => { resolveWithDocuments = resolve; @@ -63,11 +62,16 @@ async function getRetriever(bot, embeddingModel) { botId: bot.id, sourceId: null, callbacks, + knowledge_base_ids, }); } else { const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( embeddingModel, - { botId: bot.id, sourceId: null } + { + botId: bot.id, + sourceId: null, + knowledge_base_ids, + } ); retriever = vectorstore.asRetriever({ callbacks }); } @@ -105,9 +109,18 @@ async function handleChatRequest( try { const { message, history } = request.body; const { bot, embeddingModel } = await getBotAndEmbedding(request); + let knowledge_base_ids: string[] = []; + + if ( + request.body.knowledge_base_ids && + request.body.knowledge_base_ids.length > 0 + ) { + knowledge_base_ids = request.body.knowledge_base_ids; + } const { retriever, documentPromise } = await getRetriever( bot, - embeddingModel + embeddingModel, + knowledge_base_ids ); const model = await getModel(bot, request.server.prisma); diff --git a/server/src/handlers/api/v1/bot/bot/types.ts b/server/src/handlers/api/v1/bot/bot/types.ts index d4ea5dc..84f21c2 100644 --- a/server/src/handlers/api/v1/bot/bot/types.ts +++ b/server/src/handlers/api/v1/bot/bot/types.ts @@ -130,6 +130,7 @@ export interface ChatAPIRequest { role: string; text: string; }[]; + knowledge_base_ids?: string[] }; } diff --git a/server/src/routes/api/v1/bot/root.ts b/server/src/routes/api/v1/bot/root.ts index 0f3ad91..b1e1caa 100644 --- a/server/src/routes/api/v1/bot/root.ts +++ b/server/src/routes/api/v1/bot/root.ts @@ -302,6 +302,10 @@ const root: FastifyPluginAsync = async (fastify, _): Promise => { stream: { type: "boolean", }, + knowledge_base_ids: { + type: "array", + default: [], + } }, }, }, diff --git a/server/src/schema/api/v1/openai/index.ts b/server/src/schema/api/v1/openai/index.ts index 76e5780..4d829a2 100644 --- a/server/src/schema/api/v1/openai/index.ts +++ b/server/src/schema/api/v1/openai/index.ts @@ -47,7 +47,8 @@ export const createChatCompletionSchema: FastifySchema = { } } } - } + }, + default: [] } } }