From a9aac7e5662d966aa289db77cab376526f7f02e7 Mon Sep 17 00:00:00 2001
From: hanlily666
Date: Mon, 1 Jul 2024 09:53:03 -0700
Subject: [PATCH] load model before runChatCompletion

---
 src/components/Chat/Chat.tsx       | 50 ++++++++++++++++++------------
 src/utils/modelProviders/WebLLM.ts | 28 ++++++++++++-----
 2 files changed, 51 insertions(+), 27 deletions(-)

diff --git a/src/components/Chat/Chat.tsx b/src/components/Chat/Chat.tsx
index f6157d7c1..5886d4708 100644
--- a/src/components/Chat/Chat.tsx
+++ b/src/components/Chat/Chat.tsx
@@ -97,15 +97,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
   }
 
   const [chat_ui] = useState(new ChatUI(new MLCEngine()))
 
-  useEffect(() => {
-    // TODO: load the actual model the user selects... (we can hard-code for now to a single model)
-    // selectedConversation.model
-    const loadModel = async () => {
-      await chat_ui.loadModel()
-    }
-    loadModel()
-  }, [chat_ui])
 
   const [inputContent, setInputContent] = useState('')
 
@@ -164,6 +156,20 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
     dispatch: homeDispatch,
   } = useContext(HomeContext)
 
+  useEffect(() => {
+    // TODO: load the actual model the user selects... (we can hard-code for now to a single model)
+    // selectedConversation.model
+    const loadModel = async () => {
+      if (selectedConversation && !chat_ui.isModelLoading()) {
+        await chat_ui.loadModel(selectedConversation)
+        if (!chat_ui.isModelLoading()) {
+          console.log('Model has finished loading')
+        }
+      }
+    }
+    loadModel()
+  }, [selectedConversation?.model.name])
+
   const [currentMessage, setCurrentMessage] = useState()
   const [autoScrollEnabled, setAutoScrollEnabled] = useState(true)
   // const [showSettings, setShowSettings] = useState(false)
@@ -345,12 +351,12 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
             )
 
             if (imgDescIndex !== -1) {
-              ;(message.content as Content[])[imgDescIndex] = {
+              ; (message.content as Content[])[imgDescIndex] = {
                 type: 'text',
                 text: `Image description: ${imgDesc}`,
               }
             } else {
-              ;(message.content as Content[]).push({
+              ; (message.content as Content[]).push({
                 type: 'text',
                 text: `Image description: ${imgDesc}`,
               })
@@ -541,15 +547,21 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
       let response
      let reader
 
+      console.log("Selected model name:", selectedConversation.model.name);
+
       if (
         ['TinyLlama-1.1B', 'Llama-3-8B-Instruct-q4f32_1-MLC'].some((prefix) =>
           selectedConversation.model.name.startsWith(prefix),
         )
       ) {
         // TODO: Call the WebLLM API
-        response = await chat_ui.runChatCompletion(
-          chatBody.conversation.messages,
-        )
+        console.log("is model loading", chat_ui.isModelLoading())
+        if (!chat_ui.isModelLoading()) {
+          console.log("loaded model and initiate chat completions")
+          response = await chat_ui.runChatCompletion(
+            chatBody.conversation.messages,
+          )
+        }
       } else {
         // Call the OpenAI API
         response = await fetch(endpoint, {
@@ -839,7 +851,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
 
       if (imgDescIndex !== -1) {
         // Remove the existing image description
-        ;(currentMessage.content as Content[]).splice(imgDescIndex, 1)
+        ; (currentMessage.content as Content[]).splice(imgDescIndex, 1)
       }
 
       handleSend(currentMessage, 2, null, tools, enabledDocumentGroups)
@@ -932,13 +944,13 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
 
   const statements =
     courseMetadata?.example_questions &&
-    courseMetadata.example_questions.length > 0
+      courseMetadata.example_questions.length > 0
       ? courseMetadata.example_questions
       : [
-          'Make a bullet point list of key takeaways from this project.',
-          'What are the best practices for [Activity or Process] in [Context or Field]?',
-          'Can you explain the concept of [Specific Concept] in simple terms?',
-        ]
+        'Make a bullet point list of key takeaways from this project.',
+        'What are the best practices for [Activity or Process] in [Context or Field]?',
+        'Can you explain the concept of [Specific Concept] in simple terms?',
+      ]
 
   // Add this function to create dividers with statements
   const renderIntroductoryStatements = () => {
diff --git a/src/utils/modelProviders/WebLLM.ts b/src/utils/modelProviders/WebLLM.ts
index 4b284292d..3ed4e3425 100644
--- a/src/utils/modelProviders/WebLLM.ts
+++ b/src/utils/modelProviders/WebLLM.ts
@@ -50,7 +50,7 @@ export default class ChatUI {
   // all requests send to chat are sequentialized
   private chatRequestChain: Promise<void> = Promise.resolve()
   private chatHistory: ChatCompletionMessageParam[] = []
-
+  private modelLoading = false
   constructor(engine: MLCEngineInterface) {
     this.engine = engine
   }
@@ -127,14 +127,26 @@ export default class ChatUI {
     this.chatLoaded = false
   }
 
-  async loadModel() {
-    console.log('staritng to load model')
-    // TODO: don't hard-code this model name
-    // const selectedModel = 'Llama-3-8B-Instruct-q4f32_1-MLC'
-    const selectedModel = 'TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k'
-    await this.engine.reload(selectedModel)
-    console.log('done loading model')
+  async loadModel(selectedConversation: { model: { name: string } }) {
+    console.log('starting to load model')
+    this.modelLoading = true // Set loading state to true
+    const selectedModel = selectedConversation.model.name
+    try {
+      await this.engine.reload(selectedModel)
+      console.log('done loading model')
+    } catch (error) {
+      console.error('Error loading model:', error)
+    } finally {
+      this.modelLoading = false // Set loading state to false
+      console.log('model has been loaded modelLoading set to false')
+    }
   }
+  isModelLoading() {
+    console.log('ismodelloading,', this.modelLoading)
+    return this.modelLoading
+  }
+
+
   async runChatCompletion(messages: Message[]) {
     let curMessage = ''
     let usage: CompletionUsage | undefined = undefined
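
Usage note (not part of the patch): a minimal sketch of how the patched ChatUI API is meant to be driven by a caller. Only loadModel, isModelLoading, and runChatCompletion come from this change; the import paths and the conversation shape below are illustrative assumptions.

// Sketch only: assumes @mlc-ai/web-llm exports MLCEngine and that ChatUI is the
// default export of the patched src/utils/modelProviders/WebLLM.ts.
import { MLCEngine } from '@mlc-ai/web-llm'
import ChatUI from './src/utils/modelProviders/WebLLM'

const chat_ui = new ChatUI(new MLCEngine())

// Mirrors the guard added in Chat.tsx: load the selected model first, then only
// start a completion when no load is still in flight.
async function completeWithWebLLM(conversation: {
  model: { name: string }
  messages: any[]
}) {
  if (!chat_ui.isModelLoading()) {
    await chat_ui.loadModel(conversation)
  }
  if (!chat_ui.isModelLoading()) {
    return chat_ui.runChatCompletion(conversation.messages)
  }
  return undefined
}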