From e8dc3915d2b41fc0da57bc163476590a201a94eb Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Wed, 31 Jan 2024 12:13:24 -0600 Subject: [PATCH] add get config api to retrieve root agent id Signed-off-by: Bhavana Ramaram --- .../org/opensearch/ml/common/CommonValue.java | 10 ++ .../opensearch/ml/common/Configuration.java | 83 +++++++++++ .../org/opensearch/ml/common/MLConfig.java | 136 ++++++++++++++++++ .../transport/config/MLConfigGetAction.java | 16 +++ .../transport/config/MLConfigGetRequest.java | 71 +++++++++ .../transport/config/MLConfigGetResponse.java | 62 ++++++++ .../config/MLConfigGetActionTest.java | 21 +++ .../config/MLConfigGetRequestTest.java | 105 ++++++++++++++ .../config/MLConfigGetResponseTest.java | 136 ++++++++++++++++++ .../config/GetConfigTransportAction.java | 97 +++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 10 +- .../ml/rest/RestMLGetConfigAction.java | 68 +++++++++ .../opensearch/ml/utils/RestActionUtils.java | 1 + .../config/GetConfigTransportActionTests.java | 110 ++++++++++++++ .../ml/rest/RestMLGetConfigActionTests.java | 105 ++++++++++++++ 15 files changed, 1029 insertions(+), 2 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/Configuration.java create mode 100644 common/src/main/java/org/opensearch/ml/common/MLConfig.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 4ff3917080..e99736ea2a 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -45,6 +45,7 @@ public class CommonValue { public static final String MASTER_KEY = "master_key"; public static final String CREATE_TIME_FIELD = "create_time"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; public static final String BOX_TYPE_KEY = "box_type"; // hot node @@ -359,7 +360,16 @@ public class CommonValue { + MASTER_KEY + "\": {\"type\": \"keyword\"},\n" + " \"" + + MLConfig.TYPE_FIELD + + "\" : {\"type\":\"keyword\"},\n" + + " \"" + + MLConfig.CONFIGURATION_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + LAST_UPDATE_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + " }\n" + "}"; diff --git a/common/src/main/java/org/opensearch/ml/common/Configuration.java b/common/src/main/java/org/opensearch/ml/common/Configuration.java new file mode 100644 index 0000000000..fa5a1bfe22 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/Configuration.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Getter +@EqualsAndHashCode +public class Configuration implements ToXContentObject, Writeable { + + public static final String ROOT_AGENT_ID = "agent_id"; + + @Setter + private String agentId; + + @Builder(toBuilder = true) + public Configuration( + String agentId + ) { + this.agentId = agentId; + } + + public Configuration(StreamInput input) throws IOException { + this.agentId = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(agentId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + XContentBuilder builder = xContentBuilder.startObject(); + if (agentId != null) { + builder.field(ROOT_AGENT_ID, agentId); + } + return builder.endObject(); + } + + public static Configuration fromStream(StreamInput in) throws IOException { + Configuration configuration = new Configuration(in); + return configuration; + } + + public static Configuration parse(XContentParser parser) throws IOException { + String agentId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ROOT_AGENT_ID: + agentId = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return Configuration.builder() + .agentId(agentId) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/MLConfig.java b/common/src/main/java/org/opensearch/ml/common/MLConfig.java new file mode 100644 index 0000000000..c81dddcc9b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/MLConfig.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.time.Instant; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Getter +@EqualsAndHashCode +public class MLConfig implements ToXContentObject, Writeable { + + public static final String TYPE_FIELD = "type"; + + public static final String CONFIGURATION_FIELD = "configuration"; + + public static final String CREATE_TIME_FIELD = "create_time"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + + @Setter + private String type; + + private Configuration configuration; + private final Instant createTime; + private Instant lastUpdateTime; + + @Builder(toBuilder = true) + public MLConfig( + String type, + Configuration configuration, + Instant createTime, + Instant lastUpdateTime + ) { + this.type = type; + this.configuration = configuration; + this.createTime = createTime; + this.lastUpdateTime = lastUpdateTime; + } + + public MLConfig(StreamInput input) throws IOException { + this.type = input.readOptionalString(); + if (input.readBoolean()) { + configuration = new Configuration(input); + } + createTime = input.readOptionalInstant(); + lastUpdateTime = input.readOptionalInstant(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(type); + if (configuration != null) { + out.writeBoolean(true); + configuration.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalInstant(createTime); + out.writeOptionalInstant(lastUpdateTime); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + XContentBuilder builder = xContentBuilder.startObject(); + if (type != null) { + builder.field(TYPE_FIELD, type); + } + if (configuration != null) { + builder.field(CONFIGURATION_FIELD, configuration); + } + if (createTime != null) { + builder.field(CREATE_TIME_FIELD, createTime.toEpochMilli()); + } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + return builder.endObject(); + } + + public static MLConfig fromStream(StreamInput in) throws IOException { + MLConfig mlConfig = new MLConfig(in); + return mlConfig; + } + + public static MLConfig parse(XContentParser parser) throws IOException { + String type = null; + Configuration configuration = null; + Instant createTime = null; + Instant lastUpdateTime = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TYPE_FIELD: + type = parser.text(); + break; + case CONFIGURATION_FIELD: + configuration = Configuration.parse(parser); + break; + case CREATE_TIME_FIELD: + createTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); + break; + default: + parser.skipChildren(); + break; + } + } + return MLConfig.builder() + .type(type) + .configuration(configuration) + .createTime(createTime) + .lastUpdateTime(lastUpdateTime) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java new file mode 100644 index 0000000000..6287559c03 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.config; + +import org.opensearch.action.ActionType; + +public class MLConfigGetAction extends ActionType { + public static final MLConfigGetAction INSTANCE = new MLConfigGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/config/get"; + + private MLConfigGetAction() { super(NAME, MLConfigGetResponse::new);} + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java new file mode 100644 index 0000000000..0542c9480b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.config; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +public class MLConfigGetRequest extends ActionRequest { + + String configId; + + @Builder + public MLConfigGetRequest(String configId) { + this.configId = configId; + } + + public MLConfigGetRequest(StreamInput in) throws IOException { + super(in); + this.configId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.configId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.configId == null) { + exception = addValidationError("ML config id can't be null", exception); + } + + return exception; + } + + public static MLConfigGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLConfigGetRequest) { + return (MLConfigGetRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLConfigGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLConfigGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java new file mode 100644 index 0000000000..1fc353e54f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.config; + +import lombok.Builder; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLConfig; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +public class MLConfigGetResponse extends ActionResponse implements ToXContentObject { + MLConfig mlConfig; + + @Builder + public MLConfigGetResponse(MLConfig mlConfig) { + this.mlConfig = mlConfig; + } + + public MLConfigGetResponse(StreamInput in) throws IOException { + super(in); + mlConfig = MLConfig.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException{ + mlConfig.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlConfig.toXContent(xContentBuilder, params); + } + + public static MLConfigGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLConfigGetResponse) { + return (MLConfigGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLConfigGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLConfigGetResponse", e); + } + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java new file mode 100644 index 0000000000..935b4f0db8 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.config; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class MLConfigGetActionTest { + + @Test + public void testMLAgentGetActionInstance() { + assertNotNull(MLConfigGetAction.INSTANCE); + assertEquals("cluster:admin/opensearch/ml/config/get", MLConfigGetAction.NAME); + } + + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java new file mode 100644 index 0000000000..7c86587816 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.config; + +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +public class MLConfigGetRequestTest { + String configId; + + @Test + public void constructor_configId() { + configId = "test-abc"; + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + assertEquals(mlConfigGetRequest.getConfigId(),configId); + } + + @Test + public void writeTo() throws IOException { + configId = "test-hij"; + + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + BytesStreamOutput output = new BytesStreamOutput(); + mlConfigGetRequest.writeTo(output); + + MLConfigGetRequest mlConfigGetRequest1 = new MLConfigGetRequest(output.bytes().streamInput()); + + assertEquals(mlConfigGetRequest1.getConfigId(), mlConfigGetRequest.getConfigId()); + assertEquals(mlConfigGetRequest1.getConfigId(), configId); + } + + @Test + public void validate_Success() { + configId = "not-null"; + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + + assertEquals(null, mlConfigGetRequest.validate()); + } + + @Test + public void validate_Failure() { + configId = null; + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + assertEquals(null,mlConfigGetRequest.configId); + + ActionRequestValidationException exception = addValidationError("ML config id can't be null", null); + mlConfigGetRequest.validate().equals(exception) ; + } + @Test + public void fromActionRequest_Success() throws IOException { + configId = "test-lmn"; + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + assertEquals(mlConfigGetRequest.fromActionRequest(mlConfigGetRequest), mlConfigGetRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() throws IOException { + configId = "test-opq"; + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + @Override + public void writeTo(StreamOutput out) throws IOException { + mlConfigGetRequest.writeTo(out); + } + }; + MLConfigGetRequest request = mlConfigGetRequest.fromActionRequest(actionRequest); + assertEquals(request.configId, configId); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + configId = "test-rst"; + MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + mlConfigGetRequest.fromActionRequest(actionRequest); + } +} + + diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java new file mode 100644 index 0000000000..ea370f979a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.config; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.Configuration; +import org.opensearch.ml.common.MLConfig; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class MLConfigGetResponseTest { + + MLConfig mlConfig; + + @Before + public void setUp() { + Configuration configuration = Configuration.builder().agentId("agent_id").build(); + mlConfig = MLConfig.builder() + .type("olly_agent") + .configuration(configuration) + .build(); + } + + @Test + public void Create_mlConfigResponse_With_StreamInput() throws IOException { + // Create a BytesStreamOutput to simulate the StreamOutput + MLConfigGetResponse agentGetResponse = new MLConfigGetResponse(mlConfig); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + agentGetResponse.writeTo(out); + } + }; + MLConfigGetResponse parsedResponse = MLConfigGetResponse.fromActionResponse(actionResponse); + assertNotSame(agentGetResponse, parsedResponse); + assertEquals(agentGetResponse.mlConfig, parsedResponse.mlConfig); + } + + @Test + public void MLConfigGetResponse_Builder() throws IOException { + + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() + .mlConfig(mlConfig) + .build(); + + assertEquals(mlConfigGetResponse.mlConfig, mlConfig); + } + @Test + public void writeTo() throws IOException { + //create ml agent using mlConfig and mlConfigGetResponse + mlConfig = new MLConfig("olly_agent",new Configuration("agent_id"), Instant.EPOCH, Instant.EPOCH); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() + .mlConfig(mlConfig) + .build(); + //use write out for both agents + BytesStreamOutput output = new BytesStreamOutput(); + mlConfig.writeTo(output); + mlConfigGetResponse.writeTo(output); + MLConfig agent1 = mlConfigGetResponse.mlConfig; + + assertEquals(mlConfig.getType(), agent1.getType()); + assertEquals(mlConfig.getConfiguration(), agent1.getConfiguration()); + assertEquals(mlConfig.getCreateTime(), agent1.getCreateTime()); + assertEquals(mlConfig.getLastUpdateTime(), agent1.getLastUpdateTime()); + } + + @Test + public void toXContent() throws IOException { + mlConfig = new MLConfig(null, null, null, null); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() + .mlConfig(mlConfig) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + ToXContent.Params params = EMPTY_PARAMS; + XContentBuilder getResponseXContentBuilder = mlConfigGetResponse.toXContent(builder, params); + assertEquals(getResponseXContentBuilder, mlConfig.toXContent(builder, params)); + } + + @Test + public void fromActionResponse_Success() throws IOException { + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() + .mlConfig(mlConfig) + .build(); + assertEquals(mlConfigGetResponse.fromActionResponse(mlConfigGetResponse), mlConfigGetResponse); + + } + @Test + public void fromActionResponse_Success_fromActionResponse() throws IOException { + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() + .mlConfig(mlConfig) + .build(); + + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + mlConfigGetResponse.writeTo(out); + } + }; + MLConfigGetResponse response = mlConfigGetResponse.fromActionResponse(actionResponse); + assertEquals(response.mlConfig, mlConfig); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponse_IOException() { + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() + .mlConfig(mlConfig) + .build(); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + mlConfigGetResponse.fromActionResponse(actionResponse); + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java new file mode 100644 index 0000000000..787198a826 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/config/GetConfigTransportAction.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.config; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.MLConfig; +import org.opensearch.ml.common.transport.config.MLConfigGetAction; +import org.opensearch.ml.common.transport.config.MLConfigGetRequest; +import org.opensearch.ml.common.transport.config.MLConfigGetResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetConfigTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + @Inject + public GetConfigTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(MLConfigGetAction.NAME, transportService, actionFilters, MLConfigGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.fromActionRequest(request); + String configId = mlConfigGetRequest.getConfigId(); + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(configId); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + log.debug("Completed Get Agent Request, id:{}", configId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLConfig mlConfig = MLConfig.parse(parser); + actionListener.onResponse(MLConfigGetResponse.builder().mlConfig(mlConfig).build()); + } catch (Exception e) { + log.error("Failed to parse ml config" + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find config with the provided config id: " + configId, + RestStatus.NOT_FOUND + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + log.error("Failed to get agent index", e); + actionListener.onFailure(new OpenSearchStatusException("Failed to get config index", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML config " + configId, e); + actionListener.onFailure(e); + } + }), context::restore)); + } catch (Exception e) { + log.error("Failed to get ML config " + configId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 7fce1c8a7f..fe0573d673 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -41,6 +41,7 @@ import org.opensearch.ml.action.agents.GetAgentTransportAction; import org.opensearch.ml.action.agents.TransportRegisterAgentAction; import org.opensearch.ml.action.agents.TransportSearchAgentAction; +import org.opensearch.ml.action.config.GetConfigTransportAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; @@ -110,6 +111,7 @@ import org.opensearch.ml.common.transport.agent.MLAgentGetAction; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.agent.MLSearchAgentAction; +import org.opensearch.ml.common.transport.config.MLConfigGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; @@ -209,6 +211,7 @@ import org.opensearch.ml.rest.RestMLDeployModelAction; import org.opensearch.ml.rest.RestMLExecuteAction; import org.opensearch.ml.rest.RestMLGetAgentAction; +import org.opensearch.ml.rest.RestMLGetConfigAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; import org.opensearch.ml.rest.RestMLGetControllerAction; import org.opensearch.ml.rest.RestMLGetModelAction; @@ -403,7 +406,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(UpdateInteractionAction.INSTANCE, UpdateInteractionTransportAction.class), new ActionHandler<>(GetTracesAction.INSTANCE, GetTracesTransportAction.class), new ActionHandler<>(MLListToolsAction.INSTANCE, ListToolsTransportAction.class), - new ActionHandler<>(MLGetToolAction.INSTANCE, GetToolTransportAction.class) + new ActionHandler<>(MLGetToolAction.INSTANCE, GetToolTransportAction.class), + new ActionHandler<>(MLConfigGetAction.INSTANCE, GetConfigTransportAction.class) ); } @@ -713,6 +717,7 @@ public List getRestHandlers( RestMLSearchAgentAction restMLSearchAgentAction = new RestMLSearchAgentAction(); RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories); RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); + RestMLGetConfigAction restMLGetConfigAction = new RestMLGetConfigAction(); return ImmutableList .of( restMLStatsAction, @@ -764,7 +769,8 @@ public List getRestHandlers( restMemoryGetTracesAction, restMLSearchAgentAction, restMLListToolsAction, - restMLGetToolAction + restMLGetToolAction, + restMLGetConfigAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java new file mode 100644 index 0000000000..81cb02c597 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConfigAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONFIG_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.config.MLConfigGetAction; +import org.opensearch.ml.common.transport.config.MLConfigGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetConfigAction extends BaseRestHandler { + private static final String ML_GET_CONFIG_ACTION = "ml_get_config_action"; + + /** + * Constructor + */ + public RestMLGetConfigAction() {} + + @Override + public String getName() { + return ML_GET_CONFIG_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/config/{%s}", ML_BASE_URI, PARAMETER_CONFIG_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLConfigGetRequest mlConfigGetRequest = getRequest(request); + return channel -> client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLTaskGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLTaskGetRequest + */ + @VisibleForTesting + MLConfigGetRequest getRequest(RestRequest request) throws IOException { + String configID = getParameterId(request, PARAMETER_CONFIG_ID); + + if (configID.equals(MASTER_KEY)) { + throw new IllegalArgumentException("You are not allowed to access this config doc"); + } + + return new MLConfigGetRequest(configID); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 59f9e0f3a2..52c0bb1346 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -69,6 +69,7 @@ public class RestActionUtils { public static final String PARAMETER_DEPLOY_MODEL = "deploy"; public static final String PARAMETER_VERSION = "version"; public static final String PARAMETER_MODEL_GROUP_ID = "model_group_id"; + public static final String PARAMETER_CONFIG_ID = "config_id"; public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards"; public static final String[] UI_METADATA_EXCLUDE = new String[] { "ui_metadata" }; diff --git a/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java new file mode 100644 index 0000000000..13112bba48 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/config/GetConfigTransportActionTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.config; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.transport.config.MLConfigGetRequest; +import org.opensearch.ml.common.transport.config.MLConfigGetResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetConfigTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + GetConfigTransportAction getConfigTransportAction; + MLConfigGetRequest mlConfigGetRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + mlConfigGetRequest = MLConfigGetRequest.builder().configId("test_id").build(); + + getConfigTransportAction = spy(new GetConfigTransportAction(transportService, actionFilters, client, xContentRegistry)); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGetTask_NullResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + getConfigTransportAction.doExecute(null, mlConfigGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find config with the provided config id: test_id", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).get(any(), any()); + getConfigTransportAction.doExecute(null, mlConfigGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_IndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Index Not Found")); + return null; + }).when(client).get(any(), any()); + getConfigTransportAction.doExecute(null, mlConfigGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get config index", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java new file mode 100644 index 0000000000..a02a640654 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConfigActionTests.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONFIG_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.config.MLConfigGetAction; +import org.opensearch.ml.common.transport.config.MLConfigGetRequest; +import org.opensearch.ml.common.transport.config.MLConfigGetResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLGetConfigActionTests extends OpenSearchTestCase { + + private RestMLGetConfigAction restMLGetConfigAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLGetConfigAction = new RestMLGetConfigAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetConfigAction mlGetConfigAction = new RestMLGetConfigAction(); + assertNotNull(mlGetConfigAction); + } + + public void testGetName() { + String actionName = restMLGetConfigAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_config_action", actionName); + } + + public void testRoutes() { + List routes = restMLGetConfigAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/config/{config_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetConfigAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConfigGetRequest.class); + verify(client, times(1)).execute(eq(MLConfigGetAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getConfigId(); + assertEquals(taskId, "test_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_CONFIG_ID, "test_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } +}