Skip to content

Commit

Permalink
Allow specifying key-value pairs in body
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Feb 20, 2024
1 parent 7701c27 commit 7a00506
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
Expand All @@ -31,6 +32,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
Expand Down Expand Up @@ -71,23 +73,37 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String workflowId = request.param(WORKFLOW_ID);
// Get any other params from path
Map<String, String> params = request.params()
.keySet()
.stream()
.filter(k -> !WORKFLOW_ID.equals(k))
.collect(Collectors.toMap(Function.identity(), request::param));
try {
// If body is included get any params from body
if (request.hasContent()) {
try (XContentParser parser = request.contentParser()) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String key = parser.currentName();
if (params.containsKey(key)) {
throw new FlowFrameworkException("Duplicate key " + key, RestStatus.BAD_REQUEST);
}
if (parser.nextToken() != XContentParser.Token.VALUE_STRING) {
throw new FlowFrameworkException("Request body fields must have string values", RestStatus.BAD_REQUEST);
}
params.put(key, parser.text());
}
} catch (IOException e) {
throw new FlowFrameworkException("Request body parsing failed", RestStatus.BAD_REQUEST);
}
}
if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) {
throw new FlowFrameworkException(
"This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
RestStatus.FORBIDDEN
);
}
// Validate content
if (request.hasContent()) {
// BaseRestHandler will give appropriate error message
return channel -> channel.sendResponse(null);
}
// Validate params
if (workflowId == null) {
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
package org.opensearch.flowframework.rest;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.transport.WorkflowResponse;
import org.opensearch.rest.RestHandler.Route;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -21,8 +24,11 @@

import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -49,7 +55,7 @@ public void testRestProvisionWorkflowActionName() {
assertEquals("provision_workflow_action", name);
}

public void testRestProvisiionWorkflowActionRoutes() {
public void testRestProvisionWorkflowActionRoutes() {
List<Route> routes = provisionWorkflowRestAction.routes();
assertEquals(1, routes.size());
assertEquals(RestRequest.Method.POST, routes.get(0).getMethod());
Expand All @@ -71,20 +77,61 @@ public void testNullWorkflowId() throws Exception {
assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null"));
}

public void testInvalidRequestWithContent() {
public void testContentParsing() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.provisionWorkflowPath)
.withContent(new BytesArray("request body"), MediaTypeRegistry.JSON)
.withParams(Map.of("workflow_id", "abc"))
.withContent(new BytesArray("{\"foo\": \"bar\"}"), MediaTypeRegistry.JSON)
.build();

FakeRestChannel channel = new FakeRestChannel(request, false, 1);
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> {
provisionWorkflowRestAction.handleRequest(request, channel, nodeClient);
});
assertEquals(
"request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_provision] does not support having a body",
ex.getMessage()
);
doAnswer(invocation -> {
ActionListener<WorkflowResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(new WorkflowResponse("id-123"));
return null;
}).when(nodeClient).execute(any(), any(WorkflowRequest.class), any());
provisionWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.OK, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123"));
}

public void testContentParsingDuplicate() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.provisionWorkflowPath)
.withParams(Map.ofEntries(Map.entry("workflow_id", "abc"), Map.entry("foo", "bar")))
.withContent(new BytesArray("{\"bar\": \"none\", \"foo\": \"baz\"}"), MediaTypeRegistry.JSON)
.build();

FakeRestChannel channel = new FakeRestChannel(request, false, 1);
provisionWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status());
// assertEquals("", channel.capturedResponse().content().utf8ToString());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("Duplicate key foo"));
}

public void testContentParsingBadType() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.provisionWorkflowPath)
.withParams(Map.of("workflow_id", "abc"))
.withContent(new BytesArray("{\"foo\": 123}"), MediaTypeRegistry.JSON)
.build();

FakeRestChannel channel = new FakeRestChannel(request, false, 1);
provisionWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("Request body fields must have string values"));
}

public void testContentParsingError() throws Exception {
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath(this.provisionWorkflowPath)
.withContent(new BytesArray("not json"), MediaTypeRegistry.JSON)
.build();

FakeRestChannel channel = new FakeRestChannel(request, false, 1);
provisionWorkflowRestAction.handleRequest(request, channel, nodeClient);
assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status());
assertTrue(channel.capturedResponse().content().utf8ToString().contains("Request body parsing failed"));
}

public void testFeatureFlagNotEnabled() throws Exception {
Expand Down

0 comments on commit 7a00506

Please sign in to comment.