From 4a6cebab817e4c8aaef556fdbe889767a8a3b6a6 Mon Sep 17 00:00:00 2001 From: xinyual <74362153+xinyual@users.noreply.github.com> Date: Sat, 3 Feb 2024 04:22:47 +0800 Subject: [PATCH] Fix argument pass (#1941) * add logs Signed-off-by: xinyual * option2FixBug Signed-off-by: xinyual * remove useless log Signed-off-by: xinyual * remove useless log Signed-off-by: xinyual * fix spot less Signed-off-by: xinyual * change argument parsing Signed-off-by: xinyual * move common function to utils Signed-off-by: xinyual * checkout for typo Signed-off-by: xinyual * remove useless code Signed-off-by: xinyual * add UTs Signed-off-by: xinyual * apply spotless Signed-off-by: xinyual * add more uts Signed-off-by: xinyual * modify import Signed-off-by: xinyual * protect original parameters Signed-off-by: xinyual --------- Signed-off-by: xinyual --- .../engine/algorithms/agent/AgentUtils.java | 33 +++++ .../algorithms/agent/MLChatAgentRunner.java | 26 ++-- .../opensearch/ml/engine/tools/AgentTool.java | 20 ++- .../agent/MLChatAgentRunnerTest.java | 126 ++++++++++++++++++ .../ml/engine/tools/AgentToolTests.java | 30 +++-- 5 files changed, 215 insertions(+), 20 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 11b9a91f5b..7e9bdd1ab0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -15,6 +15,9 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -23,6 +26,8 @@ import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.spi.tools.Tool; public class AgentUtils { @@ -152,4 +157,32 @@ public static String extractModelResponseJson(String text) { throw new IllegalArgumentException("Model output is invalid"); } } + + public static String outputToOutputString(Object output) throws PrivilegedActionException { + String outputString; + if (output instanceof ModelTensorOutput) { + ModelTensor outputModel = ((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors().get(0); + if (outputModel.getDataAsMap() != null) { + outputString = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(outputModel.getDataAsMap())); + } else { + outputString = outputModel.getResult(); + } + } else if (output instanceof String) { + outputString = (String) output; + } else { + outputString = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + } + return outputString; + } + + public static String parseInputFromLLMReturn(Map retMap) { + Object actionInput = retMap.get("action_input"); + if (actionInput instanceof Map) { + return gson.toJson(actionInput); + } else { + return String.valueOf(actionInput); + } + + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 2a147557c8..e974328708 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -9,9 +9,9 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.extractModelResponseJson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseInputFromLLMReturn; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -325,7 +325,7 @@ private void runReAct( } String thought = String.valueOf(dataAsMap.get("thought")); String action = String.valueOf(dataAsMap.get("action")); - String actionInput = String.valueOf(dataAsMap.get("action_input")); + String actionInput = parseInputFromLLMReturn(dataAsMap); String finalAnswer = (String) dataAsMap.get("final_answer"); if (!dataAsMap.containsKey("thought")) { String response = (String) dataAsMap.get("response"); @@ -336,7 +336,7 @@ private void runReAct( Map map = gson.fromJson(jsonBlock, Map.class); thought = String.valueOf(map.get("thought")); action = String.valueOf(map.get("action")); - actionInput = String.valueOf(map.get("action_input")); + actionInput = parseInputFromLLMReturn(map); finalAnswer = (String) map.get("final_answer"); } else { finalAnswer = response; @@ -513,9 +513,7 @@ private void runReAct( } else { MLToolSpec toolSpec = toolSpecMap.get(lastAction.get()); if (toolSpec != null && toolSpec.isIncludeOutputInAgentResponse()) { - String outputString = output instanceof String - ? (String) output - : AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(output)); + String outputString = outputToOutputString(output); String toolOutputKey = String.format("%s.output", toolSpec.getType()); if (additionalInfo.get(toolOutputKey) != null) { @@ -535,7 +533,13 @@ private void runReAct( .singletonList( ModelTensor .builder() - .dataAsMap(ImmutableMap.of("response", lastThought.get() + "\nObservation: " + output)) + .dataAsMap( + ImmutableMap + .of( + "response", + lastThought.get() + "\nObservation: " + outputToOutputString(output) + ) + ) .build() ) ) @@ -544,7 +548,7 @@ private void runReAct( String toolResponse = tmpParameters.get("prompt.tool_response"); StringSubstitutor toolResponseSubstitutor = new StringSubstitutor( - ImmutableMap.of("observation", output), + ImmutableMap.of("observation", outputToOutputString(output)), "${parameters.", "}" ); @@ -556,7 +560,7 @@ private void runReAct( .conversationIndexMessageBuilder() .type("ReAct") .question(lastActionInput.get()) - .response((String) output) + .response(outputToOutputString(output)) .finalAnswer(false) .sessionId(sessionId) .build(); @@ -571,7 +575,7 @@ private void runReAct( newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); - sessionMsgAnswerBuilder.append("\nObservation: ").append(output); + sessionMsgAnswerBuilder.append("\nObservation: ").append(outputToOutputString(output)); cotModelTensors .add( ModelTensors diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index f048c62dc8..197f562bb6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -5,6 +5,9 @@ package org.opensearch.ml.engine.tools; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; import java.util.Map; import org.opensearch.action.ActionRequest; @@ -51,11 +54,12 @@ public AgentTool(Client client, String agentId) { @Override public void run(Map parameters, ActionListener listener) { + Map extractedParameters = extractInputParameters(parameters); AgentMLInput agentMLInput = AgentMLInput .AgentMLInputBuilder() .agentId(agentId) .functionName(FunctionName.AGENT) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(extractedParameters).build()) .build(); ActionRequest request = new MLExecuteTaskRequest(FunctionName.AGENT, agentMLInput, false); client.execute(MLExecuteTaskAction.INSTANCE, request, ActionListener.wrap(r -> { @@ -135,4 +139,18 @@ public String getDefaultVersion() { return null; } } + + private Map extractInputParameters(Map parameters) { + Map extractedParameters = new HashMap<>(); + extractedParameters.putAll(parameters); + if (parameters.containsKey("input")) { + try { + Map chatParameters = gson.fromJson(parameters.get("input"), Map.class); + extractedParameters.putAll(chatParameters); + } catch (Exception exception) { + log.info("fail extract parameters from key 'input' due to" + exception.getMessage()); + } + } + return extractedParameters; + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6bc84593e..47ef11fffa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -188,6 +188,86 @@ public void testParsingJsonBlockFromResponse() { assertEquals("parsed final answer", modelTensor2.getResult()); } + @Test + public void testParsingJsonBlockFromResponse2() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":\"parsed action input\", \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentWithTools(); + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1); + ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0); + ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0); + + // Verify that the parsed values from JSON block are correctly set + assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult()); + assertEquals("Thought: parsed thought", modelTensor1.getResult()); + assertEquals("parsed final answer", modelTensor2.getResult()); + } + + @Test + public void testParsingJsonBlockFromResponse3() { + // Prepare the response with JSON block + String jsonBlock = "{\"thought\":\"parsed thought\", \"action\":\"parsed action\", " + + "\"action_input\":{\"a\":\"n\"}, \"final_answer\":\"parsed final answer\"}"; + String responseWithJsonBlock = "Some text```json" + jsonBlock + "```More text"; + + // Mock LLM response to not contain "thought" but contain "response" with JSON block + Map llmResponse = new HashMap<>(); + llmResponse.put("response", responseWithJsonBlock); + doAnswer(getLLMAnswer(llmResponse)) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + // Create an MLAgent and run the MLChatAgentRunner + MLAgent mlAgent = createMLAgentWithTools(); + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put("verbose", "true"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Capture the response passed to the listener + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); + verify(agentActionListener).onResponse(responseCaptor.capture()); + + // Extract the captured response + Object capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + + ModelTensor parentInteractionModelTensor = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1); + ModelTensor modelTensor1 = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0); + ModelTensor modelTensor2 = modelTensorOutput.getMlModelOutputs().get(2).getMlModelTensors().get(0); + + // Verify that the parsed values from JSON block are correctly set + assertEquals("parent_interaction_id", parentInteractionModelTensor.getResult()); + assertEquals("Thought: parsed thought", modelTensor1.getResult()); + assertEquals("parsed final answer", modelTensor2.getResult()); + } + @Test public void testRunWithIncludeOutputNotSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -219,6 +299,35 @@ public void testRunWithIncludeOutputNotSet() { assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); } + @Test + public void testRunWithIncludeOutputMLModel() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + Mockito + .doAnswer(generateToolResponseAsMLModelResult("First tool response", 1)) + .when(firstTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); + Mockito + .doAnswer(generateToolResponseAsMLModelResult("Second tool response", 2)) + .when(secondTool) + .run(Mockito.anyMap(), toolListenerCaptor.capture()); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + MLToolSpec secondToolSpec = MLToolSpec.builder().name(SECOND_TOOL).type(SECOND_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec, secondToolSpec)) + .build(); + mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + List agentOutput = modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + // Respond with last tool output + assertEquals("This is the final answer", agentOutput.get(0).getDataAsMap().get("response")); + } + @Test public void testRunWithIncludeOutputSet() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); @@ -565,6 +674,23 @@ private Answer generateToolResponse(String response) { }; } + private Answer generateToolResponseAsMLModelResult(String response, int type) { + ModelTensor modelTensor; + if (type == 1) { + modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("return", response)).build(); + } else { + modelTensor = ModelTensor.builder().result(response).build(); + } + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + return invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mlModelTensorOutput); + return null; + }; + } + private Answer generateToolFailure(Exception e) { return invocation -> { ActionListener listener = invocation.getArgument(1); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java index 431e609bba..02b6627a31 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/AgentToolTests.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.tools.AgentTool.DEFAULT_DESCRIPTION; import java.util.Arrays; @@ -24,8 +25,6 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -67,13 +66,28 @@ public void setup() { public void testAgenttestRunMethod() { Map parameters = new HashMap<>(); parameters.put("testKey", "testValue"); - AgentMLInput agentMLInput = AgentMLInput - .AgentMLInputBuilder() - .agentId("agentId") - .functionName(FunctionName.AGENT) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) - .build(); + doTestRunMethod(parameters); + } + + @Test + public void testAgentWithChatAgentInput() { + Map parameters = new HashMap<>(); + parameters.put("testKey", "testValue"); + Map chatAgentInput = new HashMap<>(); + chatAgentInput.put("input", gson.toJson(parameters)); + doTestRunMethod(chatAgentInput); + assertEquals(chatAgentInput.size(), 1); + assertEquals(chatAgentInput.get("input"), gson.toJson(parameters)); // assert no influence on original parameters + } + + @Test + public void testAgentWithChatAgentInputWrongFormat() { + Map chatAgentInput = new HashMap<>(); + chatAgentInput.put("input", "wrong format"); + doTestRunMethod(chatAgentInput); + } + private void doTestRunMethod(Map parameters) { ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();