Skip to content

Commit

Permalink
Merge pull request #37 from nadheesh/main
Browse files Browse the repository at this point in the history
Improve GPT3.5 reasoning and adds function calling agent
  • Loading branch information
nadheesh authored Jan 19, 2024
2 parents 35641da + d6321d7 commit d8c0adf
Show file tree
Hide file tree
Showing 28 changed files with 1,047 additions and 521 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ ballerina/Config.toml
ballerina/target
ballerina/sample*
ballerina/openapi*
ballerina/schem*
ballerina/schema*


4 changes: 2 additions & 2 deletions ballerina/Ballerina.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[package]
distribution = "2201.7.2"
distribution = "2201.8.4"
org = "ballerinax"
name = "ai.agent"
version = "0.6.1"
version = "0.7.0"
license = ["Apache-2.0"]
authors = ["Ballerina"]
keywords = ["AI/Agent", "Cost/Freemium"]
Expand Down
446 changes: 188 additions & 258 deletions ballerina/agent.bal

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions ballerina/constants.bal
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ const OPENAPI_PATTERN_DATE = "yyyy-MM-dd";
const OPENAPI_PATTERN_DATE_TIME = "yyyy-MM-dd'T'HH:mm:ssZ";

//agent
const FINAL_ANSWER_KEY = "final answer";
const THOUGHT_KEY = "Thought:";
const BACKTICK = "`";
const ERROR_INSTRUCTION_KEY = "instruction";
const BACKTICKS = "```";

final string:RegExp FINAL_ANSWER_REGEX = re `^final.?answer`;

const ACTION_KEY = "action";
const ACTION_NAME_KEY = "name";
const ACTION_ARGUEMENTS_KEY = "arguments";
final string:RegExp ACTION_INPUT_REGEX = re `^action.?input`;
21 changes: 14 additions & 7 deletions ballerina/error.bal
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,33 @@ public type InvalidParameterDefinition distinct OpenApiParsingError;
# Any error occurred during LLM generation is classified under this error type.
public type LlmError distinct error;

# Errors due to invalid action generated by the LLM.
type LlmActionParseError distinct LlmError;
# Errors occurred due to invalid LLM generation.
public type LlmInvalidGenerationError distinct LlmError;

# Errors occurred during LLM generation.
public type LlmGenerationError distinct LlmError;
# Errors occurred during LLM generation due to connection.
type LlmConnectionError distinct LlmError;

# Errors occurred due to termination of the Agent's execution.
public type TaskTerminationError distinct error;
public type TaskCompletedError distinct error;

# Errors occurred due while running HTTP service toolkit.
public type HttpServiceToolKitError distinct error;

# Any error occurred during parsing HTTP response is classified under this error type.
public type HttpResponseParsingError distinct HttpServiceToolKitError;

# Errors during tool execution.
type ToolExecutionError distinct error;

# Error during unexpected output by the tool
type ToolInvaludOutputError distinct ToolExecutionError;

# Errors occurred due to invalid tool name generated by the LLM.
public type ToolNotFoundError distinct error;
public type ToolNotFoundError distinct LlmInvalidGenerationError;

# Errors occurred due to invalid input to the tool generated by the LLM.
public type ToolInvalidInputError distinct error;
public type ToolInvalidInputError distinct LlmInvalidGenerationError;

# Errors occurred due to missing mandotary path or query parameters.
public type MissingHttpParameterError distinct ToolInvalidInputError;

98 changes: 98 additions & 0 deletions ballerina/function_call.bal
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2024 WSO2 LLC (http://www.wso2.org) All Rights Reserved.
//
// WSO2 Inc. licenses this file to you under the Apache License,
// Version 2.0 (the "License"); you may not use this file except
// in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

# Function call agent.
# This agent uses OpenAI function call API to perform the tool selection.
public isolated class FunctionCallAgent {
*BaseAgent;
final ToolStore toolStore;
final FunctionCallLlm model;

# Initialize an Agent.
#
# + model - LLM model instance
# + tools - Tools to be used by the agent
public isolated function init(FunctionCallLlm model, (BaseToolKit|Tool)... tools) returns error? {
self.toolStore = check new (...tools);
self.model = model;
}

isolated function selectNextTool(ExecutionProgress progress) returns LlmToolResponse|LlmChatResponse|LlmError {
ChatMessage[] messages = createFunctionCallMessages(progress);
FunctionCall|string|error response = self.model.functionaCall(messages, self.toolStore.tools.toArray());
if response is error {
return error LlmConnectionError("Error while function call generation", response);
}
if response is string {
return {content: response};
}
string? name = response.name;
if name is () {
return {tool: error LlmInvalidGenerationError("Missing name", name = response.name, arguments = response.arguments), llmResponse: response.toJson()};
}
string? stringArgs = response.arguments;
map<json>|error? arguments = ();
if stringArgs is string {
arguments = stringArgs.fromJsonStringWithType();
}
if arguments is error {
return {tool: error LlmInvalidGenerationError("Invalid arguments", arguments, name = response.name, arguments = stringArgs), llmResponse: response.toJson()};
}
return {
tool: {
name,
arguments
},
llmResponse: {
name: name,
arguments: stringArgs
}
};
}
}

isolated function createFunctionCallMessages(ExecutionProgress progress) returns ChatMessage[] {
// add the question
ChatMessage[] messages = [
{
role: USER,
content: progress.query
}
];
// add the context as the first message
if progress.context !is () {
messages.unshift({
role: SYSTEM,
content: string `You can use these information if needed: ${progress.context.toString()}`
});
}
// include the history
foreach ExecutionStep step in progress.history {
FunctionCall|error functionCall = step.toolResponse.llmResponse.fromJsonWithType();
if functionCall is error {
panic error("Badly formated history for function call agent", generated = step.toolResponse.llmResponse);
}
messages.push({
role: ASSISTANT,
function_call: functionCall
}, {
role: FUNCTION,
name: functionCall.name,
content: getObservationString(step.observation)
});
}
return messages;
}
3 changes: 1 addition & 2 deletions ballerina/http_utils.bal
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import ballerina/http;
import ballerina/lang.regexp;
import ballerina/mime;
Expand Down Expand Up @@ -287,7 +286,7 @@ isolated function extractResponsePayload(string path, http:Response response) re
};
}

public isolated function getContentLength(http:Response response) returns int|error? {
isolated function getContentLength(http:Response response) returns int|error? {
string|error contentLengthHeader = response.getHeader(mime:CONTENT_LENGTH);
if contentLengthHeader is error || contentLengthHeader == "" {
return;
Expand Down
Loading

0 comments on commit d8c0adf

Please sign in to comment.