Skip to content

Commit

Permalink
Merge pull request #40 from nadheesh/improved-types
Browse files Browse the repository at this point in the history
Improve agent module api to support API-Chat
  • Loading branch information
nadheesh committed Jan 26, 2024
2 parents d8c0adf + 29c61d4 commit cf19f68
Show file tree
Hide file tree
Showing 16 changed files with 262 additions and 266 deletions.
2 changes: 1 addition & 1 deletion ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
distribution = "2201.8.4"
org = "ballerinax"
name = "ai.agent"
version = "0.7.0"
version = "0.7.1"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["AI/Agent", "Cost/Freemium"]
Expand Down
167 changes: 93 additions & 74 deletions ballerina/agent.bal
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,30 @@ public type ExecutionProgress record {|
# Execution history up to the current action
ExecutionStep[] history = [];
# Contextual instruction that can be used by the agent during the execution
map<json>|string context?;
map<json>|string? context = ();
|};

# Execution step information
public type ExecutionStep record {|
# Response generated by the LLM
json llmResponse;
# Observations produced by the tool during the execution
anydata|error observation;
|};

# Execution step information
public type ExecutionResult record {|
# Tool decided by the LLM during the reasoning
LlmToolResponse toolResponse;
LlmToolResponse tool;
# Observations produced by the tool during the execution
anydata|error observation;
|};

