From 7a0050688426e741a36e6cfee95e9c3996e5c292 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 19 Feb 2024 16:44:15 -0800 Subject: [PATCH] Allow specifying key-value pairs in body Signed-off-by: Daniel Widdis --- .../rest/RestProvisionWorkflowAction.java | 26 +++++-- .../RestProvisionWorkflowActionTests.java | 67 ++++++++++++++++--- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index f7d61a114..84c518615 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -16,6 +16,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; @@ -31,6 +32,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -71,23 +73,37 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); + // Get any other params from path Map params = request.params() .keySet() .stream() .filter(k -> !WORKFLOW_ID.equals(k)) .collect(Collectors.toMap(Function.identity(), request::param)); try { + // If body is included get any params from body + if (request.hasContent()) { + try (XContentParser parser = request.contentParser()) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String key = parser.currentName(); + if (params.containsKey(key)) { + throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST); + } + if (parser.nextToken() != XContentParser.Token.VALUE_STRING) { + throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST); + } + params.put(key, parser.text()); + } + } catch (IOException e) { + throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST); + } + } if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { throw new FlowFrameworkException( "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", RestStatus.FORBIDDEN ); } - // Validate content - if (request.hasContent()) { - // BaseRestHandler will give appropriate error message - return channel -> channel.sendResponse(null); - } // Validate params if (workflowId == null) { throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index 6ddd83d11..fd5cd478d 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -9,10 +9,13 @@ package org.opensearch.flowframework.rest; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.flowframework.transport.WorkflowResponse; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -21,8 +24,11 @@ import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -49,7 +55,7 @@ public void testRestProvisionWorkflowActionName() { assertEquals("provision_workflow_action", name); } - public void testRestProvisiionWorkflowActionRoutes() { + public void testRestProvisionWorkflowActionRoutes() { List routes = provisionWorkflowRestAction.routes(); assertEquals(1, routes.size()); assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); @@ -71,20 +77,61 @@ public void testNullWorkflowId() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); } - public void testInvalidRequestWithContent() { + public void testContentParsing() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.provisionWorkflowPath) - .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .withParams(Map.of("workflow_id", "abc")) + .withContent(new BytesArray("{\"foo\": \"bar\"}"), MediaTypeRegistry.JSON) .build(); FakeRestChannel channel = new FakeRestChannel(request, false, 1); - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { - provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); - }); - assertEquals( - "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_provision] does not support having a body", - ex.getMessage() - ); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("id-123")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); + } + + public void testContentParsingDuplicate() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.ofEntries(Map.entry("workflow_id", "abc"), Map.entry("foo", "bar"))) + .withContent(new BytesArray("{\"bar\": \"none\", \"foo\": \"baz\"}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + // assertEquals("", channel.capturedResponse().content().utf8ToString()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Duplicate key foo")); + } + + public void testContentParsingBadType() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.of("workflow_id", "abc")) + .withContent(new BytesArray("{\"foo\": 123}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Request body fields must have string values")); + } + + public void testContentParsingError() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withContent(new BytesArray("not json"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Request body parsing failed")); } public void testFeatureFlagNotEnabled() throws Exception {