Skip to content

Commit

Permalink
Merge pull request #41 from nadheesh/improved-types
Browse files Browse the repository at this point in the history
Update docs
  • Loading branch information
nadheesh committed Jan 29, 2024
2 parents cf19f68 + ab899d0 commit 36d8029
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 232 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,4 @@ ballerina/target
ballerina/sample*
ballerina/openapi*
ballerina/schema*


$(pwd)
2 changes: 1 addition & 1 deletion ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
distribution = "2201.8.4"
org = "ballerinax"
name = "ai.agent"
version = "0.7.1"
version = "0.7.2"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["AI/Agent", "Cost/Freemium"]
Expand Down
256 changes: 161 additions & 95 deletions ballerina/Module.md

Large diffs are not rendered by default.

256 changes: 161 additions & 95 deletions ballerina/Package.md

Large diffs are not rendered by default.

28 changes: 14 additions & 14 deletions ballerina/agent.bal
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,20 @@ public type ToolOutput record {|
|};

public type BaseAgent distinct isolated object {
LlmModel model;
ToolStore toolStore;

# Use LLMs to decide the next tool/step.
#
# + progress - QueryProgress with the current query and execution history
# + return - NextAction decided by the LLM or an error if call to the LLM fails
isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError;
public LlmModel model;
public ToolStore toolStore;

# Parse the llm response and extract the tool to be executed.
#
# + llmResponse - Raw LLM response
# + return - SelectedTool or an error if parsing fails
isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError;
# + return - A record containing the tool decided by the LLM, chat response or an error if the response is invalid
public isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError;

# Use LLM to decide the next tool/step.
#
# + progress - Execution progress with the current query and execution history
# + return - LLM response containing the tool or chat response (or an error if the call fails)
public isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError;
};

