Skip to content

Commit

Permalink
Fix argument pass (opensearch-project#1941)
Browse files Browse the repository at this point in the history
* add logs

Signed-off-by: xinyual <[email protected]>

* option2FixBug

Signed-off-by: xinyual <[email protected]>

* remove useless log

Signed-off-by: xinyual <[email protected]>

* remove useless log

Signed-off-by: xinyual <[email protected]>

* fix spot less

Signed-off-by: xinyual <[email protected]>

* change argument parsing

Signed-off-by: xinyual <[email protected]>

* move common function to utils

Signed-off-by: xinyual <[email protected]>

* checkout for typo

Signed-off-by: xinyual <[email protected]>

* remove useless code

Signed-off-by: xinyual <[email protected]>

* add UTs

Signed-off-by: xinyual <[email protected]>

* apply spotless

Signed-off-by: xinyual <[email protected]>

* add more uts

Signed-off-by: xinyual <[email protected]>

* modify import

Signed-off-by: xinyual <[email protected]>

* protect original parameters

Signed-off-by: xinyual <[email protected]>

---------

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual authored Feb 2, 2024
1 parent 343ae16 commit 4a6ceba
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;

import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -23,6 +26,8 @@
import java.util.regex.Pattern;

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.spi.tools.Tool;

public class AgentUtils {
Expand Down Expand Up @@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) {
throw new IllegalArgumentException("Model output is invalid");
}
}

public static String outputToOutputString(Object output) throws PrivilegedActionException {
String outputString;
if (output instanceof ModelTensorOutput) {
ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0);
if (outputModel.getDataAsMap() != null) {
outputString = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(outputModel.getDataAsMap()));
} else {
outputString = outputModel.getResult();
}
} else if (output instanceof String) {
outputString = (String) output;
} else {
outputString = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
}
return outputString;
}

public static String parseInputFromLLMReturn(Map<String, ?> retMap) {
Object actionInput = retMap.get("action_input");
if (actionInput instanceof Map) {
return gson.toJson(actionInput);
} else {
return String.valueOf(actionInput);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -325,7 +325,7 @@ private void runReAct(
}
String thought = String.valueOf(dataAsMap.get("thought"));
String action = String.valueOf(dataAsMap.get("action"));
String actionInput = String.valueOf(dataAsMap.get("action_input"));
String actionInput = parseInputFromLLMReturn(dataAsMap);
String finalAnswer = (String) dataAsMap.get("final_answer");
if (!dataAsMap.containsKey("thought")) {
String response = (String) dataAsMap.get("response");
Expand All @@ -336,7 +336,7 @@ private void runReAct(
Map map = gson.fromJson(jsonBlock, Map.class);
thought = String.valueOf(map.get("thought"));
action = String.valueOf(map.get("action"));
actionInput = String.valueOf(map.get("action_input"));
actionInput = parseInputFromLLMReturn(map);
finalAnswer = (String) map.get("final_answer");
} else {
finalAnswer = response;
Expand Down Expand Up @@ -513,9 +513,7 @@ private void runReAct(
} else {
MLToolSpec toolSpec = toolSpecMap.get(lastAction.get());
if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) {
String outputString = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(output));
String outputString = outputToOutputString(output);

String toolOutputKey = String.format("%s.output", toolSpec.getType());
if (additionalInfo.get(toolOutputKey) != null) {
Expand All @@ -535,7 +533,13 @@ private void runReAct(
.singletonList(
ModelTensor
.builder()
.dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output))
.dataAsMap(
ImmutableMap
.of(
"response",
lastThought.get() + "\nObservation: " + outputToOutputString(output)
)
)
.build()
)
)
Expand All @@ -544,7 +548,7 @@ private void runReAct(

String toolResponse = tmpParameters.get("prompt.tool_response");
StringSubstitutor toolResponseSubstitutor = new StringSubstitutor(
ImmutableMap.of("observation", output),
ImmutableMap.of("observation", outputToOutputString(output)),
"${parameters.",
"}"
);
Expand All @@ -556,7 +560,7 @@ private void runReAct(
.conversationIndexMessageBuilder()
.type("ReAct")
.question(lastActionInput.get())
.response((String) output)
.response(outputToOutputString(output))
.finalAnswer(false)
.sessionId(sessionId)
.build();
Expand All @@ -571,7 +575,7 @@ private void runReAct(
newPrompt.set(substitutor.replace(finalPrompt));
tmpParameters.put(PROMPT, newPrompt.get());

sessionMsgAnswerBuilder.append("\nObservation: ").append(output);
sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output));
cotModelTensors
.add(
ModelTensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -51,11 +54,12 @@ public AgentTool(Client client, String agentId) {

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
Map<String, String> extractedParameters = extractInputParameters(parameters);
AgentMLInput agentMLInput = AgentMLInput
.AgentMLInputBuilder()
.agentId(agentId)
.functionName(FunctionName.AGENT)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build())
.build();
ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false);
client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> {
Expand Down Expand Up @@ -135,4 +139,18 @@ public String getDefaultVersion() {
return null;
}
}

private Map<String, String> extractInputParameters(Map<String, String> parameters) {
Map<String, String> extractedParameters = new HashMap<>();
extractedParameters.putAll(parameters);
if (parameters.containsKey("input")) {
try {
Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class);
extractedParameters.putAll(chatParameters);
} catch (Exception exception) {
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
}
}
return extractedParameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() {
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testParsingJsonBlockFromResponse2() {
// Prepare the response with JSON block
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+ "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}";
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";

// Mock LLM response to not contain "thought" but contain "response" with JSON block
Map<String, String> llmResponse = new HashMap<>();
llmResponse.put("response", responseWithJsonBlock);
doAnswer(getLLMAnswer(llmResponse))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

// Create an MLAgent and run the MLChatAgentRunner
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Capture the response passed to the listener
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

// Extract the captured response
Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);

// Verify that the parsed values from JSON block are correctly set
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
assertEquals("Thought: parsed thought", modelTensor1.getResult());
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testParsingJsonBlockFromResponse3() {
// Prepare the response with JSON block
String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", "
+ "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}";
String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text";

// Mock LLM response to not contain "thought" but contain "response" with JSON block
Map<String, String> llmResponse = new HashMap<>();
llmResponse.put("response", responseWithJsonBlock);
doAnswer(getLLMAnswer(llmResponse))
.when(client)
.execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

// Create an MLAgent and run the MLChatAgentRunner
MLAgent mlAgent = createMLAgentWithTools();
Map<String, String> params = new HashMap<>();
params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
params.put("verbose", "true");
mlChatAgentRunner.run(mlAgent, params, agentActionListener);

// Capture the response passed to the listener
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(agentActionListener).onResponse(responseCaptor.capture());

// Extract the captured response
Object capturedResponse = responseCaptor.getValue();
assertTrue(capturedResponse instanceof ModelTensorOutput);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;

ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1);
ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0);
ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0);

// Verify that the parsed values from JSON block are correctly set
assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult());
assertEquals("Thought: parsed thought", modelTensor1.getResult());
assertEquals("parsed final answer", modelTensor2.getResult());
}

@Test
public void testRunWithIncludeOutputNotSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down Expand Up @@ -219,6 +299,35 @@ public void testRunWithIncludeOutputNotSet() {
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
}

@Test
public void testRunWithIncludeOutputMLModel() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Mockito
.doAnswer(generateToolResponseAsMLModelResult("First tool response", 1))
.when(firstTool)
.run(Mockito.anyMap(), toolListenerCaptor.capture());
Mockito
.doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2))
.when(secondTool)
.run(Mockito.anyMap(), toolListenerCaptor.capture());
MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build();
final MLAgent mlAgent = MLAgent
.builder()
.name("TestAgent")
.llm(llmSpec)
.memory(mlMemorySpec)
.tools(Arrays.asList(firstToolSpec, secondToolSpec))
.build();
mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener);
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue();
List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors();
assertEquals(1, agentOutput.size());
// Respond with last tool output
assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response"));
}

