Skip to content

Commit

Permalink
restful connector actions and UT
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Jul 10, 2023
1 parent cb4c209 commit 50ca665
Show file tree
Hide file tree
Showing 13 changed files with 984 additions and 3 deletions.
3 changes: 2 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.action.connector.DeleteConnectorTransportAction',
'org.opensearch.ml.action.connector.DeleteConnectorTransportAction.1',
'org.opensearch.ml.action.connector.TransportCreateConnectorAction',
'org.opensearch.ml.action.connector.SearchConnectorTransportAction'
'org.opensearch.ml.action.connector.SearchConnectorTransportAction',
'org.opensearch.ml.rest.RestMLCreateConnectorAction'
]

jacocoTestCoverageVerification {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ml.action.connector.DeleteConnectorTransportAction;
import org.opensearch.ml.action.connector.GetConnectorTransportAction;
import org.opensearch.ml.action.connector.SearchConnectorTransportAction;
import org.opensearch.ml.action.connector.TransportCreateConnectorAction;
import org.opensearch.ml.action.deploy.TransportDeployModelAction;
import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction;
import org.opensearch.ml.action.execute.TransportExecuteTaskAction;
Expand Down Expand Up @@ -79,6 +83,10 @@
import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams;
import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
Expand Down Expand Up @@ -116,18 +124,22 @@
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLDeleteConnectorAction;
import org.opensearch.ml.rest.RestMLDeleteModelAction;
import org.opensearch.ml.rest.RestMLDeleteModelGroupAction;
import org.opensearch.ml.rest.RestMLDeleteTaskAction;
import org.opensearch.ml.rest.RestMLDeployModelAction;
import org.opensearch.ml.rest.RestMLExecuteAction;
import org.opensearch.ml.rest.RestMLGetConnectorAction;
import org.opensearch.ml.rest.RestMLGetModelAction;
import org.opensearch.ml.rest.RestMLGetTaskAction;
import org.opensearch.ml.rest.RestMLPredictionAction;
import org.opensearch.ml.rest.RestMLProfileAction;
import org.opensearch.ml.rest.RestMLRegisterModelAction;
import org.opensearch.ml.rest.RestMLRegisterModelGroupAction;
import org.opensearch.ml.rest.RestMLRegisterModelMetaAction;
import org.opensearch.ml.rest.RestMLSearchConnectorAction;
import org.opensearch.ml.rest.RestMLSearchModelAction;
import org.opensearch.ml.rest.RestMLSearchModelGroupAction;
import org.opensearch.ml.rest.RestMLSearchTaskAction;
Expand Down Expand Up @@ -235,7 +247,11 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class),
new ActionHandler<>(MLUpdateModelGroupAction.INSTANCE, TransportUpdateModelGroupAction.class),
new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class),
new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class)
new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class),
new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class),
new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class),
new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class),
new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class)
);
}

Expand Down Expand Up @@ -453,6 +469,10 @@ public List<RestHandler> getRestHandlers(
RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction();
RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction();
RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction();
RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction();
RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction();
RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction();
return ImmutableList
.of(
restMLStatsAction,
Expand All @@ -475,7 +495,11 @@ public List<RestHandler> getRestHandlers(
restMLCreateModelGroupAction,
restMLUpdateModelGroupAction,
restMLSearchModelGroupAction,
restMLDeleteModelGroupAction
restMLDeleteModelGroupAction,
restMLCreateConnectorAction,
restMLGetConnectorAction,
restMLDeleteConnectorAction,
restMLSearchConnectorAction
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

public class RestMLCreateConnectorAction extends BaseRestHandler {
private static final String ML_CREATE_CONNECTOR_ACTION = "ml_create_connector_action";

/**
* Constructor *
*/
public RestMLCreateConnectorAction() {}

@Override
public String getName() {
return ML_CREATE_CONNECTOR_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/connectors/_create", ML_BASE_URI)));
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLCreateConnectorRequest mlCreateConnectorRequest = getRequest(request);
return channel -> client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, new RestToXContentListener<>(channel));
}

/**
* * Creates a MLCreateConnectorRequest from a RestRequest
* @param request
* @return MLCreateConnectorRequest
* @throws IOException
*/
@VisibleForTesting
MLCreateConnectorRequest getRequest(RestRequest request) throws IOException {
if (!request.hasContent()) {
throw new IOException("Create Connector request has empty body");
}
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.parse(parser);
return new MLCreateConnectorRequest(mlCreateConnectorInput);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.collect.ImmutableList;

/**
* This class consists of the REST handler to delete ML Connector.
*/
public class RestMLDeleteConnectorAction extends BaseRestHandler {
private static final String ML_DELETE_CONNECTOR_ACTION = "ml_delete_connector_action";

public void RestMLDeleteConnectorAction() {}

@Override
public String getName() {
return ML_DELETE_CONNECTOR_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(
new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/connectors/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID))
);
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String connectorId = request.param(PARAMETER_CONNECTOR_ID);

MLConnectorDeleteRequest mlConnectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
return channel -> client.execute(MLConnectorDeleteAction.INSTANCE, mlConnectorDeleteRequest, new RestToXContentListener<>(channel));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
import static org.opensearch.ml.utils.RestActionUtils.returnContent;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction;
import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

public class RestMLGetConnectorAction extends BaseRestHandler {
private static final String ML_GET_CONNECTOR_ACTION = "ml_get_connector_action";

/**
* Constructor
*/
public RestMLGetConnectorAction() {}

@Override
public String getName() {
return ML_GET_CONNECTOR_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/connectors/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID)));
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLConnectorGetRequest mlConnectorGetRequest = getRequest(request);
return channel -> client.execute(MLConnectorGetAction.INSTANCE, mlConnectorGetRequest, new RestToXContentListener<>(channel));
}

/**
* Creates a MLConnectorGetRequest from a RestRequest
*
* @param request RestRequest
* @return MLConnectorGetRequest
*/
@VisibleForTesting
MLConnectorGetRequest getRequest(RestRequest request) throws IOException {
String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID);
boolean returnContent = returnContent(request);

return new MLConnectorGetRequest(connectorId, returnContent);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;

import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;

import com.google.common.collect.ImmutableList;

public class RestMLSearchConnectorAction extends AbstractMLSearchAction<Connector> {
private static final String ML_SEARCH_CONNECTOR_ACTION = "ml_search_connector_action";
private static final String SEARCH_CONNECTOR_PATH = ML_BASE_URI + "/connectors/_search";

public RestMLSearchConnectorAction() {
super(ImmutableList.of(SEARCH_CONNECTOR_PATH), ML_CONNECTOR_INDEX, Connector.class, MLConnectorSearchAction.INSTANCE);
}

@Override
public String getName() {
return ML_SEARCH_CONNECTOR_ACTION;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class RestActionUtils {
public static final String PARAMETER_MODEL_GROUP_NAME = "model_group_name";
public static final String PARAMETER_MODEL_ID = "model_id";
public static final String PARAMETER_TASK_ID = "task_id";
public static final String PARAMETER_CONNECTOR_ID = "connector_id";
public static final String PARAMETER_DEPLOY_MODEL = "deploy";
public static final String PARAMETER_VERSION = "version";
public static final String PARAMETER_MODEL_GROUP_ID = "model_group_id";
Expand Down
Loading

0 comments on commit 50ca665

Please sign in to comment.