diff --git a/.env.local.example b/.env.local.example index 66316ce..6f7e8b6 100644 --- a/.env.local.example +++ b/.env.local.example @@ -1,3 +1,7 @@ +# Pick Vector DB +VECTOR_DB=pinecone +# VECTOR_DB=supabase + # Clerk related environment variables NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY=pk_**** CLERK_SECRET_KEY=sk_**** diff --git a/src/app/api/chatgpt/route.ts b/src/app/api/chatgpt/route.ts index f533800..75208f7 100644 --- a/src/app/api/chatgpt/route.ts +++ b/src/app/api/chatgpt/route.ts @@ -1,6 +1,3 @@ -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import { PineconeClient } from "@pinecone-database/pinecone"; -import { PineconeStore } from "langchain/vectorstores/pinecone"; import { OpenAI } from "langchain/llms/openai"; import dotenv from "dotenv"; import { LLMChain } from "langchain/chains"; @@ -38,7 +35,7 @@ export async function POST(req: Request) { // XXX Companion name passed here. Can use as a key to get backstory, chat history etc. const name = req.headers.get("name"); - const companion_file_name = name + ".txt"; + const companionFileName = name + ".txt"; console.log("prompt: ", prompt); if (isText) { @@ -69,7 +66,7 @@ export async function POST(req: Request) { // discussion. The PREAMBLE should include a seed conversation whose format will // vary by the model using it. 
const fs = require("fs").promises; - const data = await fs.readFile("companions/" + companion_file_name, "utf8"); + const data = await fs.readFile("companions/" + companionFileName, "utf8"); // Clunky way to break out PREAMBLE and SEEDCHAT from the character file const presplit = data.split("###ENDPREAMBLE###"); @@ -77,43 +74,27 @@ export async function POST(req: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - // console.log("Preamble: "+preamble); - // console.log("Seedchat: "+seedchat); - - const memoryManager = new MemoryManager({ + const companionKey = { companionName: name!, modelName: "chatgpt", userId: clerkUserId, - }); + }; + const memoryManager = await MemoryManager.getInstance(); - const records = await memoryManager.readLatestHistory(); + const records = await memoryManager.readLatestHistory(companionKey); if (records.length === 0) { - await memoryManager.seedChatHistory(seedchat, "\n\n"); + await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey); } - await memoryManager.writeToHistory("Human: " + prompt + "\n"); + await memoryManager.writeToHistory("Human: " + prompt + "\n", companionKey); + let recentChatHistory = await memoryManager.readLatestHistory(companionKey); // query Pinecone - const client = new PineconeClient(); - await client.init({ - apiKey: process.env.PINECONE_API_KEY || "", - environment: process.env.PINECONE_ENVIRONMENT || "", - }); - const pineconeIndex = client.Index(process.env.PINECONE_INDEX || ""); - - const vectorStore = await PineconeStore.fromExistingIndex( - new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), - { pineconeIndex } + const similarDocs = await memoryManager.vectorSearch( + recentChatHistory, + companionFileName ); - let recentChatHistory = await memoryManager.readLatestHistory(); - - const similarDocs = await vectorStore - .similaritySearch(recentChatHistory, 3, { fileName: companion_file_name }) - .catch((err) => { - 
console.log("WARNING: failed to get vector search results.", err); - }); - let relevantHistory = ""; if (!!similarDocs && similarDocs.length !== 0) { relevantHistory = similarDocs.map((doc) => doc.pageContent).join("\n"); @@ -161,7 +142,8 @@ export async function POST(req: Request) { console.log("result", result); const chatHistoryRecord = await memoryManager.writeToHistory( - result!.text + "\n" + result!.text + "\n", + companionKey ); console.log("chatHistoryRecord", chatHistoryRecord); if (isText) { diff --git a/src/app/api/vicuna13b/route.ts b/src/app/api/vicuna13b/route.ts index 533da50..2d99081 100644 --- a/src/app/api/vicuna13b/route.ts +++ b/src/app/api/vicuna13b/route.ts @@ -2,9 +2,6 @@ import dotenv from "dotenv"; import { StreamingTextResponse, LangChainStream } from "ai"; import { Replicate } from "langchain/llms/replicate"; import { CallbackManager } from "langchain/callbacks"; -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import { PineconeClient } from "@pinecone-database/pinecone"; -import { PineconeStore } from "langchain/vectorstores/pinecone"; import clerk from "@clerk/clerk-sdk-node"; import MemoryManager from "@/app/utils/memory"; import { currentUser } from "@clerk/nextjs"; @@ -73,47 +70,35 @@ export async function POST(request: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - // console.log("Preamble: "+preamble); - // console.log("Seedchat: "+seedchat); - - const memoryManager = new MemoryManager({ + const companionKey = { companionName: name!, userId: clerkUserId!, modelName: "vicuna13b", - }); + }; + const memoryManager = await MemoryManager.getInstance(); const { stream, handlers } = LangChainStream(); - const records = await memoryManager.readLatestHistory(); + const records = await memoryManager.readLatestHistory(companionKey); if (records.length === 0) { - await memoryManager.seedChatHistory(seedchat, "\n\n"); + await memoryManager.seedChatHistory(seedchat, "\n\n", 
companionKey); } - await memoryManager.writeToHistory("### Human: " + prompt + "\n"); + await memoryManager.writeToHistory( + "### Human: " + prompt + "\n", + companionKey + ); // Query Pinecone - const client = new PineconeClient(); - await client.init({ - apiKey: process.env.PINECONE_API_KEY || "", - environment: process.env.PINECONE_ENVIRONMENT || "", - }); - const pineconeIndex = client.Index(process.env.PINECONE_INDEX || ""); - const vectorStore = await PineconeStore.fromExistingIndex( - new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), - { pineconeIndex } - ); - - let recentChatHistory = ""; - recentChatHistory = await memoryManager.readLatestHistory(); + let recentChatHistory = await memoryManager.readLatestHistory(companionKey); // Right now the preamble is included in the similarity search, but that // shouldn't be an issue - const similarDocs = await vectorStore - .similaritySearch(recentChatHistory, 3, { fileName: companion_file_name }) - .catch((err) => { - console.log("WARNING: failed to get vector search results.", err); - }); + const similarDocs = await memoryManager.vectorSearch( + recentChatHistory, + companion_file_name + ); let relevantHistory = ""; if (!!similarDocs && similarDocs.length !== 0) { @@ -159,14 +144,14 @@ export async function POST(request: Request) { const response = chunks[0]; // const response = chunks.length > 1 ? 
chunks[0] : chunks[0]; - await memoryManager.writeToHistory("### " + response.trim()); + await memoryManager.writeToHistory("### " + response.trim(), companionKey); var Readable = require("stream").Readable; let s = new Readable(); s.push(response); s.push(null); if (response !== undefined && response.length > 1) { - await memoryManager.writeToHistory("### " + response.trim()); + await memoryManager.writeToHistory("### " + response.trim(), companionKey); } return new StreamingTextResponse(s); diff --git a/src/app/utils/memory.ts b/src/app/utils/memory.ts index 0aa3151..199c560 100644 --- a/src/app/utils/memory.ts +++ b/src/app/utils/memory.ts @@ -1,5 +1,9 @@ import { Redis } from "@upstash/redis"; -import { get } from "http"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { PineconeClient } from "@pinecone-database/pinecone"; +import { PineconeStore } from "langchain/vectorstores/pinecone"; +import { SupabaseVectorStore } from "langchain/vectorstores/supabase"; +import { SupabaseClient, createClient } from "@supabase/supabase-js"; export type CompanionKey = { companionName: string; @@ -10,31 +14,95 @@ export type CompanionKey = { class MemoryManager { private static instance: MemoryManager; private history: Redis; - private companionKey: CompanionKey; + private vectorDBClient: PineconeClient | SupabaseClient; - public constructor(companionKey: CompanionKey) { + public constructor() { this.history = Redis.fromEnv(); - this.companionKey = companionKey; + if (process.env.VECTOR_DB === "pinecone") { + this.vectorDBClient = new PineconeClient(); + } else { + const auth = { + detectSessionInUrl: false, + persistSession: false, + autoRefreshToken: false, + }; + const url = process.env.SUPABASE_URL!; + const privateKey = process.env.SUPABASE_PRIVATE_KEY!; + this.vectorDBClient = createClient(url, privateKey, { auth }); + } + } + + public async init() { + if (this.vectorDBClient instanceof PineconeClient) { + await this.vectorDBClient.init({ + 
apiKey: process.env.PINECONE_API_KEY!, + environment: process.env.PINECONE_ENVIRONMENT!, + }); + } + } + + public async vectorSearch( + recentChatHistory: string, + companionFileName: string + ) { + if (process.env.VECTOR_DB === "pinecone") { + console.log("INFO: using Pinecone for vector search."); + const pineconeClient = <PineconeClient>this.vectorDBClient; + + const pineconeIndex = pineconeClient.Index( + process.env.PINECONE_INDEX! || "" + ); + + const vectorStore = await PineconeStore.fromExistingIndex( + new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), + { pineconeIndex } + ); + + const similarDocs = await vectorStore + .similaritySearch(recentChatHistory, 3, { fileName: companionFileName }) + .catch((err) => { + console.log("WARNING: failed to get vector search results.", err); + }); + return similarDocs; + } else { + console.log("INFO: using Supabase for vector search."); + const supabaseClient = <SupabaseClient>this.vectorDBClient; + const vectorStore = await SupabaseVectorStore.fromExistingIndex( + new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), + { + client: supabaseClient, + tableName: "documents", + queryName: "match_documents", + } + ); + const similarDocs = await vectorStore + .similaritySearch(recentChatHistory, 3) + .catch((err) => { + console.log("WARNING: failed to get vector search results.", err); + }); + return similarDocs; + } } - public static getInstance(companionKey: CompanionKey): MemoryManager { + public static async getInstance(): Promise<MemoryManager> { if (!MemoryManager.instance) { - MemoryManager.instance = new MemoryManager(companionKey); + MemoryManager.instance = new MemoryManager(); + await MemoryManager.instance.init(); } return MemoryManager.instance; } - public getCompanionKey(): string { - return `${this.companionKey.companionName}-${this.companionKey.modelName}-${this.companionKey.userId}`; + private generateRedisCompanionKey(companionKey: CompanionKey): string { + return 
`${companionKey.companionName}-${companionKey.modelName}-${companionKey.userId}`; } - public async writeToHistory(text: string) { - if (!this.companionKey || typeof this.companionKey.userId == "undefined") { + public async writeToHistory(text: string, companionKey: CompanionKey) { + if (!companionKey || typeof companionKey.userId == "undefined") { console.log("Companion key set incorrectly"); return ""; } - const key = this.getCompanionKey(); + const key = this.generateRedisCompanionKey(companionKey); const result = await this.history.zadd(key, { score: Date.now(), member: text, @@ -43,13 +111,13 @@ class MemoryManager { return result; } - public async readLatestHistory(): Promise<string> { - if (!this.companionKey || typeof this.companionKey.userId == "undefined") { + public async readLatestHistory(companionKey: CompanionKey): Promise<string> { + if (!companionKey || typeof companionKey.userId == "undefined") { console.log("Companion key set incorrectly"); return ""; } - const key = this.getCompanionKey(); + const key = this.generateRedisCompanionKey(companionKey); let result = await this.history.zrange(key, 0, Date.now(), { byScore: true, }); @@ -59,8 +127,12 @@ class MemoryManager { return recentChats; } - public async seedChatHistory(seedContent: String, delimiter: string = "\n") { - const key = this.getCompanionKey(); + public async seedChatHistory( + seedContent: String, + delimiter: string = "\n", + companionKey: CompanionKey + ) { + const key = this.generateRedisCompanionKey(companionKey); if (await this.history.exists(key)) { console.log("User already has chat history"); return;