load model before runChatCompletion
hanlily666 committed Jul 1, 2024
1 parent 4a34457 commit a9aac7e
Showing 2 changed files with 51 additions and 27 deletions.
50 changes: 31 additions & 19 deletions src/components/Chat/Chat.tsx
@@ -97,15 +97,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
}

const [chat_ui] = useState(new ChatUI(new MLCEngine()))
useEffect(() => {
// TODO: load the actual model the user selects... (we can hard-code for now to a single model)
// selectedConversation.model
const loadModel = async () => {
await chat_ui.loadModel()
}

loadModel()
}, [chat_ui])

const [inputContent, setInputContent] = useState<string>('')

@@ -164,6 +156,20 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
dispatch: homeDispatch,
} = useContext(HomeContext)

useEffect(() => {
// TODO: load the actual model the user selects... (we can hard-code for now to a single model)
// selectedConversation.model
const loadModel = async () => {
if (selectedConversation && !chat_ui.isModelLoading()) {
await chat_ui.loadModel(selectedConversation)
if (!chat_ui.isModelLoading()) {
console.log('Model has finished loading')
}
}
}
loadModel()
}, [selectedConversation?.model.name])

const [currentMessage, setCurrentMessage] = useState<Message>()
const [autoScrollEnabled, setAutoScrollEnabled] = useState<boolean>(true)
// const [showSettings, setShowSettings] = useState<boolean>(false)
@@ -345,12 +351,12 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
)

if (imgDescIndex !== -1) {
;(message.content as Content[])[imgDescIndex] = {
; (message.content as Content[])[imgDescIndex] = {
type: 'text',
text: `Image description: ${imgDesc}`,
}
} else {
;(message.content as Content[]).push({
; (message.content as Content[]).push({
type: 'text',
text: `Image description: ${imgDesc}`,
})
@@ -541,15 +547,21 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {

let response
let reader
console.log("Selected model name:", selectedConversation.model.name);

if (
['TinyLlama-1.1B', 'Llama-3-8B-Instruct-q4f32_1-MLC'].some((prefix) =>
selectedConversation.model.name.startsWith(prefix),
)
) {
// TODO: Call the WebLLM API
response = await chat_ui.runChatCompletion(
chatBody.conversation.messages,
)
console.log("is model loading", chat_ui.isModelLoading())
if (!chat_ui.isModelLoading()) {
console.log("loaded model and initiate chat completions")
response = await chat_ui.runChatCompletion(
chatBody.conversation.messages,
)
}
} else {
// Call the OpenAI API
response = await fetch(endpoint, {
@@ -839,7 +851,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {

if (imgDescIndex !== -1) {
// Remove the existing image description
;(currentMessage.content as Content[]).splice(imgDescIndex, 1)
; (currentMessage.content as Content[]).splice(imgDescIndex, 1)
}

handleSend(currentMessage, 2, null, tools, enabledDocumentGroups)
@@ -932,13 +944,13 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {

const statements =
courseMetadata?.example_questions &&
courseMetadata.example_questions.length > 0
courseMetadata.example_questions.length > 0
? courseMetadata.example_questions
: [
'Make a bullet point list of key takeaways from this project.',
'What are the best practices for [Activity or Process] in [Context or Field]?',
'Can you explain the concept of [Specific Concept] in simple terms?',
]
'Make a bullet point list of key takeaways from this project.',
'What are the best practices for [Activity or Process] in [Context or Field]?',
'Can you explain the concept of [Specific Concept] in simple terms?',
]

// Add this function to create dividers with statements
const renderIntroductoryStatements = () => {
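Taken together, the Chat.tsx changes load the WebLLM model whenever the selected model name changes and only call runChatCompletion once isModelLoading() reports false. One consequence of the guard as written is that a message sent while the model is still loading skips the WebLLM call entirely and leaves response undefined. A minimal sketch of a hypothetical wait helper (not part of this commit) that polls the new flag before requesting a completion:

// Hypothetical helper, not in this commit: poll the modelLoading flag added in
// WebLLM.ts until the engine is ready, rather than silently skipping the call.
const waitForModelLoad = async (chat_ui: ChatUI, timeoutMs = 60_000): Promise<void> => {
  const start = Date.now()
  while (chat_ui.isModelLoading()) {
    if (Date.now() - start > timeoutMs) {
      throw new Error('Timed out waiting for the WebLLM model to finish loading')
    }
    // Back off briefly between checks so the UI thread stays responsive.
    await new Promise((resolve) => setTimeout(resolve, 250))
  }
}

// Possible use inside handleSend, in place of the if (!chat_ui.isModelLoading()) guard:
//   await waitForModelLoad(chat_ui)
//   response = await chat_ui.runChatCompletion(chatBody.conversation.messages)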
28 changes: 20 additions & 8 deletions src/utils/modelProviders/WebLLM.ts
@@ -50,7 +50,7 @@ export default class ChatUI {
// all requests send to chat are sequentialized
private chatRequestChain: Promise<void> = Promise.resolve()
private chatHistory: ChatCompletionMessageParam[] = []

private modelLoading = false
constructor(engine: MLCEngineInterface) {
this.engine = engine
}
@@ -127,14 +127,26 @@ export default class ChatUI {
this.chatLoaded = false
}

async loadModel() {
console.log('staritng to load model')
// TODO: don't hard-code this model name
// const selectedModel = 'Llama-3-8B-Instruct-q4f32_1-MLC'
const selectedModel = 'TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k'
await this.engine.reload(selectedModel)
console.log('done loading model')
async loadModel(selectedConversation: { model: { name: string } }) {
console.log('starting to load model')
this.modelLoading = true // Set loading state to true
const selectedModel = selectedConversation.model.name
try {
await this.engine.reload(selectedModel)
console.log('done loading model')
} catch (error) {
console.error('Error loading model:', error)
} finally {
this.modelLoading = false // Set loading state to false
console.log('model has been loaded modelLoading set to false')
}
}
isModelLoading() {
console.log('ismodelloading,', this.modelLoading)
return this.modelLoading
}


async runChatCompletion(messages: Message[]) {
let curMessage = ''
let usage: CompletionUsage | undefined = undefined
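With the WebLLM.ts changes, ChatUI exposes loadModel(selectedConversation), isModelLoading(), and runChatCompletion(messages) around a single MLCEngine. A standalone usage sketch of that surface, assuming the @mlc-ai/web-llm package for MLCEngine; the ChatUI import path, model name, and message shape below are illustrative, not taken from this commit:

import { MLCEngine } from '@mlc-ai/web-llm'
import ChatUI from '../utils/modelProviders/WebLLM' // import path is illustrative

// Usage sketch only: drives the API added in this commit end to end.
async function demo() {
  const chat_ui = new ChatUI(new MLCEngine())

  // loadModel sets modelLoading while engine.reload runs, then clears it.
  await chat_ui.loadModel({
    model: { name: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k' },
  })

  if (!chat_ui.isModelLoading()) {
    const response = await chat_ui.runChatCompletion([
      { role: 'user', content: 'Hello!' }, // minimal assumed Message shape
    ])
    console.log(response)
  }
}

demo().catch(console.error)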
