Skip to content

Commit

Permalink
Add support for pasting images
Browse files Browse the repository at this point in the history
  • Loading branch information
codewithcheese committed Aug 6, 2024
1 parent 362b3a9 commit 1d2ac99
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 77 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"@ai-sdk/google-vertex": "^0.0.19",
"@ai-sdk/mistral": "^0.0.26",
"@ai-sdk/openai": "^0.0.10",
"@ai-sdk/ui-utils": "^0.0.9",
"@ai-sdk/ui-utils": "^0.0.24",
"@fontsource-variable/inter": "^5.0.18",
"@lexical/history": "^0.16.0",
"@lexical/plain-text": "^0.16.0",
Expand Down
46 changes: 2 additions & 44 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 38 additions & 11 deletions src/lib/chat-service.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type {
IdGenerator,
JSONValue,
Message as AIMessage,
Attachment as AIAttachment,
ToolCallHandler,
} from "@ai-sdk/ui-utils";
import { callChatApi, processChatStream } from "@ai-sdk/ui-utils";
Expand All @@ -21,7 +22,7 @@ export type MessageAttachment = {
type: string;
name: string;
content: string;
attributes: {};
attributes: Record<string, any>;
};
export type ChatMessage = AIMessage & {
attachments: MessageAttachment[];
Expand Down Expand Up @@ -131,7 +132,7 @@ export type ChatOptions = {
*/
sendExtraMessageFields?: boolean;
/** Stream mode (default to "stream-data") */
streamMode?: "stream-data" | "text";
streamProtocol?: "data" | "text";
/**
Custom fetch implementation. You can use it as a middleware to intercept requests,
or to provide a custom fetch implementation for e.g. testing.
Expand All @@ -153,7 +154,7 @@ const getStreamedResponse = async (
previousMessages: ChatMessage[],
abortControllerRef: AbortController | null,
generateId: IdGenerator,
streamMode: "stream-data" | "text" | undefined,
streamProtocol: "data" | "text" | undefined,
onFinish: ((message: ChatMessage) => void) | undefined,
onResponse: ((response: Response) => void | Promise<void>) | undefined,
sendExtraMessageFields: boolean | undefined,
Expand All @@ -162,10 +163,21 @@ const getStreamedResponse = async (
const constructedMessagesPayload = sendExtraMessageFields
? chatRequest.messages
: chatRequest.messages.map(
({ role, content, name, data, annotations, function_call, tool_calls, tool_call_id }) => ({
({
role,
content,
experimental_attachments,
name,
data,
annotations,
function_call,
tool_calls,
tool_call_id,
}) => ({
role,
content,
...(name !== undefined && { name }),
...(experimental_attachments !== undefined && { experimental_attachments }),
...(data !== undefined && { data }),
...(annotations !== undefined && { annotations }),
// outdated function/tool call handling (TODO deprecate):
Expand All @@ -175,6 +187,9 @@ const getStreamedResponse = async (
}),
);

console.log("chatRequest", chatRequest);
console.log("constructedMessagesPayload", constructedMessagesPayload);

return await callChatApi({
api,
body: {
Expand All @@ -195,7 +210,7 @@ const getStreamedResponse = async (
tool_choice: chatRequest.tool_choice,
}),
},
streamMode,
streamProtocol,
credentials: extraMetadata.credentials,
headers: {
...extraMetadata.headers,
Expand Down Expand Up @@ -233,7 +248,7 @@ export class ChatService {
private sendExtraMessageFields: boolean | undefined;
private experimental_onFunctionCall: FunctionCallHandler | undefined;
private experimental_onToolCall: ToolCallHandler | undefined;
private streamMode: "stream-data" | "text" | undefined;
private streamProtocol: "data" | "text" | undefined;
private onLoading: (() => void) | undefined;
private onAppend: (() => void) | undefined;
private onRevision: (() => void) | undefined;
Expand All @@ -255,7 +270,7 @@ export class ChatService {
sendExtraMessageFields,
experimental_onFunctionCall,
experimental_onToolCall,
streamMode,
streamProtocol,
onLoading,
onAppend,
onRevision,
Expand All @@ -273,7 +288,7 @@ export class ChatService {
this.sendExtraMessageFields = sendExtraMessageFields;
this.experimental_onFunctionCall = experimental_onFunctionCall;
this.experimental_onToolCall = experimental_onToolCall;
this.streamMode = streamMode;
this.streamProtocol = streamProtocol;
this.onLoading = onLoading;
this.onAppend = onAppend;
this.onRevision = onRevision;
Expand Down Expand Up @@ -388,9 +403,13 @@ export class ChatService {
// inline attachments into message content surrounded by tags
const messages = this.messages.map((message) => {
let content = message.content;
let experimental_attachments: AIAttachment[] | undefined = undefined;
if (message.attachments) {
const attachmentContent = message.attachments
.filter((attachment) => attachment.content.trim() !== "") // Exclude empty attachments
.filter(
(attachment) =>
!attachment.type.startsWith("image/") && attachment.content.trim() !== "",
) // Exclude empty attachments
.map((attachment) => {
const escapedContent = attachment.content.replace(/</g, "&lt;").replace(/>/g, "&gt;");
const attributes = Object.entries(attachment.attributes)
Expand All @@ -402,8 +421,16 @@ export class ChatService {
if (attachmentContent.length > 0) {
content = `${attachmentContent}\n\n${content}`;
}
experimental_attachments = message.attachments
.filter((attachment) => attachment.type.startsWith("image/"))
.map((attachment) => ({
name: attachment.name,
contentType: attachment.type,
url: attachment.content,
}));
console.log("experimental_attachments", experimental_attachments);
}
return { ...message, content };
return { ...message, content, experimental_attachments };
});
return {
messages,
Expand Down Expand Up @@ -467,7 +494,7 @@ export class ChatService {
this.messages,
this.abortController,
this.generateId,
this.streamMode,
this.streamProtocol,
this.handleOnFinish,
() => {},
this.sendExtraMessageFields,
Expand Down
29 changes: 29 additions & 0 deletions src/lib/image.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
export async function compressImage(file: Blob, quality: number = 0.7): Promise<Blob> {
return new Promise((resolve, reject) => {
const img = new Image();
img.onload = () => {
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
if (!ctx) {
reject(new Error("Failed to get canvas context"));
return;
}
canvas.width = img.width;
canvas.height = img.height;
ctx.drawImage(img, 0, 0);
canvas.toBlob(
(blob) => {
if (blob) {
resolve(blob);
} else {
reject(new Error("Failed to compress image"));
}
},
"image/jpeg",
quality,
);
};
img.onerror = () => reject(new Error("Failed to load image"));
img.src = URL.createObjectURL(file);
});
}
6 changes: 4 additions & 2 deletions src/routes/(app)/api/chat/+server.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { StreamingTextResponse, streamText } from "ai";
import { convertToCoreMessages, StreamingTextResponse, streamText } from "ai";
import type { RequestHandler } from "./$types";
import { createOpenAI } from "@ai-sdk/openai";
import { createAnthropic } from "@ai-sdk/anthropic";
Expand All @@ -24,6 +24,8 @@ export const POST = (async ({ request }) => {
});
}

console.log("chat request", messages, sdkId, apiKey, baseURL, modelName);

let provider;
switch (sdkId) {
case "azure":
Expand Down Expand Up @@ -53,7 +55,7 @@ export const POST = (async ({ request }) => {
try {
const result = await streamText({
model: provider(modelName),
messages,
messages: convertToCoreMessages(messages),
});
return new StreamingTextResponse(result.toAIStream());
} catch (e: unknown) {
Expand Down
11 changes: 7 additions & 4 deletions src/routes/(app)/chat/[id]/Attachment.svelte
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
<script lang="ts">
import { XIcon } from "lucide-svelte";
import { Button } from "@/components/ui/button";
import { toTitleCase } from "$lib/string";
type Props = {
type: string;
Expand All @@ -12,7 +11,7 @@
</script>

<div
class="h-32 min-h-28 w-28 min-w-28 cursor-pointer overflow-hidden rounded-lg border border-gray-300 bg-white text-gray-700"
class="h-32 min-h-28 w-28 min-w-28 overflow-hidden rounded-lg border border-gray-300 bg-white text-gray-700"
>
<div class="relative h-full border-0">
{#if onRemove}
Expand All @@ -26,13 +25,17 @@
{/if}
<div class="flex h-full flex-col gap-1 p-2">
<div class="relative flex-1 overflow-hidden">
<p class="overflow-y-hidden break-words text-sm">{content}</p>
{#if type.startsWith("image/")}
<img src={content} alt="Pasted" class="w-full" />
{:else}
<p class="overflow-y-hidden break-words text-sm">{content}</p>
{/if}
<div
class="absolute bottom-0 left-0 right-0 h-8 bg-gradient-to-t from-white to-transparent"
></div>
</div>
<div class="name sticky bottom-0 text-sm font-semibold text-gray-700">
{toTitleCase(type)}
{type}
</div>
</div>
</div>
Expand Down
Loading

0 comments on commit 1d2ac99

Please sign in to comment.