From f6a2a145963bf12ae9a7604a7e0d41264efeab64 Mon Sep 17 00:00:00 2001
From: Nicolas Bonamy
Date: Thu, 2 May 2024 21:05:42 -0500
Subject: [PATCH] MistralAI function calling

---
 README.md                 |   1 +
 src/plugins/tavily.ts     |   1 +
 src/services/engine.ts    |  30 +++++++--
 src/services/mistralai.ts | 129 ++++++++++++++++++++++++++++++++++++--
 src/services/openai.ts    |  21 -------
 5 files changed, 151 insertions(+), 31 deletions(-)

diff --git a/README.md b/README.md
index 08177da..fd86046 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,7 @@ To use Internet search you need a [Tavily API key](https://app.tavily.com/home)
 
 ## DONE
 
+- [x] MistralAI function calling
 - [x] Auto-update
 - [x] History date sections
 - [x] Multiple selection delete

diff --git a/src/plugins/tavily.ts b/src/plugins/tavily.ts
index f1b26a9..a3218e9 100644
--- a/src/plugins/tavily.ts
+++ b/src/plugins/tavily.ts
@@ -45,6 +45,7 @@ export default class extends Plugin {
         include_answer: true,
         //include_raw_content: true,
       })
+      //console.log('Tavily response:', response)
       return response
     } catch (error) {
       return { error: error.message }

diff --git a/src/services/engine.ts b/src/services/engine.ts
index babb7ed..b5685ca 100644
--- a/src/services/engine.ts
+++ b/src/services/engine.ts
@@ -75,11 +75,6 @@
     throw new Error('Not implemented')
   }
 
-  // eslint-disable-next-line @typescript-eslint/no-unused-vars
-  getPluginAsTool(plugin: Plugin): anyDict {
-    throw new Error('Not implemented')
-  }
-
   getChatModel(): string {
     return this.config.engines[this.getName()].model.chat
   }
@@ -174,6 +169,31 @@
     return Object.values(this.plugins).map((plugin: Plugin) => this.getPluginAsTool(plugin))
   }
 
+  // this is the default implementation as per the OpenAI API;
+  // it is now almost a de facto standard that other providers
+  // such as MistralAI are following
+  getPluginAsTool(plugin: Plugin): anyDict {
+    return {
+      type: 'function',
+      function: {
+        name: plugin.getName(),
+        description: plugin.getDescription(),
+        parameters: {
+          type: 'object',
+          properties: plugin.getParameters().reduce((obj: anyDict, param: PluginParameter) => {
+            obj[param.name] = {
+              type: param.type,
+              enum: param.enum,
+              description: param.description,
+            }
+            return obj
+          }, {}),
+          required: plugin.getParameters().filter(param => param.required).map(param => param.name),
+        },
+      },
+    }
+  }
+
   getToolPreparationDescription(tool: string): string {
     const plugin = this.plugins[tool]
     return plugin?.getPreparationDescription()
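For reference, this is the descriptor shape the default `getPluginAsTool()` above produces, shown for a hypothetical `get_weather` plugin with a single required parameter (the plugin name, description, and parameter are invented for illustration; note that `enum` passes through as `undefined` when a parameter declares none):

```ts
// Tool descriptor emitted by the default getPluginAsTool() for a
// hypothetical 'get_weather' plugin (names are illustrative only)
const tool = {
  type: 'function',
  function: {
    name: 'get_weather',
    description: 'Get the current weather for a city',
    parameters: {
      type: 'object',
      properties: {
        // enum is undefined here because the parameter declares no enum
        city: { type: 'string', enum: undefined, description: 'The city to look up' },
      },
      required: ['city'],
    },
  },
}
```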
diff --git a/src/services/mistralai.ts b/src/services/mistralai.ts
index 48689ba..516f45c 100644
--- a/src/services/mistralai.ts
+++ b/src/services/mistralai.ts
@@ -1,6 +1,6 @@
 /* eslint-disable @typescript-eslint/no-unused-vars */
 import { Message } from '../types/index.d'
-import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmEventCallback } from '../types/llm.d'
+import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmEventCallback, LlmToolCall } from '../types/llm.d'
 import { EngineConfig, Configuration } from '../types/config.d'
 import LlmEngine from './engine'
 import MistralClient from '@mistralai/mistralai'
@@ -15,6 +15,9 @@ export const isMistrailAIReady = (engineConfig: EngineConfig): boolean => {
 export default class extends LlmEngine {
 
   client: MistralClient
+  currentModel: string
+  currentThread: Array<any>
+  toolCalls: LlmToolCall[]
 
   constructor(config: Configuration) {
     super(config)
@@ -73,13 +76,29 @@
   async stream(thread: Message[], opts: LlmCompletionOpts): Promise<LlmStream> {
 
     // model: switch to vision if needed
-    const model = this.selectModel(thread, opts?.model || this.getChatModel())
+    this.currentModel = this.selectModel(thread, opts?.model || this.getChatModel())
 
+    // save the message thread
+    this.currentThread = this.buildPayload(thread, this.currentModel)
+    return await this.doStream()
+
+  }
+
+  async doStream(): Promise<LlmStream> {
+
+    // reset
+    this.toolCalls = []
+
+    // tools
+    const tools = this.getAvailableToolsForModel(this.currentModel)
+
     // call
-    console.log(`[mistralai] prompting model ${model}`)
+    console.log(`[mistralai] prompting model ${this.currentModel}`)
     const stream = this.client.chatStream({
-      model: model,
-      messages: this.buildPayload(thread, model),
+      model: this.currentModel,
+      messages: this.currentThread,
+      tools: tools.length ? tools : null,
+      tool_choice: tools.length ? 'any' : null,
     })
 
     // done
@@ -93,6 +112,98 @@
 
   // eslint-disable-next-line @typescript-eslint/no-unused-vars
   async streamChunkToLlmChunk(chunk: any, eventCallback: LlmEventCallback): Promise<LlmChunk> {
+
+    // tool calls
+    if (chunk.choices[0]?.delta?.tool_calls) {
+
+      // arguments or new tool?
+      if (chunk.choices[0].delta.tool_calls[0].id) {
+
+        // debug
+        //console.log('[mistralai] tool call start:', chunk)
+
+        // record the tool call
+        const toolCall: LlmToolCall = {
+          id: chunk.choices[0].delta.tool_calls[0].id,
+          message: chunk.choices[0].delta.tool_calls.map((tc: any) => {
+            delete tc.index
+            return tc
+          }),
+          function: chunk.choices[0].delta.tool_calls[0].function.name,
+          args: chunk.choices[0].delta.tool_calls[0].function.arguments,
+        }
+        console.log('[mistralai] tool call:', toolCall)
+        this.toolCalls.push(toolCall)
+
+        // first notify
+        eventCallback?.call(this, {
+          type: 'tool',
+          content: this.getToolPreparationDescription(toolCall.function)
+        })
+
+      } else {
+
+        const toolCall = this.toolCalls[this.toolCalls.length-1]
+        toolCall.args += chunk.choices[0].delta.tool_calls[0].function.arguments
+        return null
+
+      }
+
+    }
+
+    // now tool calling
+    if (chunk.choices[0]?.finish_reason === 'tool_calls') {
+
+      // debug
+      //console.log('[mistralai] tool calls:', this.toolCalls)
+
+      // add tools
+      for (const toolCall of this.toolCalls) {
+
+        // first notify
+        eventCallback?.call(this, {
+          type: 'tool',
+          content: this.getToolRunningDescription(toolCall.function)
+        })
+
+        // now execute
+        const args = JSON.parse(toolCall.args)
+        const content = await this.callTool(toolCall.function, args)
+        console.log(`[mistralai] tool call ${toolCall.function} with ${JSON.stringify(args)} => ${JSON.stringify(content).substring(0, 128)}`)
+
+        // add tool call message
+        this.currentThread.push({
+          role: 'assistant',
+          tool_calls: toolCall.message
+        })
+
+        // add tool response message
+        this.currentThread.push({
+          role: 'tool',
+          tool_call_id: toolCall.id,
+          name: toolCall.function,
+          content: JSON.stringify(content)
+        })
+      }
+
+      // clear
+      eventCallback?.call(this, {
+        type: 'tool',
+        content: null,
+      })
+
+      // switch to new stream
+      eventCallback?.call(this, {
+        type: 'stream',
+        content: await this.doStream(),
+      })
+
+      // done
+      return null
+
+    }
+
+    // default
     return {
       text: chunk.choices[0].delta.content,
       done: chunk.choices[0].finish_reason != null
@@ -107,4 +218,12 @@
   async image(prompt: string, opts: LlmCompletionOpts): Promise<LlmResponse> {
     return null
   }
+
+  getAvailableToolsForModel(model: string): any[] {
+    if (model.includes('mistral-large') || model.includes('mixtral-8x22b')) {
+      return this.getAvailableTools()
+    } else {
+      return null
+    }
+  }
 }
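The streaming logic above buffers tool-call chunks (a chunk carrying an `id` opens a new call, subsequent chunks append argument fragments), and when MistralAI reports `finish_reason === 'tool_calls'` it executes each tool, pushes the call and its JSON result onto the saved thread, and re-enters `doStream()` so the model can produce the final answer. A minimal sketch of that round trip, assuming OpenAI-style chunk shapes and a hypothetical `runTool()` standing in for the engine's `callTool()`:

```ts
// Condensed tool-call round trip; runTool() is a hypothetical stand-in
// for the engine's callTool(), and chunk typing is deliberately loose
type ToolCall = { id: string, function: string, args: string }

async function handleChunk(
  chunk: any,
  calls: ToolCall[],
  thread: any[],
  runTool: (name: string, args: any) => Promise<any>,
): Promise<boolean> {

  const choice = chunk.choices[0]

  // a delta with an id opens a new tool call; a delta without one
  // carries more argument text for the call currently being built
  const tc = choice?.delta?.tool_calls?.[0]
  if (tc) {
    if (tc.id) {
      calls.push({ id: tc.id, function: tc.function.name, args: tc.function.arguments })
    } else {
      calls[calls.length - 1].args += tc.function.arguments
    }
  }

  // once all calls are complete, run them and echo both the call and
  // its result into the thread for the follow-up chatStream() request
  if (choice?.finish_reason === 'tool_calls') {
    for (const call of calls) {
      const result = await runTool(call.function, JSON.parse(call.args))
      thread.push({ role: 'assistant', tool_calls: [{ id: call.id, function: call.function }] })
      thread.push({ role: 'tool', tool_call_id: call.id, name: call.function, content: JSON.stringify(result) })
    }
    return true // caller restarts the stream with the updated thread
  }

  return false
}
```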
diff --git a/src/services/openai.ts b/src/services/openai.ts
index 8a08082..f22ba01 100644
--- a/src/services/openai.ts
+++ b/src/services/openai.ts
@@ -248,25 +248,4 @@
 
   }
 
-  getPluginAsTool(plugin: Plugin): anyDict {
-    return {
-      type: 'function',
-      function: {
-        name: plugin.getName(),
-        description: plugin.getDescription(),
-        parameters: {
-          type: 'object',
-          properties: plugin.getParameters().reduce((obj: anyDict, param: PluginParameter) => {
-            obj[param.name] = {
-              type: param.type,
-              enum: param.enum,
-              description: param.description,
-            }
-            return obj
-          }, {}),
-          required: plugin.getParameters().filter(param => param.required).map(param => param.name),
-        },
-      },
-    }
-  }
 }
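Tool support is gated per model: only `mistral-large` and `mixtral-8x22b` variants receive tool definitions, and every other model falls through to a plain completion because `chatStream()` is then called with `tools: null`. A short usage sketch of that gating (the import paths and config loader here are assumptions, not part of the patch):

```ts
// Hypothetical usage of the model gating (paths and config assumed)
import MistralAI from './services/mistralai'
import { loadConfig } from './config'

const engine = new MistralAI(loadConfig())

// matched by model.includes('mistral-large') → plugin tool descriptors
const tools = engine.getAvailableToolsForModel('mistral-large-latest')

// unmatched model → null, so chatStream() sends tools: null, tool_choice: null
const none = engine.getAvailableToolsForModel('mistral-tiny')
```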