Skip to content

Commit

Permalink
Improve param consuming checks, add coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Feb 17, 2024
1 parent 2c3eecb commit c50b00f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
Expand Down Expand Up @@ -78,16 +78,19 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
String workflowId = request.param(WORKFLOW_ID);
String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" });
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
Map<String, String> params = Collections.emptyMap();
final List<String> validCreateParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW);
// If provisioning, consume all other params and pass to provision transport action
if (provision) {
params = request.params()
Map<String, String> params = provision
? request.params()
.keySet()
.stream()
.filter(k -> !validCreateParams.contains(k))
.collect(Collectors.toMap(Function.identity(), request::param))
: request.params()
.entrySet()
.stream()
.filter(e -> !validCreateParams.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
if (!flowFrameworkSettings.isFlowFrameworkEnabled()) {
FlowFrameworkException ffe = new FlowFrameworkException(
"This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
Expand All @@ -98,6 +101,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
);
}
if (!provision && !params.isEmpty()) {
// Consume params and content so custom exception is processed
params.keySet().stream().forEach(request::param);
request.content();
FlowFrameworkException ffe = new FlowFrameworkException(
"Only the parameters " + validCreateParams + " are permitted unless the provision parameter is set to true.",
RestStatus.BAD_REQUEST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.VALIDATION;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -72,12 +71,11 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String workflowId = request.param(WORKFLOW_ID);
final List<String> excludeParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW);
Map<String, String> params = request.params()
.entrySet()
.keySet()
.stream()
.filter(e -> !WORKFLOW_ID.equals(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
.filter(k -> !WORKFLOW_ID.equals(k))
.collect(Collectors.toMap(Function.identity(), request::param));
try {
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
throw new FlowFrameworkException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.Version;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
Expand All @@ -19,6 +20,8 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.transport.WorkflowResponse;
import org.opensearch.rest.RestHandler.Route;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -30,12 +33,16 @@
import java.util.Locale;
import java.util.Map;

import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class RestCreateWorkflowActionTests extends OpenSearchTestCase {

private String validTemplate;
private String invalidTemplate;
private RestCreateWorkflowAction createWorkflowRestAction;
private String createWorkflowPath;
Expand Down Expand Up @@ -70,7 +77,8 @@ public void setUp() throws Exception {
);

// Invalid template configuration, wrong field name
this.invalidTemplate = template.toJson().replace("use_case", "invalid");
this.validTemplate = template.toJson();
this.invalidTemplate = this.validTemplate.replace("use_case", "invalid");
this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting);
this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI);
this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id");
Expand All @@ -92,6 +100,42 @@ public void testRestCreateWorkflowActionRoutes() {

}

public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.createWorkflowPath)
.withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("foo", "bar")))
.withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON)
.build();
FakeRestChannel channel = new FakeRestChannel(request, false, 1);
doAnswer(invocation -> {
ActionListener<WorkflowResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(new WorkflowResponse("id-123"));
return null;
}).when(nodeClient).execute(any(), any(WorkflowRequest.class), any());
createWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.CREATED, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123"));
}

public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.createWorkflowPath)
.withParams(Map.of("foo", "bar"))
.withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON)
.build();
FakeRestChannel channel = new FakeRestChannel(request, false, 1);
createWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status());
assertTrue(
channel.capturedResponse()
.content()
.utf8ToString()
.contains(
"Only the parameters [workflow_id, validation, provision] are permitted unless the provision parameter is set to true."
)
);
}

public void testInvalidCreateWorkflowRequest() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.createWorkflowPath)
Expand Down

0 comments on commit c50b00f

Please sign in to comment.