diff --git a/.github/workflows/dev-build.yaml b/.github/workflows/dev-build.yaml index 583eeff7a5b..b2d21bf01ac 100644 --- a/.github/workflows/dev-build.yaml +++ b/.github/workflows/dev-build.yaml @@ -6,7 +6,7 @@ concurrency: on: push: - branches: ['3999-chromium-flags'] # put your current branch to create a build. Core team only. + branches: ['agentic-streaming'] # put your current branch to create a build. Core team only. paths-ignore: - '**.md' - 'cloud-deployments/*' diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx index 1d568441719..6b08a3b68ce 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/HistoricalMessage/index.jsx @@ -226,6 +226,7 @@ const RenderChatContent = memo( ); let thoughtChain = null; let msgToRender = message; + if (!message) return null; // If the message is a perfect thought chain, we can render it directly // Complete == open and close tags match perfectly. 
diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/StatusResponse/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/StatusResponse/index.jsx index fabaf038936..10befa09267 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/StatusResponse/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/StatusResponse/index.jsx @@ -4,11 +4,7 @@ import { CaretDown } from "@phosphor-icons/react"; import AgentAnimation from "@/media/animations/agent-animation.webm"; import AgentStatic from "@/media/animations/agent-static.png"; -export default function StatusResponse({ - messages = [], - isThinking = false, - showCheckmark = false, -}) { +export default function StatusResponse({ messages = [], isThinking = false }) { const [isExpanded, setIsExpanded] = useState(false); const currentThought = messages[messages.length - 1]; const previousThoughts = messages.slice(0, -1); diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx index 0284e211d7d..f9ee1b80a55 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/ChatHistory/index.jsx @@ -174,10 +174,6 @@ export default function ChatHistory({ key={`status-group-${index}`} messages={item} isThinking={!hasSubsequentMessages && lastMessageInfo.isAnimating} - showCheckmark={ - hasSubsequentMessages || - (!lastMessageInfo.isAnimating && !lastMessageInfo.isStatusResponse) - } /> ); }, diff --git a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx index 6e961a95717..ee57186f5d3 100644 --- a/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx +++ b/frontend/src/components/WorkspaceChat/ChatContainer/index.jsx @@ -234,6 +234,7 @@ export default function 
ChatContainer({ workspace, knownHistory = [] }) { const socket = new WebSocket( `${websocketURI()}/api/agent-invocation/${socketId}` ); + socket.supportsAgentStreaming = false; window.addEventListener(ABORT_STREAM_EVENT, () => { window.dispatchEvent(new CustomEvent(AGENT_SESSION_END)); @@ -243,7 +244,7 @@ export default function ChatContainer({ workspace, knownHistory = [] }) { socket.addEventListener("message", (event) => { setLoadingResponse(true); try { - handleSocketResponse(event, setChatHistory); + handleSocketResponse(socket, event, setChatHistory); } catch (e) { console.error("Failed to parse data"); window.dispatchEvent(new CustomEvent(AGENT_SESSION_END)); diff --git a/frontend/src/utils/chat/agent.js b/frontend/src/utils/chat/agent.js index ad1193d304c..2194b3e1d47 100644 --- a/frontend/src/utils/chat/agent.js +++ b/frontend/src/utils/chat/agent.js @@ -12,6 +12,8 @@ const handledEvents = [ "awaitingFeedback", "wssFailure", "rechartVisualize", + // Streaming events + "reportStreamEvent", ]; export function websocketURI() { @@ -20,13 +22,13 @@ export function websocketURI() { return `${wsProtocol}//${new URL(import.meta.env.VITE_API_BASE).host}`; } -export default function handleSocketResponse(event, setChatHistory) { +export default function handleSocketResponse(socket, event, setChatHistory) { const data = safeJsonParse(event.data, null); if (data === null) return; // No message type is defined then this is a generic message // that we need to print to the user as a system response - if (!data.hasOwnProperty("type")) { + if (!data.hasOwnProperty("type") && !socket.supportsAgentStreaming) { return setChatHistory((prev) => { return [ ...prev.filter((msg) => !!msg.content), @@ -46,6 +48,90 @@ export default function handleSocketResponse(event, setChatHistory) { if (!handledEvents.includes(data.type) || !data.content) return; + if (data.type === "reportStreamEvent") { + // Enable agent streaming for the next message so we can handle streaming or non-streaming 
responses + // If we get this message we know the provider supports agentic streaming + socket.supportsAgentStreaming = true; + + return setChatHistory((prev) => { + if (data.content.type === "removeStatusResponse") + return [...prev.filter((msg) => msg.uuid !== data.content.uuid)]; + + const knownMessage = data.content.uuid + ? prev.find((msg) => msg.uuid === data.content.uuid) + : null; + if (!knownMessage) { + if (data.content.type === "fullTextResponse") { + return [ + ...prev.filter((msg) => !!msg.content), + { + uuid: data.content.uuid, + type: "textResponse", + content: data.content.content, + role: "assistant", + sources: [], + closed: true, + error: null, + animate: false, + pending: false, + }, + ]; + } + + return [ + ...prev.filter((msg) => !!msg.content), + { + uuid: data.content.uuid, + type: "statusResponse", + content: data.content.content, + role: "assistant", + sources: [], + closed: true, + error: null, + animate: false, + pending: false, + }, + ]; + } else { + const { type, content, uuid } = data.content; + // For tool call invocations, we need to update the existing message entirely since it is accumulated + // and we dont know if the function will have arguments or not while streaming - so replace the existing message entirely + if (type === "toolCallInvocation") { + const knownMessage = prev.find((msg) => msg.uuid === uuid); + if (!knownMessage) + return [...prev, { uuid, type: "toolCallInvocation", content }]; // If the message is not known, add it to the end of the list + return [ + ...prev.filter((msg) => msg.uuid !== uuid), + { ...knownMessage, content }, + ]; // If the message is known, replace it with the new content + } + + if (type === "textResponseChunk") { + return prev + .map((msg) => + msg.uuid === uuid + ? { + ...msg, + type: "textResponse", + content: msg.content + content, + } + : msg?.content + ? 
msg + : null + ) + .filter((msg) => !!msg); + } + + // Generic text response - will be put in the agent thought bubble + return prev.map((msg) => + msg.uuid === data.content.uuid + ? { ...msg, content: msg.content + data.content.content } + : msg + ); + } + }); + } + if (data.type === "fileDownload") { saveAs(data.content.b64Content, data.content.filename ?? "unknown.txt"); return; diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index 683850dfcb9..74e70540a5d 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -533,23 +533,17 @@ Only return the role. } /** - * Ask the for the AI provider to generate a reply to the chat. + * Get the chat history between two nodes or all chats to/from a node. * - * @param route.to The node that sent the chat. - * @param route.from The node that will reply to the chat. + * @param route + * @returns */ - async reply(route) { - // get the provider for the node that will reply - const fromConfig = this.getAgentConfig(route.from); - - const chatHistory = - // if it is sending message to a group, send the group chat history to the provider - // otherwise, send the chat history between the two nodes - this.channels.get(route.to) - ? [ - { - role: "user", - content: `You are in a whatsapp group. Read the following conversation and then reply. + getOrFormatNodeChatHistory(route) { + if (this.channels.get(route.to)) { + return [ + { + role: "user", + content: `You are in a whatsapp group. Read the following conversation and then reply. Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself. CHAT HISTORY @@ -558,20 +552,36 @@ ${this.getHistory({ to: route.to }) .join("\n")} @${route.from}:`, - }, - ] - : this.getHistory(route).map((c) => ({ - content: c.content, - role: c.from === route.to ? 
"user" : "assistant", - })); + }, + ]; + } - // build the messages to send to the provider + // This is normal chat between user<->agent + return this.getHistory(route).map((c) => ({ + content: c.content, + role: c.from === route.to ? "user" : "assistant", + })); + } + + /** + * Ask the for the AI provider to generate a reply to the chat. + * This will load the functions that the node can call and the chat history. + * Then before calling the provider, it will check if the provider supports agent streaming. + * If it does, it will call the provider asynchronously (streaming). + * Otherwise, it will call the provider synchronously (non-streaming). + * `.supportsAgentStreaming` is used to determine if the provider supports agent streaming on the respective provider. + * + * @param route.to The node that sent the chat. + * @param route.from The node that will reply to the chat. + */ + async reply(route) { + const fromConfig = this.getAgentConfig(route.from); + const chatHistory = this.getOrFormatNodeChatHistory(route); const messages = [ { content: fromConfig.role, role: "system", }, - // get the history of chats between the two nodes ...chatHistory, ]; @@ -585,18 +595,144 @@ ${this.getHistory({ to: route.to }) ...fromConfig, }); - // get the chat completion - const content = await this.handleExecution( - provider, + let content; + if (provider.supportsAgentStreaming) { + this.handlerProps.log?.( + "[DEBUG] Provider supports agent streaming - will use async execution!" + ); + content = await this.handleAsyncExecution( + provider, + messages, + functions, + route.from + ); + } else { + this.handlerProps.log?.( + "[DEBUG] Provider does not support agent streaming - will use synchronous execution!" + ); + content = await this.handleExecution( + provider, + messages, + functions, + route.from + ); + } + + this.newMessage({ ...route, content }); + return content; + } + + /** + * Handle the async (streaming) execution of the provider + * with tool calls. 
+ * + * @param provider + * @param messages + * @param functions + * @param byAgent + * + * @returns {Promise} + */ + async handleAsyncExecution( + provider, + messages = [], + functions = [], + byAgent = null + ) { + const eventHandler = (type, data) => { + this?.socket?.send(type, data); + }; + + /** @type {{ functionCall: { name: string, arguments: string }, textResponse: string }} */ + const completionStream = await provider.stream( messages, functions, - route.from + eventHandler ); - this.newMessage({ ...route, content }); - return content; + if (completionStream.functionCall) { + const { name, arguments: args } = completionStream.functionCall; + const fn = this.functions.get(name); + + // if provider hallucinated on the function name + // ask the provider to complete again + if (!fn) { + return await this.handleAsyncExecution( + provider, + [ + ...messages, + { + name, + role: "function", + content: `Function "${name}" not found. Try again.`, + }, + ], + functions, + byAgent + ); + } + + // Execute the function and return the result to the provider + fn.caller = byAgent || "agent"; + + // If provider is verbose, log the tool call to the frontend + if (provider?.verbose) { + this?.introspect?.( + `${fn.caller} is executing \`${name}\` tool ${JSON.stringify(args, null, 2)}` + ); + } + + // Always log the tool call to the console for debugging purposes + this.handlerProps?.log?.( + `[debug]: ${fn.caller} is attempting to call \`${name}\` tool ${JSON.stringify(args, null, 2)}` + ); + + const result = await fn.handler(args); + Telemetry.sendTelemetry("agent_tool_call", { tool: name }, null, true); + + // If the tool call has direct output enabled, return the result directly to the chat + // without any further processing and no further tool calls will be run. + if (this.skipHandleExecution) { + this.skipHandleExecution = false; // reset the flag to prevent next tool call from being skipped + this?.introspect?.( + `The tool call has direct output enabled! 
The result will be returned directly to the chat without any further processing and no further tool calls will be run.` + ); + this?.introspect?.(`Tool use completed.`); + this.handlerProps?.log?.( + `${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.` + ); + return result; + } + + return await this.handleAsyncExecution( + provider, + [ + ...messages, + { + name, + role: "function", + content: result, + }, + ], + functions, + byAgent + ); + } + + return completionStream?.textResponse; } + /** + * Handle the synchronous (non-streaming) execution of the provider + * with tool calls. + * + * @param provider + * @param messages + * @param functions + * @param byAgent + * + * @returns {Promise} + */ async handleExecution( provider, messages = [], @@ -675,7 +811,7 @@ ${this.getHistory({ to: route.to }) ); } - return completion?.result; + return completion?.textResponse; } /** diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index c2528acd948..87ec35bf486 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -10,11 +10,12 @@ * @property {(string|null)} model - Overrides model used for provider. */ +const { v4 } = require("uuid"); const { ChatOpenAI } = require("@langchain/openai"); const { ChatAnthropic } = require("@langchain/anthropic"); const { ChatBedrockConverse } = require("@langchain/aws"); const { ChatOllama } = require("@langchain/community/chat_models/ollama"); -const { toValidNumber } = require("../../../http"); +const { toValidNumber, safeJsonParse } = require("../../../http"); const { getLLMProviderClass } = require("../../../helpers"); const { parseLMStudioBasePath } = require("../../../AiProviders/lmStudio"); @@ -288,6 +289,82 @@ class Provider { return DEFAULT_WORKSPACE_PROMPT; } } + + /** + * Whether the provider supports agent streaming. 
+ * Disabled by default and needs to be explicitly enabled in the provider + * This is temporary while we migrate all providers to support agent streaming + * @returns {boolean} + */ + get supportsAgentStreaming() { + return false; + } + + /** + * Stream a chat completion from the LLM with tool calling + * Note: This using the OpenAI API format and may need to be adapted for other providers. + * + * @param {any[]} messages - The messages to send to the LLM. + * @param {any[]} functions - The functions to use in the LLM. + * @param {function} eventHandler - The event handler to use to report stream events. + * @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion. + */ + async stream(messages, functions = [], eventHandler = null) { + this.providerLog("Provider.stream - will process this chat completion."); + const msgUUID = v4(); + const stream = await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + ...(Array.isArray(functions) && functions?.length > 0 + ? 
{ functions } + : {}), + }); + + const result = { + functionCall: null, + textResponse: "", + }; + + for await (const chunk of stream) { + if (!chunk?.choices?.[0]) continue; // Skip if no choices + const choice = chunk.choices[0]; + + if (choice.delta?.content) { + result.textResponse += choice.delta.content; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: choice.delta.content, + }); + } + + if (choice.delta?.function_call) { + // accumulate the function call + if (result.functionCall) + result.functionCall.arguments += choice.delta.function_call.arguments; + else result.functionCall = choice.delta.function_call; + + eventHandler?.("reportStreamEvent", { + uuid: `${msgUUID}:tool_call_invocation`, + type: "toolCallInvocation", + content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`, + }); + } + } + + // If there are arguments, parse them as json so that the tools can use them + if (!!result.functionCall?.arguments) + result.functionCall.arguments = safeJsonParse( + result.functionCall.arguments, + {} + ); + + return { + textResponse: result.textResponse, + functionCall: result.functionCall, + }; + } } module.exports = Provider; diff --git a/server/utils/agents/aibitat/providers/anthropic.js b/server/utils/agents/aibitat/providers/anthropic.js index fd70b23f70f..7bcc9976f5a 100644 --- a/server/utils/agents/aibitat/providers/anthropic.js +++ b/server/utils/agents/aibitat/providers/anthropic.js @@ -1,6 +1,8 @@ const Anthropic = require("@anthropic-ai/sdk"); const { RetryError } = require("../error.js"); const Provider = require("./ai-provider.js"); +const { v4 } = require("uuid"); +const { safeJsonParse } = require("../../../http"); /** * The agent provider for the Anthropic API. 
@@ -15,7 +17,7 @@ class AnthropicProvider extends Provider { apiKey: process.env.ANTHROPIC_API_KEY, maxRetries: 3, }, - model = "claude-2", + model = "claude-3-5-sonnet-20240620", } = config; const client = new Anthropic(options); @@ -25,70 +27,91 @@ class AnthropicProvider extends Provider { this.model = model; } - // For Anthropic we will always need to ensure the message sequence is role,content - // as we can attach any data to message nodes and this keeps the message property - // sent to the API always in spec. - #sanitize(chats) { - const sanitized = [...chats]; - - // If the first message is not a USER, Anthropic will abort so keep shifting the - // message array until that is the case. - while (sanitized.length > 0 && sanitized[0].role !== "user") - sanitized.shift(); - - return sanitized.map((msg) => { - const { role, content } = msg; - return { role, content }; - }); + get supportsAgentStreaming() { + return true; } - #normalizeChats(messages = []) { - if (!messages.length) return messages; - const normalized = []; - - [...messages].forEach((msg, i) => { - if (msg.role !== "function") return normalized.push(msg); - - // If the last message is a role "function" this is our special aibitat message node. - // and we need to remove it from the array of messages. - // Since Anthropic needs to have the tool call resolved, we look at the previous chat to "function" - // and go through its content "thought" from ~ln:143 and get the tool_call id so we can resolve - // this tool call properly. - const functionCompletion = msg; - const toolCallId = messages[i - 1]?.content?.find( - (msg) => msg.type === "tool_use" - )?.id; - - // Append the Anthropic acceptable node to the message chain so function can resolve. 
- normalized.push({ - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: toolCallId, - content: functionCompletion.content, - }, - ], - }); - }); - return normalized; - } - - // Anthropic handles system message as a property, so here we split the system message prompt - // from all the chats and then normalize them so they will be useable in case of tool_calls or general chat. - #parseSystemPrompt(messages = []) { - const chats = []; + #prepareMessages(messages = []) { + // Extract system prompt and filter out any system messages from the main chat. let systemPrompt = "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions."; - for (const msg of messages) { + const chatMessages = messages.filter((msg) => { if (msg.role === "system") { systemPrompt = msg.content; - continue; + return false; } - chats.push(msg); + return true; + }); + + const processedMessages = chatMessages.reduce( + (processedMessages, message, index) => { + // Normalize `function` role to Anthropic's `tool_result` format. + if (message.role === "function") { + const prevMessage = chatMessages[index - 1]; + if (prevMessage?.role === "assistant") { + const toolUse = prevMessage.content.find( + (item) => item.type === "tool_use" + ); + if (toolUse) { + processedMessages.push({ + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: toolUse.id, + content: message.content + ? String(message.content) + : "Tool executed successfully.", + }, + ], + }); + } + } + return processedMessages; + } + + // Ensure message content is in array format and filter out empty text blocks. + let content = Array.isArray(message.content) + ? 
message.content + : [{ type: "text", text: message.content }]; + content = content.filter( + (item) => + item.type !== "text" || (item.text && item.text.trim().length > 0) + ); + + if (content.length === 0) return processedMessages; + + // Add a text block to assistant messages with tool use if one doesn't exist. + if ( + message.role === "assistant" && + content.some((item) => item.type === "tool_use") && + !content.some((item) => item.type === "text") + ) { + content.unshift({ + type: "text", + text: "I'll use a tool to help answer this question.", + }); + } + + const lastMessage = processedMessages[processedMessages.length - 1]; + if (lastMessage && lastMessage.role === message.role) { + // Merge consecutive messages from the same role. + lastMessage.content.push(...content); + } else { + processedMessages.push({ ...message, content }); + } + + return processedMessages; + }, + [] + ); + + // The first message must be from the user. + if (processedMessages.length > 0 && processedMessages[0].role !== "user") { + processedMessages.shift(); } - return [systemPrompt, this.#normalizeChats(chats)]; + return [systemPrompt, processedMessages]; } // Anthropic does not use the regular schema for functions so here we need to ensure it is in there specific format @@ -109,6 +132,136 @@ class AnthropicProvider extends Provider { }); } + /** + * Stream a chat completion from the LLM with tool calling + * Note: This using the OpenAI API format and may need to be adapted for other providers. + * + * @param {any[]} messages - The messages to send to the LLM. + * @param {any[]} functions - The functions to use in the LLM. + * @param {function} eventHandler - The event handler to use to report stream events. + * @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion. 
+ */ + async stream(messages, functions = [], eventHandler = null) { + try { + const msgUUID = v4(); + const [systemPrompt, chats] = this.#prepareMessages(messages); + const response = await this.client.messages.create( + { + model: this.model, + max_tokens: 4096, + system: systemPrompt, + messages: chats, + stream: true, + ...(Array.isArray(functions) && functions?.length > 0 + ? { tools: this.#formatFunctions(functions) } + : {}), + }, + { headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools. + ); + + const result = { + functionCall: null, + textResponse: "", + }; + + for await (const chunk of response) { + if (chunk.type === "content_block_start") { + if (chunk.content_block.type === "text") { + result.textResponse += chunk.content_block.text; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: chunk.content_block.text, + }); + } + + if (chunk.content_block.type === "tool_use") { + result.functionCall = { + id: chunk.content_block.id, + name: chunk.content_block.name, + // The initial arguments are empty {} (object) so we need to set it to an empty string. + // It is unclear if this is ALWAYS empty on the tool_use block or if it can possible be populated. + // This is a workaround to ensure the tool call is valid. 
+ arguments: "", + }; + eventHandler?.("reportStreamEvent", { + type: "toolCallInvocation", + uuid: `${msgUUID}:tool_call_invocation`, + content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`, + }); + } + } + + if (chunk.type === "content_block_delta") { + if (chunk.delta.type === "text_delta") { + result.textResponse += chunk.delta.text; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: chunk.delta.text, + }); + } + + if (chunk.delta.type === "input_json_delta") { + result.functionCall.arguments += chunk.delta.partial_json; + eventHandler?.("reportStreamEvent", { + type: "toolCallInvocation", + uuid: `${msgUUID}:tool_call_invocation`, + content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`, + }); + } + } + } + + if (result.functionCall) { + result.functionCall.arguments = safeJsonParse( + result.functionCall.arguments, + {} + ); + messages.push({ + role: "assistant", + content: [ + { type: "text", text: result.textResponse }, + { + type: "tool_use", + id: result.functionCall.id, + name: result.functionCall.name, + input: result.functionCall.arguments, + }, + ], + }); + return { + textResponse: result.textResponse, + functionCall: { + name: result.functionCall.name, + arguments: result.functionCall.arguments, + }, + cost: 0, + }; + } + + return { + textResponse: result.textResponse, + functionCall: null, + cost: 0, + }; + } catch (error) { + // If invalid Auth error we need to abort because no amount of waiting + // will make auth better. + if (error instanceof Anthropic.AuthenticationError) throw error; + + if ( + error instanceof Anthropic.RateLimitError || + error instanceof Anthropic.InternalServerError || + error instanceof Anthropic.APIError // Also will catch AuthenticationError!!! + ) { + throw new RetryError(error.message); + } + + throw error; + } + } + /** * Create a completion based on the received messages. 
* @@ -118,13 +271,13 @@ class AnthropicProvider extends Provider { */ async complete(messages, functions = []) { try { - const [systemPrompt, chats] = this.#parseSystemPrompt(messages); + const [systemPrompt, chats] = this.#prepareMessages(messages); const response = await this.client.messages.create( { model: this.model, max_tokens: 4096, system: systemPrompt, - messages: this.#sanitize(chats), + messages: chats, stream: false, ...(Array.isArray(functions) && functions?.length > 0 ? { tools: this.#formatFunctions(functions) } @@ -185,7 +338,7 @@ class AnthropicProvider extends Provider { const completion = response.content.find((msg) => msg.type === "text"); return { - result: + textResponse: completion?.text ?? "The model failed to complete the task and return back a valid response.", cost: 0, diff --git a/server/utils/agents/aibitat/providers/apipie.js b/server/utils/agents/aibitat/providers/apipie.js index 8e342ec963a..8b7dc14d01f 100644 --- a/server/utils/agents/aibitat/providers/apipie.js +++ b/server/utils/agents/aibitat/providers/apipie.js @@ -27,11 +27,14 @@ class ApiPieProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -46,60 +49,31 @@ class ApiPieProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/azure.js b/server/utils/agents/aibitat/providers/azure.js index 48078f2db82..037937f35f3 100644 --- a/server/utils/agents/aibitat/providers/azure.js +++ b/server/utils/agents/aibitat/providers/azure.js @@ -18,6 +18,11 @@ class AzureOpenAiProvider extends Provider { this.model = config.model ?? process.env.OPEN_MODEL_PREF; this.verbose = true; } + + get supportsAgentStreaming() { + return true; + } + /** * Create a completion based on the received messages. * @@ -29,7 +34,7 @@ class AzureOpenAiProvider extends Provider { try { const response = await this.client.chat.completions.create({ model: this.model, - // stream: true, + stream: false, messages, ...(Array.isArray(functions) && functions?.length > 0 ? 
{ functions } @@ -63,7 +68,7 @@ class AzureOpenAiProvider extends Provider { // console.log(completion, { functionArgs }) return { - result: null, + textResponse: null, functionCall: { name: completion.function_call.name, arguments: functionArgs, @@ -73,7 +78,7 @@ class AzureOpenAiProvider extends Provider { } return { - result: completion.content, + textResponse: completion.content, cost, }; } catch (error) { diff --git a/server/utils/agents/aibitat/providers/cometapi.js b/server/utils/agents/aibitat/providers/cometapi.js index 87eca7a053b..c15564d0967 100644 --- a/server/utils/agents/aibitat/providers/cometapi.js +++ b/server/utils/agents/aibitat/providers/cometapi.js @@ -31,11 +31,14 @@ class CometApiProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return false; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -50,54 +53,31 @@ class CometApiProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog("Will assume chat completion without tool call inputs."); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/deepseek.js b/server/utils/agents/aibitat/providers/deepseek.js index ca74d58cf18..3f79f1ca5f4 100644 --- a/server/utils/agents/aibitat/providers/deepseek.js +++ b/server/utils/agents/aibitat/providers/deepseek.js @@ -28,11 +28,14 @@ class DeepSeekProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, max_tokens: this.maxTokens, }) @@ -48,60 +51,31 @@ class DeepSeekProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. - */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." 
- ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. - this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/dellProAiStudio.js b/server/utils/agents/aibitat/providers/dellProAiStudio.js index 07f86416272..a4618520510 100644 --- a/server/utils/agents/aibitat/providers/dellProAiStudio.js +++ b/server/utils/agents/aibitat/providers/dellProAiStudio.js @@ -33,6 +33,10 @@ class DellProAiStudioProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ @@ -51,60 +55,31 @@ class DellProAiStudioProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/fireworksai.js b/server/utils/agents/aibitat/providers/fireworksai.js index 7c40d6e73d4..caeb49f712c 100644 --- a/server/utils/agents/aibitat/providers/fireworksai.js +++ b/server/utils/agents/aibitat/providers/fireworksai.js @@ -29,11 +29,14 @@ class FireworksAIProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -48,60 +51,31 @@ class FireworksAIProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/gemini.js b/server/utils/agents/aibitat/providers/gemini.js index d624bd64f21..101d970c848 100644 --- a/server/utils/agents/aibitat/providers/gemini.js +++ b/server/utils/agents/aibitat/providers/gemini.js @@ -6,6 +6,8 @@ const { NO_SYSTEM_PROMPT_MODELS, } = require("../../../AiProviders/gemini/index.js"); const { APIError } = require("../error.js"); +const { v4 } = require("uuid"); +const { safeJsonParse } = require("../../../http"); /** * The agent provider for the Gemini provider. @@ -33,6 +35,19 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + // Tool call streaming results in a 400/503 error for all non-gemini models + // using the compatible v1beta/openai/ endpoint + if (!this.model.startsWith("gemini")) { + this.providerLog( + `Gemini: ${this.model} does not support tool call streaming.` + ); + return false; + } + + return true; + } + /** * Format the messages to the format required by the Gemini API since some models do not support system prompts. * @see {NO_SYSTEM_PROMPT_MODELS} @@ -61,11 +76,29 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) { return formattedMessages; } + /** + * Format the functions for the LLM. + * @param {any[]} functions - The functions to format. + * @returns {any[]} - The formatted functions. 
+ */ + formatFunctions(functions = []) { + return functions.map((fn) => ({ + type: "function", + function: { + name: fn.name, + description: fn.description, + parameters: { + type: "object", + properties: fn.parameters.properties, + }, + }, + })); + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages: this.cleanMsgs(this.formatMessages(messages)), }) .then((result) => { @@ -80,9 +113,81 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) { }); } + /** + * Streaming for Gemini only supports `tools` and not `functions`, so + * we need to apply some transformations to the messages and functions. + * + * @see {formatFunctions} + * @param {*} messages + * @param {*} functions + * @param {*} eventHandler + * @returns + */ + async stream(messages, functions = [], eventHandler = null) { + const msgUUID = v4(); + const stream = await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages: this.cleanMsgs(this.formatMessages(messages)), + ...(Array.isArray(functions) && functions?.length > 0 + ? 
{ + tools: this.formatFunctions(functions), + tool_choice: "auto", + } + : {}), + }); + + const result = { + functionCall: null, + textResponse: "", + }; + + for await (const chunk of stream) { + if (!chunk?.choices?.[0]) continue; // Skip if no choices + const choice = chunk.choices[0]; + + if (choice.delta?.content) { + result.textResponse += choice.delta.content; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: choice.delta.content, + }); + } + + if (choice.delta?.tool_calls && choice.delta.tool_calls.length > 0) { + const toolCall = choice.delta.tool_calls[0]; + if (result.functionCall) + result.functionCall.arguments += toolCall.function.arguments; + else { + result.functionCall = { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }; + } + + eventHandler?.("reportStreamEvent", { + uuid: `${msgUUID}:tool_call_invocation`, + type: "toolCallInvocation", + content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`, + }); + } + } + + // If there are arguments, parse them as json so that the tools can use them + if (!!result.functionCall?.arguments) + result.functionCall.arguments = safeJsonParse( + result.functionCall.arguments, + {} + ); + return result; + } + /** * Create a completion based on the received messages. * + * TODO: see stream() - tool_calls are now supported, so we can use that instead of Untooled + * * @param messages A list of messages to send to the API. * @param functions * @returns The completion. @@ -129,7 +234,7 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) { // _but_ we should enable it to call previously used tools in a new chat interaction. 
this.deduplicator.reset("runs"); return { - result: completion.content, + textResponse: completion.content, cost: 0, }; } catch (error) { diff --git a/server/utils/agents/aibitat/providers/genericOpenAi.js b/server/utils/agents/aibitat/providers/genericOpenAi.js index c067584994f..c1e0370f2bb 100644 --- a/server/utils/agents/aibitat/providers/genericOpenAi.js +++ b/server/utils/agents/aibitat/providers/genericOpenAi.js @@ -38,6 +38,12 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + // Honor streaming being disabled via ENV via user preference. + if (process.env.GENERIC_OPENAI_STREAMING_DISABLED === "true") return false; + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ @@ -58,60 +64,31 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. - */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." 
- ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. - this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/groq.js b/server/utils/agents/aibitat/providers/groq.js index f66dc27d4af..1ab5479e1fe 100644 --- a/server/utils/agents/aibitat/providers/groq.js +++ b/server/utils/agents/aibitat/providers/groq.js @@ -28,11 +28,14 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -47,60 +50,31 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/helpers/untooled.js b/server/utils/agents/aibitat/providers/helpers/untooled.js index aff7fa70353..2da3e4f34ff 100644 --- a/server/utils/agents/aibitat/providers/helpers/untooled.js +++ b/server/utils/agents/aibitat/providers/helpers/untooled.js @@ -1,5 +1,6 @@ const { safeJsonParse } = require("../../../../http"); const { Deduplicator } = require("../../utils/dedupe"); +const { v4 } = require("uuid"); // Useful inheritance class for a model which supports OpenAi schema for API requests // but does not have tool-calling or JSON output support. @@ -93,15 +94,10 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; return { valid: true, reason: null }; } - async functionCall(messages, functions, chatCb = null) { - const history = [...messages].filter((msg) => - ["user", "assistant"].includes(msg.role) - ); - if (history[history.length - 1].role !== "user") return null; - const response = await chatCb({ - messages: [ - { - content: `You are a program which picks the most optimal function and parameters to call. + buildToolCallMessages(history = [], functions = []) { + return [ + { + content: `You are a program which picks the most optimal function and parameters to call. DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY. When a function is selection, respond in JSON with no additional text. When there is no relevant function to call - return with a regular chat text response. 
@@ -116,11 +112,20 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; ${this.showcaseFunctions(functions)} Now pick a function if there is an appropriate one to use given the last user message and the given conversation so far.`, - role: "system", - }, - ...history, - ], - }); + role: "system", + }, + ...history, + ]; + } + + async functionCall(messages, functions, chatCb = null) { + const history = [...messages].filter((msg) => + ["user", "assistant"].includes(msg.role) + ); + if (history[history.length - 1].role !== "user") return null; + const historyMessages = this.buildToolCallMessages(history, functions); + const response = await chatCb({ messages: historyMessages }); + const call = safeJsonParse(response, null); if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text. @@ -139,6 +144,253 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`; return { toolCall: call, text: null }; } + + async streamingFunctionCall( + messages, + functions, + chatCb = null, + eventHandler = null + ) { + const history = [...messages].filter((msg) => + ["user", "assistant"].includes(msg.role) + ); + if (history[history.length - 1].role !== "user") return null; + + const msgUUID = v4(); + let textResponse = ""; + const historyMessages = this.buildToolCallMessages(history, functions); + const stream = await chatCb({ messages: historyMessages }); + + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Agent is thinking...", + }); + + for await (const chunk of stream) { + if (!chunk?.choices?.[0]) continue; // Skip if no choices + const choice = chunk.choices[0]; + + if (choice.delta?.content) { + textResponse += choice.delta.content; + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: msgUUID, + content: choice.delta.content, + }); + } + } + + const call = safeJsonParse(textResponse, null); + if (call === null) + return { toolCall: null, text: textResponse, uuid: 
msgUUID }; // failed to parse, so must be regular text response. + + const { valid, reason } = this.validFuncCall(call, functions); + if (!valid) { + this.providerLog(`Invalid function tool call: ${reason}.`); + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: + "The model attempted to make an invalid function call - it was ignored.", + }); + return { toolCall: null, text: null, uuid: msgUUID }; + } + + if (this.deduplicator.isDuplicate(call.name, call.arguments)) { + this.providerLog( + `Function tool with exact arguments has already been called this stack.` + ); + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: + "The model tried to call a function with the same arguments as a previous call - it was ignored.", + }); + return { toolCall: null, text: null, uuid: msgUUID }; + } + + eventHandler?.("reportStreamEvent", { + uuid: `${msgUUID}:tool_call_invocation`, + type: "toolCallInvocation", + content: `Parsed Tool Call: ${call.name}(${JSON.stringify(call.arguments)})`, + }); + return { toolCall: call, text: null, uuid: msgUUID }; + } + + /** + * Stream a chat completion from the LLM with tool calling + * Note: This using the OpenAI API format and may need to be adapted for other providers. + * + * @param {any[]} messages - The messages to send to the LLM. + * @param {any[]} functions - The functions to use in the LLM. + * @param {function} chatCallback - A callback function to handle the chat completion. + * @param {function} eventHandler - The event handler to use to report stream events. + * @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion. 
+ */ + async stream( + messages, + functions = [], + chatCallback = null, + eventHandler = null + ) { + this.providerLog("Untooled.stream - will process this chat completion."); + try { + let completion = { content: "" }; + if (functions.length > 0) { + const { + toolCall, + text, + uuid: msgUUID, + } = await this.streamingFunctionCall( + messages, + functions, + chatCallback, + eventHandler + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + + if (text) { + this.providerLog( + `No tool call found in the response - will send as a full text response.` + ); + completion.content = text; + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: "No tool call found in the response", + }); + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Done thinking.", + }); + eventHandler?.("reportStreamEvent", { + type: "fullTextResponse", + uuid: v4(), + content: text, + }); + } + } + + if (!completion?.content) { + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Done thinking.", + }); + + this.providerLog( + "Will assume chat completion without tool call inputs." 
+ ); + const msgUUID = v4(); + completion = { content: "" }; + const stream = await chatCallback({ + messages: this.cleanMsgs(messages), + }); + + for await (const chunk of stream) { + if (!chunk?.choices?.[0]) continue; // Skip if no choices + const choice = chunk.choices[0]; + if (choice.delta?.content) { + completion.content += choice.delta.content; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: choice.delta.content, + }); + } + } + } + + // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent + // from calling the exact same function over and over in a loop within a single chat exchange + // _but_ we should enable it to call previously used tools in a new chat interaction. + this.deduplicator.reset("runs"); + return { + textResponse: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + /** + * Create a completion based on the received messages. + * + * @param messages A list of messages to send to the API. + * @param functions + * @param chatCallback - A callback function to handle the chat completion. + * @returns The completion. + */ + async complete(messages, functions = [], chatCallback = null) { + this.providerLog("Untooled.complete - will process this chat completion."); + try { + let completion = { content: "" }; + if (functions.length > 0) { + const { toolCall, text } = await this.functionCall( + messages, + functions, + chatCallback + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + completion.content = text; + } + + // If there are no functions, we want to run a normal chat completion. + if (!completion?.content) { + this.providerLog( + "Will assume chat completion without tool call inputs." 
+ ); + const response = await chatCallback({ + messages: this.cleanMsgs(messages), + }); + // If the response from the callback is the raw OpenAI Spec response object, we can use that directly. + // Otherwise, we will assume the response is just the string output we wanted (see: `#handleFunctionCallChat` which returns the content only) + // This handles both streaming and non-streaming completions. + completion = + typeof response === "string" + ? { content: response } + : response.choices?.[0]?.message; + } + + // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent + // from calling the exact same function over and over in a loop within a single chat exchange + // _but_ we should enable it to call previously used tools in a new chat interaction. + this.deduplicator.reset("runs"); + return { + textResponse: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } } module.exports = UnTooled; diff --git a/server/utils/agents/aibitat/providers/koboldcpp.js b/server/utils/agents/aibitat/providers/koboldcpp.js index 34eafe0e909..e92b0861883 100644 --- a/server/utils/agents/aibitat/providers/koboldcpp.js +++ b/server/utils/agents/aibitat/providers/koboldcpp.js @@ -27,11 +27,14 @@ class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -46,60 +49,31 @@ class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/litellm.js b/server/utils/agents/aibitat/providers/litellm.js index 9cc8a4c3b11..9f93dbc0b86 100644 --- a/server/utils/agents/aibitat/providers/litellm.js +++ b/server/utils/agents/aibitat/providers/litellm.js @@ -27,11 +27,14 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -46,60 +49,31 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } getCost(_usage) { diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js index f783ade9a33..fdaedf222ed 100644 --- a/server/utils/agents/aibitat/providers/lmstudio.js +++ b/server/utils/agents/aibitat/providers/lmstudio.js @@ -35,11 +35,14 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -54,60 +57,31 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/localai.js b/server/utils/agents/aibitat/providers/localai.js index 38c41778260..200631c7058 100644 --- a/server/utils/agents/aibitat/providers/localai.js +++ b/server/utils/agents/aibitat/providers/localai.js @@ -27,11 +27,14 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -48,59 +51,31 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { result: completion.content, cost: 0 }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/mistral.js b/server/utils/agents/aibitat/providers/mistral.js index a2595662fbc..4ed15a8f3a0 100644 --- a/server/utils/agents/aibitat/providers/mistral.js +++ b/server/utils/agents/aibitat/providers/mistral.js @@ -31,11 +31,14 @@ class MistralProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -50,60 +53,31 @@ class MistralProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/moonshotAi.js b/server/utils/agents/aibitat/providers/moonshotAi.js index b6bb3bebd03..8fac2310456 100644 --- a/server/utils/agents/aibitat/providers/moonshotAi.js +++ b/server/utils/agents/aibitat/providers/moonshotAi.js @@ -31,11 +31,14 @@ class MoonshotAiProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -50,53 +53,35 @@ class MoonshotAiProvider extends InheritMultiple([Provider, UnTooled]) { }); } - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - if (!completion?.content) { - 
this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. - this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + getCost(_usage) { + return 0; } } diff --git a/server/utils/agents/aibitat/providers/novita.js b/server/utils/agents/aibitat/providers/novita.js index a15a6eb9d11..7c9e0673372 100644 --- a/server/utils/agents/aibitat/providers/novita.js +++ b/server/utils/agents/aibitat/providers/novita.js @@ -31,11 +31,14 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -50,54 +53,31 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog("Will assume chat completion without tool call inputs."); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/nvidiaNim.js b/server/utils/agents/aibitat/providers/nvidiaNim.js index 529d21cd75f..94e3d2abd5d 100644 --- a/server/utils/agents/aibitat/providers/nvidiaNim.js +++ b/server/utils/agents/aibitat/providers/nvidiaNim.js @@ -2,6 +2,7 @@ const OpenAI = require("openai"); const Provider = require("./ai-provider.js"); const InheritMultiple = require("./helpers/classes.js"); const UnTooled = require("./helpers/untooled.js"); +const { parseNvidiaNimBasePath } = require("../../../AiProviders/nvidiaNim"); /** * The agent provider for the Nvidia NIM provider. @@ -14,7 +15,7 @@ class NvidiaNimProvider extends InheritMultiple([Provider, UnTooled]) { const { model } = config; super(); const client = new OpenAI({ - baseURL: process.env.NVIDIA_NIM_LLM_BASE_PATH, + baseURL: parseNvidiaNimBasePath(process.env.NVIDIA_NIM_LLM_BASE_PATH), apiKey: null, maxRetries: 0, }); @@ -28,11 +29,14 @@ class NvidiaNimProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -47,60 +51,31 @@ class NvidiaNimProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/ollama.js b/server/utils/agents/aibitat/providers/ollama.js index 9080cd77012..de2506197f9 100644 --- a/server/utils/agents/aibitat/providers/ollama.js +++ b/server/utils/agents/aibitat/providers/ollama.js @@ -2,6 +2,8 @@ const Provider = require("./ai-provider.js"); const InheritMultiple = require("./helpers/classes.js"); const UnTooled = require("./helpers/untooled.js"); const { Ollama } = require("ollama"); +const { v4 } = require("uuid"); +const { safeJsonParse } = require("../../../http"); /** * The agent provider for the Ollama provider. @@ -31,17 +33,215 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + + /** + * Handle a chat completion with tool calling + * + * @param messages + * @returns {Promise} The completion. 
+ */ async #handleFunctionCallChat({ messages = [] }) { const response = await this.client.chat({ model: this.model, messages, - options: { - temperature: 0, - }, }); return response?.message?.content || null; } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat({ + model: this.model, + messages, + stream: true, + }); + } + + async streamingFunctionCall( + messages, + functions, + chatCb = null, + eventHandler = null + ) { + const history = [...messages].filter((msg) => + ["user", "assistant"].includes(msg.role) + ); + if (history[history.length - 1].role !== "user") return null; + + const msgUUID = v4(); + let textResponse = ""; + const historyMessages = this.buildToolCallMessages(history, functions); + const stream = await chatCb({ messages: historyMessages }); + + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Agent is thinking...", + }); + + for await (const chunk of stream) { + if ( + !chunk.hasOwnProperty("message") || + !chunk.message.hasOwnProperty("content") + ) + continue; + + textResponse += chunk.message.content; + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: msgUUID, + content: chunk.message.content, + }); + } + + const call = safeJsonParse(textResponse, null); + if (call === null) + return { toolCall: null, text: textResponse, uuid: msgUUID }; // failed to parse, so must be regular text response. 
+
+    const { valid, reason } = this.validFuncCall(call, functions);
+    if (!valid) {
+      this.providerLog(`Invalid function tool call: ${reason}.`);
+      eventHandler?.("reportStreamEvent", {
+        type: "removeStatusResponse",
+        uuid: msgUUID,
+        content:
+          "The model attempted to make an invalid function call - it was ignored.",
+      });
+      return { toolCall: null, text: null, uuid: msgUUID };
+    }
+
+    if (this.deduplicator.isDuplicate(call.name, call.arguments)) {
+      this.providerLog(
+        `Function tool with exact arguments has already been called this stack.`
+      );
+      eventHandler?.("reportStreamEvent", {
+        type: "removeStatusResponse",
+        uuid: msgUUID,
+        content:
+          "The model tried to call a function with the same arguments as a previous call - it was ignored.",
+      });
+      return { toolCall: null, text: null, uuid: msgUUID };
+    }
+
+    eventHandler?.("reportStreamEvent", {
+      uuid: `${msgUUID}:tool_call_invocation`,
+      type: "toolCallInvocation",
+      content: `Parsed Tool Call: ${call.name}(${JSON.stringify(call.arguments)})`,
+    });
+    return { toolCall: call, text: null, uuid: msgUUID };
+  }
+
+  /**
+   * Stream a chat completion from the LLM with tool calling
+   * This is overriding the inherited `stream` method since Ollama's
+   * SDK has a different response structure than other OpenAI-compatible SDKs.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @param eventHandler
+   * @returns The completion.
+   */
+  async stream(messages, functions = [], eventHandler = null) {
+    this.providerLog(
+      "OllamaProvider.stream - will process this chat completion."
+ ); + try { + let completion = { content: "" }; + if (functions.length > 0) { + const { + toolCall, + text, + uuid: msgUUID, + } = await this.streamingFunctionCall( + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + + if (text) { + this.providerLog( + `No tool call found in the response - will send as a full text response.` + ); + completion.content = text; + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: "No tool call found in the response", + }); + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Done thinking.", + }); + eventHandler?.("reportStreamEvent", { + type: "fullTextResponse", + uuid: v4(), + content: text, + }); + } + } + + if (!completion?.content) { + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Done thinking.", + }); + this.providerLog( + "Will assume chat completion without tool call inputs." 
+ ); + const msgUUID = v4(); + completion = { content: "" }; + const stream = await this.#handleFunctionCallStream({ + messages: this.cleanMsgs(messages), + }); + + for await (const chunk of stream) { + if ( + !chunk.hasOwnProperty("message") || + !chunk.message.hasOwnProperty("content") + ) + continue; + + const delta = chunk.message.content; + completion.content += delta; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: delta, + }); + } + } + + // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent + // from calling the exact same function over and over in a loop within a single chat exchange + // _but_ we should enable it to call previously used tools in a new chat interaction. + this.deduplicator.reset("runs"); + return { + textResponse: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + /** * Create a completion based on the received messages. * @@ -50,8 +250,11 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) { * @returns The completion. */ async complete(messages, functions = []) { + this.providerLog( + "OllamaProvider.complete - will process this chat completion." + ); try { - let completion; + let completion = { content: "" }; if (functions.length > 0) { const { toolCall, text } = await this.functionCall( messages, @@ -71,22 +274,17 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) { cost: 0, }; } - completion = { content: text }; + completion.content = text; } if (!completion?.content) { this.providerLog( "Will assume chat completion without tool call inputs." 
); - const response = await this.client.chat({ - model: this.model, + const textResponse = await this.#handleFunctionCallChat({ messages: this.cleanMsgs(messages), - options: { - use_mlock: true, - temperature: 0.5, - }, }); - completion = response.message; + completion.content = textResponse; } // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent @@ -94,7 +292,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) { // _but_ we should enable it to call previously used tools in a new chat interaction. this.deduplicator.reset("runs"); return { - result: completion.content, + textResponse: completion.content, cost: 0, }; } catch (error) { diff --git a/server/utils/agents/aibitat/providers/openai.js b/server/utils/agents/aibitat/providers/openai.js index 73976bcb48e..ac31d2bf69e 100644 --- a/server/utils/agents/aibitat/providers/openai.js +++ b/server/utils/agents/aibitat/providers/openai.js @@ -8,37 +8,6 @@ const { RetryError } = require("../error.js"); */ class OpenAIProvider extends Provider { model; - static COST_PER_TOKEN = { - "gpt-3.5-turbo": { - input: 0.0015, - output: 0.002, - }, - "gpt-3.5-turbo-16k": { - input: 0.003, - output: 0.004, - }, - "gpt-4": { - input: 0.03, - output: 0.06, - }, - "gpt-4-turbo": { - input: 0.01, - output: 0.03, - }, - "gpt-4o": { - input: 0.005, - output: 0.015, - }, - "gpt-4-32k": { - input: 0.06, - output: 0.12, - }, - "gpt-4o-mini": { - input: 0.00015, - output: 0.0006, - }, - }; - constructor(config = {}) { const { options = { @@ -55,6 +24,10 @@ class OpenAIProvider extends Provider { this.model = model; } + get supportsAgentStreaming() { + return true; + } + /** * Create a completion based on the received messages. * @@ -66,7 +39,7 @@ class OpenAIProvider extends Provider { try { const response = await this.client.chat.completions.create({ model: this.model, - // stream: true, + stream: false, messages, ...(Array.isArray(functions) && functions?.length > 0 ? 
{ functions } @@ -98,9 +71,8 @@ class OpenAIProvider extends Provider { ); } - // console.log(completion, { functionArgs }) return { - result: null, + textResponse: null, functionCall: { name: completion.function_call.name, arguments: functionArgs, @@ -110,7 +82,7 @@ class OpenAIProvider extends Provider { } return { - result: completion.content, + textResponse: completion.content, cost, }; } catch (error) { @@ -133,26 +105,11 @@ class OpenAIProvider extends Provider { /** * Get the cost of the completion. * - * @param usage The completion to get the cost for. + * @param _usage The completion to get the cost for. * @returns The cost of the completion. */ - getCost(usage) { - if (!usage) { - return Number.NaN; - } - - // regex to remove the version number from the model - const modelBase = this.model.replace(/-(\d{4})$/, ""); - - if (!(modelBase in OpenAIProvider.COST_PER_TOKEN)) { - return Number.NaN; - } - - const costPerToken = OpenAIProvider.COST_PER_TOKEN?.[modelBase]; - const inputCost = (usage.prompt_tokens / 1000) * costPerToken.input; - const outputCost = (usage.completion_tokens / 1000) * costPerToken.output; - - return inputCost + outputCost; + getCost() { + return 0; } } diff --git a/server/utils/agents/aibitat/providers/openrouter.js b/server/utils/agents/aibitat/providers/openrouter.js index c7d8dc7304d..80aa2a2fe4d 100644 --- a/server/utils/agents/aibitat/providers/openrouter.js +++ b/server/utils/agents/aibitat/providers/openrouter.js @@ -31,11 +31,14 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -50,60 +53,31 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. 
- * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. - */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/perplexity.js b/server/utils/agents/aibitat/providers/perplexity.js index 03fec91e460..eef4f8f5944 100644 --- a/server/utils/agents/aibitat/providers/perplexity.js +++ b/server/utils/agents/aibitat/providers/perplexity.js @@ -27,11 +27,14 @@ class PerplexityProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -46,60 +49,31 @@ class PerplexityProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/ppio.js b/server/utils/agents/aibitat/providers/ppio.js index 2baf895b2d2..404bf9fd71a 100644 --- a/server/utils/agents/aibitat/providers/ppio.js +++ b/server/utils/agents/aibitat/providers/ppio.js @@ -31,11 +31,14 @@ class PPIOProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return false; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -50,54 +53,31 @@ class PPIOProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = null) { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog("Will assume chat completion without tool call inputs."); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/textgenwebui.js b/server/utils/agents/aibitat/providers/textgenwebui.js index b55560b7597..c4d8b91cc38 100644 --- a/server/utils/agents/aibitat/providers/textgenwebui.js +++ b/server/utils/agents/aibitat/providers/textgenwebui.js @@ -18,7 +18,7 @@ class TextWebGenUiProvider extends InheritMultiple([Provider, UnTooled]) { }); this._client = client; - this.model = null; // text-web-gen-ui does not have a model pref. + this.model = "text-generation-webui"; // text-web-gen-ui does not have a model pref, but we need a placeholder this.verbose = true; } @@ -26,11 +26,14 @@ class TextWebGenUiProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -45,60 +48,31 @@ class TextWebGenUiProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/togetherai.js b/server/utils/agents/aibitat/providers/togetherai.js index efad3922d2d..519d2e21006 100644 --- a/server/utils/agents/aibitat/providers/togetherai.js +++ b/server/utils/agents/aibitat/providers/togetherai.js @@ -27,11 +27,14 @@ class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -46,60 +49,31 @@ class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. 
- */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." - ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. 
- this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/aibitat/providers/xai.js b/server/utils/agents/aibitat/providers/xai.js index 683e6aa4293..9b4632189f7 100644 --- a/server/utils/agents/aibitat/providers/xai.js +++ b/server/utils/agents/aibitat/providers/xai.js @@ -27,11 +27,14 @@ class XAIProvider extends InheritMultiple([Provider, UnTooled]) { return this._client; } + get supportsAgentStreaming() { + return true; + } + async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, - temperature: 0, messages, }) .then((result) => { @@ -46,60 +49,31 @@ class XAIProvider extends InheritMultiple([Provider, UnTooled]) { }); } - /** - * Create a completion based on the received messages. - * - * @param messages A list of messages to send to the API. - * @param functions - * @returns The completion. - */ - async complete(messages, functions = []) { - try { - let completion; - if (functions.length > 0) { - const { toolCall, text } = await this.functionCall( - messages, - functions, - this.#handleFunctionCallChat.bind(this) - ); - - if (toolCall !== null) { - this.providerLog(`Valid tool call found - running ${toolCall.name}.`); - this.deduplicator.trackRun(toolCall.name, toolCall.arguments); - return { - result: null, - functionCall: { - name: toolCall.name, - arguments: toolCall.arguments, - }, - cost: 0, - }; - } - completion = { content: text }; - } + async #handleFunctionCallStream({ messages = [] }) { + return await this.client.chat.completions.create({ + model: this.model, + stream: true, + messages, + }); + } - if (!completion?.content) { - this.providerLog( - "Will assume chat completion without tool call inputs." 
- ); - const response = await this.client.chat.completions.create({ - model: this.model, - messages: this.cleanMsgs(messages), - }); - completion = response.choices[0].message; - } + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } - // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent - // from calling the exact same function over and over in a loop within a single chat exchange - // _but_ we should enable it to call previously used tools in a new chat interaction. - this.deduplicator.reset("runs"); - return { - result: completion.content, - cost: 0, - }; - } catch (error) { - throw error; - } + async complete(messages, functions = []) { + return await UnTooled.prototype.complete.call( + this, + messages, + functions, + this.#handleFunctionCallChat.bind(this) + ); } /** diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index 46581d3c5ce..377b3c590fb 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -254,7 +254,7 @@ class AgentHandler { case "perplexity": return process.env.PERPLEXITY_MODEL_PREF ?? "sonar-small-online"; case "textgenwebui": - return null; + return "text-generation-webui"; case "bedrock": return process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE ?? null; case "fireworksai":