From 8bee1f0efb175b6f4a0b5d1f5adb39c4667b5272 Mon Sep 17 00:00:00 2001 From: Yoko Li Date: Tue, 4 Jul 2023 20:37:55 -0700 Subject: [PATCH] refactor vector search management --- .env.local.example | 4 +++ src/app/api/chatgpt/route.ts | 32 ++++------------- src/app/api/vicuna13b/route.ts | 29 +++------------ src/app/utils/memory.ts | 66 +++++++++++++++++++++++++++++++++- 4 files changed, 80 insertions(+), 51 deletions(-) diff --git a/.env.local.example b/.env.local.example index 66316ce..6f7e8b6 100644 --- a/.env.local.example +++ b/.env.local.example @@ -1,3 +1,7 @@ +# Pick Vector DB +VECTOR_DB=pinecone +# VECTOR_DB=supabase + # Clerk related environment variables NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY=pk_**** CLERK_SECRET_KEY=sk_**** diff --git a/src/app/api/chatgpt/route.ts b/src/app/api/chatgpt/route.ts index f533800..8bfd1d2 100644 --- a/src/app/api/chatgpt/route.ts +++ b/src/app/api/chatgpt/route.ts @@ -1,6 +1,3 @@ -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import { PineconeClient } from "@pinecone-database/pinecone"; -import { PineconeStore } from "langchain/vectorstores/pinecone"; import { OpenAI } from "langchain/llms/openai"; import dotenv from "dotenv"; import { LLMChain } from "langchain/chains"; @@ -38,7 +35,7 @@ export async function POST(req: Request) { // XXX Companion name passed here. Can use as a key to get backstory, chat history etc. const name = req.headers.get("name"); - const companion_file_name = name + ".txt"; + const companionFileName = name + ".txt"; console.log("prompt: ", prompt); if (isText) { @@ -69,7 +66,7 @@ export async function POST(req: Request) { // discussion. The PREAMBLE should include a seed conversation whose format will // vary by the model using it. const fs = require("fs").promises; - const data = await fs.readFile("companions/" + companion_file_name, "utf8"); + const data = await fs.readFile("companions/" + companionFileName, "utf8"); // Clunky way to break out PREAMBLE and SEEDCHAT from the character file const presplit = data.split("###ENDPREAMBLE###"); @@ -77,9 +74,6 @@ export async function POST(req: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - // console.log("Preamble: "+preamble); - // console.log("Seedchat: "+seedchat); - const memoryManager = new MemoryManager({ companionName: name!, modelName: "chatgpt", @@ -92,28 +86,14 @@ export async function POST(req: Request) { } await memoryManager.writeToHistory("Human: " + prompt + "\n"); + let recentChatHistory = await memoryManager.readLatestHistory(); // query Pinecone - const client = new PineconeClient(); - await client.init({ - apiKey: process.env.PINECONE_API_KEY || "", - environment: process.env.PINECONE_ENVIRONMENT || "", - }); - const pineconeIndex = client.Index(process.env.PINECONE_INDEX || ""); - - const vectorStore = await PineconeStore.fromExistingIndex( - new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), - { pineconeIndex } + const similarDocs = await memoryManager.vectorSearch( + recentChatHistory, + companionFileName ); - let recentChatHistory = await memoryManager.readLatestHistory(); - - const similarDocs = await vectorStore - .similaritySearch(recentChatHistory, 3, { fileName: companion_file_name }) - .catch((err) => { - console.log("WARNING: failed to get vector search results.", err); - }); - let relevantHistory = ""; if (!!similarDocs && similarDocs.length !== 0) { relevantHistory = similarDocs.map((doc) => doc.pageContent).join("\n"); diff --git a/src/app/api/vicuna13b/route.ts b/src/app/api/vicuna13b/route.ts index 533da50..db32c3a 100644 --- a/src/app/api/vicuna13b/route.ts +++ b/src/app/api/vicuna13b/route.ts @@ -2,9 +2,6 @@ import dotenv from "dotenv"; import { StreamingTextResponse, LangChainStream } from "ai"; import { Replicate } from "langchain/llms/replicate"; import { CallbackManager } from "langchain/callbacks"; -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import { PineconeClient } from "@pinecone-database/pinecone"; -import { PineconeStore } from "langchain/vectorstores/pinecone"; import clerk from "@clerk/clerk-sdk-node"; import MemoryManager from "@/app/utils/memory"; import { currentUser } from "@clerk/nextjs"; @@ -73,9 +70,6 @@ export async function POST(request: Request) { const seedsplit = presplit[1].split("###ENDSEEDCHAT###"); const seedchat = seedsplit[0]; - // console.log("Preamble: "+preamble); - // console.log("Seedchat: "+seedchat); - const memoryManager = new MemoryManager({ companionName: name!, userId: clerkUserId!, @@ -91,29 +85,16 @@ export async function POST(request: Request) { await memoryManager.writeToHistory("### Human: " + prompt + "\n"); // Query Pinecone - const client = new PineconeClient(); - await client.init({ - apiKey: process.env.PINECONE_API_KEY || "", - environment: process.env.PINECONE_ENVIRONMENT || "", - }); - const pineconeIndex = client.Index(process.env.PINECONE_INDEX || ""); - const vectorStore = await PineconeStore.fromExistingIndex( - new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), - { pineconeIndex } - ); - - let recentChatHistory = ""; - recentChatHistory = await memoryManager.readLatestHistory(); + let recentChatHistory = await memoryManager.readLatestHistory(); // Right now the preamble is included in the similarity search, but that // shouldn't be an issue - const similarDocs = await vectorStore - .similaritySearch(recentChatHistory, 3, { fileName: companion_file_name }) - .catch((err) => { - console.log("WARNING: failed to get vector search results.", err); - }); + const similarDocs = await memoryManager.vectorSearch( + recentChatHistory, + companion_file_name + ); let relevantHistory = ""; if (!!similarDocs && similarDocs.length !== 0) { diff --git a/src/app/utils/memory.ts b/src/app/utils/memory.ts index 0aa3151..d9b738b 100644 --- a/src/app/utils/memory.ts +++ b/src/app/utils/memory.ts @@ -1,5 +1,9 @@ import { Redis } from "@upstash/redis"; -import { get } from "http"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { PineconeClient } from "@pinecone-database/pinecone"; +import { PineconeStore } from "langchain/vectorstores/pinecone"; +import { SupabaseVectorStore } from "langchain/vectorstores/supabase"; +import { SupabaseClient, createClient } from "@supabase/supabase-js"; export type CompanionKey = { companionName: string; @@ -11,10 +15,70 @@ class MemoryManager { private static instance: MemoryManager; private history: Redis; private companionKey: CompanionKey; + private vectorDBClient: PineconeClient | SupabaseClient; public constructor(companionKey: CompanionKey) { this.history = Redis.fromEnv(); this.companionKey = companionKey; + if (process.env.VECTOR_DB === "pinecone") { + this.vectorDBClient = new PineconeClient(); + this.vectorDBClient.init({ + apiKey: process.env.PINECONE_API_KEY!, + environment: process.env.PINECONE_ENVIRONMENT!, + }); + } else { + const auth = { + detectSessionInUrl: false, + persistSession: false, + autoRefreshToken: false, + }; + const url = process.env.SUPABASE_URL!; + const privateKey = process.env.SUPABASE_PRIVATE_KEY!; + this.vectorDBClient = createClient(url, privateKey, { auth }); + } + } + + public async vectorSearch( + recentChatHistory: string, + companionFileName: string + ) { + if (process.env.VECTOR_DB === "pinecone") { + console.log("INFO: using Pinecone for vector search."); + const pineconeClient = this.vectorDBClient; + + const pineconeIndex = pineconeClient.Index( + process.env.PINECONE_INDEX || "" + ); + + const vectorStore = await PineconeStore.fromExistingIndex( + new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), + { pineconeIndex } + ); + + const similarDocs = await vectorStore + .similaritySearch(recentChatHistory, 3, { fileName: companionFileName }) + .catch((err) => { + console.log("WARNING: failed to get vector search results.", err); + }); + return similarDocs; + } else { + console.log("INFO: using Supabase for vector search."); + const supabaseClient = this.vectorDBClient; + const vectorStore = await SupabaseVectorStore.fromExistingIndex( + new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }), + { + client: supabaseClient, + tableName: "documents", + queryName: "match_documents", + } + ); + const similarDocs = await vectorStore + .similaritySearch(recentChatHistory, 3) + .catch((err) => { + console.log("WARNING: failed to get vector search results.", err); + }); + return similarDocs; + } } public static getInstance(companionKey: CompanionKey): MemoryManager {