Skip to content

Commit

Permalink
implementing hidden agent (opensearch-project#2204) (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#2220)

* implementing hidden agent

Signed-off-by: Dhrubo Saha <[email protected]>

* added more validation in agent input

Signed-off-by: Dhrubo Saha <[email protected]>

* updated branch coverage

Signed-off-by: Dhrubo Saha <[email protected]>

* adding more test

Signed-off-by: Dhrubo Saha <[email protected]>

* fixing test

Signed-off-by: Dhrubo Saha <[email protected]>

* adding filter in search agent action

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* add locale root

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments + put restriction on deleting hidden agents

Signed-off-by: Dhrubo Saha <[email protected]>

* updating isHiddenAgentfield

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
(cherry picked from commit affb047)

Co-authored-by: Dhrubo Saha <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and dhrubo-os authored Mar 19, 2024
1 parent af539dc commit a7d1ef0
Show file tree
Hide file tree
Showing 26 changed files with 912 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand Down Expand Up @@ -484,7 +485,7 @@ public void deleteConnector() {

@Test
public void testRegisterAgent() {
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
MLAgent mlAgent = MLAgent.builder().name("Agent name").type(MLAgentType.FLOW.name()).build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
Expand Down Expand Up @@ -869,7 +870,7 @@ public void testRegisterAgent() {
}).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class);
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
MLAgent mlAgent = MLAgent.builder().name("Agent name").type(MLAgentType.FLOW.name()).build();

machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public class CommonValue {
public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String ML_AGENT_INDEX = ".plugins-ml-agent";
public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 2;
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
Expand Down Expand Up @@ -419,6 +419,9 @@ public class CommonValue {
+ MLAgent.MEMORY_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ MLAgent.IS_HIDDEN_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLAgent.CREATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
Expand Down
25 changes: 25 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLAgentType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import java.util.Locale;

public enum MLAgentType {
FLOW,
CONVERSATIONAL,
CONVERSATIONAL_FLOW;

public static MLAgentType from(String value) {
if (value == null) {
throw new IllegalArgumentException("Agent type can't be null");
}
try {
return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Agent type");
}
}
}
63 changes: 59 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLModel;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand All @@ -41,6 +45,9 @@ public class MLAgent implements ToXContentObject, Writeable {
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
public static final String APP_TYPE_FIELD = "app_type";
public static final String IS_HIDDEN_FIELD = "is_hidden";

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = Version.V_2_13_0;

private String name;
private String type;
Expand All @@ -53,6 +60,7 @@ public class MLAgent implements ToXContentObject, Writeable {
private Instant createdTime;
private Instant lastUpdateTime;
private String appType;
private Boolean isHidden;

@Builder(toBuilder = true)
public MLAgent(String name,
Expand All @@ -64,7 +72,8 @@ public MLAgent(String name,
MLMemorySpec memory,
Instant createdTime,
Instant lastUpdateTime,
String appType) {
String appType,
Boolean isHidden) {
this.name = name;
this.type = type;
this.description = description;
Expand All @@ -75,12 +84,18 @@ public MLAgent(String name,
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
this.appType = appType;
// is_hidden field isn't going to be set by user. It will be set by the code.
this.isHidden = isHidden;
validate();
}

private void validate() {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
throw new IllegalArgumentException("Agent name can't be null");
}
validateMLAgentType(type);
if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) {
throw new IllegalArgumentException("We need model information for the conversational agent type");
}
Set<String> toolNames = new HashSet<>();
if (tools != null) {
Expand All @@ -95,7 +110,21 @@ private void validate() {
}
}

private void validateMLAgentType(String agentType) {
if (type == null) {
throw new IllegalArgumentException("Agent type can't be null");
} else {
try {
MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching
} catch (IllegalArgumentException e) {
// The typeStr does not match any MLAgentType, so throw a new exception with a clearer message.
throw new IllegalArgumentException(agentType + " is not a valid Agent Type");
}
}
}

public MLAgent(StreamInput input) throws IOException{
Version streamInputVersion = input.getVersion();
name = input.readString();
type = input.readString();
description = input.readOptionalString();
Expand All @@ -118,10 +147,15 @@ public MLAgent(StreamInput input) throws IOException{
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
appType = input.readOptionalString();
// is_hidden field isn't going to be set by user. It will be set by the code.
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
isHidden = input.readOptionalBoolean();
}
validate();
}

public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeString(name);
out.writeString(type);
out.writeOptionalString(description);
Expand Down Expand Up @@ -155,6 +189,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalString(appType);
// is_hidden field isn't going to be set by user. It will be set by the code.
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) {
out.writeOptionalBoolean(isHidden);
}
}

@Override
Expand Down Expand Up @@ -190,21 +228,34 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (appType != null) {
builder.field(APP_TYPE_FIELD, appType);
}
// is_hidden field isn't going to be set by user. It will be set by the code.
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
builder.endObject();
return builder;
}

public static MLAgent parse(XContentParser parser) throws IOException {
return parseCommonFields(parser, true); // true to parse isHidden field
}

public static MLAgent parseFromUserInput(XContentParser parser) throws IOException {
return parseCommonFields(parser, false); // false to skip isHidden field
}

private static MLAgent parseCommonFields(XContentParser parser, boolean parseHidden) throws IOException {
String name = null;
String type = null;
String description = null;;
String description = null;
LLMSpec llm = null;
List<MLToolSpec> tools = null;
Map<String, String> parameters = null;
MLMemorySpec memory = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
String appType = null;
boolean isHidden = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -246,11 +297,15 @@ public static MLAgent parse(XContentParser parser) throws IOException {
case APP_TYPE_FIELD:
appType = parser.text();
break;
case IS_HIDDEN_FIELD:
if (parseHidden) isHidden = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}

return MLAgent.builder()
.name(name)
.type(type)
Expand All @@ -262,9 +317,9 @@ public static MLAgent parse(XContentParser parser) throws IOException {
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.appType(appType)
.isHidden(isHidden)
.build();
}

public static MLAgent fromStream(StreamInput in) throws IOException {
MLAgent agent = new MLAgent(in);
return agent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,28 @@
public class MLAgentGetRequest extends ActionRequest {

String agentId;
// This is to identify if the get request is initiated by user or not. Sometimes during
// delete/update options, we also perform get operation. This field is to distinguish between
// these two situations.
boolean isUserInitiatedGetRequest;

@Builder
public MLAgentGetRequest(String agentId) {
public MLAgentGetRequest(String agentId, boolean isUserInitiatedGetRequest) {
this.agentId = agentId;
this.isUserInitiatedGetRequest = isUserInitiatedGetRequest;
}

public MLAgentGetRequest(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
this.isUserInitiatedGetRequest = in.readBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.agentId);
out.writeBoolean(isUserInitiatedGetRequest);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import static org.junit.Assert.*;
public class MLAgentTypeTests {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
@Test
public void testFromWithValidTypes() {
// Test all enum values to ensure they return correctly
assertEquals(MLAgentType.FLOW, MLAgentType.from("FLOW"));
assertEquals(MLAgentType.CONVERSATIONAL, MLAgentType.from("CONVERSATIONAL"));
assertEquals(MLAgentType.CONVERSATIONAL_FLOW, MLAgentType.from("CONVERSATIONAL_FLOW"));
}

@Test
public void testFromWithLowerCase() {
// Test with lowercase input
assertEquals(MLAgentType.FLOW, MLAgentType.from("flow"));
assertEquals(MLAgentType.CONVERSATIONAL, MLAgentType.from("conversational"));
assertEquals(MLAgentType.CONVERSATIONAL_FLOW, MLAgentType.from("conversational_flow"));
}

@Test
public void testFromWithMixedCase() {
// Test with mixed case input
assertEquals(MLAgentType.FLOW, MLAgentType.from("Flow"));
assertEquals(MLAgentType.CONVERSATIONAL, MLAgentType.from("Conversational"));
assertEquals(MLAgentType.CONVERSATIONAL_FLOW, MLAgentType.from("Conversational_Flow"));
}

@Test
public void testFromWithInvalidType() {
// This should throw an IllegalArgumentException
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Agent type");
MLAgentType.from("INVALID_TYPE");
}

@Test
public void testFromWithEmptyString() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Agent type");
// This should also throw an IllegalArgumentException
MLAgentType.from("");
}

@Test
public void testFromWithNull() {
// This should also throw an IllegalArgumentException
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Agent type can't be null");
MLAgentType.from(null);
}
}
Loading

0 comments on commit a7d1ef0

Please sign in to comment.