# An iterator to iterate over agent's execution
Expand All @@ -96,8 +96,8 @@ public class Iterator {
# + agent - Agent instance to be executed
# + query - Natural language query to be executed by the agent
# + context - Contextual information to be used by the agent during the execution
public isolated function init(BaseAgent agent, string query, map<json>|string? context = ()) {
self.executor = new (agent, query = query, context = context);
public isolated function init(BaseAgent agent, *ExecutionProgress progress) {
self.executor = new (agent, progress);
}

# Iterate over the agent's execution steps.
Expand Down Expand Up @@ -226,7 +226,7 @@ public class Executor {
public isolated function run(BaseAgent agent, string query, int maxIter = 5, string|map<json> context = {}, boolean verbose = true) returns record {|(ExecutionResult|ExecutionError)[] steps; string answer?;|} {
(ExecutionResult|ExecutionError)[] steps = [];
string? content = ();
Iterator iterator = new (agent, query, context = context);
Iterator iterator = new (agent, query = query, context = context);
int iter = 0;
foreach ExecutionResult|LlmChatResponse|ExecutionError|error step in iterator {
if iter == maxIter {
Expand All @@ -253,7 +253,7 @@ public isolated function run(BaseAgent agent, string query, int maxIter = 5, str
${BACKTICKS}
{
${ACTION_NAME_KEY}: ${tool.name},
${ACTION_ARGUEMENTS_KEY}: ${(tool.arguments ?: "None").toString()}}
${ACTION_ARGUEMENTS_KEY}: ${(tool.arguments ?: "None").toString()}
}
${BACKTICKS}`);
anydata|error observation = step?.observation;
Expand Down
22 changes: 16 additions & 6 deletions ballerina/function_call.bal
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@
# This agent uses OpenAI function call API to perform the tool selection.
public isolated class FunctionCallAgent {
*BaseAgent;
final ToolStore toolStore;
final FunctionCallLlm model;
# Tool store to be used by the agent
public final ToolStore toolStore;
# LLM model instance (should be a function call model)
public final FunctionCallLlmModel model;

# Initialize an Agent.
#
# + model - LLM model instance
# + tools - Tools to be used by the agent
public isolated function init(FunctionCallLlm model, (BaseToolKit|Tool)... tools) returns error? {
public isolated function init(FunctionCallLlmModel model, (BaseToolKit|Tool)... tools) returns error? {
self.toolStore = check new (...tools);
self.model = model;
}

isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError {
# Parse the function calling API response and extract the tool to be executed.
#
# + llmResponse - Raw LLM response
# + return - A record containing the tool decided by the LLM, chat response or an error if the response is invalid
public isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError {
if llmResponse is string {
return {content: llmResponse};
}
Expand All @@ -55,9 +61,13 @@ public isolated class FunctionCallAgent {
};
}

isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError {
# Use LLM to decide the next tool/step based on the function calling APIs.
#
# + progress - Execution progress with the current query and execution history
# + return - LLM response containing the tool or chat response (or an error if the call fails)
public isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError {
ChatMessage[] messages = createFunctionCallMessages(progress);
return self.model.functionaCall(messages,
return self.model.functionCall(messages,
from AgentTool tool in self.toolStore.tools.toArray()
select {
name: tool.name,
Expand Down
23 changes: 10 additions & 13 deletions ballerina/llm.bal
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,25 @@ public type LlmModel distinct isolated object {
# Extendable LLM model object for completion models.
public type CompletionLlmModel distinct isolated object {
*LlmModel;
CompletionModelConfig modelConfig;
public isolated function complete(string prompt, string? stop = ()) returns string|LlmError;
};

# Extendable LLM model object for chat LLM models
public type ChatLlmModel distinct isolated object {
*LlmModel;
ChatModelConfig modelConfig;
public isolated function chatComplete(ChatMessage[] messages, string? stop = ()) returns string|LlmError;
};

# Extendable LLM model object for LLM models with function call API
public type FunctionCallLlm distinct isolated object {
public type FunctionCallLlmModel distinct isolated object {
*LlmModel;
ChatModelConfig modelConfig;
public isolated function functionaCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns FunctionCall|string|LlmError;
public isolated function functionCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns string|FunctionCall|LlmError;
};

public isolated class Gpt3Model {
*CompletionLlmModel;
final text:Client llmClient;
final CompletionModelConfig modelConfig;
public final CompletionModelConfig modelConfig;

# Initializes the GPT-3 model with the given connection configuration and model configuration.
#
Expand Down Expand Up @@ -140,7 +137,7 @@ public isolated class Gpt3Model {
public isolated class AzureGpt3Model {
*CompletionLlmModel;
final azure_text:Client llmClient;
final CompletionModelConfig modelConfig;
public final CompletionModelConfig modelConfig;
private final string deploymentId;
private final string apiVersion;

Expand Down Expand Up @@ -179,10 +176,10 @@ public isolated class AzureGpt3Model {
}

public isolated class ChatGptModel {
*FunctionCallLlm;
*FunctionCallLlmModel;
*ChatLlmModel;
final chat:Client llmClient;
final ChatModelConfig modelConfig;
public final ChatModelConfig modelConfig;

# Initializes the ChatGPT model with the given connection configuration and model configuration.
#
Expand Down Expand Up @@ -219,7 +216,7 @@ public isolated class ChatGptModel {
# + functions - Function definitions to be used for the function call
# + stop - Stop sequence to stop the completion
# + return - Function to be called, chat response or an error in-case of failures
public isolated function functionaCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns FunctionCall|string|LlmError {
public isolated function functionCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns string|FunctionCall|LlmError {

chat:CreateChatCompletionResponse|error response = self.llmClient->/chat/completions.post({
...self.modelConfig,
Expand All @@ -246,10 +243,10 @@ public isolated class ChatGptModel {
}

public isolated class AzureChatGptModel {
*FunctionCallLlm;
*FunctionCallLlmModel;
*ChatLlmModel;
final azure_chat:Client llmClient;
final ChatModelConfig modelConfig;
public final ChatModelConfig modelConfig;
private final string deploymentId;
private final string apiVersion;

Expand Down Expand Up @@ -294,7 +291,7 @@ public isolated class AzureChatGptModel {
# + functions - Function definitions to be used for the function call
# + stop - Stop sequence to stop the completion
# + return - Function to be called, chat response or an error in-case of failures
public isolated function functionaCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns FunctionCall|string|LlmError {
public isolated function functionCall(ChatMessage[] messages, ChatCompletionFunctions[] functions, string? stop = ()) returns string|FunctionCall|LlmError {
azure_chat:CreateChatCompletionResponse|error response =
self.llmClient->/deployments/[self.deploymentId]/chat/completions.post(self.apiVersion, {
...self.modelConfig,
Expand Down
18 changes: 14 additions & 4 deletions ballerina/react.bal
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ type ToolInfo readonly & record {|
public isolated class ReActAgent {
*BaseAgent;
final string instructionPrompt;
final ToolStore toolStore;
final CompletionLlmModel|ChatLlmModel model;
# ToolStore instance to store the tools used by the agent
public final ToolStore toolStore;
# LLM model instance to be used by the agent (Can be either CompletionLlmModel or ChatLlmModel)
public final CompletionLlmModel|ChatLlmModel model;

# Initialize an Agent.
#
Expand All @@ -39,9 +41,17 @@ public isolated class ReActAgent {
log:printDebug("Instruction Prompt Generated Successfully", instructionPrompt = self.instructionPrompt);
}

isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError => parseReActLlmResponse(normalizeLlmResponse(llmResponse.toString()));
# Parse the ReAct llm response and extract the tool to be executed.
#
# + llmResponse - Raw LLM response
# + return - A record containing the tool decided by the LLM, chat response or an error if the response is invalid
public isolated function parseLlmResponse(json llmResponse) returns LlmToolResponse|LlmChatResponse|LlmInvalidGenerationError => parseReActLlmResponse(normalizeLlmResponse(llmResponse.toString()));

isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError {
# Use LLM to decide the next tool/step based on the ReAct prompting
#
# + progress - Execution progress with the current query and execution history
# + return - LLM response containing the tool or chat response (or an error if the call fails)
public isolated function selectNextTool(ExecutionProgress progress) returns json|LlmError {
map<json>|string? context = progress.context;
string contextPrompt = context is () ? "" : string `${"\n\n"}You can use these information if needed: ${context.toString()}$`;

Expand Down
4 changes: 2 additions & 2 deletions ballerina/tool.bal
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ public type AgentTool record {|
isolated function caller;
|};

isolated class ToolStore {
public isolated class ToolStore {
final map<AgentTool> & readonly tools;

# Register tools to the agent.
# These tools will be used by the LLM to perform tasks.
#
# + tools - A list of tools that are available to the LLM
# + return - An error if the tool is already registered
isolated function init((BaseToolKit|Tool)... tools) returns error? {
public isolated function init((BaseToolKit|Tool)... tools) returns error? {
if tools.length() == 0 {
return error("Initialization failed.", cause = "No tools provided to the agent.");
}
Expand Down

0 comments on commit 36d8029

Please sign in to comment.