Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor vector search management #38

Merged
merged 2 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.local.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Pick Vector DB
VECTOR_DB=pinecone
# VECTOR_DB=supabase

# Clerk related environment variables
NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY=pk_****
CLERK_SECRET_KEY=sk_****
Expand Down
46 changes: 14 additions & 32 deletions src/app/api/chatgpt/route.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { PineconeClient } from "@pinecone-database/pinecone";
import { PineconeStore } from "langchain/vectorstores/pinecone";
import { OpenAI } from "langchain/llms/openai";
import dotenv from "dotenv";
import { LLMChain } from "langchain/chains";
Expand Down Expand Up @@ -38,7 +35,7 @@ export async function POST(req: Request) {

// XXX Companion name passed here. Can use as a key to get backstory, chat history etc.
const name = req.headers.get("name");
const companion_file_name = name + ".txt";
const companionFileName = name + ".txt";

console.log("prompt: ", prompt);
if (isText) {
Expand Down Expand Up @@ -69,51 +66,35 @@ export async function POST(req: Request) {
// discussion. The PREAMBLE should include a seed conversation whose format will
// vary by the model using it.
const fs = require("fs").promises;
const data = await fs.readFile("companions/" + companion_file_name, "utf8");
const data = await fs.readFile("companions/" + companionFileName, "utf8");

// Clunky way to break out PREAMBLE and SEEDCHAT from the character file
const presplit = data.split("###ENDPREAMBLE###");
const preamble = presplit[0];
const seedsplit = presplit[1].split("###ENDSEEDCHAT###");
const seedchat = seedsplit[0];

// console.log("Preamble: "+preamble);
// console.log("Seedchat: "+seedchat);

const memoryManager = new MemoryManager({
const companionKey = {
companionName: name!,
modelName: "chatgpt",
userId: clerkUserId,
});
};
const memoryManager = await MemoryManager.getInstance();

const records = await memoryManager.readLatestHistory();
const records = await memoryManager.readLatestHistory(companionKey);
if (records.length === 0) {
await memoryManager.seedChatHistory(seedchat, "\n\n");
await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey);
}

await memoryManager.writeToHistory("Human: " + prompt + "\n");
await memoryManager.writeToHistory("Human: " + prompt + "\n", companionKey);
let recentChatHistory = await memoryManager.readLatestHistory(companionKey);

// query Pinecone
const client = new PineconeClient();
await client.init({
apiKey: process.env.PINECONE_API_KEY || "",
environment: process.env.PINECONE_ENVIRONMENT || "",
});
const pineconeIndex = client.Index(process.env.PINECONE_INDEX || "");

const vectorStore = await PineconeStore.fromExistingIndex(
new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }),
{ pineconeIndex }
const similarDocs = await memoryManager.vectorSearch(
recentChatHistory,
companionFileName
);

let recentChatHistory = await memoryManager.readLatestHistory();

const similarDocs = await vectorStore
.similaritySearch(recentChatHistory, 3, { fileName: companion_file_name })
.catch((err) => {
console.log("WARNING: failed to get vector search results.", err);
});

let relevantHistory = "";
if (!!similarDocs && similarDocs.length !== 0) {
relevantHistory = similarDocs.map((doc) => doc.pageContent).join("\n");
Expand Down Expand Up @@ -161,7 +142,8 @@ export async function POST(req: Request) {

console.log("result", result);
const chatHistoryRecord = await memoryManager.writeToHistory(
result!.text + "\n"
result!.text + "\n",
companionKey
);
console.log("chatHistoryRecord", chatHistoryRecord);
if (isText) {
Expand Down
47 changes: 16 additions & 31 deletions src/app/api/vicuna13b/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ import dotenv from "dotenv";
import { StreamingTextResponse, LangChainStream } from "ai";
import { Replicate } from "langchain/llms/replicate";
import { CallbackManager } from "langchain/callbacks";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { PineconeClient } from "@pinecone-database/pinecone";
import { PineconeStore } from "langchain/vectorstores/pinecone";
import clerk from "@clerk/clerk-sdk-node";
import MemoryManager from "@/app/utils/memory";
import { currentUser } from "@clerk/nextjs";
Expand Down Expand Up @@ -73,47 +70,35 @@ export async function POST(request: Request) {
const seedsplit = presplit[1].split("###ENDSEEDCHAT###");
const seedchat = seedsplit[0];

// console.log("Preamble: "+preamble);
// console.log("Seedchat: "+seedchat);

const memoryManager = new MemoryManager({
const companionKey = {
companionName: name!,
userId: clerkUserId!,
modelName: "vicuna13b",
});
};
const memoryManager = await MemoryManager.getInstance();

const { stream, handlers } = LangChainStream();

const records = await memoryManager.readLatestHistory();
const records = await memoryManager.readLatestHistory(companionKey);
if (records.length === 0) {
await memoryManager.seedChatHistory(seedchat, "\n\n");
await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey);
}
await memoryManager.writeToHistory("### Human: " + prompt + "\n");
await memoryManager.writeToHistory(
"### Human: " + prompt + "\n",
companionKey
);

// Query Pinecone
const client = new PineconeClient();
await client.init({
apiKey: process.env.PINECONE_API_KEY || "",
environment: process.env.PINECONE_ENVIRONMENT || "",
});
const pineconeIndex = client.Index(process.env.PINECONE_INDEX || "");

const vectorStore = await PineconeStore.fromExistingIndex(
new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }),
{ pineconeIndex }
);

let recentChatHistory = "";
recentChatHistory = await memoryManager.readLatestHistory();
let recentChatHistory = await memoryManager.readLatestHistory(companionKey);

// Right now the preamble is included in the similarity search, but that
// shouldn't be an issue

const similarDocs = await vectorStore
.similaritySearch(recentChatHistory, 3, { fileName: companion_file_name })
.catch((err) => {
console.log("WARNING: failed to get vector search results.", err);
});
const similarDocs = await memoryManager.vectorSearch(
recentChatHistory,
companion_file_name
);

let relevantHistory = "";
if (!!similarDocs && similarDocs.length !== 0) {
Expand Down Expand Up @@ -159,14 +144,14 @@ export async function POST(request: Request) {
const response = chunks[0];
// const response = chunks.length > 1 ? chunks[0] : chunks[0];

await memoryManager.writeToHistory("### " + response.trim());
await memoryManager.writeToHistory("### " + response.trim(), companionKey);
var Readable = require("stream").Readable;

let s = new Readable();
s.push(response);
s.push(null);
if (response !== undefined && response.length > 1) {
await memoryManager.writeToHistory("### " + response.trim());
await memoryManager.writeToHistory("### " + response.trim(), companionKey);
}

return new StreamingTextResponse(s);
Expand Down
104 changes: 88 additions & 16 deletions src/app/utils/memory.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { Redis } from "@upstash/redis";
import { get } from "http";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { PineconeClient } from "@pinecone-database/pinecone";
import { PineconeStore } from "langchain/vectorstores/pinecone";
import { SupabaseVectorStore } from "langchain/vectorstores/supabase";
import { SupabaseClient, createClient } from "@supabase/supabase-js";

export type CompanionKey = {
companionName: string;
Expand All @@ -10,31 +14,95 @@ export type CompanionKey = {
class MemoryManager {
private static instance: MemoryManager;
private history: Redis;
private companionKey: CompanionKey;
private vectorDBClient: PineconeClient | SupabaseClient;

public constructor(companionKey: CompanionKey) {
public constructor() {
this.history = Redis.fromEnv();
this.companionKey = companionKey;
if (process.env.VECTOR_DB === "pinecone") {
this.vectorDBClient = new PineconeClient();
} else {
const auth = {
detectSessionInUrl: false,
persistSession: false,
autoRefreshToken: false,
};
const url = process.env.SUPABASE_URL!;
const privateKey = process.env.SUPABASE_PRIVATE_KEY!;
this.vectorDBClient = createClient(url, privateKey, { auth });
}
}

public async init() {
if (this.vectorDBClient instanceof PineconeClient) {
await this.vectorDBClient.init({
apiKey: process.env.PINECONE_API_KEY!,
environment: process.env.PINECONE_ENVIRONMENT!,
});
}
}

/**
 * Finds chat snippets relevant to the current conversation via similarity
 * search against the configured vector store.
 *
 * The backend is selected by VECTOR_DB ("pinecone" vs. Supabase), mirroring
 * the client chosen in the constructor.
 *
 * @param recentChatHistory - Query text (the recent chat transcript).
 * @param companionFileName - Companion source file name; used as a metadata
 *   filter on the Pinecone path only.
 * @returns Up to 3 matching documents, or undefined if the search failed
 *   (failures are logged, never thrown, so callers can degrade gracefully).
 */
public async vectorSearch(
  recentChatHistory: string,
  companionFileName: string
) {
  // Both backends embed the query with the same OpenAI embedding model.
  const embeddings = new OpenAIEmbeddings({
    openAIApiKey: process.env.OPENAI_API_KEY,
  });

  let vectorStore;
  // Pinecone records are tagged with a fileName metadata field, so we scope
  // the search to this companion; the Supabase table is queried unfiltered
  // (matching the original behavior of each branch).
  let filter: { fileName: string } | undefined;

  if (process.env.VECTOR_DB === "pinecone") {
    console.log("INFO: using Pinecone for vector search.");
    const pineconeClient = <PineconeClient>this.vectorDBClient;

    // The original wrote `process.env.PINECONE_INDEX! || ""`: the non-null
    // assertion is contradicted by the empty-string fallback. `?? ""` keeps
    // the identical runtime behavior without lying to the type checker.
    const pineconeIndex = pineconeClient.Index(
      process.env.PINECONE_INDEX ?? ""
    );

    vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
      pineconeIndex,
    });
    filter = { fileName: companionFileName };
  } else {
    console.log("INFO: using Supabase for vector search.");
    const supabaseClient = <SupabaseClient>this.vectorDBClient;
    vectorStore = await SupabaseVectorStore.fromExistingIndex(embeddings, {
      client: supabaseClient,
      tableName: "documents",
      queryName: "match_documents",
    });
  }

  // Single shared search tail (previously duplicated per branch): take the
  // top 3 matches; on failure log a warning and resolve to undefined so the
  // caller can fall back to plain chat history.
  const similarDocs = await vectorStore
    .similaritySearch(recentChatHistory, 3, filter)
    .catch((err) => {
      console.log("WARNING: failed to get vector search results.", err);
    });
  return similarDocs;
}

public static getInstance(companionKey: CompanionKey): MemoryManager {
public static async getInstance(): Promise<MemoryManager> {
if (!MemoryManager.instance) {
MemoryManager.instance = new MemoryManager(companionKey);
MemoryManager.instance = new MemoryManager();
await MemoryManager.instance.init();
}
return MemoryManager.instance;
}

public getCompanionKey(): string {
return `${this.companionKey.companionName}-${this.companionKey.modelName}-${this.companionKey.userId}`;
private generateRedisCompanionKey(companionKey: CompanionKey): string {
return `${companionKey.companionName}-${companionKey.modelName}-${companionKey.userId}`;
}

public async writeToHistory(text: string) {
if (!this.companionKey || typeof this.companionKey.userId == "undefined") {
public async writeToHistory(text: string, companionKey: CompanionKey) {
if (!companionKey || typeof companionKey.userId == "undefined") {
console.log("Companion key set incorrectly");
return "";
}

const key = this.getCompanionKey();
const key = this.generateRedisCompanionKey(companionKey);
const result = await this.history.zadd(key, {
score: Date.now(),
member: text,
Expand All @@ -43,13 +111,13 @@ class MemoryManager {
return result;
}

public async readLatestHistory(): Promise<string> {
if (!this.companionKey || typeof this.companionKey.userId == "undefined") {
public async readLatestHistory(companionKey: CompanionKey): Promise<string> {
if (!companionKey || typeof companionKey.userId == "undefined") {
console.log("Companion key set incorrectly");
return "";
}

const key = this.getCompanionKey();
const key = this.generateRedisCompanionKey(companionKey);
let result = await this.history.zrange(key, 0, Date.now(), {
byScore: true,
});
Expand All @@ -59,8 +127,12 @@ class MemoryManager {
return recentChats;
}

public async seedChatHistory(seedContent: String, delimiter: string = "\n") {
const key = this.getCompanionKey();
public async seedChatHistory(
seedContent: String,
delimiter: string = "\n",
companionKey: CompanionKey
) {
const key = this.generateRedisCompanionKey(companionKey);
if (await this.history.exists(key)) {
console.log("User already has chat history");
return;
Expand Down