diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 595ed4932..bf4943403 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -27,10 +27,10 @@ import org.opensearch.rest.RestRequest; import java.io.IOException; -import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -78,16 +78,19 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String workflowId = request.param(WORKFLOW_ID); String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" }); boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); - Map params = Collections.emptyMap(); final List validCreateParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW); // If provisioning, consume all other params and pass to provision transport action - if (provision) { - params = request.params() + Map params = provision + ? request.params() + .keySet() + .stream() + .filter(k -> !validCreateParams.contains(k)) + .collect(Collectors.toMap(Function.identity(), request::param)) + : request.params() .entrySet() .stream() .filter(e -> !validCreateParams.contains(e.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { FlowFrameworkException ffe = new FlowFrameworkException( "This API is disabled. To enable it, set [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", @@ -98,6 +101,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } if (!provision && !params.isEmpty()) { + // Consume params and content so custom exception is processed + params.keySet().stream().forEach(request::param); + request.content(); FlowFrameworkException ffe = new FlowFrameworkException( "Only the parameters " + validCreateParams + " are permitted unless the provision parameter is set to true.", RestStatus.BAD_REQUEST diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 7c0fdb3d4..f7d61a114 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -28,10 +28,9 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; -import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; -import static org.opensearch.flowframework.common.CommonValue.VALIDATION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -72,12 +71,11 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); - final List excludeParams = List.of(WORKFLOW_ID, VALIDATION, PROVISION_WORKFLOW); Map params = request.params() - .entrySet() + .keySet() .stream() - .filter(e -> !WORKFLOW_ID.equals(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + .filter(k -> !WORKFLOW_ID.equals(k)) + .collect(Collectors.toMap(Function.identity(), request::param)); try { if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index fcdaf5757..1d99cf517 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -10,6 +10,7 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -19,6 +20,8 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -30,12 +33,16 @@ import java.util.Locale; import java.util.Map; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class RestCreateWorkflowActionTests extends OpenSearchTestCase { + private String validTemplate; private String invalidTemplate; private RestCreateWorkflowAction createWorkflowRestAction; private String createWorkflowPath; @@ -70,7 +77,8 @@ public void setUp() throws Exception { ); // Invalid template configuration, wrong field name - this.invalidTemplate = template.toJson().replace("use_case", "invalid"); + this.validTemplate = template.toJson(); + this.invalidTemplate = this.validTemplate.replace("use_case", "invalid"); this.createWorkflowRestAction = new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting); this.createWorkflowPath = String.format(Locale.ROOT, "%s", WORKFLOW_URI); this.updateWorkflowPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); @@ -92,6 +100,42 @@ public void testRestCreateWorkflowActionRoutes() { } + public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.ofEntries(Map.entry(PROVISION_WORKFLOW, "true"), Map.entry("foo", "bar"))) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + + public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.createWorkflowPath) + .withParams(Map.of("foo", "bar")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue( + channel.capturedResponse() + .content() + .utf8ToString() + .contains( + "Only the parameters [workflow_id, validation, provision] are permitted unless the provision parameter is set to true." + ) + ); + } + public void testInvalidCreateWorkflowRequest() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath)