Adding support for GPT4V image uploads (#45)
* added support for image uploads on the chat ui

* fixed error message for invalid image uploads

* added upload image to s3 functionality

* added uuid file naming for image uploads (S3 helper sketched below, after the changed-files summary)

* removed unused import in ChatInput

* removed unused imports

* fixed vercel error

* added export default to fix npm build import error

* Initial work towards image functionality locally

* refined local image previews

* fixed padding in text area and fixed typing bug

* slightly rounded edges of text input

* image padding fix

* added functionality for handling multiple images

* message structure fix

* fixed message structure for openai api calls (shape sketched below)

* fixed vercel error

* Changes to add support for GPT4 Vision API (without image based retrieval)

* Fixed image rendering on Chat screen, fixed previews based on website theme

* Retrieval using image description

* Refactored handleSend method for better readability

* Minor cleanup, nothing major

* Added logic to validate and regenerate presigned URLs and propagate them into the updated messages (validation sketched below).

* fixed img previews resizing of long vertical imgs

* Improved dropzone for images in chat

* Bugfixes on local storage updates and presigned link validation

* Bug fixes and feedback: Handling regenerate gracefully, hiding edit button, removing image preview title

* Minor prompt improvement

* Removing duplicate import added while resolving conflicts

* Removed the wrong import earlier, correcting it

* Added accordion for image description

* Build fix

* Adding a deep equality check to handle an infinite loop in memo caused by a strict equality check on objects instead of values (sketched below)

* Minor bugfixes with dependencies and conditions

* Build fix

* Dependency removal broke switching to an older conversation, reverting the change. Some more styling and error handling changes.

* Adding conditional checks for file drag events based on GPT-4 Vision model

* Improve: new conversation defaults to last convo's model, full error handling. Very nice

* Improve: fix padding on chat input box with/without image input icon

* Image filetype support, 100% of what OpenAI allows, ignoring caps (validation sketched below)

* Image filetype support; one more push

* Delete .vscode/settings.json

* Update Image Description header for readability

* Keep GPT4-V from using too many tokens: it’s a 40k TPM (tokens-per-minute) limit for some people, so set the limit to 15k (sketched below)

* Rename models to even better human-readable names

* Properly await image description to be fully generated

---------

Co-authored-by: Rohan Marwaha <[email protected]>
Co-authored-by: Kastan Day <[email protected]>
3 people authored Dec 7, 2023
1 parent d608891 commit 5ab5a06
Showing 25 changed files with 1,447 additions and 360 deletions.
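The S3 upload and uuid-naming commits above touch helper files that are not part of the Chat.tsx diff shown here. A minimal sketch of what such a helper plausibly looks like, using the AWS SDK v3 and the `uuid` package; the function names, key prefix, and `S3_BUCKET_NAME` env var are assumptions, not the repository's actual code:

```typescript
// Hypothetical sketch of the S3 upload helper described in the commit
// message; uuid keys keep concurrent uploads from colliding.
import { S3Client, PutObjectCommand, GetObjectCommand } from '@aws-sdk/client-s3'
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
import { v4 as uuidv4 } from 'uuid'

const s3 = new S3Client({ region: process.env.AWS_REGION })
const Bucket = process.env.S3_BUCKET_NAME // assumed env var name

// Store the image under a uuid key, keeping the original extension.
export async function uploadImageToS3(
  body: Buffer,
  originalName: string,
  contentType: string,
): Promise<string> {
  const extension = originalName.split('.').pop()?.toLowerCase() ?? 'png'
  const Key = `image-uploads/${uuidv4()}.${extension}`
  await s3.send(new PutObjectCommand({ Bucket, Key, Body: body, ContentType: contentType }))
  return Key
}

// Mint a time-limited presigned GET URL so the chat UI can render the image.
export async function presignedViewUrl(Key: string): Promise<string> {
  return getSignedUrl(s3, new GetObjectCommand({ Bucket, Key }), { expiresIn: 3600 })
}
```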
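For reference on the "message structure" fixes above: GPT-4 Vision expects a user message whose `content` is an array of typed parts rather than a plain string, which is why the diff below repeatedly branches on `Array.isArray(message.content)`. A minimal example of the payload shape accepted by OpenAI's chat completions endpoint (the URL is a placeholder):

```typescript
// A GPT-4 Vision user message: content is a list of text and image parts.
const visionMessage = {
  role: 'user' as const,
  content: [
    { type: 'text', text: 'What is in this image?' },
    { type: 'image_url', image_url: { url: 'https://example.com/photo.png' } },
  ],
}
```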
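The presigned-URL validation logic itself lives in `~/utils/apiUtils`, which this diff only imports. A hedged sketch of how such validation can work: a SigV4 presigned URL carries its signing time (`X-Amz-Date`) and lifetime in seconds (`X-Amz-Expires`) as query parameters, so expiry is checkable client-side without a network round trip; the helper names are illustrative:

```typescript
// Hypothetical sketch of the presigned-URL validation described in the
// commit message; not the repository's actual implementation.
export function isPresignedUrlExpired(url: string): boolean {
  const params = new URL(url).searchParams
  const amzDate = params.get('X-Amz-Date')      // e.g. "20231207T120000Z"
  const expiresIn = params.get('X-Amz-Expires') // lifetime in seconds
  if (!amzDate || !expiresIn) return true       // treat malformed URLs as stale

  // Convert the SigV4 basic-format timestamp to ISO-8601 for Date.parse.
  const iso = amzDate.replace(
    /^(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})(\d{2})Z$/,
    '$1-$2-$3T$4:$5:$6Z',
  )
  return Date.now() >= Date.parse(iso) + Number(expiresIn) * 1000
}

// Regenerate stale links, then let the caller persist the refreshed URLs
// back into the conversation (the diff's onImageUrlsUpdate plays that role).
export async function validOrRefreshedUrl(
  url: string,
  regenerate: () => Promise<string>, // e.g. a fetchPresignedUrl(key) call
): Promise<string> {
  return isPresignedUrlExpired(url) ? await regenerate() : url
}
```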
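On the deep-equality commit above: `React.memo`'s default comparison checks each prop with reference equality, so a message object that is rebuilt on every render never compares equal, and memoization keyed on it can thrash indefinitely. A sketch of the value-based comparison, with an illustrative prop shape rather than the repository's real `MemoizedChatMessage` props:

```typescript
// Sketch of the deep-equality fix: compare props by value with
// lodash's isEqual instead of by reference.
import { memo } from 'react'
import isEqual from 'lodash/isEqual'

interface ChatMessageProps {
  message: { role: string; content: unknown } // illustrative shape
  messageIndex: number
}

function ChatMessageInner({ message }: ChatMessageProps) {
  return <span>{JSON.stringify(message.content)}</span>
}

export const MemoizedChatMessage = memo(
  ChatMessageInner,
  (prev, next) =>
    prev.messageIndex === next.messageIndex &&
    isEqual(prev.message, next.message), // value equality, not reference equality
)
```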
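The filetype commit above ("100% of what OpenAI allows, ignoring caps") suggests a case-insensitive extension check. A sketch under the assumption that the allowed list matches the image types OpenAI documents for vision input (PNG, JPEG, WEBP, and non-animated GIF); the helper name is made up:

```typescript
// Illustrative sketch: accept the image types OpenAI's vision models
// support, ignoring extension casing.
const ALLOWED_IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg', 'gif', 'webp']

export function isSupportedImageType(fileName: string): boolean {
  const extension = fileName.split('.').pop()?.toLowerCase() ?? ''
  return ALLOWED_IMAGE_EXTENSIONS.includes(extension)
}
```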
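The 15k token cap from the commit message is not visible in this file's diff (`handleContextSearch` below still reads the model's full `tokenLimit`), so the change presumably lands in another file. The idea, sketched with assumed names:

```typescript
// Sketch: clamp the retrieval token budget below the model maximum so a
// GPT-4 Vision request stays under a 40k tokens-per-minute rate limit.
// The 15k figure comes from the commit message; the names are assumed.
const VISION_TOKEN_CAP = 15_000

function effectiveTokenLimit(modelTokenLimit: number): number {
  return Math.min(modelTokenLimit, VISION_TOKEN_CAP)
}
```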
212 changes: 181 additions & 31 deletions src/components/Chat/Chat.tsx
@@ -29,15 +29,13 @@ import {
useRef,
useState,
} from 'react'
import toast from 'react-hot-toast'
import { Button, Container, Text, Title } from '@mantine/core'
import { Button, Text } from '@mantine/core'
import { useTranslation } from 'next-i18next'

import { getEndpoint } from '@/utils/app/api'
import {
saveConversation,
saveConversations,
updateConversation,
} from '@/utils/app/conversation'
import { throttle } from '@/utils/data/throttle'

@@ -46,6 +44,7 @@ import {
type ChatBody,
type Conversation,
type Message,
Content,
} from '@/types/chat'
import { type Plugin } from '@/types/plugin'

@@ -55,7 +54,7 @@ import { ChatInput } from './ChatInput'
import { ChatLoader } from './ChatLoader'
import { ErrorMessageDiv } from './ErrorMessageDiv'
import { MemoizedChatMessage } from './MemoizedChatMessage'
import { fetchPresignedUrl } from '~/components/UIUC-Components/ContextCards'
import { fetchPresignedUrl } from '~/utils/apiUtils'

import { type CourseMetadata } from '~/types/courseMetadata'

@@ -75,7 +74,6 @@ import ChatNavbar from '../UIUC-Components/navbars/ChatNavbar'
import { notifications } from '@mantine/notifications'
import { Montserrat } from 'next/font/google'
import { montserrat_heading, montserrat_paragraph } from 'fonts'
import { NextResponse } from 'next/server'

const montserrat_med = Montserrat({
weight: '500',
@@ -114,6 +112,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
loading,
prompts,
showModelSettings,
isImg2TextLoading
},
handleUpdateConversation,
dispatch: homeDispatch,
@@ -173,14 +172,90 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
}
}

const handleImageContent = async (message: Message, endpoint: string, updatedConversation: Conversation, searchQuery: string, controller: AbortController) => {
const imageContent = (message.content as Content[]).filter(content => content.type === 'image_url');
if (imageContent.length > 0) {
homeDispatch({ field: 'isImg2TextLoading', value: true })
const chatBody: ChatBody = {
model: updatedConversation.model,
messages: [
{
...message,
content: [
...imageContent,
{ type: 'text', text: 'Provide detailed description of the image(s) focusing on any text (OCR information), distinct objects, colors, and actions depicted. Include contextual information, subtle details, and specific terminologies relevant for semantic document retrieval.' }
]
}
],
key: courseMetadata?.openai_api_key && courseMetadata?.openai_api_key != '' ? courseMetadata.openai_api_key : apiKey,
prompt: updatedConversation.prompt,
temperature: updatedConversation.temperature,
course_name: getCurrentPageName(),
stream: false,
};

try {
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(chatBody),
signal: controller.signal,
});

if (!response.ok) {
const final_response = await response.json();
homeDispatch({ field: 'loading', value: false });
homeDispatch({ field: 'messageIsStreaming', value: false });
throw new Error(final_response.message);
}

const data = await response.json();
const imgDesc = data.choices[0].message.content || '';

searchQuery += ` Image description: ${imgDesc}`;

const imgDescIndex = (message.content as Content[]).findIndex(content => content.type === 'text' && (content.text as string).startsWith('Image description: '));

if (imgDescIndex !== -1) {
(message.content as Content[])[imgDescIndex] = { type: 'text', text: `Image description: ${imgDesc}` };
} else {
(message.content as Content[]).push({ type: 'text', text: `Image description: ${imgDesc}` });
}
} catch (error) {
console.error('Error in chat.tsx running onResponseCompletion():', error);
controller.abort();
} finally {
homeDispatch({ field: 'isImg2TextLoading', value: false })
};
}
return searchQuery;
}

const handleContextSearch = async (message: Message, selectedConversation: Conversation, searchQuery: string) => {
if (getCurrentPageName() != 'gpt4') {
const token_limit = OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit
await fetchContexts(getCurrentPageName(), searchQuery, token_limit).then((curr_contexts) => {
message.contexts = curr_contexts as ContextWithMetadata[]
})
}
}

// THIS IS WHERE MESSAGES ARE SENT.
const handleSend = useCallback(
async (message: Message, deleteCount = 0, plugin: Plugin | null = null) => {

setCurrentMessage(message)
// New way with React Context API
// TODO: MOVE THIS INTO ChatMessage
// console.log('IN handleSend: ', message)
// setSearchQuery(message.content)
const searchQuery = message.content
let searchQuery = Array.isArray(message.content)
? message.content.map((content) => content.text).join(' ')
: message.content;

// console.log("QUERY: ", searchQuery)

if (selectedConversation) {
let updatedConversation: Conversation
@@ -206,21 +281,18 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
homeDispatch({ field: 'loading', value: true })
homeDispatch({ field: 'messageIsStreaming', value: true })

// Run context search, attach to Message object.
if (getCurrentPageName() != 'gpt4') {
// THE ONLY place we fetch contexts (except ExtremePromptStuffing is still in api/chat.ts)
const token_limit =
OpenAIModels[selectedConversation?.model.id as OpenAIModelID]
.tokenLimit
await fetchContexts(
getCurrentPageName(),
searchQuery,
token_limit,
).then((curr_contexts) => {
message.contexts = curr_contexts as ContextWithMetadata[]
})
const endpoint = getEndpoint(plugin);

const controller = new AbortController()

// Run image to text conversion, attach to Message object.
if (Array.isArray(message.content)) {
searchQuery = await handleImageContent(message, endpoint, updatedConversation, searchQuery, controller);
}

// Run context search, attach to Message object.
await handleContextSearch(message, selectedConversation, searchQuery);

const chatBody: ChatBody = {
model: updatedConversation.model,
messages: updatedConversation.messages,
@@ -232,8 +304,9 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
prompt: updatedConversation.prompt,
temperature: updatedConversation.temperature,
course_name: getCurrentPageName(),
stream: true
}
const endpoint = getEndpoint(plugin) // THIS is where we could support EXTREME prompt stuffing.

let body
if (!plugin) {
body = JSON.stringify(chatBody)
@@ -248,7 +321,8 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
?.requiredKeys.find((key) => key.key === 'GOOGLE_CSE_ID')?.value,
})
}
const controller = new AbortController()

// This is where we call the OpenAI API
const response = await fetch(endpoint, {
method: 'POST',
headers: {
@@ -301,13 +375,17 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
}
if (!plugin) {
if (updatedConversation.messages.length === 1) {
const { content } = message
const { content } = message;
// Use only texts instead of content itself
const contentText = Array.isArray(content)
? content.map((content) => content.text).join(' ')
: content;
const customName =
content.length > 30 ? content.substring(0, 30) + '...' : content
contentText.length > 30 ? contentText.substring(0, 30) + '...' : contentText;
updatedConversation = {
...updatedConversation,
name: customName,
}
};
}
homeDispatch({ field: 'loading', value: false })
const reader = data.getReader()
@@ -390,6 +468,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
updatedConversations.push(updatedConversation)
}
homeDispatch({ field: 'conversations', value: updatedConversations })
console.log('updatedConversations: ', updatedConversations)
saveConversations(updatedConversations)
homeDispatch({ field: 'messageIsStreaming', value: false })
} else {
@@ -434,6 +513,20 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
],
)

const handleRegenerate = useCallback(() => {
if (currentMessage && Array.isArray(currentMessage.content)) {
// Find the index of the existing image description
const imgDescIndex = (currentMessage.content as Content[]).findIndex(content => content.type === 'text' && (content.text as string).startsWith('Image description: '));

if (imgDescIndex !== -1) {
// Remove the existing image description
(currentMessage.content as Content[]).splice(imgDescIndex, 1);
}

handleSend(currentMessage, 2, null);
}
}, [currentMessage, handleSend]);

const scrollToBottom = useCallback(() => {
if (autoScrollEnabled) {
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' })
@@ -575,6 +668,64 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
</div>
)
}
// Inside Chat function before the return statement
const renderMessageContent = (message: Message) => {
if (Array.isArray(message.content)) {
return (
<>
{message.content.map((content, index) => {
if (content.type === 'image' && content.image_url) {
return <img key={index} src={content.image_url.url} alt="Uploaded content" />;
}
return <span key={index}>{content.text}</span>;
})}
</>
);
}
return <span>{message.content}</span>;
};

const updateMessages = (updatedMessage: Message, messageIndex: number) => {
return selectedConversation?.messages.map((message, index) => {
return index === messageIndex ? updatedMessage : message;
});
};

const updateConversations = (updatedConversation: Conversation) => {
return conversations.map((conversation) =>
conversation.id === selectedConversation?.id ? updatedConversation : conversation
);
};

const onImageUrlsUpdate = useCallback((updatedMessage: Message, messageIndex: number) => {
if (!selectedConversation) {
throw new Error("No selected conversation found");
}

const updatedMessages = updateMessages(updatedMessage, messageIndex);
if (!updatedMessages) {
throw new Error("Failed to update messages");
}

const updatedConversation = {
...selectedConversation,
messages: updatedMessages,
};

homeDispatch({
field: 'selectedConversation',
value: updatedConversation,
});

const updatedConversations = updateConversations(updatedConversation);
if (!updatedConversations) {
throw new Error("Failed to update conversations");
}

homeDispatch({ field: 'conversations', value: updatedConversations });
saveConversations(updatedConversations);
}, [selectedConversation, conversations]);


return (
<div className="overflow-wrap relative flex h-screen w-full flex-col overflow-hidden bg-white dark:bg-[#15162c]">
@@ -671,14 +822,16 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
<MemoizedChatMessage
key={index}
message={message}
contentRenderer={renderMessageContent}
messageIndex={index}
onEdit={(editedMessage) => {
setCurrentMessage(editedMessage)
// setCurrentMessage(editedMessage)
handleSend(
editedMessage,
selectedConversation?.messages.length - index,
)
}}
onImageUrlsUpdate={onImageUrlsUpdate}
/>
))}
{loading && <ChatLoader />}
@@ -694,18 +847,15 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
stopConversationRef={stopConversationRef}
textareaRef={textareaRef}
onSend={(message, plugin) => {
setCurrentMessage(message)
// setCurrentMessage(message)
handleSend(message, 0, plugin)
}}
onScrollDownClick={handleScrollDown}
onRegenerate={() => {
if (currentMessage) {
handleSend(currentMessage, 2, null)
}
}}
onRegenerate={handleRegenerate}
showScrollDownButton={showScrollDownButton}
inputContent={inputContent}
setInputContent={setInputContent}
courseName={getCurrentPageName()}
/>
{/* </div> */}
</>
