Skip to content

Commit

Permalink
Fix custom prompt substitute with List issue in ml inference search r…
Browse files Browse the repository at this point in the history
…esponse processor (opensearch-project#2871) (opensearch-project#2874)

(cherry picked from commit 49d4a01)

Co-authored-by: Mingshi Liu <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and mingshl authored Sep 3, 2024
1 parent a1a7dbb commit 54d7daf
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -322,40 +323,16 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
parseParameters(parameters);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
String payloadAfterEscape = connectorAction.get().getRequestBody();
Map<String, String> escapedParameters = escapeMapValues(parameters);
StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape);
if (!isJson(payloadAfterEscape)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else {
payload = payloadAfterEscape;
}
throw new IllegalArgumentException("Invalid payload: " + payload);
}
return (T) payload;
}
return (T) parameters.get("http_body");

}

public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
Map<String, String> escapedMap = new HashMap<>();
if (parameters != null) {
for (Map.Entry<String, String> entry : parameters.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String escapedValue = escapeValue(value);
escapedMap.put(key, escapedValue);
}
}
return escapedMap;
}

private static String escapeValue(String value) {
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class StringUtils {
static {
gson = new Gson();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static boolean isValidJsonString(String Json) {
try {
Expand Down Expand Up @@ -233,4 +234,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
* @param map A map containing key-value pairs where the values may contain toString() method calls.
* @return A list of prefixes for the toString() method calls found in the map values.
*/
public static List<String> collectToStringPrefixes(Map<String, String> map) {
List<String> prefixes = new ArrayList<>();
for (String key : map.keySet()) {
String value = map.get(key);
if (value != null) {
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String prefix = matcher.group(1);
prefixes.add(prefix);
}
}
}
return prefixes;
}

/**
* Parses the given parameters map and processes the values containing toString() method calls.
*
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
* @return A new map with the processed values for the toString() method calls.
*/
public static Map<String, String> parseParameters(Map<String, String> parameters) {
if (parameters != null) {
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);

if (!toStringParametersPrefixes.isEmpty()) {
for (String prefix : toStringParametersPrefixes) {
String value = parameters.get(prefix);
if (value != null) {
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
}
}
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -184,114 +183,6 @@ public void createPayload_InvalidJson() {
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();

parameters.put("prompt", "answer question based on context: ${parameters.context}");
parameters.put("context", "document1");
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
}

@Test
public void createPayloadWithList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
listOfDocuments.add("document2");
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithMap() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedListOfMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<Object> NestedListOfDocuments = new ArrayList<>();
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayload() {
HttpConnector connector = createHttpConnector();
Expand Down
Loading

0 comments on commit 54d7daf

Please sign in to comment.