Commit: MistralAI function calling

nbonamy committed May 3, 2024
1 parent e88e26f commit f6a2a14
Showing 5 changed files with 151 additions and 31 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -73,6 +73,7 @@ To use Internet search you need a [Tavily API key](https://app.tavily.com/home).

## DONE

- [x] MistralAI function calling
- [x] Auto-update
- [x] History date sections
- [x] Multiple selection delete
1 change: 1 addition & 0 deletions src/plugins/tavily.ts
@@ -45,6 +45,7 @@ export default class extends Plugin {
include_answer: true,
//include_raw_content: true,
})
//console.log('Tavily response:', response)
return response
} catch (error) {
return { error: error.message }
30 changes: 25 additions & 5 deletions src/services/engine.ts
@@ -75,11 +75,6 @@ export default class LlmEngine {
throw new Error('Not implemented')
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
getPluginAsTool(plugin: Plugin): anyDict {
throw new Error('Not implemented')
}

getChatModel(): string {
return this.config.engines[this.getName()].model.chat
}
@@ -174,6 +169,31 @@ export default class LlmEngine {
return Object.values(this.plugins).map((plugin: Plugin) => this.getPluginAsTool(plugin))
}

// this is the default implementation, following the OpenAI API:
// its function-calling schema is now almost a de facto standard
// that other providers, such as MistralAI, are adopting
getPluginAsTool(plugin: Plugin): anyDict {
return {
type: 'function',
function: {
name: plugin.getName(),
description: plugin.getDescription(),
parameters: {
type: 'object',
properties: plugin.getParameters().reduce((obj: anyDict, param: PluginParameter) => {
obj[param.name] = {
type: param.type,
enum: param.enum,
description: param.description,
}
return obj
}, {}),
required: plugin.getParameters().filter(param => param.required).map(param => param.name),
},
},
}
}

getToolPreparationDescription(tool: string): string {
const plugin = this.plugins[tool]
return plugin?.getPreparationDescription()
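
For illustration, here is a minimal sketch of the tool definition the default getPluginAsTool() above would produce. The plugin name, description and parameter are invented for this sketch and do not come from the repository:

// hedged sketch: hypothetical result of getPluginAsTool() for an imaginary
// "search_internet" plugin with one required string parameter; the shape is
// the OpenAI function-calling schema that MistralAI now accepts as well
const exampleTool = {
  type: 'function',
  function: {
    name: 'search_internet',
    description: 'Search the internet for a query',
    parameters: {
      type: 'object',
      properties: {
        query: {
          type: 'string',
          enum: undefined, // this parameter has no enum constraint
          description: 'The query to search for',
        },
      },
      required: ['query'],
    },
  },
}
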
129 changes: 124 additions & 5 deletions src/services/mistralai.ts
@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
import { Message } from '../types/index.d'
import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmEventCallback } from '../types/llm.d'
import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmEventCallback, LlmToolCall } from '../types/llm.d'
import { EngineConfig, Configuration } from '../types/config.d'
import LlmEngine from './engine'

@@ -15,6 +15,9 @@ export const isMistrailAIReady = (engineConfig: EngineConfig): boolean => {
export default class extends LlmEngine {

client: MistralClient
currentModel: string
currentThread: Array<any>
toolCalls: LlmToolCall[]

constructor(config: Configuration) {
super(config)
@@ -73,13 +76,29 @@ export default class extends LlmEngine {
async stream(thread: Message[], opts: LlmCompletionOpts): Promise<LlmStream> {

// model: switch to vision if needed
const model = this.selectModel(thread, opts?.model || this.getChatModel())
this.currentModel = this.selectModel(thread, opts?.model || this.getChatModel())

// save the message thread
this.currentThread = this.buildPayload(thread, this.currentModel)
return await this.doStream()

}

async doStream(): Promise<LlmStream> {

// reset
this.toolCalls = []

// tools
const tools = this.getAvailableToolsForModel(this.currentModel)

// call
console.log(`[mistralai] prompting model ${model}`)
console.log(`[mistralai] prompting model ${this.currentModel}`)
const stream = this.client.chatStream({
model: model,
messages: this.buildPayload(thread, model),
model: this.currentModel,
messages: this.currentThread,
tools: tools.length ? tools : null,

[CI annotation — check failure on line 100 in src/services/mistralai.ts — GitHub Actions / build]
tests/unit/engine_mistralai.test.ts > MistralAI stream
TypeError: Cannot read properties of null (reading 'length')
  ❯ __vite_ssr_exports__.default.doStream src/services/mistralai.ts:100:20
  ❯ __vite_ssr_exports__.default.stream src/services/mistralai.ts:83:23
  ❯ tests/unit/engine_mistralai.test.ts:58:36
tool_choice: tools.length ? 'any' : null,
})

// done
@@ -93,6 +112,98 @@ export default class extends LlmEngine {

// eslint-disable-next-line @typescript-eslint/no-unused-vars
async streamChunkToLlmChunk(chunk: any, eventCallback: LlmEventCallback): Promise<LlmChunk|null> {

// tool calls
if (chunk.choices[0]?.delta?.tool_calls) {

// arguments or new tool?
if (chunk.choices[0].delta.tool_calls[0].id) {

// debug
//console.log('[mistralai] tool call start:', chunk)

// record the tool call
const toolCall: LlmToolCall = {
id: chunk.choices[0].delta.tool_calls[0].id,
message: chunk.choices[0].delta.tool_calls.map((tc: any) => {
delete tc.index
return tc
}),
function: chunk.choices[0].delta.tool_calls[0].function.name,
args: chunk.choices[0].delta.tool_calls[0].function.arguments,
}
console.log('[mistralai] tool call:', toolCall)
this.toolCalls.push(toolCall)

// first notify
eventCallback?.call(this, {
type: 'tool',
content: this.getToolPreparationDescription(toolCall.function)
})

} else {

const toolCall = this.toolCalls[this.toolCalls.length-1]
toolCall.args += chunk.choices[0].delta.tool_calls[0].function.arguments
return null

}

}

// the model has finished emitting tool calls: execute them now
if (chunk.choices[0]?.finish_reason === 'tool_calls') {

// debug
//console.log('[mistralai] tool calls:', this.toolCalls)

// add tools
for (const toolCall of this.toolCalls) {

// first notify
eventCallback?.call(this, {
type: 'tool',
content: this.getToolRunningDescription(toolCall.function)
})

// now execute
const args = JSON.parse(toolCall.args)
const content = await this.callTool(toolCall.function, args)
console.log(`[mistralai] tool call ${toolCall.function} with ${JSON.stringify(args)} => ${JSON.stringify(content).substring(0, 128)}`)

// add tool call message
this.currentThread.push({
role: 'assistant',
tool_calls: toolCall.message
})

// add tool response message
this.currentThread.push({
role: 'tool',
tool_call_id: toolCall.id,
name: toolCall.function,
content: JSON.stringify(content)
})
}

// clear
eventCallback?.call(this, {
type: 'tool',
content: null,
})

// switch to new stream
eventCallback?.call(this, {
type: 'stream',
content: await this.doStream(),
})

// done
return null

}

// default
return {
text: chunk.choices[0].delta.content,
done: chunk.choices[0].finish_reason != null
@@ -107,4 +218,12 @@ export default class extends LlmEngine {
async image(prompt: string, opts: LlmCompletionOpts): Promise<LlmResponse|null> {
return null
}

getAvailableToolsForModel(model: string): any[] {
if (model.includes('mistral-large') || model.includes('mixtral-8x22b')) {
return this.getAvailableTools()
} else {
return null
}
}
}
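
To make the round trip concrete, here is a hedged sketch of what currentThread holds after one completed tool call: the assistant tool_calls message and the matching tool result pushed by streamChunkToLlmChunk, after which doStream() reissues the request so the model can answer with the tool output in context. All identifiers and values below are invented for illustration:

// hedged sketch (invented values): the two messages appended to
// currentThread once a tool call has been executed
const threadAfterToolCall = [
  { role: 'user', content: 'What is the weather in Paris?' },
  {
    // echoes the tool_calls recorded from the stream deltas
    role: 'assistant',
    tool_calls: [{
      id: 'call_abc123',
      function: { name: 'search_internet', arguments: '{"query":"weather in Paris"}' },
    }],
  },
  {
    // the executed tool result, serialized as JSON
    role: 'tool',
    tool_call_id: 'call_abc123',
    name: 'search_internet',
    content: JSON.stringify({ answer: 'Sunny, 21°C' }),
  },
]
// doStream() then re-sends the extended thread and the model
// streams a normal text answer based on the tool result
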
21 changes: 0 additions & 21 deletions src/services/openai.ts
@@ -248,25 +248,4 @@ export default class extends LlmEngine {

}

getPluginAsTool(plugin: Plugin): anyDict {
return {
type: 'function',
function: {
name: plugin.getName(),
description: plugin.getDescription(),
parameters: {
type: 'object',
properties: plugin.getParameters().reduce((obj: anyDict, param: PluginParameter) => {
obj[param.name] = {
type: param.type,
enum: param.enum,
description: param.description,
}
return obj
}, {}),
required: plugin.getParameters().filter(param => param.required).map(param => param.name),
},
},
}
}
}
