Skip to content

Commit

Permalink
Refactor ChatService for simple append and revise handling
Browse files Browse the repository at this point in the history
  • Loading branch information
codewithcheese committed Jul 29, 2024
1 parent 21e3f49 commit 28fe86d
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 136 deletions.
128 changes: 69 additions & 59 deletions src/lib/chat-service.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ const getStreamedResponse = async (
sendExtraMessageFields: boolean | undefined,
fetch: FetchFunction | undefined,
) => {
// Do an optimistic update to the chat state to show the updated messages
// immediately.
// mutate(chatRequest.messages);

const constructedMessagesPayload = sendExtraMessageFields
? chatRequest.messages
: chatRequest.messages.map(
Expand Down Expand Up @@ -222,19 +218,16 @@ const getStreamedResponse = async (
};

export class ChatService {
version: number;
messages: ChatMessage[] = $state([]);
error: undefined | Error = $state(undefined);
input: string = $state("");
attachments: MessageAttachment[] = [];
isLoading: boolean | undefined = $state(undefined);
data: JSONValue[] | undefined = $state(undefined);
metadata?: Object;
hasChanges: boolean = $state(false);
hasEdits: boolean = $state(false);

private id: string;
private version: number;
private api: string;
private mode: { type: "edit"; index: number } | { type: "append" };
private generateId: IdGenerator;
private abortController: AbortController | null;
private sendExtraMessageFields: boolean | undefined;
Expand All @@ -251,15 +244,14 @@ export class ChatService {
private fetch: FetchFunction | undefined;

// initial snapshot of messages before any changes
private initialSate: string | undefined;
private initialState: string | undefined;
private initialLength: number;

constructor({
id,
version,
api = "/api/chat",
initialMessages = [],
initialInput = "",
mode = { type: "append" },
sendExtraMessageFields,
experimental_onFunctionCall,
experimental_onToolCall,
Expand All @@ -278,7 +270,6 @@ export class ChatService {
this.id = id;
this.version = version;
this.api = api;
this.mode = mode;
this.sendExtraMessageFields = sendExtraMessageFields;
this.experimental_onFunctionCall = experimental_onFunctionCall;
this.experimental_onToolCall = experimental_onToolCall;
Expand All @@ -298,39 +289,50 @@ export class ChatService {
this.abortController = new AbortController();
this.generateId = generateId;
this.fetch = fetch;
console.log("mode", mode);
if (mode.type === "edit" && !initialMessages[mode.index]) {
throw new Error(`Message to edit at index ${mode.index} not found`);
}
this.input = mode.type === "edit" ? initialMessages[mode.index].content : initialInput;
this.initialLength = initialMessages.length;

$effect(() => {
const currentState = JSON.stringify(this.messages);
// on startup, check cache for changes
if (!this.initialSate) {
this.initialSate = currentState;
const cachedState = localStorage.getItem(this.cacheKey);
if (cachedState) {
// if the last user message is empty, remove it
const messages = JSON.parse(cachedState);
const lastMessage = messages[messages.length - 1];
if (lastMessage.role === "user" && lastMessage.content.trim() === "") {
messages.pop();
}
this.messages = messages;
this.hasChanges = true;
}
return;
if (!this.initialState) {
this.initialState = currentState;
console.log("initialState", this.initialState);
this.tryLoadFromCache();
}
untrack(() => {
this.hasChanges = this.initialSate !== currentState;
if (this.hasChanges) {
// has edits if initial messages have been modified
this.hasEdits =
JSON.stringify(this.messages.slice(0, this.initialLength)) !== this.initialState;
if (this.hasEdits) {
console.log(
"has edits",
JSON.stringify(this.messages.slice(0, this.initialLength)),
this.initialState,
);
}
if (currentState !== this.initialState) {
// cache changes to local storage
localStorage.setItem(this.cacheKey, JSON.stringify(this.messages));
if (!this.isLoading) {
console.log("caching changes", currentState, this.initialState);
localStorage.setItem(this.cacheKey, JSON.stringify(this.messages));
}
}
});
});
}

tryLoadFromCache() {
const cachedState = localStorage.getItem(this.cacheKey);
if (cachedState) {
// if the last user message is empty, remove it
const messages = JSON.parse(cachedState);
const lastMessage = messages[messages.length - 1];
if (lastMessage.role === "user" && lastMessage.content.trim() === "") {
messages.pop();
}
this.messages = messages;
}
}

get key() {
return `${this.api}|${this.id}`;
}
Expand Down Expand Up @@ -435,30 +437,38 @@ export class ChatService {
};
}

submit(requestBody: Record<string, any>) {
if (this.mode.type === "edit") {
return this.revise({ options: { body: requestBody } });
} else {
const inputContent = this.input;
const inputAttachments = this.attachments;
if (!inputContent) return;
this.input = "";
return this.append(
{
content: inputContent,
role: "user",
attachments: inputAttachments,
},
{ options: { body: requestBody } },
);
}
submit(requestOptions: ChatRequestOptions = {}) {
return this.triggerRequest(this.createChatRequest(requestOptions));
// if (this.mode.type === "edit") {
// return this.revise({ options: { body: requestBody } });
// } else {
// const inputContent = this.input;
// const inputAttachments = this.attachments;
// if (!inputContent) return;
// this.input = "";
// return this.append(
// {
// content: inputContent,
// role: "user",
// attachments: inputAttachments,
// },
// { options: { body: requestBody } },
// );
// }
}

resetChanges() {
localStorage.removeItem(this.cacheKey);
this.initialSate = JSON.stringify(this.messages);
this.hasChanges = false;
}
handleOnFinish = (message: ChatMessage) => {
if (this.onFinish) {
try {
this.onFinish(message as ChatMessage);
// clear cache
localStorage.removeItem(this.cacheKey);
} catch (e: unknown) {
console.error(e);
this.onError && this.onError(e instanceof Error ? e : new Error("Unknown error"));
}
}
};

get cacheKey() {
return `chat-${this.id}-v${this.version}`;
Expand Down Expand Up @@ -496,7 +506,7 @@ export class ChatService {
this.abortController,
this.generateId,
this.streamMode,
this.onFinish,
this.handleOnFinish,
this.onResponse,
this.sendExtraMessageFields,
this.fetch,
Expand Down
15 changes: 11 additions & 4 deletions src/routes/(app)/$data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@ import { invalidate } from "$app/navigation";

export async function newChat() {
const id = nanoid(10);
await useDb().insert(chatTable).values({
id: id,
name: "Untitled",
prompt: "",
await useDb().transaction(async (tx) => {
await tx.insert(chatTable).values({
id: id,
name: "Untitled",
prompt: "",
});
await tx.insert(revisionTable).values({
id: nanoid(10),
version: 1,
chatId: id,
});
});
await invalidate("view:chats");
return id;
Expand Down
4 changes: 2 additions & 2 deletions src/routes/(app)/chat/[id]/$data.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
getLatestRevision,
getModelService,
createRevision,
appendMessage,
appendMessages,
createRevision,
isTab,
tabRouteId,
Expand Down Expand Up @@ -216,7 +216,7 @@ describe("appendMessage", () => {
content: "New message content",
revisionId: "revision2",
};
await appendMessage(message, []);
await appendMessages(message, []);

const messages = await db.query.messageTable.findMany({
where: eq(schema.messageTable.revisionId, "revision2"),
Expand Down
53 changes: 26 additions & 27 deletions src/routes/(app)/chat/[id]/$data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,39 +148,38 @@ export async function interpolateDocuments(prompt: string) {
return interpolatedPrompt;
}

export async function appendMessage(
message: Omit<InsertMessage, "index" | "createdAt">,
attachments: MessageAttachment[] = [],
) {
export async function appendMessages(revisionId: string, messages: ChatMessage[]) {
try {
return await useDb().transaction(async (tx) => {
await tx
.insert(messageTable)
.values({
id: message.id,
revisionId: message.revisionId,
role: message.role,
content: message.content,
index: sql`(SELECT COUNT(id) FROM ${messageTable} WHERE ${eq(messageTable.revisionId, message.revisionId)})`,
})
.execute();
for (const attachment of attachments) {
const documentId = attachment.id;
for (const message of messages) {
await tx
.insert(documentTable)
.insert(messageTable)
.values({
id: documentId,
type: attachment.type,
name: `Pasted ${new Date().toLocaleString()}`,
description: "",
content: attachment.content,
id: message.id,
revisionId,
role: message.role,
content: message.content,
index: sql`(SELECT COUNT(id) FROM ${messageTable} WHERE ${eq(messageTable.revisionId, revisionId)})`,
})
.execute();
await tx.insert(attachmentTable).values({
id: nanoid(10),
documentId,
messageId: message.id,
});
for (const attachment of message.attachments) {
const documentId = attachment.id;
await tx
.insert(documentTable)
.values({
id: documentId,
type: attachment.type,
name: `Pasted ${new Date().toLocaleString()}`,
description: "",
content: attachment.content,
})
.execute();
await tx.insert(attachmentTable).values({
id: nanoid(10),
documentId,
messageId: message.id,
});
}
}
});
} catch (e) {
Expand Down
Loading

0 comments on commit 28fe86d

Please sign in to comment.