@Test
public void testRunWithIncludeOutputSet() {
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
Expand Down Expand Up @@ -565,6 +674,23 @@ private Answer generateToolResponse(String response) {
};
}

private Answer generateToolResponseAsMLModelResult(String response, int type) {
ModelTensor modelTensor;
if (type == 1) {
modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build();
} else {
modelTensor = ModelTensor.builder().result(response).build();
}
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
listener.onResponse(mlModelTensorOutput);
return null;
};
}

private Answer generateToolFailure(Exception e) {
return invocation -> {
ActionListener<Object> listener = invocation.getArgument(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION;

import java.util.Arrays;
Expand All @@ -24,8 +25,6 @@
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand Down Expand Up @@ -67,13 +66,28 @@ public void setup() {
public void testAgenttestRunMethod() {
Map<String, String> parameters = new HashMap<>();
parameters.put("testKey", "testValue");
AgentMLInput agentMLInput = AgentMLInput
.AgentMLInputBuilder()
.agentId("agentId")
.functionName(FunctionName.AGENT)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build())
.build();
doTestRunMethod(parameters);
}

@Test
public void testAgentWithChatAgentInput() {
Map<String, String> parameters = new HashMap<>();
parameters.put("testKey", "testValue");
Map<String, String> chatAgentInput = new HashMap<>();
chatAgentInput.put("input", gson.toJson(parameters));
doTestRunMethod(chatAgentInput);
assertEquals(chatAgentInput.size(), 1);
assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters
}

@Test
public void testAgentWithChatAgentInputWrongFormat() {
Map<String, String> chatAgentInput = new HashMap<>();
chatAgentInput.put("input", "wrong format");
doTestRunMethod(chatAgentInput);
}

private void doTestRunMethod(Map<String, String> parameters) {
ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
Expand Down

0 comments on commit 4a6ceba

Please sign in to comment.