From 74a4ab49eded4a6c74c461ee10d7104b88603268 Mon Sep 17 00:00:00 2001 From: "Nadheesh Jihan, nadheesh@wso2.com" Date: Tue, 23 Jan 2024 21:49:47 +0530 Subject: [PATCH 1/3] Improve agent module api --- ballerina/Ballerina.toml | 2 +- ballerina/agent.bal | 169 ++++++++++++++++------------ ballerina/error.bal | 9 +- ballerina/function_call.bal | 46 ++++---- ballerina/llm.bal | 107 ++++++++++-------- ballerina/openapi_utils.bal | 2 +- ballerina/react.bal | 66 +++++++---- ballerina/tests/agent-test.bal | 62 +++------- ballerina/tests/tool-test.bal | 6 +- ballerina/tool.bal | 37 +----- examples/function-as-tool/main.bal | 2 +- examples/http-toolkit/main.bal | 2 +- examples/load-from-openapi/main.bal | 2 +- examples/mock-tools/main.bal | 2 +- examples/multi-type-tools/main.bal | 1 - examples/setup/Ballerina.toml | 2 +- 16 files changed, 260 insertions(+), 257 deletions(-) diff --git a/ballerina/Ballerina.toml b/ballerina/Ballerina.toml index 7d362d2..5aad396 100644 --- a/ballerina/Ballerina.toml +++ b/ballerina/Ballerina.toml @@ -2,7 +2,7 @@ distribution = "2201.8.4" org = "ballerinax" name = "ai.agent" -version = "0.7.0" +version = "0.7.1" license = ["Apache-2.0"] authors = ["Ballerina"] keywords = ["AI/Agent", "Cost/Freemium"] diff --git a/ballerina/agent.bal b/ballerina/agent.bal index 7572e20..392be7c 100644 --- a/ballerina/agent.bal +++ b/ballerina/agent.bal @@ -23,23 +23,30 @@ public type ExecutionProgress record {| # Execution history up to the current action ExecutionStep[] history = []; # Contextual instruction that can be used by the agent during the execution - map|string context?; + map|string? context = (); |}; # Execution step information public type ExecutionStep record {| + # Response generated by the LLM + json llmResponse; + # Observations produced by the tool during the execution + anydata|error observation; +|}; + +# Execution step information +public type ExecutionResult record {| # Tool decided by the LLM during the reasoning - LlmToolResponse toolResponse; + LlmToolResponse tool; # Observations produced by the tool during the execution anydata|error observation; |}; -# An LLM response containing the selected tool to be executed -public type LlmToolResponse record {| - # Next tool to be executed - SelectedTool|LlmInvalidGenerationError tool; - # Raw LLM generated output +public type ExecutionError record {| + # Response generated by the LLM json llmResponse; + # Error caused during the execution + LlmInvalidGenerationError|ToolExecutionError 'error; |}; # An chat response by the LLM @@ -49,7 +56,7 @@ public type LlmChatResponse record {| |}; # Tool selected by LLM to be performed by the agent -public type SelectedTool record {| +public type LlmToolResponse record {| # Name of the tool to selected string name; # Input to the tool @@ -70,13 +77,19 @@ public type BaseAgent distinct isolated object { # # + progress - QueryProgress with the current query and execution history # + return - NextAction decided by the LLM or an error if call to the LLM fails - isolated function selectNextTool(ExecutionProgress progress) returns LlmToolResponse|LlmChatResponse|LlmError; + isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError; + + # Parse the llm response and extract the tool to be executed. + # + # + llmResponse - Raw LLM response + # + return - SelectedTool or an error if parsing fails + isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError; }; # An iterator to iterate over agent's execution -public class AgentIterator { +public class Iterator { *object:Iterable; - private final AgentExecutor executor; + private final Executor executor; # Initialize the iterator with the agent and the query. # @@ -84,20 +97,20 @@ public class AgentIterator { # + query - Natural language query to be executed by the agent # + context - Contextual information to be used by the agent during the execution public isolated function init(BaseAgent agent, string query, map|string? context = ()) { - self.executor = new (agent, query, context = context); + self.executor = new (agent, query = query, context = context); } # Iterate over the agent's execution steps. # + return - a record with the execution step or an error if the agent failed public function iterator() returns object { - public function next() returns record {|ExecutionStep|LlmChatResponse|error value;|}?; + public function next() returns record {|ExecutionResult|LlmChatResponse|ExecutionError|error value;|}?; } { return self.executor; } } # An executor to perform step-by-step execution of the agent. -public class AgentExecutor { +public class Executor { private boolean isCompleted = false; private final BaseAgent agent; # Contains the current execution progress for the agent and the query @@ -109,13 +122,9 @@ public class AgentExecutor { # + query - Natural language query to be executed by the agent # + history - Execution history of the agent (This is used to continue an execution paused without completing) # + context - Contextual information to be used by the agent during the execution - public isolated function init(BaseAgent agent, string query, ExecutionStep[] history = [], map|string? context = ()) { + public isolated function init(BaseAgent agent, *ExecutionProgress progress) { self.agent = agent; - self.progress = { - query, - history, - context - }; + self.progress = progress; } # Checks whether agent has more steps to execute. @@ -127,47 +136,61 @@ public class AgentExecutor { # Reason the next step of the agent. # - # + return - Llm tool response, chat response or an error if reasoning has failed - public isolated function reason() returns LlmToolResponse|LlmChatResponse|TaskCompletedError|LlmError { + # + return - generated LLM response during the reasoning or an error if the reasoning fails + public isolated function reason() returns json|TaskCompletedError|LlmError { if self.isCompleted { return error TaskCompletedError("Task is already completed. No more reasoning is needed."); } - LlmToolResponse|LlmChatResponse response = check self.agent.selectNextTool(self.progress); - if response is LlmChatResponse { - self.isCompleted = true; - } - return response; + return check self.agent.selectNextTool(self.progress); } # Execute the next step of the agent. # - # + toolResponse - LLM tool response containing the tool to be executed and the raw LLM output + # + llmResponse - LLM response containing the tool to be executed and the raw LLM output # + return - Observations from the tool can be any|error|null - public isolated function act(LlmToolResponse toolResponse) returns ToolOutput { - ToolOutput observation; - SelectedTool|LlmInvalidGenerationError tool = toolResponse.tool; + public isolated function act(json llmResponse) returns ExecutionResult|LlmChatResponse|ExecutionError { + LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError parseLlmResponse = self.agent.parseLlmResponse(llmResponse); + if parseLlmResponse is LlmChatResponse { + self.isCompleted = true; + return parseLlmResponse; + } - // TODO Improve to use intructions from the error instead of generic error instructions - if tool is SelectedTool { - ToolOutput|error output = self.agent.toolStore.execute(tool); - if output is ToolNotFoundError { - observation = {value: "Tool is not found. Please check the tool name and retry."}; - } else if output is ToolInvalidInputError { - observation = {value: "Tool execution failed due to invalid inputs. Retry with correct inputs."}; - } else if output is error { - observation = {value: "Tool execution failed. Retry with correct inputs."}; + anydata observation; + ExecutionResult|ExecutionError executionResult; + if parseLlmResponse is LlmToolResponse { + ToolOutput|ToolExecutionError|LlmInvalidGenerationError output = self.agent.toolStore.execute(parseLlmResponse); + if output is error { + if output is ToolNotFoundError { + observation = "Tool is not found. Please check the tool name and retry."; + } else if output is ToolInvalidInputError { + observation = "Tool execution failed due to invalid inputs. Retry with correct inputs."; + } else { + observation = "Tool execution failed. Retry with correct inputs."; + } + executionResult = { + llmResponse, + 'error: output + }; } else { - observation = output; + anydata|error value = output.value; + observation = value is error ? value.toString() : value; + executionResult = { + tool: parseLlmResponse, + observation: value + }; } } else { - observation = {value: "Tool extraction failed due to invalid JSON_BLOB. Retry with correct JSON_BLOB."}; + observation = "Tool extraction failed due to invalid JSON_BLOB. Retry with correct JSON_BLOB."; + executionResult = { + llmResponse, + 'error: parseLlmResponse + }; } - // update the execution history with the latest step self.update({ - toolResponse, - observation: observation.value + llmResponse, + observation }); - return observation; + return executionResult; } # Update the agent with an execution step. @@ -179,25 +202,16 @@ public class AgentExecutor { # Reason and execute the next step of the agent. # - # + return - A record with ExecutionStep, chat response or an error - public isolated function next() returns record {|ExecutionStep|LlmChatResponse|error value;|}? { - LlmToolResponse|LlmChatResponse|error toolResponse = self.reason(); - if toolResponse is LlmChatResponse|error { - return {value: toolResponse}; + # + return - A record with ExecutionResult, chat response or an error + public isolated function next() returns record {|ExecutionResult|LlmChatResponse|ExecutionError|error value;|}? { + if self.isCompleted { + return (); } - return { - value: { - toolResponse, - observation: self.act(toolResponse).value - } - }; - } - - # Allow retrieving the execution history during previous steps. - # - # + return - Execution history of the agent (A list of ExecutionStep) - public isolated function getExecutionHistory() returns ExecutionStep[] { - return self.progress.history; + json|TaskCompletedError|LlmError llmResponse = self.reason(); + if llmResponse is error { + return {value: llmResponse}; + } + return {value: self.act(llmResponse)}; } } @@ -209,12 +223,12 @@ public class AgentExecutor { # + context - Context values to be used by the agent to execute the task # + verbose - If true, then print the reasoning steps (default: true) # + return - Returns the execution steps tracing the agent's reasoning and outputs from the tools -public isolated function run(BaseAgent agent, string query, int maxIter = 5, string|map context = {}, boolean verbose = true) returns record {|ExecutionStep[] steps; string answer?;|} { - ExecutionStep[] steps = []; +public isolated function run(BaseAgent agent, string query, int maxIter = 5, string|map context = {}, boolean verbose = true) returns record {|(ExecutionResult|ExecutionError)[] steps; string answer?;|} { + (ExecutionResult|ExecutionError)[] steps = []; string? content = (); - AgentIterator iterator = new (agent, query, context = context); + Iterator iterator = new (agent, query, context = context); int iter = 0; - foreach ExecutionStep|LlmChatResponse|error step in iterator { + foreach ExecutionResult|LlmChatResponse|ExecutionError|error step in iterator { if iter == maxIter { break; } @@ -233,8 +247,8 @@ public isolated function run(BaseAgent agent, string query, int maxIter = 5, str iter += 1; if verbose { io:println(string `${"\n\n"}Agent Iteration ${iter.toString()}`); - SelectedTool|LlmInvalidGenerationError tool = step.toolResponse.tool; - if tool is SelectedTool { + if step is ExecutionResult { + LlmToolResponse tool = step.tool; io:println(string `Action: ${BACKTICKS} { @@ -249,14 +263,13 @@ ${BACKTICKS}`); io:println(string `${OBSERVATION_KEY}: ${observation.toString()}`); } } else { - error? cause = tool.cause(); - string llmResponse = step.toolResponse.llmResponse.toString(); + error? cause = step.'error.cause(); io:println(string `LLM Generation Error: ${BACKTICKS} { - message: ${tool.message()}, + message: ${step.'error.message()}, cause: ${(cause is error ? cause.message() : "Unspecified")}, - llmResponse: ${llmResponse} + llmResponse: ${step.llmResponse.toString()} } ${BACKTICKS}`); } @@ -282,3 +295,11 @@ isolated function getObservationString(anydata|error observation) returns string return observation.toString().trim(); } } + +# Get the tools registered with the agent. +# +# + agent - Agent instance +# + return - Array of tools registered with the agent +public isolated function getTools(BaseAgent agent) returns AgentTool[] { + return agent.toolStore.tools.toArray(); +} diff --git a/ballerina/error.bal b/ballerina/error.bal index 91c6b52..14ec6a1 100644 --- a/ballerina/error.bal +++ b/ballerina/error.bal @@ -25,11 +25,14 @@ public type InvalidParameterDefinition distinct OpenApiParsingError; # Any error occurred during LLM generation is classified under this error type. public type LlmError distinct error; +# Errors occurred due to unexpected responses from the LLM. +public type LlmInvalidResponseError distinct LlmError; + # Errors occurred due to invalid LLM generation. public type LlmInvalidGenerationError distinct LlmError; # Errors occurred during LLM generation due to connection. -type LlmConnectionError distinct LlmError; +public type LlmConnectionError distinct LlmError; # Errors occurred due to termination of the Agent's execution. public type TaskCompletedError distinct error; @@ -41,10 +44,10 @@ public type HttpServiceToolKitError distinct error; public type HttpResponseParsingError distinct HttpServiceToolKitError; # Errors during tool execution. -type ToolExecutionError distinct error; +public type ToolExecutionError distinct error; # Error during unexpected output by the tool -type ToolInvaludOutputError distinct ToolExecutionError; +public type ToolInvaludOutputError distinct ToolExecutionError; # Errors occurred due to invalid tool name generated by the LLM. public type ToolNotFoundError distinct LlmInvalidGenerationError; diff --git a/ballerina/function_call.bal b/ballerina/function_call.bal index 96873b6..1b65a4b 100644 --- a/ballerina/function_call.bal +++ b/ballerina/function_call.bal @@ -30,38 +30,42 @@ public isolated class FunctionCallAgent { self.model = model; } - isolated function selectNextTool(ExecutionProgress progress) returns LlmToolResponse|LlmChatResponse|LlmError { - ChatMessage[] messages = createFunctionCallMessages(progress); - FunctionCall|string|error response = self.model.functionaCall(messages, self.toolStore.tools.toArray()); - if response is error { - return error LlmConnectionError("Error while function call generation", response); + isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { + if llmResponse is string { + return {content: llmResponse}; } - if response is string { - return {content: response}; + if llmResponse !is FunctionCall { + return error LlmInvalidGenerationError("Invalid response", llmResponse = llmResponse); } - string? name = response.name; + string? name = llmResponse.name; if name is () { - return {tool: error LlmInvalidGenerationError("Missing name", name = response.name, arguments = response.arguments), llmResponse: response.toJson()}; + return error LlmInvalidGenerationError("Missing name", name = llmResponse.name, arguments = llmResponse.arguments); } - string? stringArgs = response.arguments; + string? stringArgs = llmResponse.arguments; map|error? arguments = (); if stringArgs is string { arguments = stringArgs.fromJsonStringWithType(); } if arguments is error { - return {tool: error LlmInvalidGenerationError("Invalid arguments", arguments, name = response.name, arguments = stringArgs), llmResponse: response.toJson()}; + return error LlmInvalidGenerationError("Invalid arguments", arguments, name = llmResponse.name, arguments = stringArgs); } return { - tool: { - name, - arguments - }, - llmResponse: { - name: name, - arguments: stringArgs - } + name, + arguments }; } + + isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError { + ChatMessage[] messages = createFunctionCallMessages(progress); + FunctionCall|string|LlmError functionaCall = self.model.functionaCall(messages, + from AgentTool tool in self.toolStore.tools.toArray() + select { + name: tool.name, + description: tool.description, + parameters: tool.variables + }); + return functionaCall; + } } isolated function createFunctionCallMessages(ExecutionProgress progress) returns ChatMessage[] { @@ -81,9 +85,9 @@ isolated function createFunctionCallMessages(ExecutionProgress progress) returns } // include the history foreach ExecutionStep step in progress.history { - FunctionCall|error functionCall = step.toolResponse.llmResponse.fromJsonWithType(); + FunctionCall|error functionCall = step.llmResponse.fromJsonWithType(); if functionCall is error { - panic error("Badly formated history for function call agent", generated = step.toolResponse.llmResponse); + panic error("Badly formated history for function call agent", llmResponse = step.llmResponse); } messages.push({ role: ASSISTANT, diff --git a/ballerina/llm.bal b/ballerina/llm.bal index 5ec7e6a..a41ed89 100644 --- a/ballerina/llm.bal +++ b/ballerina/llm.bal @@ -60,6 +60,16 @@ public type ChatMessage record {| FunctionCall function_call?; |}; +# Function definitions for function calling API. +public type ChatCompletionFunctions record {| + # Name of the function + string name; + # Description of the function + string description; + # Parameters of the function + JsonInputSchema parameters?; +|}; + # Function call record public type FunctionCall record {| # Name of the function @@ -77,21 +87,21 @@ public type LlmModel distinct isolated object { public type CompletionLlmModel distinct isolated object { *LlmModel; CompletionModelConfig modelConfig; - public isolated function complete(string prompt, string? stop = ()) returns string|error; + public isolated function complete(string prompt, string? stop = ()) returns string|LlmError; }; # Extendable LLM model object for chat LLM models public type ChatLlmModel distinct isolated object { *LlmModel; ChatModelConfig modelConfig; - public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|error; + public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|LlmError; }; # Extendable LLM model object for LLM models with function call API public type FunctionCallLlm distinct isolated object { *LlmModel; ChatModelConfig modelConfig; - public isolated function functionaCall(ChatMessage[] messages, AgentTool[] tools, string? stop = ()) returns FunctionCall|string|error; + public isolated function functionaCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns FunctionCall|string|LlmError; }; public isolated class Gpt3Model { @@ -114,13 +124,16 @@ public isolated class Gpt3Model { # + prompt - Prompt to be completed # + stop - Stop sequence to stop the completion # + return - Completed prompt or error if the completion fails - public isolated function complete(string prompt, string? stop = ()) returns string|error { - text:CreateCompletionResponse response = check self.llmClient->/completions.post({ + public isolated function complete(string prompt, string? stop = ()) returns string|LlmError { + text:CreateCompletionResponse|error response = self.llmClient->/completions.post({ ...self.modelConfig, stop, prompt }); - return response.choices[0].text ?: error("Empty response from the model"); + if response is error { + return error LlmConnectionError("Error while connecting to the model", response); + } + return response.choices[0].text ?: error LlmInvalidResponseError("Empty response from the model"); } } @@ -152,13 +165,16 @@ public isolated class AzureGpt3Model { # + prompt - Prompt to be completed # + stop - Stop sequence to stop the completion # + return - Completed prompt or error if the completion fails - public isolated function complete(string prompt, string? stop = ()) returns string|error { - azure_text:Inline_response_200 response = check self.llmClient->/deployments/[self.deploymentId]/completions.post(self.apiVersion, { + public isolated function complete(string prompt, string? stop = ()) returns string|LlmError { + azure_text:Inline_response_200|error response = self.llmClient->/deployments/[self.deploymentId]/completions.post(self.apiVersion, { ...self.modelConfig, stop, prompt }); - return response.choices[0].text ?: error("Empty response from the model"); + if response is error { + return error LlmConnectionError("Error while connecting to the model", response); + } + return response.choices[0].text ?: error LlmInvalidResponseError("Empty response from the model"); } } @@ -183,38 +199,37 @@ public isolated class ChatGptModel { # + messages - Messages to be completed # + stop - Stop sequence to stop the completion # + return - Completed message or error if the completion fails - public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|error { - chat:CreateChatCompletionResponse response = check self.llmClient->/chat/completions.post({ + public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|LlmError { + chat:CreateChatCompletionResponse|error response = self.llmClient->/chat/completions.post({ ...self.modelConfig, stop, messages }); + if response is error { + return error LlmConnectionError("Error while connecting to the model", response); + } chat:ChatCompletionResponseMessage? message = response.choices[0].message; string? content = message?.content; - return content ?: error("Empty response from the model"); + return content ?: error LlmInvalidResponseError("Empty response from the model"); } # Uses function call API to determine next function to be called # # + messages - List of chat messages - # + tools - Tools to be used for the function call + # + functions - Function definitions to be used for the function call # + stop - Stop sequence to stop the completion - # + return - Next tool to be used or error if the function call fails - public isolated function functionaCall(ChatMessage[] messages, AgentTool[] tools, string? stop = ()) returns FunctionCall|string|error { + # + return - Function to be called, chat response or an error in-case of failures + public isolated function functionaCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns FunctionCall|string|LlmError { - chat:CreateChatCompletionResponse response = check self.llmClient->/chat/completions.post( - { - ...self.modelConfig, - stop, - messages, - functions: from AgentTool tool in tools - select { - name: tool.name, - description: tool.description, - parameters: tool.variables - } - } - ); + chat:CreateChatCompletionResponse|error response = self.llmClient->/chat/completions.post({ + ...self.modelConfig, + stop, + messages, + functions + }); + if response is error { + return error LlmConnectionError("Error while connecting to the model", response); + } chat:ChatCompletionResponseMessage? message = response.choices[0].message; string? content = message?.content; if content is string { @@ -226,7 +241,7 @@ public isolated class ChatGptModel { ...'function }; } - return error LlmInvalidGenerationError("Empty response from the model when using function call API"); + return error LlmInvalidResponseError("Empty response from the model when using function call API"); } } @@ -259,37 +274,37 @@ public isolated class AzureChatGptModel { # + messages - Messages to be completed # + stop - Stop sequence to stop the completion # + return - Completed message or error if the completion fails - public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|error { - azure_chat:CreateChatCompletionResponse response = check self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, { + public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|LlmError { + azure_chat:CreateChatCompletionResponse|error response = self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, { ...self.modelConfig, stop, messages }); + if response is error { + return error LlmConnectionError("Error while connecting to the model", response); + } azure_chat:ChatCompletionResponseMessage? message = response.choices[0].message; string? content = message?.content; - return content ?: error("Empty response from the model"); + return content ?: error LlmInvalidResponseError("Empty response from the model"); } # Uses function call API to determine next function to be called # # + messages - List of chat messages - # + tools - Tools to be used for the function call + # + functions - Function definitions to be used for the function call # + stop - Stop sequence to stop the completion - # + return - Next tool to be used or error if the function call fails - public isolated function functionaCall(ChatMessage[] messages, AgentTool[] tools, string? stop = ()) returns FunctionCall|string|error { - azure_chat:CreateChatCompletionRequest request = { + # + return - Function to be called, chat response or an error in-case of failures + public isolated function functionaCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns FunctionCall|string|LlmError { + azure_chat:CreateChatCompletionResponse|error response = + self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, { ...self.modelConfig, stop, messages, - functions: from AgentTool tool in tools - select { - name: tool.name, - description: tool.description, - parameters: tool.variables - } - }; - azure_chat:CreateChatCompletionResponse response = - check self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, request); + functions + }); + if response is error { + return error LlmConnectionError("Error while connecting to the model", response); + } azure_chat:ChatCompletionResponseMessage? message = response.choices[0].message; string? content = message?.content; if content is string { @@ -301,7 +316,7 @@ public isolated class AzureChatGptModel { ...'function }; } - return error LlmInvalidGenerationError("Empty response from the model when using function call API"); + return error LlmInvalidResponseError("Empty response from the model when using function call API"); } } diff --git a/ballerina/openapi_utils.bal b/ballerina/openapi_utils.bal index 4f07e5d..610ba08 100644 --- a/ballerina/openapi_utils.bal +++ b/ballerina/openapi_utils.bal @@ -174,7 +174,7 @@ class OpenApiSpecVisitor { } string? name = operation.operationId; if name is () { - return error(string `OperationId is mandotory for API paths. But, tt is missing for the resource "[${method}]:${path}"`); + return error(string `OperationId is mandotory for API paths. But, it is missing for the resource "[${method}]:${path}"`); } // resolve parameters diff --git a/ballerina/react.bal b/ballerina/react.bal index e02ba57..fc3b419 100644 --- a/ballerina/react.bal +++ b/ballerina/react.bal @@ -16,6 +16,11 @@ import ballerina/lang.regexp; import ballerina/log; +type ToolInfo readonly & record {| + string toolList; + string toolIntro; +|}; + # A ReAct Agent that uses ReAct prompt to answer questions by using tools. public isolated class ReActAgent { *BaseAgent; @@ -30,11 +35,15 @@ public isolated class ReActAgent { public isolated function init(CompletionLlmModel|ChatLlmModel model, (BaseToolKit|Tool)... tools) returns error? { self.toolStore = check new (...tools); self.model = model; - self.instructionPrompt = constructReActPrompt(self.toolStore.extractToolInfo()); + self.instructionPrompt = constructReActPrompt(extractToolInfo(self.toolStore)); log:printDebug("Instruction Prompt Generated Successfully", instructionPrompt = self.instructionPrompt); } - isolated function selectNextTool(ExecutionProgress progress) returns LlmToolResponse|LlmChatResponse|LlmError { + isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { + return parseReActLlmResponse(llmResponse); + } + + isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError { map|string? context = progress.context; string contextPrompt = context is () ? "" : string `${"\n\n"}You can use these information if needed: ${context.toString()}$`; @@ -43,24 +52,15 @@ public isolated class ReActAgent { Question: ${progress.query} ${constructHistoryPrompt(progress.history)} ${THOUGHT_KEY}`; - - string llmResponse = check self.generate(reactPrompt); - SelectedTool|LlmChatResponse|LlmInvalidGenerationError parsedResponse = parseLlmReponse(normalizeLlmResponse(llmResponse)); - if parsedResponse is LlmChatResponse { - return parsedResponse; - } - return { - tool: parsedResponse, - llmResponse - }; + return check self.generate(reactPrompt); } # Generate ReAct response for the given prompt # # + prompt - ReAct prompt to decide the next tool # + return - ReAct response - isolated function generate(string prompt) returns string|LlmConnectionError { - string|error? llmResult = (); + isolated function generate(string prompt) returns string|LlmError { + string|LlmError llmResult; CompletionLlmModel|ChatLlmModel model = self.model; if model is CompletionLlmModel { llmResult = model.complete(prompt, stop = OBSERVATION_KEY); @@ -71,13 +71,11 @@ ${THOUGHT_KEY}`; content: prompt } ], stop = OBSERVATION_KEY); + } else { + return error LlmError("Invalid LLM model is given."); } - if llmResult is string { - return llmResult; - } - return error LlmConnectionError("Geneartion Failed.", llmResult); + return llmResult; } - } isolated function normalizeLlmResponse(string llmResponse) returns string { @@ -99,8 +97,9 @@ isolated function normalizeLlmResponse(string llmResponse) returns string { return normalizedResponse; } -isolated function parseLlmReponse(string llmResponse) returns SelectedTool|LlmChatResponse|LlmInvalidGenerationError { - string[] content = regexp:split(re `${BACKTICKS}`, llmResponse + ""); +isolated function parseReActLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { + string llmResponseStr = normalizeLlmResponse(llmResponse.toString()); + string[] content = regexp:split(re `${BACKTICKS}`, llmResponseStr + ""); if content.length() < 3 { log:printWarn("Unexpected LLM response is given", llmResponse = llmResponse); return error LlmInvalidGenerationError("Unable to extract the tool due to invalid generation", thought = llmResponse, instruction = "Tool execution failed due to invalid generation."); @@ -126,7 +125,7 @@ isolated function parseLlmReponse(string llmResponse) returns SelectedTool|LlmCh content: input }; } - SelectedTool|error tool = jsonAction.fromJsonWithType(); + LlmToolResponse|error tool = jsonAction.fromJsonWithType(); if tool is error { log:printError("Error while extracting action name and inputs from LLM response.", tool, llmResponse = llmResponse); return error LlmInvalidGenerationError("Generated 'Action' JSON_BLOB contains invalid action name or inputs.", tool, thought = llmResponse, instruction = "Tool execution failed due to an invalid schema for 'Action' JSON_BLOB."); @@ -141,12 +140,33 @@ isolated function constructHistoryPrompt(ExecutionStep[] history) returns string string historyPrompt = ""; foreach ExecutionStep step in history { string observationStr = getObservationString(step.observation); - string thoughtStr = step.toolResponse.llmResponse.toString(); + string thoughtStr = step.llmResponse.toString(); historyPrompt += string `${thoughtStr}${"\n"}${OBSERVATION_KEY}: ${observationStr}${"\n"}`; } return historyPrompt; } +# Generate descriptions for the tools registered. +# +# + toolStore - ToolStore instance +# + return - Return a record with tool names and descriptions +isolated function extractToolInfo(ToolStore toolStore) returns ToolInfo { + string[] toolNameList = []; + string[] toolIntroList = []; + foreach AgentTool tool in toolStore.tools { + toolNameList.push(string `${tool.name}`); + record {|string description; JsonInputSchema inputSchema?;|} toolDescription = { + description: tool.description, + inputSchema: tool.variables + }; + toolIntroList.push(tool.name + ": " + toolDescription.toString()); + } + return { + toolList: string:'join(", ", ...toolNameList).trim(), + toolIntro: string:'join("\n", ...toolIntroList).trim() + }; +} + isolated function constructReActPrompt(ToolInfo toolInfo) returns string => string `System: Respond to the human as helpfully and accurately as possible. You have access to the following tools: ${toolInfo.toolIntro} diff --git a/ballerina/tests/agent-test.bal b/ballerina/tests/agent-test.bal index 8103b14..5f4464e 100644 --- a/ballerina/tests/agent-test.bal +++ b/ballerina/tests/agent-test.bal @@ -37,8 +37,7 @@ function testReActAgentInitialization() { Calculator: ${{"description": calculatorTool.description, "inputSchema": calculatorTool.parameters}.toString()}` }; - ToolStore store = agent.toolStore; - test:assertEquals(store.extractToolInfo(), toolInfo); + test:assertEquals(extractToolInfo(agent.toolStore), toolInfo); } @test:Config {} @@ -91,12 +90,12 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use function testAgentExecutorRun() returns error? { ReActAgent agent = check new (model, searchTool, calculatorTool); string query = "Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"; - AgentExecutor agentExecutor = new (agent, query); - record {|ExecutionStep|LlmChatResponse|error value;|}? result = agentExecutor.next(); + Executor agentExecutor = new (agent, query = query); + record {|ExecutionResult|LlmChatResponse|ExecutionError|error value;|}? result = agentExecutor.next(); if result is () { test:assertFail("AgentExecutor.next returns an null during first iteration"); } - ExecutionStep|LlmChatResponse|error output = result.value; + ExecutionResult|LlmChatResponse|ExecutionError|error output = result.value; if output is error { test:assertFail("AgentExecutor.next returns an error during first iteration"); } @@ -127,19 +126,7 @@ function testAgentExecutorRun() returns error? { function testConstructHistoryPrompt() { ExecutionStep[] history = [ { - toolResponse: { - tool: { - name: "Create wifi", - arguments: { - "path": "/guest-wifi-accounts", - "requestBody": { - "email": "johnny@wso2.com", - "username": "newGuest", - "password": "jh123" - } - } - }, - llmResponse: string `Thought: I need to use the "Create wifi" tool to create a new guest wifi account with the given username and password. + llmResponse: string `Thought: I need to use the "Create wifi" tool to create a new guest wifi account with the given username and password. Action: { "tool": "Create wifi", @@ -151,41 +138,23 @@ Action: "password": "jh123" } } -}` - - }, +}`, observation: "Successfully added the wifi account" }, { - toolResponse: { - tool: { - name: "List wifi", - arguments: { - "path": "//guest-wifi-accounts/johnny@wso2.com" - } - }, - llmResponse: string `Thought: Next, I need to use the "List wifi" tool to get the available list of wifi accounts for the given email. + + llmResponse: string `Thought: Next, I need to use the "List wifi" tool to get the available list of wifi accounts for the given email. Action: { "tool": "List wifi", "tool_input": { "path": "/guest-wifi-accounts/johnny@wso2.com" } -}` - }, +}`, observation: ["freeWifi.guestOf.johnny", "newGuest.guestOf.johnny"] }, { - toolResponse: { - tool: { - name: "Send mail", - arguments: { - "recipient": "alica@wso2.com", - "subject": "Available Wifi Accounts", - "messageBody": "Here are the available wifi accounts: ['newGuest.guestOf.johnny','newGuest.guestOf.johnny']" - } - }, - llmResponse: string `Thought: Finally, I need to use the "Send mail" tool to send the list of available wifi accounts to the given email address. + llmResponse: string `Thought: Finally, I need to use the "Send mail" tool to send the list of available wifi accounts to the given email address. Action: { "tool": "Send mail", @@ -194,8 +163,7 @@ Action: "subject": "Available Wifi Accounts", "messageBody": "Here are the available wifi accounts: ['newGuest.guestOf.johnny','newGuest.guestOf.johnny']" } -}` - }, +}`, observation: error("Error while sending the email(ballerinax/googleapis.gmail)GmailError") } ]; @@ -250,8 +218,8 @@ ${"```"} } ${"```"}`; - SelectedTool|LlmChatResponse parsedResult = check parseLlmReponse(llmResponse); - if parsedResult is SelectedTool { + LlmToolResponse|LlmChatResponse parsedResult = check parseReActLlmResponse(llmResponse); + if parsedResult is LlmToolResponse { test:assertFail("Parsed result should be a ChatResponse"); } test:assertEquals(parsedResult.content, "The guest wifi account guestJohn with password abc123 has been successfully created. There are currently no other available wifi accounts."); @@ -281,8 +249,8 @@ ${"```"} } ${"```"}`; - SelectedTool|LlmChatResponse parsedResult = check parseLlmReponse(llmResponse); - if parsedResult is SelectedTool { + LlmToolResponse|LlmChatResponse parsedResult = check parseReActLlmResponse(llmResponse); + if parsedResult is LlmToolResponse { test:assertFail("Parsed result should be a ChatResponse"); } } diff --git a/ballerina/tests/tool-test.bal b/ballerina/tests/tool-test.bal index 8c3b9ec..507062d 100644 --- a/ballerina/tests/tool-test.bal +++ b/ballerina/tests/tool-test.bal @@ -96,7 +96,7 @@ function testExecuteSuccessfullOutput() { }, caller: sendMail }; - SelectedTool sendMailInput = { + LlmToolResponse sendMailInput = { name: "Send_mail", arguments: { "messageRequest": { @@ -143,7 +143,7 @@ function testExecuteErrorOutput() { }, caller: sendMail }; - SelectedTool sendMailInput = { + LlmToolResponse sendMailInput = { name: "Send_mail", arguments: { "messageRequest": { @@ -190,7 +190,7 @@ function testExecutionError() { }, caller: sendMail }; - SelectedTool sendMailInput = { + LlmToolResponse sendMailInput = { name: "Send_mail", arguments: { "messageRequest": { diff --git a/ballerina/tool.bal b/ballerina/tool.bal index ccf1624..908bb39 100644 --- a/ballerina/tool.bal +++ b/ballerina/tool.bal @@ -31,11 +31,6 @@ public type AgentTool record {| isolated function caller; |}; -type ToolInfo readonly & record {| - string toolList; - string toolIntro; -|}; - isolated class ToolStore { final map & readonly tools; @@ -66,12 +61,12 @@ isolated class ToolStore { # # + action - Action object that contains the tool name and inputs # + return - ActionResult containing the results of the tool execution or an error if tool execution fails - isolated function execute(SelectedTool action) returns ToolOutput|error { + isolated function execute(LlmToolResponse action) returns ToolOutput|LlmInvalidGenerationError|ToolExecutionError { string name = action.name; map? inputs = action.arguments; if !self.tools.hasKey(name) { - return error ToolNotFoundError("Cannot find the tool.", toolName = name, instruction = string `Tool "${name}" does not exists. Use a tool from the list: ${self.extractToolInfo().toolList}`); + return error ToolNotFoundError("Cannot find the tool.", toolName = name, instruction = string `Tool "${name}" does not exists. Use a tool from the list: ${self.tools.keys().toString()}}`); } map|error inputValues = mergeInputs(inputs, self.tools.get(name).constants); @@ -101,36 +96,11 @@ isolated class ToolStore { } return error ToolExecutionError("Tool execution failed.", observation, toolName = name, inputs = inputValues.length() == 0 ? {} : inputValues); } - - # Generate descriptions for the tools registered. - # - # + return - Return a record with tool names and descriptions - isolated function extractToolInfo() returns ToolInfo { - string[] toolNameList = []; - string[] toolIntroList = []; - - map tools = self.tools; - foreach AgentTool tool in tools { - toolNameList.push(string `${tool.name}`); - record {|string description; JsonInputSchema inputSchema?;|} toolDescription = { - description: tool.description, - inputSchema: tool.variables - }; - toolIntroList.push(tool.name + ": " + toolDescription.toString()); - } - return { - toolList: string:'join(", ", ...toolNameList).trim(), - toolIntro: string:'join("\n", ...toolIntroList).trim() - }; - } } isolated function registerTool(map toolMap, Tool[] tools) returns error? { foreach Tool tool in tools { string name = tool.name; - if toolMap.hasKey(name) { - return error("Duplicated tools. Tool name should be unique.", toolName = name); - } if name.toLowerAscii().matches(FINAL_ANSWER_REGEX) { return error(string ` Tool name '${name}' is reserved for the 'Final answer'.`); } @@ -141,6 +111,9 @@ isolated function registerTool(map toolMap, Tool[] tools) } name = regexp:replaceAll(re `[^a-zA-Z0-9_-]`, name, "_"); } + if toolMap.hasKey(name) { + return error("Duplicated tools. Tool name should be unique.", toolName = name); + } JsonInputSchema? variables = check tool.parameters.cloneWithType(); map constants = {}; diff --git a/examples/function-as-tool/main.bal b/examples/function-as-tool/main.bal index e8f2f26..5dcacbc 100644 --- a/examples/function-as-tool/main.bal +++ b/examples/function-as-tool/main.bal @@ -50,7 +50,7 @@ const string DEFAULT_QUERY = "create a new guest wifi with user guestJohn and pa public function main(string query = DEFAULT_QUERY) returns error? { // 1) Create the model (brain of the agent) - agent:Gpt3Model model = check new ({auth: {token: openAIToken}}); + agent:ChatGptModel model = check new ({auth: {token: openAIToken}}); // 2) Define functions as tools agent:Tool listwifi = { diff --git a/examples/http-toolkit/main.bal b/examples/http-toolkit/main.bal index 51739d6..b59fe14 100644 --- a/examples/http-toolkit/main.bal +++ b/examples/http-toolkit/main.bal @@ -71,7 +71,7 @@ public function main(string query = DEFAULT_QUERY) returns error? { // 2) Create the model (brain of the agent) agent:ChatGptModel model = check new ({auth: {token: openAIToken}}); // 3) Create the agent - agent:ReActAgent agent = check new (model, httpToolKit); + agent:FunctionCallAgent agent = check new (model, httpToolKit); // 4) Run the agent to execute user's query _ = agent:run(agent, query, maxIter = 5, context = "email is john@gmail.com"); } diff --git a/examples/load-from-openapi/main.bal b/examples/load-from-openapi/main.bal index 93cd99e..fb278cb 100644 --- a/examples/load-from-openapi/main.bal +++ b/examples/load-from-openapi/main.bal @@ -32,7 +32,7 @@ const string DEFAULT_QUERY = "create a new guest wifi with user openAPIwifi and public function main(string openAPIPath = OPENAPI_PATH, string query = DEFAULT_QUERY) returns error? { // 1) Create the model (brain of the agent) - agent:AzureGpt3Model model = check new ({auth: {apiKey}}, serviceUrl, deploymentId, apiVersion); + agent:AzureChatGptModel model = check new ({auth: {apiKey}}, serviceUrl, deploymentId, apiVersion); // 2) Extract tools from openAPI specification final agent:HttpApiSpecification apiSpecification = check agent:extractToolsFromOpenApiSpecFile(openAPIPath); diff --git a/examples/mock-tools/main.bal b/examples/mock-tools/main.bal index a9f9ff5..4203ee9 100644 --- a/examples/mock-tools/main.bal +++ b/examples/mock-tools/main.bal @@ -77,7 +77,7 @@ public function main(string query = DEFAULT_QUERY) returns error? { caller: calculatorToolMock }; - agent:Gpt3Model model = check new ({auth: {token: openAIToken}}); + agent:ChatGptModel model = check new ({auth: {token: openAIToken}}); agent:ReActAgent agent = check new (model, searchTool, calculatorTool); _ = agent:run(agent, query); } diff --git a/examples/multi-type-tools/main.bal b/examples/multi-type-tools/main.bal index f943438..dea091a 100644 --- a/examples/multi-type-tools/main.bal +++ b/examples/multi-type-tools/main.bal @@ -97,7 +97,6 @@ public function main(string query = DEFAULT_QUERY) returns error? { }); agent:ChatGptModel model = check new ({auth: {token: openAIToken}}); - agent:FunctionCallAgent agent = check new (model, wifiApiToolKit, sendEmailTool); // Execute the query using agent iterator diff --git a/examples/setup/Ballerina.toml b/examples/setup/Ballerina.toml index 51e49a5..8a0a27e 100644 --- a/examples/setup/Ballerina.toml +++ b/examples/setup/Ballerina.toml @@ -2,7 +2,7 @@ org = "wso2" name = "wifi_service_setup" version = "0.1.0" -distribution = "2201.7.1" +distribution = "2201.8.4" [build-options] observabilityIncluded = true From cf2a0a80d8634796014bee5a216fe37d5fb84ade Mon Sep 17 00:00:00 2001 From: "Nadheesh Jihan, nadheesh@wso2.com" Date: Wed, 24 Jan 2024 09:43:18 +0530 Subject: [PATCH 2/3] Fix the testcase --- ballerina/react.bal | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/ballerina/react.bal b/ballerina/react.bal index fc3b419..9ce7e55 100644 --- a/ballerina/react.bal +++ b/ballerina/react.bal @@ -40,7 +40,7 @@ public isolated class ReActAgent { } isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { - return parseReActLlmResponse(llmResponse); + return parseReActLlmResponse(normalizeLlmResponse(llmResponse.toString())); } isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError { @@ -93,26 +93,24 @@ isolated function normalizeLlmResponse(string llmResponse) returns string { } normalizedResponse = regexp:replace(re `${BACKTICKS}json`, normalizedResponse, BACKTICKS); // replace ```json normalizedResponse = regexp:replaceAll(re `"\{\}"`, normalizedResponse, "{}"); // replace "{}" - normalizedResponse = regexp:replaceAll(re `\\"`, normalizedResponse, "\""); // replace \" return normalizedResponse; } -isolated function parseReActLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { - string llmResponseStr = normalizeLlmResponse(llmResponse.toString()); - string[] content = regexp:split(re `${BACKTICKS}`, llmResponseStr + ""); +isolated function parseReActLlmResponse(string llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { + string[] content = regexp:split(re `${BACKTICKS}`, llmResponse + ""); if content.length() < 3 { log:printWarn("Unexpected LLM response is given", llmResponse = llmResponse); - return error LlmInvalidGenerationError("Unable to extract the tool due to invalid generation", thought = llmResponse, instruction = "Tool execution failed due to invalid generation."); + return error LlmInvalidGenerationError("Unable to extract the tool due to invalid generation", llmResponse = llmResponse, instruction = "Tool execution failed due to invalid generation."); } - map|error jsonThought = content[1].fromJsonStringWithType(); - if jsonThought is error { - log:printWarn("Invalid JSON is given as the action.", jsonThought); - return error LlmInvalidGenerationError("Invalid JSON is given as the action.", jsonThought, thought = llmResponse, instruction = "Tool execution failed due to an invalid 'Action' JSON_BLOB."); + map|error jsonResponse = content[1].fromJsonStringWithType(); + if jsonResponse is error { + log:printWarn("Invalid JSON is given as the action.", jsonResponse); + return error LlmInvalidGenerationError("Invalid JSON is given as the action.", jsonResponse, llmResponse = llmResponse, instruction = "Tool execution failed due to an invalid 'Action' JSON_BLOB."); } map jsonAction = {}; - foreach [string, json] [key, value] in jsonThought.entries() { + foreach [string, json] [key, value] in jsonResponse.entries() { if key.toLowerAscii() == ACTION_KEY { jsonAction[ACTION_NAME_KEY] = value; } else if key.toLowerAscii().matches(ACTION_INPUT_REGEX) { @@ -128,7 +126,7 @@ isolated function parseReActLlmResponse(json llmResponse) returns LlmToolRespons LlmToolResponse|error tool = jsonAction.fromJsonWithType(); if tool is error { log:printError("Error while extracting action name and inputs from LLM response.", tool, llmResponse = llmResponse); - return error LlmInvalidGenerationError("Generated 'Action' JSON_BLOB contains invalid action name or inputs.", tool, thought = llmResponse, instruction = "Tool execution failed due to an invalid schema for 'Action' JSON_BLOB."); + return error LlmInvalidGenerationError("Generated 'Action' JSON_BLOB contains invalid action name or inputs.", tool, llmResponse = llmResponse, instruction = "Tool execution failed due to an invalid schema for 'Action' JSON_BLOB."); } return { name: tool.name, @@ -140,8 +138,8 @@ isolated function constructHistoryPrompt(ExecutionStep[] history) returns string string historyPrompt = ""; foreach ExecutionStep step in history { string observationStr = getObservationString(step.observation); - string thoughtStr = step.llmResponse.toString(); - historyPrompt += string `${thoughtStr}${"\n"}${OBSERVATION_KEY}: ${observationStr}${"\n"}`; + string llmResponseStr = step.llmResponse.toString(); + historyPrompt += string `${llmResponseStr}${"\n"}${OBSERVATION_KEY}: ${observationStr}${"\n"}`; } return historyPrompt; } From 29c61d43162226c6749eac257a0699828135f5b6 Mon Sep 17 00:00:00 2001 From: "Nadheesh Jihan, nadheesh@wso2.com" Date: Fri, 26 Jan 2024 08:32:35 +0530 Subject: [PATCH 3/3] Apply PR comments --- ballerina/agent.bal | 4 +--- ballerina/error.bal | 2 +- ballerina/function_call.bal | 3 +-- ballerina/react.bal | 6 ++---- ballerina/tool.bal | 2 +- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/ballerina/agent.bal b/ballerina/agent.bal index 392be7c..91edbe2 100644 --- a/ballerina/agent.bal +++ b/ballerina/agent.bal @@ -300,6 +300,4 @@ isolated function getObservationString(anydata|error observation) returns string # # + agent - Agent instance # + return - Array of tools registered with the agent -public isolated function getTools(BaseAgent agent) returns AgentTool[] { - return agent.toolStore.tools.toArray(); -} +public isolated function getTools(BaseAgent agent) returns AgentTool[] => agent.toolStore.tools.toArray(); diff --git a/ballerina/error.bal b/ballerina/error.bal index 14ec6a1..a711dc3 100644 --- a/ballerina/error.bal +++ b/ballerina/error.bal @@ -47,7 +47,7 @@ public type HttpResponseParsingError distinct HttpServiceToolKitError; public type ToolExecutionError distinct error; # Error during unexpected output by the tool -public type ToolInvaludOutputError distinct ToolExecutionError; +public type ToolInvalidOutputError distinct ToolExecutionError; # Errors occurred due to invalid tool name generated by the LLM. public type ToolNotFoundError distinct LlmInvalidGenerationError; diff --git a/ballerina/function_call.bal b/ballerina/function_call.bal index 1b65a4b..884f1bc 100644 --- a/ballerina/function_call.bal +++ b/ballerina/function_call.bal @@ -57,14 +57,13 @@ public isolated class FunctionCallAgent { isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError { ChatMessage[] messages = createFunctionCallMessages(progress); - FunctionCall|string|LlmError functionaCall = self.model.functionaCall(messages, + return self.model.functionaCall(messages, from AgentTool tool in self.toolStore.tools.toArray() select { name: tool.name, description: tool.description, parameters: tool.variables }); - return functionaCall; } } diff --git a/ballerina/react.bal b/ballerina/react.bal index 9ce7e55..bb6b4b6 100644 --- a/ballerina/react.bal +++ b/ballerina/react.bal @@ -39,9 +39,7 @@ public isolated class ReActAgent { log:printDebug("Instruction Prompt Generated Successfully", instructionPrompt = self.instructionPrompt); } - isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError { - return parseReActLlmResponse(normalizeLlmResponse(llmResponse.toString())); - } + isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError => parseReActLlmResponse(normalizeLlmResponse(llmResponse.toString())); isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError { map|string? context = progress.context; @@ -157,7 +155,7 @@ isolated function extractToolInfo(ToolStore toolStore) returns ToolInfo { description: tool.description, inputSchema: tool.variables }; - toolIntroList.push(tool.name + ": " + toolDescription.toString()); + toolIntroList.push(string `${tool.name}: ${toolDescription.toString()}`); } return { toolList: string:'join(", ", ...toolNameList).trim(), diff --git a/ballerina/tool.bal b/ballerina/tool.bal index 908bb39..09135c4 100644 --- a/ballerina/tool.bal +++ b/ballerina/tool.bal @@ -89,7 +89,7 @@ isolated class ToolStore { return {value: observation}; } if observation !is error { - return error ToolInvaludOutputError("Tool returns an invalid output. Expected anydata or error.", outputType = typeof observation, toolName = name, inputs = inputValues.length() == 0 ? {} : inputValues); + return error ToolInvalidOutputError("Tool returns an invalid output. Expected anydata or error.", outputType = typeof observation, toolName = name, inputs = inputValues.length() == 0 ? {} : inputValues); } if observation.message() == "{ballerina/lang.function}IncompatibleArguments" { return error ToolInvalidInputError("Tool is provided with invalid inputs.", observation, toolName = name, inputs = inputValues.length() == 0 ? {} : inputValues, instruction = string `Tool "${name}" execution failed due to invalid inputs provided. Use the schema to provide inputs: ${self.tools.get(name).variables.toString()}`);