From 8bee1f0efb175b6f4a0b5d1f5adb39c4667b5272 Mon Sep 17 00:00:00 2001 From: Yoko Li Date: Tue, 4 Jul 2023 20:37:55 -0700 Subject: [PATCH 1/2] refactor vector search management --- .env.local.example | 4 +++ src/app/api/chatgpt/route.ts | 32 ++++------------- src/app/api/vicuna13b/route.ts | 29 +++------------ src/app/utils/memory.ts | 66 +++++++++++++++++++++++++++++++++- 4 files changed, 80 insertions(+), 51 deletions(-) diff --git a/.env.local.example b/.env.local.example index 66316ce..6f7e8b6 100644 --- a/.env.local.example +++ b/.env.local.example @@ -1,3 +1,7 @@ +# Pick Vector DB +VECTOR_DB=pinecone +# VECTOR_DB=supabase + # Clerk related environment variables NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY=pk_**** CLERK_SECRET_KEY=sk_**** diff --git a/src/app/api/chatgpt/route.ts b/src/app/api/chatgpt/route.ts index f533800..8bfd1d2 100644 --- a/src/app/api/chatgpt/route.ts +++ b/src/app/api/chatgpt/route.ts @@ -1,6 +1,3 @@ -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import { PineconeClient } from "@pinecone-database/pinecone"; -import { PineconeStore } from "langchain/vectorstores/pinecone"; import { OpenAI } from "langchain/llms/openai"; import dotenv from "dotenv"; import { LLMChain } from "langchain/chains"; @@ -38,7 +35,7 @@ export async function POST(req: Request) { // XXX Companion name passed here. Can use as a key to get backstory, chat history etc. const name = req.headers.get("name"); - const companion_file_name = name + ".txt"; + const companionFileName = name + ".txt"; console.log("prompt: ", prompt); if (isText) { @@ -69,7 +66,7 @@ export async function POST(req: Request) { // discussion. The PREAMBLE should include a seed conversation whose format will // vary by the model using it. const fs = require("fs").promises; - const data = await fs.readFile("companions/" + companion_file_name, "utf8"); + const data = await fs.readFile("companions/" + companionFileName, "utf8"); // Clunky way to break out PREAMBLE and SEEDCHAT from the character file const presplit = data.split("###ENDPREAMBLE###"); @@ -77,9 +74,6 @@ export async function POST(req: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - // console.log("Preamble: "+preamble); - // console.log("Seedchat: "+seedchat); - const memoryManager = new MemoryManager({ companionName: name!, modelName: "chatgpt", @@ -92,28 +86,14 @@ export async function POST(req: Request) { } await memoryManager.writeToHistory("Human: " + prompt + "\n"); + let recentChatHistory = await memoryManager.readLatestHistory(); // query Pinecone - const client = new PineconeClient(); - await client.init({ - apiKey: process.env.PINECONE_API_KEY || "", - environment: process.env.PINECONE_ENVIRONMENT || "", - }); - const pineconeIndex = client.Index(process.env.PINECONE_INDEX || ""); - - const vectorStore = await PineconeStore.fromExistingIndex( - new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), - { pineconeIndex } + const similarDocs = await memoryManager.vectorSearch( + recentChatHistory, + companionFileName ); - let recentChatHistory = await memoryManager.readLatestHistory(); - - const similarDocs = await vectorStore - .similaritySearch(recentChatHistory, 3, { fileName: companion_file_name }) - .catch((err) => { - console.log("WARNING: failed to get vector search results.", err); - }); - let relevantHistory = ""; if (!!similarDocs && similarDocs.length !== 0) { relevantHistory = similarDocs.map((doc) => doc.pageContent).join("\n"); diff --git a/src/app/api/vicuna13b/route.ts b/src/app/api/vicuna13b/route.ts index 533da50..db32c3a 100644 --- a/src/app/api/vicuna13b/route.ts +++ b/src/app/api/vicuna13b/route.ts @@ -2,9 +2,6 @@ import dotenv from "dotenv"; import { StreamingTextResponse, LangChainStream } from "ai"; import { Replicate } from "langchain/llms/replicate"; import { CallbackManager } from "langchain/callbacks"; -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import { PineconeClient } from "@pinecone-database/pinecone"; -import { PineconeStore } from "langchain/vectorstores/pinecone"; import clerk from "@clerk/clerk-sdk-node"; import MemoryManager from "@/app/utils/memory"; import { currentUser } from "@clerk/nextjs"; @@ -73,9 +70,6 @@ export async function POST(request: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - // console.log("Preamble: "+preamble); - // console.log("Seedchat: "+seedchat); - const memoryManager = new MemoryManager({ companionName: name!, userId: clerkUserId!, @@ -91,29 +85,16 @@ export async function POST(request: Request) { await memoryManager.writeToHistory("### Human: " + prompt + "\n"); // Query Pinecone - const client = new PineconeClient(); - await client.init({ - apiKey: process.env.PINECONE_API_KEY || "", - environment: process.env.PINECONE_ENVIRONMENT || "", - }); - const pineconeIndex = client.Index(process.env.PINECONE_INDEX || ""); - const vectorStore = await PineconeStore.fromExistingIndex( - new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), - { pineconeIndex } - ); - - let recentChatHistory = ""; - recentChatHistory = await memoryManager.readLatestHistory(); + let recentChatHistory = await memoryManager.readLatestHistory(); // Right now the preamble is included in the similarity search, but that // shouldn't be an issue - const similarDocs = await vectorStore - .similaritySearch(recentChatHistory, 3, { fileName: companion_file_name }) - .catch((err) => { - console.log("WARNING: failed to get vector search results.", err); - }); + const similarDocs = await memoryManager.vectorSearch( + recentChatHistory, + companion_file_name + ); let relevantHistory = ""; if (!!similarDocs && similarDocs.length !== 0) { diff --git a/src/app/utils/memory.ts b/src/app/utils/memory.ts index 0aa3151..d9b738b 100644 --- a/src/app/utils/memory.ts +++ b/src/app/utils/memory.ts @@ -1,5 +1,9 @@ import { Redis } from "@upstash/redis"; -import { get } from "http"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { PineconeClient } from "@pinecone-database/pinecone"; +import { PineconeStore } from "langchain/vectorstores/pinecone"; +import { SupabaseVectorStore } from "langchain/vectorstores/supabase"; +import { SupabaseClient, createClient } from "@supabase/supabase-js"; export type CompanionKey = { companionName: string; @@ -11,10 +15,70 @@ class MemoryManager { private static instance: MemoryManager; private history: Redis; private companionKey: CompanionKey; + private vectorDBClient: PineconeClient | SupabaseClient; public constructor(companionKey: CompanionKey) { this.history = Redis.fromEnv(); this.companionKey = companionKey; + if (process.env.VECTOR_DB === "pinecone") { + this.vectorDBClient = new PineconeClient(); + this.vectorDBClient.init({ + apiKey: process.env.PINECONE_API_KEY!, + environment: process.env.PINECONE_ENVIRONMENT!, + }); + } else { + const auth = { + detectSessionInUrl: false, + persistSession: false, + autoRefreshToken: false, + }; + const url = process.env.SUPABASE_URL!; + const privateKey = process.env.SUPABASE_PRIVATE_KEY!; + this.vectorDBClient = createClient(url, privateKey, { auth }); + } + } + + public async vectorSearch( + recentChatHistory: string, + companionFileName: string + ) { + if (process.env.VECTOR_DB === "pinecone") { + console.log("INFO: using Pinecone for vector search."); + const pineconeClient = this.vectorDBClient; + + const pineconeIndex = pineconeClient.Index( + process.env.PINECONE_INDEX || "" + ); + + const vectorStore = await PineconeStore.fromExistingIndex( + new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), + { pineconeIndex } + ); + + const similarDocs = await vectorStore + .similaritySearch(recentChatHistory, 3, { fileName: companionFileName }) + .catch((err) => { + console.log("WARNING: failed to get vector search results.", err); + }); + return similarDocs; + } else { + console.log("INFO: using Supabase for vector search."); + const supabaseClient = this.vectorDBClient; + const vectorStore = await SupabaseVectorStore.fromExistingIndex( + new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), + { + client: supabaseClient, + tableName: "documents", + queryName: "match_documents", + } + ); + const similarDocs = await vectorStore + .similaritySearch(recentChatHistory, 3) + .catch((err) => { + console.log("WARNING: failed to get vector search results.", err); + }); + return similarDocs; + } } public static getInstance(companionKey: CompanionKey): MemoryManager { From 088be5601da607fe55ba74cb0f42ef9fee529a33 Mon Sep 17 00:00:00 2001 From: Yoko Li Date: Sun, 9 Jul 2023 23:43:44 -0700 Subject: [PATCH 2/2] solved the companion key bug; refactor; got rid of this.companionKey --- src/app/api/chatgpt/route.ts | 16 +++++++----- src/app/api/vicuna13b/route.ts | 20 ++++++++------ src/app/utils/memory.ts | 48 ++++++++++++++++++++-------------- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/src/app/api/chatgpt/route.ts b/src/app/api/chatgpt/route.ts index 8bfd1d2..75208f7 100644 --- a/src/app/api/chatgpt/route.ts +++ b/src/app/api/chatgpt/route.ts @@ -74,19 +74,20 @@ export async function POST(req: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - const memoryManager = new MemoryManager({ + const companionKey = { companionName: name!, modelName: "chatgpt", userId: clerkUserId, - }); + }; + const memoryManager = await MemoryManager.getInstance(); - const records = await memoryManager.readLatestHistory(); + const records = await memoryManager.readLatestHistory(companionKey); if (records.length === 0) { - await memoryManager.seedChatHistory(seedchat, "\n\n"); + await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey); } - await memoryManager.writeToHistory("Human: " + prompt + "\n"); - let recentChatHistory = await memoryManager.readLatestHistory(); + await memoryManager.writeToHistory("Human: " + prompt + "\n", companionKey); + let recentChatHistory = await memoryManager.readLatestHistory(companionKey); // query Pinecone const similarDocs = await memoryManager.vectorSearch( @@ -141,7 +142,8 @@ export async function POST(req: Request) { console.log("result", result); const chatHistoryRecord = await memoryManager.writeToHistory( - result!.text + "\n" + result!.text + "\n", + companionKey ); console.log("chatHistoryRecord", chatHistoryRecord); if (isText) { diff --git a/src/app/api/vicuna13b/route.ts b/src/app/api/vicuna13b/route.ts index db32c3a..2d99081 100644 --- a/src/app/api/vicuna13b/route.ts +++ b/src/app/api/vicuna13b/route.ts @@ -70,23 +70,27 @@ export async function POST(request: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - const memoryManager = new MemoryManager({ + const companionKey = { companionName: name!, userId: clerkUserId!, modelName: "vicuna13b", - }); + }; + const memoryManager = await MemoryManager.getInstance(); const { stream, handlers } = LangChainStream(); - const records = await memoryManager.readLatestHistory(); + const records = await memoryManager.readLatestHistory(companionKey); if (records.length === 0) { - await memoryManager.seedChatHistory(seedchat, "\n\n"); + await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey); } - await memoryManager.writeToHistory("### Human: " + prompt + "\n"); + await memoryManager.writeToHistory( + "### Human: " + prompt + "\n", + companionKey + ); // Query Pinecone - let recentChatHistory = await memoryManager.readLatestHistory(); + let recentChatHistory = await memoryManager.readLatestHistory(companionKey); // Right now the preamble is included in the similarity search, but that // shouldn't be an issue @@ -140,14 +144,14 @@ export async function POST(request: Request) { const response = chunks[0]; // const response = chunks.length > 1 ? chunks[0] : chunks[0]; - await memoryManager.writeToHistory("### " + response.trim()); + await memoryManager.writeToHistory("### " + response.trim(), companionKey); var Readable = require("stream").Readable; let s = new Readable(); s.push(response); s.push(null); if (response !== undefined && response.length > 1) { - await memoryManager.writeToHistory("### " + response.trim()); + await memoryManager.writeToHistory("### " + response.trim(), companionKey); } return new StreamingTextResponse(s); diff --git a/src/app/utils/memory.ts b/src/app/utils/memory.ts index d9b738b..199c560 100644 --- a/src/app/utils/memory.ts +++ b/src/app/utils/memory.ts @@ -14,18 +14,12 @@ export type CompanionKey = { class MemoryManager { private static instance: MemoryManager; private history: Redis; - private companionKey: CompanionKey; private vectorDBClient: PineconeClient | SupabaseClient; - public constructor(companionKey: CompanionKey) { + public constructor() { this.history = Redis.fromEnv(); - this.companionKey = companionKey; if (process.env.VECTOR_DB === "pinecone") { this.vectorDBClient = new PineconeClient(); - this.vectorDBClient.init({ - apiKey: process.env.PINECONE_API_KEY!, - environment: process.env.PINECONE_ENVIRONMENT!, - }); } else { const auth = { detectSessionInUrl: false, @@ -38,6 +32,15 @@ class MemoryManager { } } + public async init() { + if (this.vectorDBClient instanceof PineconeClient) { + await this.vectorDBClient.init({ + apiKey: process.env.PINECONE_API_KEY!, + environment: process.env.PINECONE_ENVIRONMENT!, + }); + } + } + public async vectorSearch( recentChatHistory: string, companionFileName: string @@ -47,7 +50,7 @@ class MemoryManager { const pineconeClient = this.vectorDBClient; const pineconeIndex = pineconeClient.Index( - process.env.PINECONE_INDEX || "" + process.env.PINECONE_INDEX! || "" ); const vectorStore = await PineconeStore.fromExistingIndex( @@ -81,24 +84,25 @@ class MemoryManager { } } - public static getInstance(companionKey: CompanionKey): MemoryManager { + public static async getInstance(): Promise { if (!MemoryManager.instance) { - MemoryManager.instance = new MemoryManager(companionKey); + MemoryManager.instance = new MemoryManager(); + await MemoryManager.instance.init(); } return MemoryManager.instance; } - public getCompanionKey(): string { - return `${this.companionKey.companionName}-${this.companionKey.modelName}-${this.companionKey.userId}`; + private generateRedisCompanionKey(companionKey: CompanionKey): string { + return `${companionKey.companionName}-${companionKey.modelName}-${companionKey.userId}`; } - public async writeToHistory(text: string) { - if (!this.companionKey || typeof this.companionKey.userId == "undefined") { + public async writeToHistory(text: string, companionKey: CompanionKey) { + if (!companionKey || typeof companionKey.userId == "undefined") { console.log("Companion key set incorrectly"); return ""; } - const key = this.getCompanionKey(); + const key = this.generateRedisCompanionKey(companionKey); const result = await this.history.zadd(key, { score: Date.now(), member: text, @@ -107,13 +111,13 @@ class MemoryManager { return result; } - public async readLatestHistory(): Promise { - if (!this.companionKey || typeof this.companionKey.userId == "undefined") { + public async readLatestHistory(companionKey: CompanionKey): Promise { + if (!companionKey || typeof companionKey.userId == "undefined") { console.log("Companion key set incorrectly"); return ""; } - const key = this.getCompanionKey(); + const key = this.generateRedisCompanionKey(companionKey); let result = await this.history.zrange(key, 0, Date.now(), { byScore: true, }); @@ -123,8 +127,12 @@ class MemoryManager { return recentChats; } - public async seedChatHistory(seedContent: String, delimiter: string = "\n") { - const key = this.getCompanionKey(); + public async seedChatHistory( + seedContent: String, + delimiter: string = "\n", + companionKey: CompanionKey + ) { + const key = this.generateRedisCompanionKey(companionKey); if (await this.history.exists(key)) { console.log("User already has chat history"); return;