# An LLM response containing the selected tool to be executed
public type LlmToolResponse record {|
# Next tool to be executed
SelectedTool|LlmInvalidGenerationError tool;
# Raw LLM generated output
public type ExecutionError record {|
# Response generated by the LLM
json llmResponse;
# Error caused during the execution
LlmInvalidGenerationError|ToolExecutionError 'error;
|};

# An chat response by the LLM
Expand All @@ -49,7 +56,7 @@ public type LlmChatResponse record {|
|};

# Tool selected by LLM to be performed by the agent
public type SelectedTool record {|
public type LlmToolResponse record {|
# Name of the tool to selected
string name;
# Input to the tool
Expand All @@ -70,34 +77,40 @@ public type BaseAgent distinct isolated object {
#
# + progress - QueryProgress with the current query and execution history
# + return - NextAction decided by the LLM or an error if call to the LLM fails
isolated function selectNextTool(ExecutionProgress progress) returns LlmToolResponse|LlmChatResponse|LlmError;
isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError;

# Parse the llm response and extract the tool to be executed.
#
# + llmResponse - Raw LLM response
# + return - SelectedTool or an error if parsing fails
isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError;
};

# An iterator to iterate over agent's execution
public class AgentIterator {
public class Iterator {
*object:Iterable;
private final AgentExecutor executor;
private final Executor executor;

# Initialize the iterator with the agent and the query.
#
# + agent - Agent instance to be executed
# + query - Natural language query to be executed by the agent
# + context - Contextual information to be used by the agent during the execution
public isolated function init(BaseAgent agent, string query, map<json>|string? context = ()) {
self.executor = new (agent, query, context = context);
self.executor = new (agent, query = query, context = context);
}

# Iterate over the agent's execution steps.
# + return - a record with the execution step or an error if the agent failed
public function iterator() returns object {
public function next() returns record {|ExecutionStep|LlmChatResponse|error value;|}?;
public function next() returns record {|ExecutionResult|LlmChatResponse|ExecutionError|error value;|}?;
} {
return self.executor;
}
}

# An executor to perform step-by-step execution of the agent.
public class AgentExecutor {
public class Executor {
private boolean isCompleted = false;
private final BaseAgent agent;
# Contains the current execution progress for the agent and the query
Expand All @@ -109,13 +122,9 @@ public class AgentExecutor {
# + query - Natural language query to be executed by the agent
# + history - Execution history of the agent (This is used to continue an execution paused without completing)
# + context - Contextual information to be used by the agent during the execution
public isolated function init(BaseAgent agent, string query, ExecutionStep[] history = [], map<json>|string? context = ()) {
public isolated function init(BaseAgent agent, *ExecutionProgress progress) {
self.agent = agent;
self.progress = {
query,
history,
context
};
self.progress = progress;
}

# Checks whether agent has more steps to execute.
Expand All @@ -127,47 +136,61 @@ public class AgentExecutor {

# Reason the next step of the agent.
#
# + return - Llm tool response, chat response or an error if reasoning has failed
public isolated function reason() returns LlmToolResponse|LlmChatResponse|TaskCompletedError|LlmError {
# + return - generated LLM response during the reasoning or an error if the reasoning fails
public isolated function reason() returns json|TaskCompletedError|LlmError {
if self.isCompleted {
return error TaskCompletedError("Task is already completed. No more reasoning is needed.");
}
LlmToolResponse|LlmChatResponse response = check self.agent.selectNextTool(self.progress);
if response is LlmChatResponse {
self.isCompleted = true;
}
return response;
return check self.agent.selectNextTool(self.progress);
}

# Execute the next step of the agent.
#
# + toolResponse - LLM tool response containing the tool to be executed and the raw LLM output
# + llmResponse - LLM response containing the tool to be executed and the raw LLM output
# + return - Observations from the tool can be any|error|null
public isolated function act(LlmToolResponse toolResponse) returns ToolOutput {
ToolOutput observation;
SelectedTool|LlmInvalidGenerationError tool = toolResponse.tool;
public isolated function act(json llmResponse) returns ExecutionResult|LlmChatResponse|ExecutionError {
LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError parseLlmResponse = self.agent.parseLlmResponse(llmResponse);
if parseLlmResponse is LlmChatResponse {
self.isCompleted = true;
return parseLlmResponse;
}

// TODO Improve to use intructions from the error instead of generic error instructions
if tool is SelectedTool {
ToolOutput|error output = self.agent.toolStore.execute(tool);
if output is ToolNotFoundError {
observation = {value: "Tool is not found. Please check the tool name and retry."};
} else if output is ToolInvalidInputError {
observation = {value: "Tool execution failed due to invalid inputs. Retry with correct inputs."};
} else if output is error {
observation = {value: "Tool execution failed. Retry with correct inputs."};
anydata observation;
ExecutionResult|ExecutionError executionResult;
if parseLlmResponse is LlmToolResponse {
ToolOutput|ToolExecutionError|LlmInvalidGenerationError output = self.agent.toolStore.execute(parseLlmResponse);
if output is error {
if output is ToolNotFoundError {
observation = "Tool is not found. Please check the tool name and retry.";
} else if output is ToolInvalidInputError {
observation = "Tool execution failed due to invalid inputs. Retry with correct inputs.";
} else {
observation = "Tool execution failed. Retry with correct inputs.";
}
executionResult = {
llmResponse,
'error: output
};
} else {
observation = output;
anydata|error value = output.value;
observation = value is error ? value.toString() : value;
executionResult = {
tool: parseLlmResponse,
observation: value
};
}
} else {
observation = {value: "Tool extraction failed due to invalid JSON_BLOB. Retry with correct JSON_BLOB."};
observation = "Tool extraction failed due to invalid JSON_BLOB. Retry with correct JSON_BLOB.";
executionResult = {
llmResponse,
'error: parseLlmResponse
};
}
// update the execution history with the latest step
self.update({
toolResponse,
observation: observation.value
llmResponse,
observation
});
return observation;
return executionResult;
}

# Update the agent with an execution step.
Expand All @@ -179,25 +202,16 @@ public class AgentExecutor {

# Reason and execute the next step of the agent.
#
# + return - A record with ExecutionStep, chat response or an error
public isolated function next() returns record {|ExecutionStep|LlmChatResponse|error value;|}? {
LlmToolResponse|LlmChatResponse|error toolResponse = self.reason();
if toolResponse is LlmChatResponse|error {
return {value: toolResponse};
# + return - A record with ExecutionResult, chat response or an error
public isolated function next() returns record {|ExecutionResult|LlmChatResponse|ExecutionError|error value;|}? {
if self.isCompleted {
return ();
}
return {
value: {
toolResponse,
observation: self.act(toolResponse).value
}
};
}

# Allow retrieving the execution history during previous steps.
#
# + return - Execution history of the agent (A list of ExecutionStep)
public isolated function getExecutionHistory() returns ExecutionStep[] {
return self.progress.history;
json|TaskCompletedError|LlmError llmResponse = self.reason();
if llmResponse is error {
return {value: llmResponse};
}
return {value: self.act(llmResponse)};
}
}

Expand All @@ -209,12 +223,12 @@ public class AgentExecutor {
# + context - Context values to be used by the agent to execute the task
# + verbose - If true, then print the reasoning steps (default: true)
# + return - Returns the execution steps tracing the agent's reasoning and outputs from the tools
public isolated function run(BaseAgent agent, string query, int maxIter = 5, string|map<json> context = {}, boolean verbose = true) returns record {|ExecutionStep[] steps; string answer?;|} {
ExecutionStep[] steps = [];
public isolated function run(BaseAgent agent, string query, int maxIter = 5, string|map<json> context = {}, boolean verbose = true) returns record {|(ExecutionResult|ExecutionError)[] steps; string answer?;|} {
(ExecutionResult|ExecutionError)[] steps = [];
string? content = ();
AgentIterator iterator = new (agent, query, context = context);
Iterator iterator = new (agent, query, context = context);
int iter = 0;
foreach ExecutionStep|LlmChatResponse|error step in iterator {
foreach ExecutionResult|LlmChatResponse|ExecutionError|error step in iterator {
if iter == maxIter {
break;
}
Expand All @@ -233,8 +247,8 @@ public isolated function run(BaseAgent agent, string query, int maxIter = 5, str
iter += 1;
if verbose {
io:println(string `${"\n\n"}Agent Iteration ${iter.toString()}`);
SelectedTool|LlmInvalidGenerationError tool = step.toolResponse.tool;
if tool is SelectedTool {
if step is ExecutionResult {
LlmToolResponse tool = step.tool;
io:println(string `Action:
${BACKTICKS}
{
Expand All @@ -249,14 +263,13 @@ ${BACKTICKS}`);
io:println(string `${OBSERVATION_KEY}: ${observation.toString()}`);
}
} else {
error? cause = tool.cause();
string llmResponse = step.toolResponse.llmResponse.toString();
error? cause = step.'error.cause();
io:println(string `LLM Generation Error:
${BACKTICKS}
{
message: ${tool.message()},
message: ${step.'error.message()},
cause: ${(cause is error ? cause.message() : "Unspecified")},
llmResponse: ${llmResponse}
llmResponse: ${step.llmResponse.toString()}
}
${BACKTICKS}`);
}
Expand All @@ -282,3 +295,9 @@ isolated function getObservationString(anydata|error observation) returns string
return observation.toString().trim();
}
}

# Get the tools registered with the agent.
#
# + agent - Agent instance
# + return - Array of tools registered with the agent
public isolated function getTools(BaseAgent agent) returns AgentTool[] => agent.toolStore.tools.toArray();
9 changes: 6 additions & 3 deletions ballerina/error.bal
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ public type InvalidParameterDefinition distinct OpenApiParsingError;
# Any error occurred during LLM generation is classified under this error type.
public type LlmError distinct error;

# Errors occurred due to unexpected responses from the LLM.
public type LlmInvalidResponseError distinct LlmError;

# Errors occurred due to invalid LLM generation.
public type LlmInvalidGenerationError distinct LlmError;

# Errors occurred during LLM generation due to connection.
type LlmConnectionError distinct LlmError;
public type LlmConnectionError distinct LlmError;

# Errors occurred due to termination of the Agent's execution.
public type TaskCompletedError distinct error;
Expand All @@ -41,10 +44,10 @@ public type HttpServiceToolKitError distinct error;
public type HttpResponseParsingError distinct HttpServiceToolKitError;

# Errors during tool execution.
type ToolExecutionError distinct error;
public type ToolExecutionError distinct error;

# Error during unexpected output by the tool
type ToolInvaludOutputError distinct ToolExecutionError;
public type ToolInvalidOutputError distinct ToolExecutionError;

# Errors occurred due to invalid tool name generated by the LLM.
public type ToolNotFoundError distinct LlmInvalidGenerationError;
Expand Down
45 changes: 24 additions & 21 deletions ballerina/function_call.bal
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,41 @@ public isolated class FunctionCallAgent {
self.model = model;
}

isolated function selectNextTool(ExecutionProgress progress) returns LlmToolResponse|LlmChatResponse|LlmError {
ChatMessage[] messages = createFunctionCallMessages(progress);
FunctionCall|string|error response = self.model.functionaCall(messages, self.toolStore.tools.toArray());
if response is error {
return error LlmConnectionError("Error while function call generation", response);
isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError {
if llmResponse is string {
return {content: llmResponse};
}
if response is string {
return {content: response};
if llmResponse !is FunctionCall {
return error LlmInvalidGenerationError("Invalid response", llmResponse = llmResponse);
}
string? name = response.name;
string? name = llmResponse.name;
if name is () {
return {tool: error LlmInvalidGenerationError("Missing name", name = response.name, arguments = response.arguments), llmResponse: response.toJson()};
return error LlmInvalidGenerationError("Missing name", name = llmResponse.name, arguments = llmResponse.arguments);
}
string? stringArgs = response.arguments;
string? stringArgs = llmResponse.arguments;
map<json>|error? arguments = ();
if stringArgs is string {
arguments = stringArgs.fromJsonStringWithType();
}
if arguments is error {
return {tool: error LlmInvalidGenerationError("Invalid arguments", arguments, name = response.name, arguments = stringArgs), llmResponse: response.toJson()};
return error LlmInvalidGenerationError("Invalid arguments", arguments, name = llmResponse.name, arguments = stringArgs);
}
return {
tool: {
name,
arguments
},
llmResponse: {
name: name,
arguments: stringArgs
}
name,
arguments
};
}

isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError {
ChatMessage[] messages = createFunctionCallMessages(progress);
return self.model.functionaCall(messages,
from AgentTool tool in self.toolStore.tools.toArray()
select {
name: tool.name,
description: tool.description,
parameters: tool.variables
});
}
}

isolated function createFunctionCallMessages(ExecutionProgress progress) returns ChatMessage[] {
Expand All @@ -81,9 +84,9 @@ isolated function createFunctionCallMessages(ExecutionProgress progress) returns
}
// include the history
foreach ExecutionStep step in progress.history {
FunctionCall|error functionCall = step.toolResponse.llmResponse.fromJsonWithType();
FunctionCall|error functionCall = step.llmResponse.fromJsonWithType();
if functionCall is error {
panic error("Badly formated history for function call agent", generated = step.toolResponse.llmResponse);
panic error("Badly formated history for function call agent", llmResponse = step.llmResponse);
}
messages.push({
role: ASSISTANT,
Expand Down
Loading

0 comments on commit cf19f68

Please sign in to comment.