From 02c65f967deb05a9197db4dacd0ce862bd4e4785 Mon Sep 17 00:00:00 2001
From: Dennis Huebner
Date: Fri, 13 Dec 2024 10:40:12 +0100
Subject: [PATCH] Ollama LLM provider tools support #14610

---
 .../src/node/ollama-language-model.ts | 170 ++++++++++++++----
 1 file changed, 134 insertions(+), 36 deletions(-)

diff --git a/packages/ai-ollama/src/node/ollama-language-model.ts b/packages/ai-ollama/src/node/ollama-language-model.ts
index 8c90ac29dcf1b..75e8044e34c5c 100644
--- a/packages/ai-ollama/src/node/ollama-language-model.ts
+++ b/packages/ai-ollama/src/node/ollama-language-model.ts
@@ -20,7 +20,9 @@ import {
     LanguageModelRequest,
     LanguageModelRequestMessage,
     LanguageModelResponse,
+    LanguageModelStreamResponse,
     LanguageModelStreamResponsePart,
+    ToolCall,
     ToolRequest
 } from '@theia/ai-core';
 import { CancellationToken } from '@theia/core';
@@ -31,7 +33,9 @@ export const OllamaModelIdentifier = Symbol('OllamaModelIdentifier');
 export class OllamaModel implements LanguageModel {
 
     protected readonly DEFAULT_REQUEST_SETTINGS: Partial<Omit<ChatRequest, 'stream' | 'model'>> = {
-        keep_alive: '15m'
+        keep_alive: '15m',
+        // for available options see: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+        options: {}
     };
 
     readonly providerId = 'ollama';
@@ -50,6 +54,26 @@ export class OllamaModel implements LanguageModel {
         public defaultRequestSettings?: { [key: string]: unknown }
     ) { }
 
+    async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
+        const settings = this.getSettings(request);
+        const ollama = this.initializeOllama();
+
+        const ollamaRequest: ExtendedChatRequest = {
+            model: this.model,
+            ...this.DEFAULT_REQUEST_SETTINGS,
+            ...settings,
+            messages: request.messages.map(this.toOllamaMessage),
+            tools: request.tools?.map(this.toOllamaTool)
+        };
+        const structured = request.response_format?.type === 'json_schema';
+        return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken);
+    }
+
+    /**
+     * Retrieves the settings for the chat request, merging the request-specific settings with the default settings.
+     * @param request The language model request containing specific settings.
+     * @returns A partial ChatRequest object containing the merged settings.
+     */
     protected getSettings(request: LanguageModelRequest): Partial<ChatRequest> {
         const settings = request.settings ?? this.defaultRequestSettings ?? {};
         return {
@@ -57,55 +81,98 @@ export class OllamaModel implements LanguageModel {
         };
     }
 
-    async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
-        const settings = this.getSettings(request);
-        const ollama = this.initializeOllama();
+    protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise<LanguageModelResponse> {
+
+        // Handle structured output request
+        if (structured) {
+            return this.handleStructuredOutputRequest(ollama, ollamaRequest);
+        }
 
-        if (request.response_format?.type === 'json_schema') {
-            return this.handleStructuredOutputRequest(ollama, request);
+        // Handle tool request - response may call tools
+        if (ollamaRequest.tools && ollamaRequest.tools?.length > 0) {
+            return this.handleToolsRequest(ollama, ollamaRequest);
         }
+
+        // Handle standard chat request
         const response = await ollama.chat({
-            model: this.model,
-            ...this.DEFAULT_REQUEST_SETTINGS,
-            ...settings,
-            messages: request.messages.map(this.toOllamaMessage),
-            stream: true,
-            tools: request.tools?.map(this.toOllamaTool),
+            ...ollamaRequest,
+            stream: true
         });
+        return this.handleCancellationAndWrapIterator(response, cancellation);
+    }
 
-        cancellationToken?.onCancellationRequested(() => {
-            response.abort();
+    protected async handleToolsRequest(ollama: Ollama, chatRequest: ExtendedChatRequest, prevResponse?: ChatResponse): Promise<LanguageModelResponse> {
+        const response = prevResponse || await ollama.chat({
+            ...chatRequest,
+            stream: false
         });
-
-        async function* wrapAsyncIterator(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
-            for await (const item of inputIterable) {
-                // TODO handle tool calls
-                yield { content: item.message.content };
+        if (response.message.tool_calls) {
+            const tools: ToolWithHandler[] = chatRequest.tools ?? [];
+            // Add response message to chat history
+            chatRequest.messages.push(response.message);
+            const tool_calls: ToolCall[] = [];
+            for (const [idx, toolCall] of response.message.tool_calls.entries()) {
+                const functionToCall = tools.find(tool => tool.function.name === toolCall.function.name);
+                if (functionToCall) {
+                    const args = JSON.stringify(toolCall.function?.arguments);
+                    const funcResult = await functionToCall.handler(args);
+                    chatRequest.messages.push({
+                        role: 'tool',
+                        content: `Tool call ${functionToCall.function.name} returned: ${String(funcResult)}`,
+                    });
+                    let resultString = String(funcResult);
+                    if (resultString.length > 1000) {
+                        // truncate result string if it is too long
+                        resultString = resultString.substring(0, 1000) + '...';
+                    }
+                    tool_calls.push({
+                        id: `ollama_${response.created_at}_${idx}`,
+                        function: {
+                            name: functionToCall.function.name,
+                            arguments: Object.values(toolCall.function?.arguments ?? {}).join(', ')
+                        },
+                        result: resultString,
+                        finished: true
+                    });
+                }
+            }
+            // Get final response from model with function outputs
+            const finalResponse = await ollama.chat({ ...chatRequest, stream: false });
+            if (finalResponse.message.tool_calls) {
+                // If the final response also calls tools, recursively handle them
+                return this.handleToolsRequest(ollama, chatRequest, finalResponse);
             }
+            return { stream: this.createAsyncIterable([{ tool_calls }, { content: finalResponse.message.content }]) };
         }
-
-        return { stream: wrapAsyncIterator(response) };
+        return { text: response.message.content };
     }
 
-    protected async handleStructuredOutputRequest(ollama: Ollama, request: LanguageModelRequest): Promise<LanguageModelParsedResponse> {
-        const settings = this.getSettings(request);
-        const result = await ollama.chat({
-            ...settings,
-            ...this.DEFAULT_REQUEST_SETTINGS,
-            model: this.model,
-            messages: request.messages.map(this.toOllamaMessage),
+    protected createAsyncIterable<T>(items: T[]): AsyncIterable<T> {
+        return {
+            [Symbol.asyncIterator]: async function* (): AsyncIterableIterator<T> {
+                for (const item of items) {
+                    yield item;
+                }
+            }
+        };
+    }
+
+    protected async handleStructuredOutputRequest(ollama: Ollama, chatRequest: ChatRequest): Promise<LanguageModelParsedResponse> {
+        const response = await ollama.chat({
+            ...chatRequest,
             format: 'json',
             stream: false,
         });
         try {
             return {
-                content: result.message.content,
-                parsed: JSON.parse(result.message.content)
+                content: response.message.content,
+                parsed: JSON.parse(response.message.content)
             };
         } catch (error) {
             // TODO use ILogger
             console.log('Failed to parse structured response from the language model.', error);
             return {
-                content: result.message.content,
+                content: response.message.content,
                 parsed: {}
             };
         }
@@ -119,11 +186,21 @@ export class OllamaModel implements LanguageModel {
         return new Ollama({ host: host });
     }
 
-    protected toOllamaTool(tool: ToolRequest): Tool {
-        const transform = (props: Record<string, {
-            type: string;
-            [key: string]: unknown;
-        }> | undefined) => {
+    protected handleCancellationAndWrapIterator(response: AbortableAsyncIterable<ChatResponse>, token?: CancellationToken): LanguageModelStreamResponse {
+        token?.onCancellationRequested(() => {
+            // maybe it is better to use ollama.abort() as we are using one client per request
+            response.abort();
+        });
+        async function* wrapAsyncIterator(inputIterable: AsyncIterable<ChatResponse>): AsyncIterable<LanguageModelStreamResponsePart> {
+            for await (const item of inputIterable) {
+                yield { content: item.message.content };
+            }
+        }
+        return { stream: wrapAsyncIterator(response) };
+    }
+
+    protected toOllamaTool(tool: ToolRequest): ToolWithHandler {
+        const transform = (props: Record<string, { type: string, [key: string]: unknown }> | undefined) => {
         if (!props) {
             return undefined;
         }
@@ -148,7 +225,8 @@ export class OllamaModel implements LanguageModel {
                 required: Object.keys(tool.parameters?.properties ?? {}),
                 properties: transform(tool.parameters?.properties) ?? {}
             },
-        }
+        },
+        handler: tool.handler
         };
     }
 
@@ -165,3 +243,23 @@ export class OllamaModel implements LanguageModel {
         return { role: 'system', content: '' };
     }
 }
+
+/**
+ * Extended Tool containing a handler
+ * @see Tool
+ */
+type ToolWithHandler = Tool & { handler: (arg_string: string) => Promise<unknown> };
+
+/**
+ * Extended chat request with mandatory messages and ToolWithHandler tools
+ *
+ * @see ChatRequest
+ * @see ToolWithHandler
+ */
+type ExtendedChatRequest = ChatRequest & {
+    messages: Message[]
+    tools?: ToolWithHandler[]
+};
+
+// Ollama doesn't export this type, so we have to define it here
+type AbortableAsyncIterable<T> = AsyncIterable<T> & { abort: () => void };
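
Note on the tool-handling flow added in handleToolsRequest: it is the usual ollama-js tool-calling round trip. The chat request is sent with `tools` and `stream: false`; if the response contains `tool_calls`, the matching handler is executed for each call, its result is appended to the history as a `role: 'tool'` message, and the model is queried again until it answers without requesting further calls. The sketch below shows that round trip in isolation against the `ollama` npm package; it is a minimal illustration, and the model name, the getWeather tool, and its handler are made-up assumptions, not part of this patch.

import { Message, Ollama, Tool } from 'ollama';

// Hypothetical tool definition and handler table, for illustration only.
const weatherTool: Tool = {
    type: 'function',
    function: {
        name: 'getWeather',
        description: 'Returns the current weather for a city',
        parameters: {
            type: 'object',
            required: ['city'],
            properties: {
                city: { type: 'string', description: 'Name of the city' }
            }
        }
    }
};

const handlers: Record<string, (argsJson: string) => Promise<unknown>> = {
    getWeather: async () => 'sunny, 21 degrees'
};

async function chatWithTools(prompt: string): Promise<string> {
    const ollama = new Ollama({ host: 'http://localhost:11434' });
    const messages: Message[] = [{ role: 'user', content: prompt }];
    // Same round trip as handleToolsRequest, written as a loop instead of recursion.
    for (;;) {
        const response = await ollama.chat({ model: 'llama3.1', messages, tools: [weatherTool], stream: false });
        if (!response.message.tool_calls?.length) {
            // No further tool calls: this is the final answer.
            return response.message.content;
        }
        // Keep the assistant turn that requested the calls in the history...
        messages.push(response.message);
        // ...then execute each call and report its result as a 'tool' message.
        for (const call of response.message.tool_calls) {
            const handler = handlers[call.function.name];
            const result = handler ? await handler(JSON.stringify(call.function.arguments)) : 'unknown tool';
            messages.push({ role: 'tool', content: `Tool call ${call.function.name} returned: ${String(result)}` });
        }
    }
}

chatWithTools('What is the weather in Berlin?').then(console.log);

Unlike this sketch, which returns plain text, the patch reports the executed calls back to the Theia AI core by emitting a `tool_calls` stream part followed by the final `content` part, presumably so the chat UI can show which tools ran.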