From a7d1ef0770157c6f1c6d52337770b760a120e599 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 18 Mar 2024 23:40:23 -0700 Subject: [PATCH] implementing hidden agent (#2204) (#2220) * implementing hidden agent Signed-off-by: Dhrubo Saha * added more validation in agent input Signed-off-by: Dhrubo Saha * updated branch coverage Signed-off-by: Dhrubo Saha * adding more test Signed-off-by: Dhrubo Saha * fixing test Signed-off-by: Dhrubo Saha * adding filter in search agent action Signed-off-by: Dhrubo Saha * addressing comments Signed-off-by: Dhrubo Saha * addressing comments Signed-off-by: Dhrubo Saha * add locale root Signed-off-by: Dhrubo Saha * addressing comments + put restriction on deleting hidden agents Signed-off-by: Dhrubo Saha * updating isHiddenAgentfield Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha (cherry picked from commit affb0473e631f8d63cadb0b84ddf690cc51cde28) Co-authored-by: Dhrubo Saha --- .../ml/client/MachineLearningClientTest.java | 3 +- .../client/MachineLearningNodeClientTest.java | 3 +- .../org/opensearch/ml/common/CommonValue.java | 5 +- .../org/opensearch/ml/common/MLAgentType.java | 25 +++ .../opensearch/ml/common/agent/MLAgent.java | 63 +++++- .../transport/agent/MLAgentGetRequest.java | 9 +- .../ml/common/MLAgentTypeTests.java | 64 ++++++ .../ml/common/agent/MLAgentTest.java | 114 +++++++++-- .../agent/MLAgentGetRequestTest.java | 16 +- .../agent/MLAgentGetResponseTest.java | 7 +- .../algorithms/agent/MLAgentExecutor.java | 10 +- .../algorithms/agent/MLAgentExecutorTest.java | 5 +- .../agent/MLChatAgentRunnerTest.java | 24 ++- .../agent/MLFlowAgentRunnerTest.java | 8 +- .../engine/indices/MLIndicesHandlerTest.java | 17 +- .../agents/DeleteAgentTransportAction.java | 79 ++++++-- .../agents/GetAgentTransportAction.java | 26 ++- .../agents/TransportRegisterAgentAction.java | 11 +- .../agents/TransportSearchAgentAction.java | 25 +++ .../ml/rest/RestMLGetAgentAction.java | 2 +- .../ml/rest/RestMLRegisterAgentAction.java | 2 +- .../DeleteAgentTransportActionTests.java | 189 +++++++++++++++++- .../agents/GetAgentTransportActionTests.java | 91 +++++++-- .../RegisterAgentTransportActionTests.java | 106 +++++++++- .../TransportSearchAgentActionTests.java | 103 ++++++++-- .../RegisterAgentTransportActionTests.java | 5 +- 26 files changed, 912 insertions(+), 100 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/MLAgentType.java create mode 100644 common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index ccc0e050e9..d7ae5ec334 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -29,6 +29,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; @@ -484,7 +485,7 @@ public void deleteConnector() { @Test public void testRegisterAgent() { - MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + MLAgent mlAgent = MLAgent.builder().name("Agent name").type(MLAgentType.FLOW.name()).build(); assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet()); } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index f81b20747f..ebdcbf9e87 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -56,6 +56,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; @@ -869,7 +870,7 @@ public void testRegisterAgent() { }).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); - MLAgent mlAgent = MLAgent.builder().name("Agent name").build(); + MLAgent mlAgent = MLAgent.builder().name("Agent name").type(MLAgentType.FLOW.name()).build(); machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener); diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 196b1eb478..c69a5e6c08 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -66,7 +66,7 @@ public class CommonValue { public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1; public static final String ML_MAP_RESPONSE_KEY = "response"; public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; - public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; + public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 2; public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; @@ -419,6 +419,9 @@ public class CommonValue { + MLAgent.MEMORY_FIELD + "\" : {\"type\": \"flat_object\"},\n" + " \"" + + MLAgent.IS_HIDDEN_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + MLAgent.CREATED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" diff --git a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java new file mode 100644 index 0000000000..a877024e97 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import java.util.Locale; + +public enum MLAgentType { + FLOW, + CONVERSATIONAL, + CONVERSATIONAL_FLOW; + + public static MLAgentType from(String value) { + if (value == null) { + throw new IllegalArgumentException("Agent type can't be null"); + } + try { + return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong Agent type"); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 9033b92afc..f068863b1d 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -8,18 +8,22 @@ import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.MLModel; import java.io.IOException; import java.time.Instant; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -41,6 +45,9 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String CREATED_TIME_FIELD = "created_time"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; public static final String APP_TYPE_FIELD = "app_type"; + public static final String IS_HIDDEN_FIELD = "is_hidden"; + + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = Version.V_2_13_0; private String name; private String type; @@ -53,6 +60,7 @@ public class MLAgent implements ToXContentObject, Writeable { private Instant createdTime; private Instant lastUpdateTime; private String appType; + private Boolean isHidden; @Builder(toBuilder = true) public MLAgent(String name, @@ -64,7 +72,8 @@ public MLAgent(String name, MLMemorySpec memory, Instant createdTime, Instant lastUpdateTime, - String appType) { + String appType, + Boolean isHidden) { this.name = name; this.type = type; this.description = description; @@ -75,12 +84,18 @@ public MLAgent(String name, this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; this.appType = appType; + // is_hidden field isn't going to be set by user. It will be set by the code. + this.isHidden = isHidden; validate(); } private void validate() { if (name == null) { - throw new IllegalArgumentException("agent name is null"); + throw new IllegalArgumentException("Agent name can't be null"); + } + validateMLAgentType(type); + if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) { + throw new IllegalArgumentException("We need model information for the conversational agent type"); } Set toolNames = new HashSet<>(); if (tools != null) { @@ -95,7 +110,21 @@ private void validate() { } } + private void validateMLAgentType(String agentType) { + if (type == null) { + throw new IllegalArgumentException("Agent type can't be null"); + } else { + try { + MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching + } catch (IllegalArgumentException e) { + // The typeStr does not match any MLAgentType, so throw a new exception with a clearer message. + throw new IllegalArgumentException(agentType + " is not a valid Agent Type"); + } + } + } + public MLAgent(StreamInput input) throws IOException{ + Version streamInputVersion = input.getVersion(); name = input.readString(); type = input.readString(); description = input.readOptionalString(); @@ -118,10 +147,15 @@ public MLAgent(StreamInput input) throws IOException{ createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); appType = input.readOptionalString(); + // is_hidden field isn't going to be set by user. It will be set by the code. + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { + isHidden = input.readOptionalBoolean(); + } validate(); } public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeString(name); out.writeString(type); out.writeOptionalString(description); @@ -155,6 +189,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInstant(createdTime); out.writeOptionalInstant(lastUpdateTime); out.writeOptionalString(appType); + // is_hidden field isn't going to be set by user. It will be set by the code. + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { + out.writeOptionalBoolean(isHidden); + } } @Override @@ -190,14 +228,26 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (appType != null) { builder.field(APP_TYPE_FIELD, appType); } + // is_hidden field isn't going to be set by user. It will be set by the code. + if (isHidden != null) { + builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); + } builder.endObject(); return builder; } public static MLAgent parse(XContentParser parser) throws IOException { + return parseCommonFields(parser, true); // true to parse isHidden field + } + + public static MLAgent parseFromUserInput(XContentParser parser) throws IOException { + return parseCommonFields(parser, false); // false to skip isHidden field + } + + private static MLAgent parseCommonFields(XContentParser parser, boolean parseHidden) throws IOException { String name = null; String type = null; - String description = null;; + String description = null; LLMSpec llm = null; List tools = null; Map parameters = null; @@ -205,6 +255,7 @@ public static MLAgent parse(XContentParser parser) throws IOException { Instant createdTime = null; Instant lastUpdateTime = null; String appType = null; + boolean isHidden = false; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -246,11 +297,15 @@ public static MLAgent parse(XContentParser parser) throws IOException { case APP_TYPE_FIELD: appType = parser.text(); break; + case IS_HIDDEN_FIELD: + if (parseHidden) isHidden = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } + return MLAgent.builder() .name(name) .type(type) @@ -262,9 +317,9 @@ public static MLAgent parse(XContentParser parser) throws IOException { .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) .appType(appType) + .isHidden(isHidden) .build(); } - public static MLAgent fromStream(StreamInput in) throws IOException { MLAgent agent = new MLAgent(in); return agent; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java index 4880a07abf..ea65a768df 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java @@ -25,21 +25,28 @@ public class MLAgentGetRequest extends ActionRequest { String agentId; + // This is to identify if the get request is initiated by user or not. Sometimes during + // delete/update options, we also perform get operation. This field is to distinguish between + // these two situations. + boolean isUserInitiatedGetRequest; @Builder - public MLAgentGetRequest(String agentId) { + public MLAgentGetRequest(String agentId, boolean isUserInitiatedGetRequest) { this.agentId = agentId; + this.isUserInitiatedGetRequest = isUserInitiatedGetRequest; } public MLAgentGetRequest(StreamInput in) throws IOException { super(in); this.agentId = in.readString(); + this.isUserInitiatedGetRequest = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(this.agentId); + out.writeBoolean(isUserInitiatedGetRequest); } @Override diff --git a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java new file mode 100644 index 0000000000..491324dc56 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static org.junit.Assert.*; +public class MLAgentTypeTests { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Test + public void testFromWithValidTypes() { + // Test all enum values to ensure they return correctly + assertEquals(MLAgentType.FLOW, MLAgentType.from("FLOW")); + assertEquals(MLAgentType.CONVERSATIONAL, MLAgentType.from("CONVERSATIONAL")); + assertEquals(MLAgentType.CONVERSATIONAL_FLOW, MLAgentType.from("CONVERSATIONAL_FLOW")); + } + + @Test + public void testFromWithLowerCase() { + // Test with lowercase input + assertEquals(MLAgentType.FLOW, MLAgentType.from("flow")); + assertEquals(MLAgentType.CONVERSATIONAL, MLAgentType.from("conversational")); + assertEquals(MLAgentType.CONVERSATIONAL_FLOW, MLAgentType.from("conversational_flow")); + } + + @Test + public void testFromWithMixedCase() { + // Test with mixed case input + assertEquals(MLAgentType.FLOW, MLAgentType.from("Flow")); + assertEquals(MLAgentType.CONVERSATIONAL, MLAgentType.from("Conversational")); + assertEquals(MLAgentType.CONVERSATIONAL_FLOW, MLAgentType.from("Conversational_Flow")); + } + + @Test + public void testFromWithInvalidType() { + // This should throw an IllegalArgumentException + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Wrong Agent type"); + MLAgentType.from("INVALID_TYPE"); + } + + @Test + public void testFromWithEmptyString() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Wrong Agent type"); + // This should also throw an IllegalArgumentException + MLAgentType.from(""); + } + + @Test + public void testFromWithNull() { + // This should also throw an IllegalArgumentException + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Agent type can't be null"); + MLAgentType.from(null); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index e00a49aeb6..34e03c8419 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -1,16 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.common.agent; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; @@ -30,9 +38,25 @@ public class MLAgentTest { @Test public void constructor_NullName() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("agent name is null"); + exceptionRule.expectMessage("Agent name can't be null"); + + MLAgent agent = new MLAgent(null, MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + } + + @Test + public void constructor_NullType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Agent type can't be null"); + + MLAgent agent = new MLAgent("test_agent", null, "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + } + + @Test + public void constructor_NullLLMSpec() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("We need model information for the conversational agent type"); - MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test_agent", MLAgentType.CONVERSATIONAL.name(), "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); } @Test @@ -40,12 +64,12 @@ public void constructor_DuplicateTool() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Duplicate tool defined: test_tool_name"); MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false); - MLAgent agent = new MLAgent("test_name", "test_type", "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test_name", MLAgentType.CONVERSATIONAL.name(), "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); } @Test public void writeTo() throws IOException { - MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -60,7 +84,7 @@ public void writeTo() throws IOException { @Test public void writeTo_NullLLM() throws IOException { - MLAgent agent = new MLAgent("test", "test", "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", "FLOW", "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -70,7 +94,7 @@ public void writeTo_NullLLM() throws IOException { @Test public void writeTo_NullTools() throws IOException { - MLAgent agent = new MLAgent("test", "flow", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", "FLOW", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -80,7 +104,7 @@ public void writeTo_NullTools() throws IOException { @Test public void writeTo_NullParameters() throws IOException { - MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -90,7 +114,7 @@ public void writeTo_NullParameters() throws IOException { @Test public void writeTo_NullMemory() throws IOException { - MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), null, Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), null, Instant.EPOCH, Instant.EPOCH, "test", false); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -100,25 +124,25 @@ public void writeTo_NullMemory() throws IOException { @Test public void toXContent() throws IOException { - MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); agent.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - String expectedStr = "{\"name\":\"test\",\"type\":\"test\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\"}"; + String expectedStr = "{\"name\":\"test\",\"type\":\"CONVERSATIONAL\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\",\"is_hidden\":false}"; Assert.assertEquals(content, expectedStr); } @Test public void parse() throws IOException { - String jsonStr = "{\"name\":\"test\",\"type\":\"test\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\"}"; + String jsonStr = "{\"name\":\"test\",\"type\":\"CONVERSATIONAL\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\",\"is_hidden\":false}"; XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); MLAgent agent = MLAgent.parse(parser); Assert.assertEquals(agent.getName(), "test"); - Assert.assertEquals(agent.getType(), "test"); + Assert.assertEquals(agent.getType(), "CONVERSATIONAL"); Assert.assertEquals(agent.getDescription(), "test"); Assert.assertEquals(agent.getLlm().getModelId(), "test_model"); Assert.assertEquals(agent.getLlm().getParameters(), Map.of("test_key", "test_value")); @@ -132,11 +156,12 @@ public void parse() throws IOException { Assert.assertEquals(agent.getAppType(), "test"); Assert.assertEquals(agent.getMemory().getSessionId(), "123"); Assert.assertEquals(agent.getParameters(), Map.of("test", "test")); + Assert.assertEquals(agent.getIsHidden(), false); } @Test public void fromStream() throws IOException { - MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + MLAgent agent = new MLAgent("test", MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = MLAgent.fromStream(output.bytes().streamInput()); @@ -148,4 +173,65 @@ public void fromStream() throws IOException { Assert.assertEquals(agent.getParameters(), agent1.getParameters()); Assert.assertEquals(agent.getType(), agent1.getType()); } -} \ No newline at end of file + + @Test + public void constructor_InvalidAgentType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage(" is not a valid Agent Type"); + + new MLAgent("test_name", "INVALID_TYPE", "test_description", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + } + + @Test + public void constructor_NonConversationalNoLLM() { + try { + MLAgent agent = new MLAgent("test_name", MLAgentType.FLOW.name(), "test_description", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + assertNotNull(agent); // Ensuring object creation was successful without throwing an exception + } catch (IllegalArgumentException e) { + fail("Should not throw an exception for non-conversational types without LLM"); + } + } + + @Test + public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOException { + MLAgent agent = new MLAgent("test", "FLOW", "test", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", true); + BytesStreamOutput output = new BytesStreamOutput(); + Version oldVersion = Version.fromString("2.12.0"); + output.setVersion(oldVersion); // Version before MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(oldVersion); + MLAgent agentOldVersion = new MLAgent(streamInput); + assertNull(agentOldVersion.getIsHidden()); // Hidden should be null for old versions + + output = new BytesStreamOutput(); + output.setVersion(Version.V_2_13_0); // Version at or after MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT + agent.writeTo(output); + StreamInput streamInput1 = output.bytes().streamInput(); + streamInput1.setVersion(Version.V_2_13_0); + MLAgent agentNewVersion = new MLAgent(output.bytes().streamInput()); + assertEquals(Boolean.TRUE, agentNewVersion.getIsHidden()); // Hidden should be true for new versions + } + + @Test + public void parse_MissingFields() throws IOException { + String jsonStr = "{\"name\":\"test\",\"type\":\"FLOW\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + MLAgent agent = MLAgent.parse(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertNull(agent.getDescription()); + assertNull(agent.getLlm()); + assertNull(agent.getTools()); + assertNull(agent.getParameters()); + assertNull(agent.getMemory()); + assertNull(agent.getCreatedTime()); + assertNull(agent.getLastUpdateTime()); + assertNull(agent.getAppType()); + assertFalse(agent.getIsHidden()); // Default value for boolean when not specified + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java index a207d9cc8b..c32fdebb5b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java @@ -22,15 +22,16 @@ public class MLAgentGetRequestTest { @Test public void constructor_AgentId() { agentId = "test-abc"; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); assertEquals(mLAgentGetRequest.getAgentId(),agentId); + assertEquals(mLAgentGetRequest.isUserInitiatedGetRequest(),true); } @Test public void writeTo() throws IOException { agentId = "test-hij"; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); BytesStreamOutput output = new BytesStreamOutput(); mLAgentGetRequest.writeTo(output); @@ -38,12 +39,13 @@ public void writeTo() throws IOException { assertEquals(mLAgentGetRequest1.getAgentId(), mLAgentGetRequest.getAgentId()); assertEquals(mLAgentGetRequest1.getAgentId(), agentId); + assertEquals(mLAgentGetRequest.isUserInitiatedGetRequest(), true); } @Test public void validate_Success() { agentId = "not-null"; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); assertEquals(null, mLAgentGetRequest.validate()); } @@ -51,7 +53,7 @@ public void validate_Success() { @Test public void validate_Failure() { agentId = null; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); assertEquals(null,mLAgentGetRequest.agentId); ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); @@ -60,14 +62,14 @@ public void validate_Failure() { @Test public void fromActionRequest_Success() throws IOException { agentId = "test-lmn"; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); assertEquals(mLAgentGetRequest.fromActionRequest(mLAgentGetRequest), mLAgentGetRequest); } @Test public void fromActionRequest_Success_fromActionRequest() throws IOException { agentId = "test-opq"; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); ActionRequest actionRequest = new ActionRequest() { @Override @@ -86,7 +88,7 @@ public void writeTo(StreamOutput out) throws IOException { @Test(expected = UncheckedIOException.class) public void fromActionRequest_IOException() { agentId = "test-rst"; - MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 7ec0b9e4cb..34ef3f332b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -12,6 +12,7 @@ import org.opensearch.core.common.io.stream.*; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -36,7 +37,7 @@ public void setUp() { mlAgent = MLAgent.builder() .name("test_agent") .appType("test_app") - .type("flow") + .type(MLAgentType.FLOW.name()) .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) .build(); } @@ -68,7 +69,7 @@ public void mLAgentGetResponse_Builder() throws IOException { @Test public void writeTo() throws IOException { //create ml agent using MLAgent and mlAgentGetResponse - mlAgent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); + mlAgent = new MLAgent("test", MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() .mlAgent(mlAgent) .build(); @@ -88,7 +89,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - mlAgent = new MLAgent("mock", "flow", "test", null, null, null, null, null, null, "test"); + mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() .mlAgent(mlAgent) .build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 715ea94348..ec5031b595 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -29,6 +29,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -280,10 +281,11 @@ private ActionListener createAgentActionListener( @VisibleForTesting protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { - switch (mlAgent.getType()) { - case "flow": + final MLAgentType agentType = MLAgentType.from(mlAgent.getType().toUpperCase()); + switch (agentType) { + case FLOW: return new MLFlowAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap); - case "conversational_flow": + case CONVERSATIONAL_FLOW: return new MLConversationalFlowAgentRunner( client, settings, @@ -292,7 +294,7 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap ); - case "conversational": + case CONVERSATIONAL: return new MLChatAgentRunner(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap); default: throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index ce89464e37..3475e69672 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -43,6 +43,8 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.conversation.Interaction; @@ -480,7 +482,8 @@ public void test_CreateFlowAgent() { @Test public void test_CreateChatAgent() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("conversational").build(); + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLAgent mlAgent = MLAgent.builder().name("test_agent").type(MLAgentType.CONVERSATIONAL.name()).llm(llmSpec).build(); MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); Assert.assertTrue(mlAgentRunner instanceof MLChatAgentRunner); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 3d3538412e..760d9f9c06 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -44,6 +44,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -355,6 +356,7 @@ public void testRunWithIncludeOutputNotSet() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) .llm(llmSpec) .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) @@ -384,6 +386,7 @@ public void testRunWithIncludeOutputMLModel() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) .llm(llmSpec) .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) @@ -417,6 +420,7 @@ public void testRunWithIncludeOutputSet() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) .memory(mlMemorySpec) .llm(llmSpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) @@ -455,6 +459,7 @@ public void testChatHistoryExcludeOngoingQuestion() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) .memory(mlMemorySpec) .llm(llmSpec) .description("mlagent description") @@ -510,6 +515,7 @@ private void testInteractions(String maxInteraction) { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) .memory(mlMemorySpec) .llm(llmSpec) .description("mlagent description") @@ -542,6 +548,7 @@ public void testChatHistoryException() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) .memory(mlMemorySpec) .llm(llmSpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) @@ -609,7 +616,13 @@ public void testToolValidationFailure() { public void testToolNotFound() { // Create an MLAgent without tools LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLAgent mlAgent = MLAgent.builder().memory(mlMemorySpec).llm(llmSpec).name("TestAgent").build(); + MLAgent mlAgent = MLAgent + .builder() + .type(MLAgentType.CONVERSATIONAL.name()) + .memory(mlMemorySpec) + .llm(llmSpec) + .name("TestAgent") + .build(); // Create parameters for the agent with a non-existent tool Map params = createAgentParamsWithAction("nonExistentTool", "someInput"); @@ -763,7 +776,14 @@ private MLAgent createMLAgentWithTools() { .type(FIRST_TOOL) .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) .build(); - return MLAgent.builder().name("TestAgent").tools(Arrays.asList(firstToolSpec)).memory(mlMemorySpec).llm(llmSpec).build(); + return MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .tools(Arrays.asList(firstToolSpec)) + .memory(mlMemorySpec) + .llm(llmSpec) + .build(); } private Map createAgentParamsWithAction(String action, String actionInput) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index 68cae251b9..609609438a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -46,6 +46,7 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; @@ -183,6 +184,7 @@ public void testRunWithIncludeOutputNotSet() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.FLOW.name()) .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); @@ -205,7 +207,7 @@ public void testRunWithNoToolSpec() { final Map params = new HashMap<>(); params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); - final MLAgent mlAgent = MLAgent.builder().name("TestAgent").memory(mlMemorySpec).build(); + final MLAgent mlAgent = MLAgent.builder().name("TestAgent").type(MLAgentType.FLOW.name()).memory(mlMemorySpec).build(); mlFlowAgentRunner.run(mlAgent, params, agentActionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); verify(agentActionListener).onFailure(argCaptor.capture()); @@ -236,6 +238,7 @@ public void testRunWithIncludeOutputSet() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.FLOW.name()) .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); @@ -264,6 +267,7 @@ public void testRunWithModelTensorOutput() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.FLOW.name()) .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); @@ -394,6 +398,7 @@ public void testWithMemoryNotSet() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.FLOW.name()) .memory(null) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); @@ -456,6 +461,7 @@ public void testRunWithUpdateFailure() { final MLAgent mlAgent = MLAgent .builder() .name("TestAgent") + .type(MLAgentType.FLOW.name()) .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java index be2a5669ca..021397fae0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java @@ -60,10 +60,15 @@ public class MLIndicesHandlerTest { Metadata metadata; @Mock - IndexMetadata indexMetadata; + IndexMetadata agentindexMetadata; + @Mock + IndexMetadata memorymetaindexMetadata; + + @Mock + MappingMetadata agentmappingMetadata; @Mock - MappingMetadata mappingMetadata; + MappingMetadata memorymappingMetadata; @Mock private ThreadPool threadPool; @@ -87,9 +92,11 @@ public void setUp() { when(clusterState.metadata()).thenReturn(metadata); when(clusterState.getMetadata()).thenReturn(metadata); when(metadata.hasIndex(anyString())).thenReturn(true); - when(metadata.indices()).thenReturn(Map.of(ML_AGENT_INDEX, indexMetadata, ML_MEMORY_META_INDEX, indexMetadata)); - when(indexMetadata.mapping()).thenReturn(mappingMetadata); - when(mappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(1)))); + when(metadata.indices()).thenReturn(Map.of(ML_AGENT_INDEX, agentindexMetadata, ML_MEMORY_META_INDEX, memorymetaindexMetadata)); + when(agentindexMetadata.mapping()).thenReturn(agentmappingMetadata); + when(memorymetaindexMetadata.mapping()).thenReturn(memorymappingMetadata); + when(agentmappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(2)))); + when(memorymappingMetadata.getSourceAsMap()).thenReturn(Map.of(META, Map.of(SCHEMA_VERSION_FIELD, Integer.valueOf(1)))); settings = Settings.builder().put("test_key", 10).build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java index 0376fdbba9..33b113c403 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/DeleteAgentTransportAction.java @@ -5,23 +5,34 @@ package org.opensearch.ml.action.agents; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; +import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + import lombok.extern.log4j.Log4j2; @Log4j2 @@ -30,41 +41,79 @@ public class DeleteAgentTransportAction extends HandledTransportAction actionListener) { MLAgentDeleteRequest mlAgentDeleteRequest = MLAgentDeleteRequest.fromActionRequest(request); String agentId = mlAgentDeleteRequest.getAgentId(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); - DeleteRequest deleteRequest = new DeleteRequest(ML_AGENT_INDEX, agentId); - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - log.debug("Completed Delete Agent Request, agent id:{} deleted", agentId); - wrappedListener.onResponse(deleteResponse); - } + GetRequest getRequest = new GetRequest(ML_AGENT_INDEX).id(agentId); + boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); - @Override - public void onFailure(Exception e) { - log.error("Failed to delete ML Agent " + agentId, e); - wrappedListener.onFailure(e); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, getResponse.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + if (mlAgent.getIsHidden() && !isSuperAdmin) { + actionListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this hidden agent", + RestStatus.FORBIDDEN + ) + ); + } else { + // If the agent is not hidden or if the user is a super admin, proceed with deletion + DeleteRequest deleteRequest = new DeleteRequest(ML_AGENT_INDEX, agentId); + client.delete(deleteRequest, ActionListener.wrap(deleteResponse -> { + log.debug("Completed Delete Agent Request, agent id:{} deleted", agentId); + actionListener.onResponse(deleteResponse); + }, deleteException -> { + log.error("Failed to delete ML Agent " + agentId, deleteException); + actionListener.onFailure(deleteException); + })); + } + } catch (Exception parseException) { + log.error("Failed to parse ml agent " + getResponse.getId(), parseException); + actionListener.onFailure(parseException); + } + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ) + ); } - }); + }, getException -> { + log.error("Failed to get ml agent " + agentId, getException); + actionListener.onFailure(getException); + })); } catch (Exception e) { log.error("Failed to delete ml agent " + agentId, e); actionListener.onFailure(e); } } + + @VisibleForTesting + boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { + return RestActionUtils.isSuperAdminUser(clusterService, client); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java index 59d17651a3..a50a6f70a1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/GetAgentTransportAction.java @@ -15,6 +15,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -26,9 +27,12 @@ import org.opensearch.ml.common.transport.agent.MLAgentGetAction; import org.opensearch.ml.common.transport.agent.MLAgentGetRequest; import org.opensearch.ml.common.transport.agent.MLAgentGetResponse; +import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + import lombok.AccessLevel; import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; @@ -40,16 +44,20 @@ public class GetAgentTransportAction extends HandledTransportAction { @@ -66,7 +75,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { Instant now = Instant.now(); - MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build(); + boolean isHiddenAgent = RestActionUtils.isSuperAdminUser(clusterService, client); + MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).isHidden(isHiddenAgent).build(); mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> { if (result) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -87,5 +93,4 @@ private void registerAgent(MLAgent agent, ActionListener action ActionListener listener = wrapRestActionListener(actionListener, "Fail to search agent"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + // Check if the original query is not null before adding it to the must clause + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + if (request.source().query() != null) { + queryBuilder.must(request.source().query()); + } + + // Create a BoolQueryBuilder for the should clauses + BoolQueryBuilder shouldQuery = QueryBuilders.boolQuery(); + + // Add a should clause to include documents where IS_HIDDEN_FIELD is false + shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false)); + + // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null + shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD))); + + // Set minimum should match to 1 to ensure at least one of the should conditions is met + shouldQuery.minimumShouldMatch(1); + + // Add the shouldQuery to the main queryBuilder + queryBuilder.filter(shouldQuery); + + request.source().query(queryBuilder); client.search(request, wrappedListener); } catch (Exception e) { log.error("failed to search the agent index", e); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java index 1030f8bb57..10da7ccaae 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetAgentAction.java @@ -66,6 +66,6 @@ MLAgentGetRequest getRequest(RestRequest request) throws IOException { } String agentId = getParameterId(request, PARAMETER_AGENT_ID); - return new MLAgentGetRequest(agentId); + return new MLAgentGetRequest(agentId, true); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java index cbcb616f2b..9242158e01 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java @@ -66,7 +66,7 @@ MLRegisterAgentRequest getRequest(RestRequest request) throws IOException { } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLAgent mlAgent = MLAgent.parse(parser); + MLAgent mlAgent = MLAgent.parseFromUserInput(parser); return new MLRegisterAgentRequest(mlAgent); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java index 212112841a..3aa8d72906 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java @@ -14,12 +14,18 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest; import org.opensearch.tasks.Task; @@ -38,6 +44,9 @@ public class DeleteAgentTransportActionTests { @Mock private TransportService transportService; + @Mock + ClusterService clusterService; + @Mock private ActionFilters actionFilters; @@ -49,10 +58,13 @@ public class DeleteAgentTransportActionTests { @Before public void setup() { MockitoAnnotations.openMocks(this); - deleteAgentTransportAction = new DeleteAgentTransportAction(transportService, actionFilters, client, xContentRegistry); + deleteAgentTransportAction = spy( + new DeleteAgentTransportAction(transportService, actionFilters, client, xContentRegistry, clusterService) + ); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); + when(clusterService.getSettings()).thenReturn(settings); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -67,13 +79,26 @@ public void testConstructor() { public void testDoExecute_Success() { String agentId = "test-agent-id"; DeleteResponse deleteResponse = mock(DeleteResponse.class); + GetResponse getResponse = mock(GetResponse.class); ActionListener actionListener = mock(ActionListener.class); MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + doReturn(true).when(deleteAgentTransportAction).isSuperAdminUserWrapper(clusterService, client); + + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsBytesRef()).thenReturn(new BytesArray("{\"is_hidden\":true, \"name\":\"agent\", \"type\":\"flow\"}")); // Mock + // agent + // source Task task = mock(Task.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -88,6 +113,11 @@ public void testDoExecute_Success() { @Test public void testDoExecute_Failure() { String agentId = "test-non-existed-agent-id"; + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsBytesRef()).thenReturn(new BytesArray("{\"is_hidden\":false, \"name\":\"agent\", \"type\":\"flow\"}")); // Mock + // agent + // source ActionListener actionListener = mock(ActionListener.class); @@ -95,6 +125,12 @@ public void testDoExecute_Failure() { Task task = mock(Task.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); NullPointerException NullPointerException = new NullPointerException("Failed to delete ML Agent " + agentId); @@ -106,7 +142,158 @@ public void testDoExecute_Failure() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to delete ML Agent " + agentId, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testDoExecute_HiddenAgentSuperAdmin() { + String agentId = "test-agent-id"; + DeleteResponse deleteResponse = mock(DeleteResponse.class); + GetResponse getResponse = mock(GetResponse.class); + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsBytesRef()).thenReturn(new BytesArray("{\"is_hidden\":true, \"name\":\"agent\", \"type\":\"flow\"}")); // Mock + // agent + // source + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + @Test + public void testDoExecute_HiddenAgentDeletionByNonSuperAdmin() { + String agentId = "hidden-agent-id"; + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsBytesRef()) + .thenReturn(new BytesArray("{\"is_hidden\":true, \"name\":\"hidden-agent\", \"type\":\"flow\"}")); + + ActionListener actionListener = mock(ActionListener.class); + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + doReturn(false).when(deleteAgentTransportAction).isSuperAdminUserWrapper(clusterService, client); + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status()); + } + + @Test + public void testDoExecute_NonHiddenAgentDeletionByNonSuperAdmin() { + String agentId = "non-hidden-agent-id"; + GetResponse getResponse = mock(GetResponse.class); + DeleteResponse deleteResponse = mock(DeleteResponse.class); + + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsBytesRef()) + .thenReturn(new BytesArray("{\"is_hidden\":false, \"name\":\"non-hidden-agent\", \"type\":\"flow\"}")); + ActionListener actionListener = mock(ActionListener.class); + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + doReturn(false).when(deleteAgentTransportAction).isSuperAdminUserWrapper(clusterService, client); + + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + + verify(actionListener).onResponse(any(DeleteResponse.class)); } + @Test + public void testDoExecute_GetFails() { + String agentId = "test-agent-id"; + ActionListener actionListener = mock(ActionListener.class); + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + Task task = mock(Task.class); + Exception expectedException = new RuntimeException("Failed to fetch agent"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(expectedException); + return null; + }).when(client).get(any(), any()); + + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + + verify(actionListener).onFailure(any(RuntimeException.class)); + } + + @Test + public void testDoExecute_DeleteFails() { + String agentId = "test-agent-id"; + GetResponse getResponse = mock(GetResponse.class); + Exception expectedException = new RuntimeException("Deletion failed"); + + ActionListener actionListener = mock(ActionListener.class); + + MLAgentDeleteRequest deleteRequest = new MLAgentDeleteRequest(agentId); + + Task task = mock(Task.class); + + // Mock the GetResponse to simulate finding the agent + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsBytesRef()).thenReturn(new BytesArray("{\"is_hidden\":false, \"name\":\"agent\", \"type\":\"flow\"}")); + + // Mock the client.get() call to return the mocked GetResponse + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + // Mock the client.delete() call to throw an exception + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(expectedException); + return null; + }).when(client).delete(any(), any()); + + // Execute the action + deleteAgentTransportAction.doExecute(task, deleteRequest, actionListener); + + // Verify that actionListener.onFailure() was called with the expected exception + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Deletion failed", argumentCaptor.getValue().getMessage()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index 07f406ac07..8a0ab62168 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -24,6 +24,9 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -35,6 +38,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -55,12 +59,19 @@ public class GetAgentTransportActionTests extends OpenSearchTestCase { private Client client; @Mock ThreadPool threadPool; + + private ClusterSettings clusterSettings; + + @Mock + private ClusterService clusterService; @Mock private NamedXContentRegistry xContentRegistry; @Mock private TransportService transportService; + @Mock + ClusterState clusterState; @Mock private ActionFilters actionFilters; @@ -73,11 +84,15 @@ public class GetAgentTransportActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - getAgentTransportAction = new GetAgentTransportAction(transportService, actionFilters, client, xContentRegistry); + getAgentTransportAction = spy( + new GetAgentTransportAction(transportService, actionFilters, client, clusterService, xContentRegistry) + ); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); } @@ -87,7 +102,7 @@ public void testDoExecute_Failure_Get_Agent() { ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true); Task task = mock(Task.class); @@ -111,7 +126,7 @@ public void testDoExecute_Failure_IndexNotFound() { ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true); Task task = mock(Task.class); @@ -135,7 +150,7 @@ public void testDoExecute_Failure_OpenSearchStatus() throws IOException { ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true); Task task = mock(Task.class); @@ -162,7 +177,7 @@ public void testDoExecute_RuntimeException() { Task task = mock(Task.class); ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Failed to get ML agent " + agentId)); @@ -179,7 +194,7 @@ public void testGetTask_NullResponse() { String agentId = "test-agent-id-NullResponse"; Task task = mock(Task.class); ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(null); @@ -196,12 +211,13 @@ public void testDoExecute_Failure_Context_Exception() { String agentId = "test-agent-id"; ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId); + MLAgentGetRequest getRequest = new MLAgentGetRequest(agentId, true); Task task = mock(Task.class); GetAgentTransportAction getAgentTransportActionNullContext = new GetAgentTransportAction( transportService, actionFilters, client, + clusterService, xContentRegistry ); when(client.threadPool()).thenReturn(threadPool); @@ -220,11 +236,11 @@ public void testDoExecute_Failure_Context_Exception() { @Test public void testDoExecute_NoAgentId() throws IOException { - GetResponse getResponse = prepareMLAgent(null); + GetResponse getResponse = prepareMLAgent(null, false); String agentId = "test-agent-id"; ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest request = new MLAgentGetRequest(agentId); + MLAgentGetRequest request = new MLAgentGetRequest(agentId, true); Task task = mock(Task.class); doAnswer(invocation -> { @@ -244,9 +260,9 @@ public void testDoExecute_NoAgentId() throws IOException { public void testDoExecute_Success() throws IOException { String agentId = "test-agent-id"; - GetResponse getResponse = prepareMLAgent(agentId); + GetResponse getResponse = prepareMLAgent(agentId, false); ActionListener actionListener = mock(ActionListener.class); - MLAgentGetRequest request = new MLAgentGetRequest(agentId); + MLAgentGetRequest request = new MLAgentGetRequest(agentId, true); Task task = mock(Task.class); doAnswer(invocation -> { @@ -259,11 +275,11 @@ public void testDoExecute_Success() throws IOException { verify(actionListener).onResponse(any(MLAgentGetResponse.class)); } - public GetResponse prepareMLAgent(String agentId) throws IOException { + public GetResponse prepareMLAgent(String agentId, boolean isHidden) throws IOException { mlAgent = new MLAgent( "test", - "test", + MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), @@ -271,7 +287,8 @@ public GetResponse prepareMLAgent(String agentId) throws IOException { new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, - "test" + "test", + isHidden ); XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); @@ -280,4 +297,50 @@ public GetResponse prepareMLAgent(String agentId) throws IOException { GetResponse getResponse = new GetResponse(getResult); return getResponse; } + + @Test + public void testRemoveModelIDIfHiddenAndNotSuperUser() throws IOException { + + String agentId = "test-agent-id"; + GetResponse getResponse = prepareMLAgent(agentId, true); + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest request = new MLAgentGetRequest(agentId, true); + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doReturn(false).when(getAgentTransportAction).isSuperAdminUserWrapper(clusterService, client); + getAgentTransportAction.doExecute(task, request, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("User doesn't have privilege to perform this operation on this agent", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testNotRemoveModelIDIfHiddenAndSuperUser() throws IOException { + + String agentId = "test-agent-id"; + GetResponse getResponse = prepareMLAgent(agentId, true); + ActionListener actionListener = mock(ActionListener.class); + MLAgentGetRequest request = new MLAgentGetRequest(agentId, true); + Task task = mock(Task.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doReturn(true).when(getAgentTransportAction).isSuperAdminUserWrapper(clusterService, client); + getAgentTransportAction.doExecute(task, request, actionListener); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLAgentGetResponse.class); + verify(actionListener, times(1)).onResponse(captor.capture()); + MLAgentGetResponse mlAgentGetResponse = captor.getValue(); + assertNotNull(mlAgentGetResponse.getMlAgent().getLlm()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java index 592bca752d..6af2565aa4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java @@ -15,6 +15,7 @@ import java.util.HashMap; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -22,11 +23,13 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; @@ -51,6 +54,9 @@ public class RegisterAgentTransportActionTests extends OpenSearchTestCase { @Mock private TransportService transportService; + @Mock + private ClusterService clusterService; + @Mock private Task task; @@ -74,15 +80,23 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - transportRegisterAgentAction = new TransportRegisterAgentAction(transportService, actionFilters, client, mlIndicesHandler); + when(clusterService.getSettings()).thenReturn(settings); + transportRegisterAgentAction = new TransportRegisterAgentAction( + transportService, + actionFilters, + client, + mlIndicesHandler, + clusterService + ); } + @Test public void test_execute_registerAgent_success() { MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); MLAgent mlAgent = MLAgent .builder() .name("agent") - .type("some type") + .type(MLAgentType.CONVERSATIONAL.name()) .description("description") .llm(new LLMSpec("model_id", new HashMap<>())) .build(); @@ -106,12 +120,13 @@ public void test_execute_registerAgent_success() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void test_execute_registerAgent_AgentIndexNotInitialized() { MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); MLAgent mlAgent = MLAgent .builder() .name("agent") - .type("some type") + .type(MLAgentType.CONVERSATIONAL.name()) .description("description") .llm(new LLMSpec("model_id", new HashMap<>())) .build(); @@ -129,12 +144,13 @@ public void test_execute_registerAgent_AgentIndexNotInitialized() { assertEquals("Failed to create ML agent index", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_registerAgent_IndexFailure() { MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); MLAgent mlAgent = MLAgent .builder() .name("agent") - .type("some type") + .type(MLAgentType.CONVERSATIONAL.name()) .description("description") .llm(new LLMSpec("model_id", new HashMap<>())) .build(); @@ -159,12 +175,13 @@ public void test_execute_registerAgent_IndexFailure() { assertEquals("index failure", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_registerAgent_InitAgentIndexFailure() { MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); MLAgent mlAgent = MLAgent .builder() .name("agent") - .type("some type") + .type(MLAgentType.CONVERSATIONAL.name()) .description("description") .llm(new LLMSpec("model_id", new HashMap<>())) .build(); @@ -181,4 +198,83 @@ public void test_execute_registerAgent_InitAgentIndexFailure() { verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("agent index initialization failed", argumentCaptor.getValue().getMessage()); } + + @Test + public void test_execute_registerAgent_ModelNotHidden() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + MLAgent mlAgent = MLAgent + .builder() + .name("agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .description("description") + .llm(new LLMSpec("model_id", new HashMap<>())) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); // Simulate successful index initialization + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(mock(IndexResponse.class)); // Simulating successful indexing + return null; + }).when(client).index(any(), any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + + assertNotNull(argumentCaptor.getValue()); + } + + @Test + public void test_execute_registerAgent_Othertype() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + MLAgent mlAgent = MLAgent.builder().name("agent").type(MLAgentType.FLOW.name()).description("description").build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); // Simulate successful index initialization + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(mock(IndexResponse.class)); // Simulating successful indexing + return null; + }).when(client).index(any(), any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + + assertNotNull(argumentCaptor.getValue()); + } + + // @Test + // public void test_execute_ModelNotFound() { + // MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + // MLAgent mlAgent = MLAgent + // .builder() + // .name("agent") + // .type(MLAgentType.CONVERSATIONAL.name()) + // .description("description") + // .llm(new LLMSpec("model_id", new HashMap<>())) + // .build(); + // when(request.getMlAgent()).thenReturn(mlAgent); + // + // transportRegisterAgentAction.doExecute(task, request, actionListener); + // + // ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + // verify(actionListener).onFailure(argumentCaptor.capture()); + // + // assertNotNull(argumentCaptor.getValue()); + // } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java index 9712aa9db9..9853578bfa 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java @@ -7,11 +7,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.*; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchRequest; @@ -21,6 +21,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -35,9 +37,6 @@ public class TransportSearchAgentActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; - @Mock - SearchRequest searchRequest; - @Mock ActionListener actionListener; @@ -56,26 +55,106 @@ public void setup() { when(threadPool.getThreadContext()).thenReturn(threadContext); } - public void test_DoExecute_OnResponse() { + @Test + public void testDoExecuteWithEmptyQuery() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(mockedSearchResponse); return null; - }).when(client).search(any(), isA(ActionListener.class)); - transportSearchAgentAction.doExecute(null, searchRequest, actionListener); - verify(client, times(1)).search(eq(searchRequest), any()); + }).when(client).search(eq(request), any()); + + transportSearchAgentAction.doExecute(null, request, actionListener); + + verify(client, times(1)).search(eq(request), any()); + verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + } + + @Test + public void testDoExecuteWithNonEmptyQuery() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchAllQuery()); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(eq(request), any()); + + transportSearchAgentAction.doExecute(null, request, actionListener); + + verify(client, times(1)).search(eq(request), any()); verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); } @Test - public void test_DoExecute_OnFailure() { + public void testDoExecuteOnFailure() { + SearchRequest request = new SearchRequest("my_index"); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("runtime exception")); + listener.onFailure(new Exception("test exception")); return null; - }).when(client).search(any(), isA(ActionListener.class)); + }).when(client).search(eq(request), any()); + + transportSearchAgentAction.doExecute(null, request, actionListener); + + verify(client, times(1)).search(eq(request), any()); + verify(actionListener, times(1)).onFailure(any(Exception.class)); + } + + @Test + public void testSearchWithHiddenField() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(eq(request), any()); + + transportSearchAgentAction.doExecute(null, request, actionListener); + + verify(client, times(1)).search(eq(request), any()); + verify(actionListener, times(1)).onResponse(eq(mockedSearchResponse)); + } + + @Test + public void testSearchException() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("failed to search the agent index")); + return null; + }).when(client).search(eq(request), any()); + + transportSearchAgentAction.doExecute(null, request, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to search agent", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testSearchThrowsException() { + // Mock the client to throw an exception when the search method is called + doThrow(new RuntimeException("Search failed")).when(client).search(any(SearchRequest.class), any()); + + // Create a search request + SearchRequest searchRequest = new SearchRequest(); + + // Execute the action transportSearchAgentAction.doExecute(null, searchRequest, actionListener); - verify(client, times(1)).search(eq(searchRequest), any()); + + // Verify that the actionListener's onFailure method was called verify(actionListener, times(1)).onFailure(any(RuntimeException.class)); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java index ec7870ec0b..76d4c63581 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java @@ -40,6 +40,7 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.settings.MLFeatureEnabledSetting; @@ -89,7 +90,7 @@ public void testPrepareRequest() throws Exception { AGENT_NAME_FIELD, "agent-name", AGENT_TYPE_FIELD, - "agent-type", + MLAgentType.CONVERSATIONAL.name(), DESCRIPTION_FIELD, "description", LLM_FIELD, @@ -120,7 +121,7 @@ public void testPrepareRequest() throws Exception { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentRequest.class); verify(client, times(1)).execute(eq(MLRegisterAgentAction.INSTANCE), argumentCaptor.capture(), any()); assert (argumentCaptor.getValue().getMlAgent().getName().equals("agent-name")); - assert (argumentCaptor.getValue().getMlAgent().getType().equals("agent-type")); + assert (argumentCaptor.getValue().getMlAgent().getType().equals(MLAgentType.CONVERSATIONAL.name())); assert (argumentCaptor.getValue().getMlAgent().getDescription().equals("description")); assert (argumentCaptor.getValue().getMlAgent().getTools().equals(new ArrayList<>())); assert (argumentCaptor.getValue().getMlAgent().getLlm().getModelId().equals("id"));