Skip to content

Commit

Permalink
solved the companion key bug; refactor; got rid of this.companionKey
Browse files Browse the repository at this point in the history
  • Loading branch information
ykhli committed Jul 10, 2023
1 parent 8bee1f0 commit 088be56
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 35 deletions.
16 changes: 9 additions & 7 deletions src/app/api/chatgpt/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,20 @@ export async function POST(req: Request) {
const seedsplit = presplit[1].split("###ENDSEEDCHAT###");
const seedchat = seedsplit[0];

const memoryManager = new MemoryManager({
const companionKey = {
companionName: name!,
modelName: "chatgpt",
userId: clerkUserId,
});
};
const memoryManager = await MemoryManager.getInstance();

const records = await memoryManager.readLatestHistory();
const records = await memoryManager.readLatestHistory(companionKey);
if (records.length === 0) {
await memoryManager.seedChatHistory(seedchat, "\n\n");
await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey);
}

await memoryManager.writeToHistory("Human: " + prompt + "\n");
let recentChatHistory = await memoryManager.readLatestHistory();
await memoryManager.writeToHistory("Human: " + prompt + "\n", companionKey);
let recentChatHistory = await memoryManager.readLatestHistory(companionKey);

// query Pinecone
const similarDocs = await memoryManager.vectorSearch(
Expand Down Expand Up @@ -141,7 +142,8 @@ export async function POST(req: Request) {

console.log("result", result);
const chatHistoryRecord = await memoryManager.writeToHistory(
result!.text + "\n"
result!.text + "\n",
companionKey
);
console.log("chatHistoryRecord", chatHistoryRecord);
if (isText) {
Expand Down
20 changes: 12 additions & 8 deletions src/app/api/vicuna13b/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,27 @@ export async function POST(request: Request) {
const seedsplit = presplit[1].split("###ENDSEEDCHAT###");
const seedchat = seedsplit[0];

const memoryManager = new MemoryManager({
const companionKey = {
companionName: name!,
userId: clerkUserId!,
modelName: "vicuna13b",
});
};
const memoryManager = await MemoryManager.getInstance();

const { stream, handlers } = LangChainStream();

const records = await memoryManager.readLatestHistory();
const records = await memoryManager.readLatestHistory(companionKey);
if (records.length === 0) {
await memoryManager.seedChatHistory(seedchat, "\n\n");
await memoryManager.seedChatHistory(seedchat, "\n\n", companionKey);
}
await memoryManager.writeToHistory("### Human: " + prompt + "\n");
await memoryManager.writeToHistory(
"### Human: " + prompt + "\n",
companionKey
);

// Query Pinecone

let recentChatHistory = await memoryManager.readLatestHistory();
let recentChatHistory = await memoryManager.readLatestHistory(companionKey);

// Right now the preamble is included in the similarity search, but that
// shouldn't be an issue
Expand Down Expand Up @@ -140,14 +144,14 @@ export async function POST(request: Request) {
const response = chunks[0];
// const response = chunks.length > 1 ? chunks[0] : chunks[0];

await memoryManager.writeToHistory("### " + response.trim());
await memoryManager.writeToHistory("### " + response.trim(), companionKey);
var Readable = require("stream").Readable;

let s = new Readable();
s.push(response);
s.push(null);
if (response !== undefined && response.length > 1) {
await memoryManager.writeToHistory("### " + response.trim());
await memoryManager.writeToHistory("### " + response.trim(), companionKey);
}

return new StreamingTextResponse(s);
Expand Down
48 changes: 28 additions & 20 deletions src/app/utils/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@ export type CompanionKey = {
class MemoryManager {
private static instance: MemoryManager;
private history: Redis;
private companionKey: CompanionKey;
private vectorDBClient: PineconeClient | SupabaseClient;

public constructor(companionKey: CompanionKey) {
public constructor() {
this.history = Redis.fromEnv();
this.companionKey = companionKey;
if (process.env.VECTOR_DB === "pinecone") {
this.vectorDBClient = new PineconeClient();
this.vectorDBClient.init({
apiKey: process.env.PINECONE_API_KEY!,
environment: process.env.PINECONE_ENVIRONMENT!,
});
} else {
const auth = {
detectSessionInUrl: false,
Expand All @@ -38,6 +32,15 @@ class MemoryManager {
}
}

public async init() {
if (this.vectorDBClient instanceof PineconeClient) {
await this.vectorDBClient.init({
apiKey: process.env.PINECONE_API_KEY!,
environment: process.env.PINECONE_ENVIRONMENT!,
});
}
}

public async vectorSearch(
recentChatHistory: string,
companionFileName: string
Expand All @@ -47,7 +50,7 @@ class MemoryManager {
const pineconeClient = <PineconeClient>this.vectorDBClient;

const pineconeIndex = pineconeClient.Index(
process.env.PINECONE_INDEX || ""
process.env.PINECONE_INDEX! || ""
);

const vectorStore = await PineconeStore.fromExistingIndex(
Expand Down Expand Up @@ -81,24 +84,25 @@ class MemoryManager {
}
}

public static getInstance(companionKey: CompanionKey): MemoryManager {
public static async getInstance(): Promise<MemoryManager> {
if (!MemoryManager.instance) {
MemoryManager.instance = new MemoryManager(companionKey);
MemoryManager.instance = new MemoryManager();
await MemoryManager.instance.init();
}
return MemoryManager.instance;
}

public getCompanionKey(): string {
return `${this.companionKey.companionName}-${this.companionKey.modelName}-${this.companionKey.userId}`;
private generateRedisCompanionKey(companionKey: CompanionKey): string {
return `${companionKey.companionName}-${companionKey.modelName}-${companionKey.userId}`;
}

public async writeToHistory(text: string) {
if (!this.companionKey || typeof this.companionKey.userId == "undefined") {
public async writeToHistory(text: string, companionKey: CompanionKey) {
if (!companionKey || typeof companionKey.userId == "undefined") {
console.log("Companion key set incorrectly");
return "";
}

const key = this.getCompanionKey();
const key = this.generateRedisCompanionKey(companionKey);
const result = await this.history.zadd(key, {
score: Date.now(),
member: text,
Expand All @@ -107,13 +111,13 @@ class MemoryManager {
return result;
}

public async readLatestHistory(): Promise<string> {
if (!this.companionKey || typeof this.companionKey.userId == "undefined") {
public async readLatestHistory(companionKey: CompanionKey): Promise<string> {
if (!companionKey || typeof companionKey.userId == "undefined") {
console.log("Companion key set incorrectly");
return "";
}

const key = this.getCompanionKey();
const key = this.generateRedisCompanionKey(companionKey);
let result = await this.history.zrange(key, 0, Date.now(), {
byScore: true,
});
Expand All @@ -123,8 +127,12 @@ class MemoryManager {
return recentChats;
}

public async seedChatHistory(seedContent: String, delimiter: string = "\n") {
const key = this.getCompanionKey();
public async seedChatHistory(
seedContent: String,
delimiter: string = "\n",
companionKey: CompanionKey
) {
const key = this.generateRedisCompanionKey(companionKey);
if (await this.history.exists(key)) {
console.log("User already has chat history");
return;
Expand Down

0 comments on commit 088be56

Please sign in to comment.