diff --git a/app/ui/package.json b/app/ui/package.json index 2ca75cf6..73d2544c 100644 --- a/app/ui/package.json +++ b/app/ui/package.json @@ -1,7 +1,7 @@ { "name": "app", "private": true, - "version": "1.8.1", + "version": "1.8.2", "type": "module", "scripts": { "dev": "vite", diff --git a/app/ui/src/@types/bot.ts b/app/ui/src/@types/bot.ts index 3547bc40..3cc8b6b9 100644 --- a/app/ui/src/@types/bot.ts +++ b/app/ui/src/@types/bot.ts @@ -6,6 +6,7 @@ export type BotSettings = { public_id: string; temperature: number; embedding: string; + noOfDocumentsToRetrieve: number; qaPrompt: string; questionGeneratorPrompt: string; streaming: boolean; diff --git a/app/ui/src/components/Bot/Settings/SettingsCard.tsx b/app/ui/src/components/Bot/Settings/SettingsCard.tsx index 6a56de8a..10d717e6 100644 --- a/app/ui/src/components/Bot/Settings/SettingsCard.tsx +++ b/app/ui/src/components/Bot/Settings/SettingsCard.tsx @@ -1,4 +1,12 @@ -import { Form, Input, notification, Select, Slider, Switch } from "antd"; +import { + Form, + Input, + InputNumber, + notification, + Select, + Slider, + Switch, +} from "antd"; import { useNavigate, useParams } from "react-router-dom"; import api from "../../../services/api"; import { useMutation, useQueryClient } from "@tanstack/react-query"; @@ -155,6 +163,7 @@ export const SettingsCard: React.FC = ({ bot_protect: data.bot_protect, use_rag: data.use_rag, bot_model_api_key: data.bot_model_api_key, + noOfDocumentsToRetrieve: data.noOfDocumentsToRetrieve, }} form={form} requiredMark={false} @@ -252,14 +261,43 @@ export const SettingsCard: React.FC = ({ /> - + +

+ If you change the embedding method, make sure to + re-fetch the data source or choose a model with the same + dimensions +

+ + } + > +
+ + + + + + +
+ { try { @@ -130,21 +131,29 @@ export const getAllModelsHandler = async ( request: FastifyRequest, reply: FastifyReply ) => { - try { - const prisma = request.server.prisma; - const user = request.user; + const prisma = request.server.prisma; + const user = request.user; - if (!user.is_admin) { - return reply.status(403).send({ - message: "Forbidden", - }); - } - const allModels = await prisma.dialoqbaseModels.findMany({ - where: { - deleted: false, - }, + if (!user.is_admin) { + return reply.status(403).send({ + message: "Forbidden", }); + } + const settings = await getSettings(prisma); + + const not_to_hide_providers = settings?.hideDefaultModels + ? [ "Local", "local", "ollama", "transformer", "Transformer"] + : undefined; + const allModels = await prisma.dialoqbaseModels.findMany({ + where: { + deleted: false, + model_provider: { + in: not_to_hide_providers, + }, + }, + }); + try { return { data: allModels.filter((model) => model.model_type !== "embedding"), embedding: allModels.filter((model) => model.model_type === "embedding"), @@ -245,7 +254,7 @@ export const saveModelFromInputedUrlHandler = async ( }); } - let newModelId = model_id.trim() + `_custom_${new Date().getTime()}`; + let newModelId = model_id.trim() + `_dialoqbase_${new Date().getTime()}`; await prisma.dialoqbaseModels.create({ data: { name: isModelExist.name, diff --git a/server/src/handlers/api/v1/bot/bot/api.handler.ts b/server/src/handlers/api/v1/bot/bot/api.handler.ts index eb9dbb5a..bda42af7 100644 --- a/server/src/handlers/api/v1/bot/bot/api.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/api.handler.ts @@ -16,6 +16,7 @@ import { uniqueNamesGenerator, } from "unique-names-generator"; import { validateDataSource } from "../../../../../utils/datasource-validation"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const createBotAPIHandler = async ( request: FastifyRequest, @@ -55,19 +56,11 @@ export const createBotAPIHandler = async ( message: `Reach maximum limit of ${maxBotsAllowed} bots per user`, }); } - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - hide: false, - deleted: false, - OR: [ - { - model_id: model, - }, - { - model_id: `${model}-dbase`, - }, - ], - }, + + const modelInfo = await getModelInfo({ + model, + prisma, + type: "chat", }); if (!modelInfo) { @@ -76,19 +69,10 @@ export const createBotAPIHandler = async ( }); } - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - OR: [ - { - model_id: embedding, - }, - { - model_id: `dialoqbase_eb_${embedding}`, - }, - ], - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/handlers/api/v1/bot/bot/chat.handler.ts b/server/src/handlers/api/v1/bot/bot/chat.handler.ts index 6a371bf6..18e5b10f 100644 --- a/server/src/handlers/api/v1/bot/bot/chat.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/chat.handler.ts @@ -7,377 +7,368 @@ import { DialoqbaseHybridRetrival } from "../../../../../utils/hybrid"; import { DialoqbaseVectorStore } from "../../../../../utils/store"; import { chatModelProvider } from "../../../../../utils/models"; import { createChain, groupMessagesByConversation } from "../../../../../chain"; +import { getModelInfo } from "../../../../../utils/get-model-info"; function nextTick() { - return new Promise((resolve) => setTimeout(resolve, 0)); + return new Promise((resolve) => setTimeout(resolve, 0)); } export const chatRequestAPIHandler = async ( - request: FastifyRequest, - reply: FastifyReply + request: FastifyRequest, + reply: FastifyReply ) => { - const { message, history, stream } = request.body; - if (stream) { - try { - const bot_id = request.params.id; - const prisma = request.server.prisma; - const user_id = request.user.user_id; - - const bot = await prisma.bot.findFirst({ - where: { - id: bot_id, - user_id - }, - }); - - if (!bot) { - return reply.status(404).send({ - message: "Bot not found", - }); - } - - - const temperature = bot.temperature; - - const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); - - if (!embeddingInfo) { - return reply.status(404).send({ - message: "Embedding not found", - }); - } - - const embeddingModel = embeddings( - embeddingInfo.model_provider!.toLowerCase(), - embeddingInfo.model_id, - embeddingInfo?.config - ); - - reply.raw.on("close", () => { - console.log("closed"); - }); - - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); - - if (!modelinfo) { - return reply.status(404).send({ - message: "Model not found", - }); - } - - const botConfig = (modelinfo.config as {}) || {}; - let retriever: BaseRetriever; - let resolveWithDocuments: (value: Document[]) => void; - const documentPromise = new Promise((resolve) => { - resolveWithDocuments = resolve; - }); - if (bot.use_hybrid_search) { - retriever = new DialoqbaseHybridRetrival(embeddingModel, { - botId: bot.id, - sourceId: null, - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } else { - const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( - embeddingModel, - { - botId: bot.id, - sourceId: null, - } - ); - - retriever = vectorstore.asRetriever({ - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } - - let response: string = ""; - const streamedModel = chatModelProvider( - bot.provider, - bot.model, - temperature, - { - streaming: true, - ...botConfig, - } - ); - - const nonStreamingModel = chatModelProvider( - bot.provider, - bot.model, - temperature, - { - ...botConfig, - } - ); - - reply.raw.on("close", () => { - // close the model - }); - - const chain = createChain({ - llm: streamedModel, - question_llm: nonStreamingModel, - question_template: bot.questionGeneratorPrompt, - response_template: bot.qaPrompt, - retriever, - }); - - let stream = await chain.stream({ - question: sanitizedQuestion, - chat_history: groupMessagesByConversation( - history.map((message) => ({ - type: message.role, - content: message.text, - })) - ), - }); - - for await (const token of stream) { - reply.sse({ - id: "", - event: "chunk", - data: JSON.stringify({ - bot: { - text: token || "", - sourceDocuments: [], - }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: token || "", - }, - ], - }), - }); - response += token; - } - - const documents = await documentPromise; - - - await prisma.botApiHistory.create({ - data: { - api_key: request.headers.authorization || "", - bot_id: bot.id, - human: message, - bot: response, - } - }) - - reply.sse({ - event: "result", - id: "", - data: JSON.stringify({ - bot: { - text: response, - sourceDocuments: documents, - }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: response, - }, - ], - }), - }); - await nextTick(); - return reply.raw.end(); - } catch (e) { - return reply.status(500).send({ - message: "Internal Server Error", - }); + const { message, history, stream } = request.body; + if (stream) { + try { + const bot_id = request.params.id; + const prisma = request.server.prisma; + const user_id = request.user.user_id; + + const bot = await prisma.bot.findFirst({ + where: { + id: bot_id, + user_id, + }, + }); + + if (!bot) { + return reply.status(404).send({ + message: "Bot not found", + }); + } + + const temperature = bot.temperature; + + const sanitizedQuestion = message.trim().replaceAll("\n", " "); + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }); + + if (!embeddingInfo) { + return reply.status(404).send({ + message: "Embedding not found", + }); + } + + const embeddingModel = embeddings( + embeddingInfo.model_provider!.toLowerCase(), + embeddingInfo.model_id, + embeddingInfo?.config + ); + + reply.raw.on("close", () => { + console.log("closed"); + }); + + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }); + + if (!modelinfo) { + return reply.status(404).send({ + message: "Model not found", + }); + } + + const botConfig = (modelinfo.config as {}) || {}; + let retriever: BaseRetriever; + let resolveWithDocuments: (value: Document[]) => void; + const documentPromise = new Promise((resolve) => { + resolveWithDocuments = resolve; + }); + if (bot.use_hybrid_search) { + retriever = new DialoqbaseHybridRetrival(embeddingModel, { + botId: bot.id, + sourceId: null, + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } else { + const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( + embeddingModel, + { + botId: bot.id, + sourceId: null, + } + ); + + retriever = vectorstore.asRetriever({ + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } + + let response: string = ""; + const streamedModel = chatModelProvider( + bot.provider, + bot.model, + temperature, + { + streaming: true, + ...botConfig, } - } else { - try { - const bot_id = request.params.id; - const user_id = request.user.user_id; - - const prisma = request.server.prisma; - - const bot = await prisma.bot.findFirst({ - where: { - id: bot_id, - user_id - }, - }); - - if (!bot) { - return reply.status(404).send({ - message: "Bot not found", - }); - } - - const temperature = bot.temperature; - - const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); - - if (!embeddingInfo) { - return reply.status(404).send({ - message: "Embedding not found", - }); - } - - const embeddingModel = embeddings( - embeddingInfo.model_provider!.toLowerCase(), - embeddingInfo.model_id, - embeddingInfo?.config - ); - - let retriever: BaseRetriever; - let resolveWithDocuments: (value: Document[]) => void; - const documentPromise = new Promise((resolve) => { - resolveWithDocuments = resolve; - }); - if (bot.use_hybrid_search) { - retriever = new DialoqbaseHybridRetrival(embeddingModel, { - botId: bot.id, - sourceId: null, - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } else { - const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( - embeddingModel, - { - botId: bot.id, - sourceId: null, - } - ); - - retriever = vectorstore.asRetriever({ - callbacks: [ - { - handleRetrieverEnd(documents) { - resolveWithDocuments(documents); - }, - }, - ], - }); - } - - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); - - if (!modelinfo) { - return reply.status(404).send({ - message: "Model not found", - }); - } - - const botConfig: any = (modelinfo.config as {}) || {}; - if (bot.provider.toLowerCase() === "openai") { - if (bot.bot_model_api_key && bot.bot_model_api_key.trim() !== "") { - botConfig.configuration = { - apiKey: bot.bot_model_api_key, - }; - } - } - - const model = chatModelProvider(bot.provider, bot.model, temperature, { - ...botConfig, - }); - - const chain = createChain({ - llm: model, - question_llm: model, - question_template: bot.questionGeneratorPrompt, - response_template: bot.qaPrompt, - retriever, - }); - - const botResponse = await chain.invoke({ - question: sanitizedQuestion, - chat_history: groupMessagesByConversation( - history.map((message) => ({ - type: message.role, - content: message.text, - })) - ), - }); - - const documents = await documentPromise; - - await prisma.botApiHistory.create({ - data: { - api_key: request.headers.authorization || "", - bot_id: bot.id, - human: message, - bot: botResponse, - } - }) - - return { - bot: { - text: botResponse, - sourceDocuments: documents, - }, - history: [ - ...history, - { - type: "human", - text: message, - }, - { - type: "ai", - text: botResponse, - }, - ], - }; - } catch (e) { - return reply.status(500).send({ - message: "Internal Server Error", - }); + ); + + const nonStreamingModel = chatModelProvider( + bot.provider, + bot.model, + temperature, + { + ...botConfig, } + ); + + reply.raw.on("close", () => { + // close the model + }); + + const chain = createChain({ + llm: streamedModel, + question_llm: nonStreamingModel, + question_template: bot.questionGeneratorPrompt, + response_template: bot.qaPrompt, + retriever, + }); + + let stream = await chain.stream({ + question: sanitizedQuestion, + chat_history: groupMessagesByConversation( + history.map((message) => ({ + type: message.role, + content: message.text, + })) + ), + }); + + for await (const token of stream) { + reply.sse({ + id: "", + event: "chunk", + data: JSON.stringify({ + bot: { + text: token || "", + sourceDocuments: [], + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: token || "", + }, + ], + }), + }); + response += token; + } + + const documents = await documentPromise; + + await prisma.botApiHistory.create({ + data: { + api_key: request.headers.authorization || "", + bot_id: bot.id, + human: message, + bot: response, + }, + }); + + reply.sse({ + event: "result", + id: "", + data: JSON.stringify({ + bot: { + text: response, + sourceDocuments: documents, + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: response, + }, + ], + }), + }); + await nextTick(); + return reply.raw.end(); + } catch (e) { + return reply.status(500).send({ + message: "Internal Server Error", + }); } + } else { + try { + const bot_id = request.params.id; + const user_id = request.user.user_id; + + const prisma = request.server.prisma; + + const bot = await prisma.bot.findFirst({ + where: { + id: bot_id, + user_id, + }, + }); + + if (!bot) { + return reply.status(404).send({ + message: "Bot not found", + }); + } + + const temperature = bot.temperature; + + const sanitizedQuestion = message.trim().replaceAll("\n", " "); + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }); + + if (!embeddingInfo) { + return reply.status(404).send({ + message: "Embedding not found", + }); + } + + const embeddingModel = embeddings( + embeddingInfo.model_provider!.toLowerCase(), + embeddingInfo.model_id, + embeddingInfo?.config + ); + + let retriever: BaseRetriever; + let resolveWithDocuments: (value: Document[]) => void; + const documentPromise = new Promise((resolve) => { + resolveWithDocuments = resolve; + }); + if (bot.use_hybrid_search) { + retriever = new DialoqbaseHybridRetrival(embeddingModel, { + botId: bot.id, + sourceId: null, + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } else { + const vectorstore = await DialoqbaseVectorStore.fromExistingIndex( + embeddingModel, + { + botId: bot.id, + sourceId: null, + } + ); + + retriever = vectorstore.asRetriever({ + callbacks: [ + { + handleRetrieverEnd(documents) { + resolveWithDocuments(documents); + }, + }, + ], + }); + } + + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }); + + if (!modelinfo) { + return reply.status(404).send({ + message: "Model not found", + }); + } + + const botConfig: any = (modelinfo.config as {}) || {}; + if (bot.provider.toLowerCase() === "openai") { + if (bot.bot_model_api_key && bot.bot_model_api_key.trim() !== "") { + botConfig.configuration = { + apiKey: bot.bot_model_api_key, + }; + } + } + + const model = chatModelProvider(bot.provider, bot.model, temperature, { + ...botConfig, + }); + + const chain = createChain({ + llm: model, + question_llm: model, + question_template: bot.questionGeneratorPrompt, + response_template: bot.qaPrompt, + retriever, + }); + + const botResponse = await chain.invoke({ + question: sanitizedQuestion, + chat_history: groupMessagesByConversation( + history.map((message) => ({ + type: message.role, + content: message.text, + })) + ), + }); + + const documents = await documentPromise; + + await prisma.botApiHistory.create({ + data: { + api_key: request.headers.authorization || "", + bot_id: bot.id, + human: message, + bot: botResponse, + }, + }); + + return { + bot: { + text: botResponse, + sourceDocuments: documents, + }, + history: [ + ...history, + { + type: "human", + text: message, + }, + { + type: "ai", + text: botResponse, + }, + ], + }; + } catch (e) { + return reply.status(500).send({ + message: "Internal Server Error", + }); + } + } }; diff --git a/server/src/handlers/api/v1/bot/bot/get.handler.ts b/server/src/handlers/api/v1/bot/bot/get.handler.ts index c6f2cc9d..ac75ce92 100644 --- a/server/src/handlers/api/v1/bot/bot/get.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/get.handler.ts @@ -2,6 +2,7 @@ import { FastifyReply, FastifyRequest } from "fastify"; import { GetBotRequestById } from "./types"; import { getSettings } from "../../../../../utils/common"; +import { getAllOllamaModels } from "../../../../../utils/ollama"; export const getBotByIdEmbeddingsHandler = async ( request: FastifyRequest, @@ -125,10 +126,18 @@ export const getCreateBotConfigHandler = async ( reply: FastifyReply ) => { const prisma = request.server.prisma; + const settings = await getSettings(prisma); + + const not_to_hide_providers = settings?.hideDefaultModels + ? ["Local", "local", "ollama", "transformer", "Transformer"] + : undefined; const models = await prisma.dialoqbaseModels.findMany({ where: { hide: false, deleted: false, + model_provider: { + in: not_to_hide_providers, + }, }, }); @@ -146,16 +155,29 @@ export const getCreateBotConfigHandler = async ( .filter((model) => model.model_type === "embedding") .map((model) => { return { - label: `${model.name || model.model_id} ${model.model_id === "dialoqbase_eb_dialoqbase-ollama" + label: `${model.name || model.model_id} ${ + model.model_id === "dialoqbase_eb_dialoqbase-ollama" ? "(Deprecated)" : "" - }`, + }`, value: model.model_id, disabled: model.model_id === "dialoqbase_eb_dialoqbase-ollama", }; }); - - const settings = await getSettings(prisma); + if (settings?.dynamicallyFetchOllamaModels) { + const ollamaModels = await getAllOllamaModels(settings.ollamaURL); + chatModel.push( + ...ollamaModels?.filter((model) => { + return ( + !model?.details?.families?.includes("bert") && + !model?.details?.families?.includes("nomic-bert") + ); + }) + ); + embeddingModel.push( + ...ollamaModels.map((model) => ({ ...model, disabled: false })) + ); + } return { chatModel, @@ -178,11 +200,23 @@ export const getBotByIdSettingsHandler = async ( user_id: request.user.user_id, }, }); + if (!bot) { + return reply.status(404).send({ + message: "Bot not found", + }); + } + const settings = await getSettings(prisma); + const not_to_hide_providers = settings?.hideDefaultModels + ? ["Local", "local", "ollama", "transformer", "Transformer"] + : undefined; const models = await prisma.dialoqbaseModels.findMany({ where: { hide: false, deleted: false, + model_provider: { + in: not_to_hide_providers, + }, }, }); @@ -200,19 +234,28 @@ export const getBotByIdSettingsHandler = async ( .filter((model) => model.model_type === "embedding") .map((model) => { return { - label: `${model.name || model.model_id} ${model.model_id === "dialoqbase_eb_dialoqbase-ollama" + label: `${model.name || model.model_id} ${ + model.model_id === "dialoqbase_eb_dialoqbase-ollama" ? "(Deprecated)" : "" - }`, + }`, value: model.model_id, disabled: model.model_id === "dialoqbase_eb_dialoqbase-ollama", }; }); - - if (!bot) { - return reply.status(404).send({ - message: "Bot not found", - }); + if (settings?.dynamicallyFetchOllamaModels) { + const ollamaModels = await getAllOllamaModels(settings.ollamaURL); + chatModel.push( + ...ollamaModels?.filter((model) => { + return ( + !model?.details?.families?.includes("bert") && + !model?.details?.families?.includes("nomic-bert") + ); + }) + ); + embeddingModel.push( + ...ollamaModels.map((model) => ({ ...model, disabled: false })) + ); } return { data: bot, @@ -221,7 +264,6 @@ export const getBotByIdSettingsHandler = async ( }; }; - export const isBotReadyHandler = async ( request: FastifyRequest, reply: FastifyReply @@ -252,4 +294,4 @@ export const isBotReadyHandler = async ( return { is_ready: source === 0, }; -}; \ No newline at end of file +}; diff --git a/server/src/handlers/api/v1/bot/bot/post.handler.ts b/server/src/handlers/api/v1/bot/bot/post.handler.ts index 238b7133..fd557a7b 100644 --- a/server/src/handlers/api/v1/bot/bot/post.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/post.handler.ts @@ -16,6 +16,7 @@ import { HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT, HELPFUL_ASSISTANT_WITHOUT_CONTEXT_PROMPT, } from "../../../../../utils/prompts"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const createBotHandler = async ( request: FastifyRequest, @@ -55,12 +56,10 @@ export const createBotHandler = async ( }); } // const providerName = modelProviderName(model); - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: model, - hide: false, - deleted: false, - }, + const modelInfo = await getModelInfo({ + model, + prisma, + type: "chat", }); if (!modelInfo) { @@ -69,12 +68,10 @@ export const createBotHandler = async ( }); } - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/handlers/api/v1/bot/bot/put.handler.ts b/server/src/handlers/api/v1/bot/bot/put.handler.ts index 380d6fe7..7dc0fc97 100644 --- a/server/src/handlers/api/v1/bot/bot/put.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/put.handler.ts @@ -4,6 +4,7 @@ import { apiKeyValidaton, apiKeyValidatonMessage, } from "../../../../../utils/validate"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const updateBotByIdHandler = async ( request: FastifyRequest, @@ -25,12 +26,9 @@ export const updateBotByIdHandler = async ( }); } - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: request.body.model, - hide: false, - deleted: false, - }, + const modelInfo = await getModelInfo({ + model: request.body.model, + prisma, }); if (!modelInfo) { @@ -96,26 +94,15 @@ export const updateBotAPIByIdHandler = async ( questionGeneratorPrompt: request.body?.question_generator_prompt, system_prompt: undefined, question_generator_prompt: undefined, - } - + }; if (updateBody.model) { - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - hide: false, - deleted: false, - OR: [ - { - model_id: updateBody.model - }, - { - model_id: `${updateBody.model}-dbase` - } - ], - }, + const modelInfo = await getModelInfo({ + model: updateBody.model, + prisma, + type: "chat", }); - if (!modelInfo) { return reply.status(400).send({ message: "Model not found", @@ -140,8 +127,7 @@ export const updateBotAPIByIdHandler = async ( updateBody = { ...updateBody, provider: modelInfo.model_provider || "", - } - + }; } await prisma.bot.update({ where: { diff --git a/server/src/handlers/api/v1/bot/bot/upload.handler.ts b/server/src/handlers/api/v1/bot/bot/upload.handler.ts index cf0f1b9a..6264c04c 100644 --- a/server/src/handlers/api/v1/bot/bot/upload.handler.ts +++ b/server/src/handlers/api/v1/bot/bot/upload.handler.ts @@ -19,6 +19,7 @@ const pump = util.promisify(pipeline); import { fileTypeFinder } from "../../../../../utils/fileType"; import { getSettings } from "../../../../../utils/common"; import { HELPFUL_ASSISTANT_WITH_CONTEXT_PROMPT } from "../../../../../utils/prompts"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const createBotFileHandler = async ( request: FastifyRequest, @@ -52,12 +53,10 @@ export const createBotFileHandler = async ( }); } - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -76,12 +75,10 @@ export const createBotFileHandler = async ( }); } - const modelInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: model, - hide: false, - deleted: false, - }, + const modelInfo = await getModelInfo({ + model, + prisma, + type: "chat", }); if (!modelInfo) { diff --git a/server/src/handlers/api/v1/bot/playground/chat.handler.ts b/server/src/handlers/api/v1/bot/playground/chat.handler.ts index ab950f97..2e7661ec 100644 --- a/server/src/handlers/api/v1/bot/playground/chat.handler.ts +++ b/server/src/handlers/api/v1/bot/playground/chat.handler.ts @@ -7,6 +7,7 @@ import { DialoqbaseHybridRetrival } from "../../../../../utils/hybrid"; import { BaseRetriever } from "@langchain/core/retrievers"; import { Document } from "langchain/document"; import { createChain, groupMessagesByConversation } from "../../../../../chain"; +import { getModelInfo } from "../../../../../utils/get-model-info"; export const chatRequestHandler = async ( request: FastifyRequest, @@ -48,14 +49,11 @@ export const chatRequestHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "all", }); - if (!embeddingInfo) { return { bot: { @@ -118,12 +116,10 @@ export const chatRequestHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { @@ -295,12 +291,10 @@ export const chatRequestStreamHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -375,12 +369,10 @@ export const chatRequestStreamHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { reply.raw.setHeader("Content-Type", "text/event-stream"); diff --git a/server/src/handlers/bot/api.handler.ts b/server/src/handlers/bot/api.handler.ts index aef50c1b..d7e995aa 100644 --- a/server/src/handlers/bot/api.handler.ts +++ b/server/src/handlers/bot/api.handler.ts @@ -8,6 +8,7 @@ import { Document } from "langchain/document"; import { BaseRetriever } from "@langchain/core/retrievers"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { createChain, groupMessagesByConversation } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; export const chatRequestAPIHandler = async ( request: FastifyRequest, @@ -40,12 +41,10 @@ export const chatRequestAPIHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + prisma, + model: bot.embedding, + type: "embedding", }); if (!embeddingInfo) { @@ -64,12 +63,10 @@ export const chatRequestAPIHandler = async ( console.log("closed"); }); - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + prisma, + model: bot.model, + type: "chat", }); if (!modelinfo) { @@ -172,15 +169,14 @@ export const chatRequestAPIHandler = async ( }); const documents = await documentPromise; - await prisma.botApiHistory.create({ data: { api_key: request.headers["x-api-key"], bot_id: bot.id, human: message, bot: response, - } - }) + }, + }); reply.sse({ event: "result", @@ -237,12 +233,10 @@ export const chatRequestAPIHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + prisma, + model: bot.embedding, + type: "embedding", }); if (!embeddingInfo) { @@ -308,12 +302,10 @@ export const chatRequestAPIHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + prisma, + model: bot.model, + type: "chat", }); if (!modelinfo) { @@ -375,8 +367,8 @@ export const chatRequestAPIHandler = async ( bot_id: bot.id, human: message, bot: botResponse, - } - }) + }, + }); return { bot: { diff --git a/server/src/handlers/bot/post.handler.ts b/server/src/handlers/bot/post.handler.ts index 9bbc17c9..04e2a2c9 100644 --- a/server/src/handlers/bot/post.handler.ts +++ b/server/src/handlers/bot/post.handler.ts @@ -7,6 +7,7 @@ import { BaseRetriever } from "@langchain/core/retrievers"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { Document } from "langchain/document"; import { createChain, groupMessagesByConversation } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; export const chatRequestHandler = async ( request: FastifyRequest, @@ -69,12 +70,10 @@ export const chatRequestHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -139,13 +138,11 @@ export const chatRequestHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }) if (!modelinfo) { return { @@ -341,13 +338,11 @@ export const chatRequestStreamHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { return { @@ -416,13 +411,11 @@ export const chatRequestStreamHandler = async ( console.log("closed"); }); - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, - }); + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", + }) if (!modelinfo) { reply.raw.setHeader("Content-Type", "text/event-stream"); diff --git a/server/src/integration/handlers/discord.handler.ts b/server/src/integration/handlers/discord.handler.ts index 3eb9c4da..4549b37e 100644 --- a/server/src/integration/handlers/discord.handler.ts +++ b/server/src/integration/handlers/discord.handler.ts @@ -6,6 +6,7 @@ import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { Document } from "langchain/document"; import { BaseRetriever } from "@langchain/core/retrievers"; import { createChain } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; const prisma = new PrismaClient(); export const discordBotHandler = async ( @@ -51,14 +52,11 @@ export const discordBotHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, - }); - + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { return { text: "Opps! Model not found", @@ -108,12 +106,10 @@ export const discordBotHandler = async ( }); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { diff --git a/server/src/integration/handlers/telegram.handler.ts b/server/src/integration/handlers/telegram.handler.ts index 9f675ff2..0be218d1 100644 --- a/server/src/integration/handlers/telegram.handler.ts +++ b/server/src/integration/handlers/telegram.handler.ts @@ -5,6 +5,7 @@ import { chatModelProvider } from "../../utils/models"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { BaseRetriever } from "@langchain/core/retrievers"; import { createChain } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; const prisma = new PrismaClient(); export const telegramBotHandler = async ( @@ -46,12 +47,10 @@ export const telegramBotHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -83,14 +82,12 @@ export const telegramBotHandler = async ( retriever = vectorstore.asRetriever({}); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); - + if (!modelinfo) { return "Unable to find model"; } diff --git a/server/src/integration/handlers/whatsapp.handler.ts b/server/src/integration/handlers/whatsapp.handler.ts index 7945312f..bfed6e21 100644 --- a/server/src/integration/handlers/whatsapp.handler.ts +++ b/server/src/integration/handlers/whatsapp.handler.ts @@ -5,6 +5,7 @@ import { chatModelProvider } from "../../utils/models"; import { BaseRetriever } from "@langchain/core/retrievers"; import { DialoqbaseHybridRetrival } from "../../utils/hybrid"; import { createChain } from "../../chain"; +import { getModelInfo } from "../../utils/get-model-info"; const prisma = new PrismaClient(); export const whatsappBotHandler = async ( @@ -55,12 +56,10 @@ export const whatsappBotHandler = async ( const temperature = bot.temperature; const sanitizedQuestion = message.trim().replaceAll("\n", " "); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: bot.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -92,12 +91,10 @@ export const whatsappBotHandler = async ( retriever = vectorstore.asRetriever({}); } - const modelinfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: bot.model, - hide: false, - deleted: false, - }, + const modelinfo = await getModelInfo({ + model: bot.model, + prisma, + type: "chat", }); if (!modelinfo) { diff --git a/server/src/queue/controllers/audio.controller.ts b/server/src/queue/controllers/audio.controller.ts index 86f6b07c..daeddd40 100644 --- a/server/src/queue/controllers/audio.controller.ts +++ b/server/src/queue/controllers/audio.controller.ts @@ -5,6 +5,7 @@ import { embeddings } from "../../utils/embeddings"; import { DialoqbaseAudioVideoLoader } from "../../loader/audio-video"; import { convertMp3ToWave } from "../../utils/ffmpeg"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const audioQueueController = async ( source: QSource, @@ -25,13 +26,11 @@ export const audioQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/csv.controller.ts b/server/src/queue/controllers/csv.controller.ts index 8e159e5b..c9e83179 100644 --- a/server/src/queue/controllers/csv.controller.ts +++ b/server/src/queue/controllers/csv.controller.ts @@ -4,6 +4,7 @@ import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const csvQueueController = async ( source: QSource, @@ -21,12 +22,10 @@ export const csvQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/queue/controllers/docx.controller.ts b/server/src/queue/controllers/docx.controller.ts index 185a8e26..aea1dc4e 100644 --- a/server/src/queue/controllers/docx.controller.ts +++ b/server/src/queue/controllers/docx.controller.ts @@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbaseDocxLoader } from "../../loader/docx"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const DocxQueueController = async ( source: QSource, @@ -22,13 +23,11 @@ export const DocxQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/github.controller.ts b/server/src/queue/controllers/github.controller.ts index 7300c90f..fc95d477 100644 --- a/server/src/queue/controllers/github.controller.ts +++ b/server/src/queue/controllers/github.controller.ts @@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbaseGithub } from "../../loader/github"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const githubQueueController = async ( source: QSource, @@ -25,13 +26,11 @@ export const githubQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/pdf.controller.ts b/server/src/queue/controllers/pdf.controller.ts index 2bfcaa65..2b211935 100644 --- a/server/src/queue/controllers/pdf.controller.ts +++ b/server/src/queue/controllers/pdf.controller.ts @@ -5,6 +5,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbasePDFLoader } from "../../loader/pdf"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const pdfQueueController = async ( source: QSource, @@ -22,13 +23,11 @@ export const pdfQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/rest.controller.ts b/server/src/queue/controllers/rest.controller.ts index 722dc08a..b6375b0f 100644 --- a/server/src/queue/controllers/rest.controller.ts +++ b/server/src/queue/controllers/rest.controller.ts @@ -3,6 +3,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { DialoqbaseRestApi } from "../../loader/rest"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const restQueueController = async ( source: QSource, @@ -18,13 +19,11 @@ export const restQueueController = async ( }); const docs = await loader.load(); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/text.controller.ts b/server/src/queue/controllers/text.controller.ts index 11e8c2a3..b4d73c65 100644 --- a/server/src/queue/controllers/text.controller.ts +++ b/server/src/queue/controllers/text.controller.ts @@ -3,6 +3,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const textQueueController = async ( source: QSource, @@ -20,15 +21,12 @@ export const textQueueController = async ( }, }, ]); - - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); - + + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); } diff --git a/server/src/queue/controllers/txt.controller.ts b/server/src/queue/controllers/txt.controller.ts index baf894fd..445fa24b 100644 --- a/server/src/queue/controllers/txt.controller.ts +++ b/server/src/queue/controllers/txt.controller.ts @@ -4,6 +4,7 @@ import { DialoqbaseVectorStore } from "../../utils/store"; import { embeddings } from "../../utils/embeddings"; import { TextLoader } from "langchain/document_loaders/fs/text"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const txtQueueController = async ( source: QSource, @@ -20,13 +21,11 @@ export const txtQueueController = async ( chunkOverlap: source.chunkOverlap, }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/video.controller.ts b/server/src/queue/controllers/video.controller.ts index c0646a3a..c41db9df 100644 --- a/server/src/queue/controllers/video.controller.ts +++ b/server/src/queue/controllers/video.controller.ts @@ -6,6 +6,7 @@ import { embeddings } from "../../utils/embeddings"; import { DialoqbaseAudioVideoLoader } from "../../loader/audio-video"; import { convertMp4ToWave } from "../../utils/ffmpeg"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const videoQueueController = async ( source: QSource, @@ -26,13 +27,11 @@ export const videoQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/website.controller.ts b/server/src/queue/controllers/website.controller.ts index d0843f75..2c492326 100644 --- a/server/src/queue/controllers/website.controller.ts +++ b/server/src/queue/controllers/website.controller.ts @@ -8,6 +8,7 @@ import { DialoqbasePDFLoader } from "../../loader/pdf"; import { DialoqbaseWebLoader } from "../../loader/web"; import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio"; import { PrismaClient } from "@prisma/client"; +import { getModelInfo } from "../../utils/get-model-info"; export const websiteQueueController = async ( source: QSource, @@ -37,13 +38,11 @@ export const websiteQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); @@ -79,13 +78,11 @@ export const websiteQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, - }); + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", + }) if (!embeddingInfo) { throw new Error("Embedding not found. Please verify the embedding id"); diff --git a/server/src/queue/controllers/youtube.controller.ts b/server/src/queue/controllers/youtube.controller.ts index fdb0e12a..5f1f30ef 100644 --- a/server/src/queue/controllers/youtube.controller.ts +++ b/server/src/queue/controllers/youtube.controller.ts @@ -6,6 +6,7 @@ import { embeddings } from "../../utils/embeddings"; import { DialoqbaseYoutube } from "../../loader/youtube"; import { PrismaClient } from "@prisma/client"; import { DialoqbaseYoutubeTranscript } from "../../loader/youtube-transcript"; +import { getModelInfo } from "../../utils/get-model-info"; export const youtubeQueueController = async ( source: QSource, @@ -29,12 +30,10 @@ export const youtubeQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { @@ -66,12 +65,10 @@ export const youtubeQueueController = async ( }); const chunks = await textSplitter.splitDocuments(docs); - const embeddingInfo = await prisma.dialoqbaseModels.findFirst({ - where: { - model_id: source.embedding, - hide: false, - deleted: false, - }, + const embeddingInfo = await getModelInfo({ + model: source.embedding, + prisma, + type: "embedding", }); if (!embeddingInfo) { diff --git a/server/src/schema/api/v1/admin/index.ts b/server/src/schema/api/v1/admin/index.ts index 6f4c9599..424b6e6d 100644 --- a/server/src/schema/api/v1/admin/index.ts +++ b/server/src/schema/api/v1/admin/index.ts @@ -20,6 +20,8 @@ export const dialoqbaseSettingsSchema: FastifySchema = { defaultChatModel: { type: "string" }, defaultEmbeddingModel: { type: "string" }, dynamicallyFetchOllamaModels: { type: "boolean" }, + hideDefaultModels: { type: "boolean" }, + ollamaURL: { type: "string" }, }, }, }; @@ -43,12 +45,9 @@ export const updateDialoqbaseSettingsSchema: FastifySchema = { dynamicallyFetchOllamaModels: { type: "boolean" }, defaultChatModel: { type: "string" }, defaultEmbeddingModel: { type: "string" }, + hideDefaultModels: { type: "boolean" }, + ollamaURL: { type: "string" }, }, - required: [ - "noOfBotsPerUser", - "allowUserToCreateBots", - "allowUserToRegister", - ], }, response: { 200: { diff --git a/server/src/schema/api/v1/bot/bot/index.ts b/server/src/schema/api/v1/bot/bot/index.ts index ca2a74be..a81e6c40 100644 --- a/server/src/schema/api/v1/bot/bot/index.ts +++ b/server/src/schema/api/v1/bot/bot/index.ts @@ -159,6 +159,9 @@ export const updateBotByIdSchema: FastifySchema = { bot_model_api_key: { type: "string", }, + noOfDocumentsToRetrieve: { + type: "number", + } }, }, }; diff --git a/server/src/utils/get-model-info.ts b/server/src/utils/get-model-info.ts new file mode 100644 index 00000000..e5aa21f6 --- /dev/null +++ b/server/src/utils/get-model-info.ts @@ -0,0 +1,92 @@ +import { PrismaClient, DialoqbaseModels } from "@prisma/client"; +import { getSettings } from "./common"; +import { cleanUrl, getAllOllamaModels } from "./ollama"; + +export const getModelInfo = async ({ + model, + prisma, + type = "all", +}: { + prisma: PrismaClient; + model: string; + type?: "all" | "chat" | "embedding"; +}): Promise => { + let modelInfo: DialoqbaseModels | null = null; + const settings = await getSettings(prisma); + const not_to_hide_providers = settings?.hideDefaultModels + ? [ "Local", "local", "ollama", "transformer", "Transformer"] + : undefined; + if (type === "all") { + modelInfo = await prisma.dialoqbaseModels.findFirst({ + where: { + model_id: model, + hide: false, + deleted: false, + model_provider: { + in: not_to_hide_providers, + }, + }, + }); + } else if (type === "chat") { + modelInfo = await prisma.dialoqbaseModels.findFirst({ + where: { + hide: false, + deleted: false, + model_provider: { + in: not_to_hide_providers, + }, + OR: [ + { + model_id: model, + }, + { + model_id: `${model}-dbase`, + }, + ], + }, + }); + } else if (type === "embedding") { + modelInfo = await prisma.dialoqbaseModels.findFirst({ + where: { + OR: [ + { + model_id: model, + }, + { + model_id: `dialoqbase_eb_${model}`, + }, + ], + hide: false, + deleted: false, + model_provider: { + in: not_to_hide_providers, + }, + }, + }); + } + if (!modelInfo) { + if (settings?.dynamicallyFetchOllamaModels) { + const ollamaModles = await getAllOllamaModels(settings.ollamaURL); + const ollamaInfo = ollamaModles.find((m) => m.value === model); + if (ollamaInfo) { + return { + name: ollamaInfo.name, + model_id: ollamaInfo.name, + stream_available: true, + local_model: true, + model_provider: "ollama", + config: { + baseURL: cleanUrl(settings.ollamaURL), + }, + createdAt: new Date(), + model_type: "chat", + deleted: false, + hide: false, + id: 1, + }; + } + } + } + + return modelInfo; +}; diff --git a/server/src/utils/hybrid.ts b/server/src/utils/hybrid.ts index a821bf88..28956b0c 100644 --- a/server/src/utils/hybrid.ts +++ b/server/src/utils/hybrid.ts @@ -40,47 +40,53 @@ export class DialoqbaseHybridRetrival extends BaseRetriever { protected async similaritySearch( query: string, k: number, - _callbacks?: Callbacks, + _callbacks?: Callbacks ): Promise { try { + const embeddedQuery = await this.embeddings.embedQuery(query); - const embeddedQuery = await this.embeddings.embedQuery(query); + const vector = `[${embeddedQuery.join(",")}]`; + const bot_id = this.botId; - const vector = `[${embeddedQuery.join(",")}]`; - const bot_id = this.botId; - - const data = await prisma.$queryRaw` + const data = await prisma.$queryRaw` SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${k}::int) `; - const result: [Document, number, number][] = ( - data as SearchEmbeddingsResponse[] - ).map((resp) => [ - new Document({ - metadata: resp.metadata, - pageContent: resp.content, - }), - resp.similarity * 10, - resp.id, - ]); - - - return result; -} catch (e) { - console.log(e) - return [] -} + const result: [Document, number, number][] = ( + data as SearchEmbeddingsResponse[] + ).map((resp) => [ + new Document({ + metadata: resp.metadata, + pageContent: resp.content, + }), + resp.similarity * 10, + resp.id, + ]); + + return result; + } catch (e) { + console.log(e); + return []; + } } protected async keywordSearch( query: string, - k: number, + k: number ): Promise { const query_text = query; const bot_id = this.botId; + const botInfo = await prisma.bot.findFirst({ + where: { + id: bot_id, + }, + }); + + const match_count = botInfo?.noOfDocumentsToRetrieve || k; + const data = await prisma.$queryRaw` - SELECT * FROM "kw_match_documents"(query_text := ${query_text}::text, bot_id := ${bot_id}::text,match_count := ${k}::int) + SELECT * FROM "kw_match_documents"(query_text := ${query_text}::text, bot_id := ${bot_id}::text,match_count := ${match_count}::int) `; const result: [Document, number, number][] = ( @@ -104,12 +110,12 @@ export class DialoqbaseHybridRetrival extends BaseRetriever { query: string, similarityK: number, keywordK: number, - callbacks?: Callbacks, + callbacks?: Callbacks ): Promise { const similarity_search = this.similaritySearch( query, similarityK, - callbacks, + callbacks ); const keyword_search = this.keywordSearch(query, keywordK); @@ -136,13 +142,13 @@ export class DialoqbaseHybridRetrival extends BaseRetriever { async _getRelevantDocuments( query: string, - runManager?: CallbackManagerForRetrieverRun, + runManager?: CallbackManagerForRetrieverRun ): Promise { const searchResults = await this.hybridSearch( query, this.similarityK, this.keywordK, - runManager?.getChild("hybrid_search"), + runManager?.getChild("hybrid_search") ); return searchResults.map(([doc]) => doc); diff --git a/server/src/utils/ollama.ts b/server/src/utils/ollama.ts new file mode 100644 index 00000000..bdfaac77 --- /dev/null +++ b/server/src/utils/ollama.ts @@ -0,0 +1,38 @@ +import axios from "axios"; + +export const cleanUrl = (url: string) => { + if (url.endsWith("/")) { + return url.slice(0, -1); + } + return url; +}; + +export const getAllOllamaModels = async (url: string) => { + try { + const response = await axios.get(`${cleanUrl(url)}/api/tags`); + const { models } = response.data as { + models: { + name: string; + details?: { + parent_model?: string + format: string + family: string + families: string[] + parameter_size: string + quantization_level: string + } + }[]; + }; + return models.map((data) => { + return { + ...data, + label: data.name, + value: data.name, + stream: true + }; + }); + } catch (error) { + console.log(`Error fetching Ollama models`, error); + return []; + } +}; diff --git a/server/src/utils/store.ts b/server/src/utils/store.ts index 7bcff0d1..c104fd43 100644 --- a/server/src/utils/store.ts +++ b/server/src/utils/store.ts @@ -40,8 +40,7 @@ export class DialoqbaseVectorStore extends VectorStore { if (row?.embedding) { const vector = `[${row.embedding.join(",")}]`; const content = row?.content.replace(/\x00/g, "").trim(); - await prisma - .$executeRaw`INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES (${content}, ${vector}::vector, ${row.metadata}, ${row.botId}, ${row.sourceId})`; + await prisma.$executeRaw`INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES (${content}, ${vector}::vector, ${row.metadata}, ${row.botId}, ${row.sourceId})`; } }); } catch (e) { @@ -57,7 +56,7 @@ export class DialoqbaseVectorStore extends VectorStore { static async fromDocuments( docs: Document[], embeddings: Embeddings, - dbConfig: DialoqbaseLibArgs, + dbConfig: DialoqbaseLibArgs ) { const instance = new this(embeddings, dbConfig); await instance.addDocuments(docs); @@ -68,7 +67,7 @@ export class DialoqbaseVectorStore extends VectorStore { texts: string[], metadatas: object[] | object, embeddings: Embeddings, - dbConfig: DialoqbaseLibArgs, + dbConfig: DialoqbaseLibArgs ) { const docs = []; for (let i = 0; i < texts.length; i += 1) { @@ -84,7 +83,7 @@ export class DialoqbaseVectorStore extends VectorStore { static async fromExistingIndex( embeddings: Embeddings, - dbConfig: DialoqbaseLibArgs, + dbConfig: DialoqbaseLibArgs ) { const instance = new this(embeddings, dbConfig); return instance; @@ -93,14 +92,22 @@ export class DialoqbaseVectorStore extends VectorStore { async similaritySearchVectorWithScore( query: number[], k: number, - filter?: this["FilterType"] | undefined, + filter?: this["FilterType"] | undefined ): Promise<[Document>, number][]> { console.log(this.botId); const vector = `[${query.join(",")}]`; const bot_id = this.botId; + const botInfo = await prisma.bot.findFirst({ + where: { + id: bot_id, + }, + }); + + const match_count = botInfo?.noOfDocumentsToRetrieve || k; + const data = await prisma.$queryRaw` - SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${k}::int) + SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${match_count}::int) `; const result: [Document, number][] = (