Skip to content

Commit 494711d

Browse files
committed
feat: support task cancellation for other agents
Signed-off-by: Pavan Yekbote <[email protected]>
1 parent 9ddf0e5 commit 494711d

File tree

5 files changed

+46
-10
lines changed

5 files changed

+46
-10
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,17 @@
6464
import org.opensearch.core.xcontent.NamedXContentRegistry;
6565
import org.opensearch.core.xcontent.XContentParser;
6666
import org.opensearch.index.IndexNotFoundException;
67+
import org.opensearch.ml.common.MLTaskState;
6768
import org.opensearch.ml.common.agent.MLAgent;
6869
import org.opensearch.ml.common.agent.MLToolSpec;
6970
import org.opensearch.ml.common.connector.Connector;
7071
import org.opensearch.ml.common.connector.McpConnector;
7172
import org.opensearch.ml.common.output.model.ModelTensor;
7273
import org.opensearch.ml.common.output.model.ModelTensorOutput;
7374
import org.opensearch.ml.common.spi.tools.Tool;
75+
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
76+
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
77+
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
7478
import org.opensearch.ml.common.utils.StringUtils;
7579
import org.opensearch.ml.engine.MLEngineClassLoader;
7680
import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
@@ -931,4 +935,14 @@ public static void cleanUpResource(Map<String, Tool> tools) {
931935
}
932936
}
933937
}
938+
939+
public static boolean isTaskMarkedForCancel(String taskId, Client client) {
940+
if (taskId != null && !taskId.isEmpty()) {
941+
MLTaskGetRequest taskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
942+
MLTaskGetResponse taskResponse = client.execute(MLTaskGetAction.INSTANCE, taskGetRequest).actionGet();
943+
return taskResponse.getMlTask().getState().equals(MLTaskState.CANCELLING);
944+
}
945+
946+
return false;
947+
}
934948
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.algorithms.agent;
77

8+
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
89
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
910
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
1011
import static org.opensearch.ml.common.utils.StringUtils.gson;
@@ -35,6 +36,7 @@
3536
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
3637
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
3738
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolNames;
39+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.isTaskMarkedForCancel;
3840
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.outputToOutputString;
3941
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput;
4042
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute;
@@ -47,6 +49,7 @@
4749
import java.util.List;
4850
import java.util.Locale;
4951
import java.util.Map;
52+
import java.util.concurrent.CancellationException;
5053
import java.util.concurrent.ConcurrentHashMap;
5154
import java.util.concurrent.CopyOnWriteArrayList;
5255
import java.util.concurrent.atomic.AtomicInteger;
@@ -426,6 +429,13 @@ private void runReAct(
426429
int finalI = i;
427430
StepListener<?> nextStepListener = new StepListener<>();
428431

432+
// check if task has been marked to cancel
433+
String taskId = parameters.get(TASK_ID_FIELD);
434+
if (isTaskMarkedForCancel(taskId, client)) {
435+
listener.onFailure(new CancellationException(String.format("Agent execution cancelled for task: %s", taskId)));
436+
return;
437+
}
438+
429439
lastStepListener.whenComplete(output -> {
430440
StringBuilder sessionMsgAnswerBuilder = new StringBuilder();
431441
if (finalI % 2 == 0) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
99
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1011
import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD;
1112
import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD;
1213
import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID;
@@ -16,6 +17,7 @@
1617
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit;
1718
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
1819
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
20+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.isTaskMarkedForCancel;
1921
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
2022

2123
import java.io.IOException;
@@ -26,6 +28,7 @@
2628
import java.util.HashMap;
2729
import java.util.List;
2830
import java.util.Map;
31+
import java.util.concurrent.CancellationException;
2932
import java.util.concurrent.ConcurrentHashMap;
3033
import java.util.concurrent.atomic.AtomicInteger;
3134

@@ -181,6 +184,12 @@ private void runAgent(
181184

182185
MLMemorySpec memorySpec = mlAgent.getMemory();
183186
for (int i = 0; i <= toolSpecs.size(); i++) {
187+
String taskId = params.get(TASK_ID_FIELD);
188+
if (isTaskMarkedForCancel(taskId, client)) {
189+
listener.onFailure(new CancellationException(String.format("Agent execution cancelled for task: %s", taskId)));
190+
return;
191+
}
192+
184193
if (i == 0) {
185194
MLToolSpec toolSpec = toolSpecs.get(i);
186195
Tool tool = createTool(toolFactories, params, toolSpec, mlAgent.getTenantId());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import static org.apache.commons.text.StringEscapeUtils.escapeJson;
99
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
10+
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1011
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
1112
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getToolName;
13+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.isTaskMarkedForCancel;
1214

1315
import java.io.IOException;
1416
import java.security.AccessController;
@@ -17,6 +19,7 @@
1719
import java.util.HashMap;
1820
import java.util.List;
1921
import java.util.Map;
22+
import java.util.concurrent.CancellationException;
2023
import java.util.concurrent.ConcurrentHashMap;
2124

2225
import org.apache.commons.text.StringSubstitutor;
@@ -102,6 +105,12 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
102105
String parentInteractionId = params.get(MLAgentExecutor.PARENT_INTERACTION_ID);
103106

104107
for (int i = 0; i <= toolSpecs.size(); i++) {
108+
String taskId = params.get(TASK_ID_FIELD);
109+
if (isTaskMarkedForCancel(taskId, client)) {
110+
listener.onFailure(new CancellationException(String.format("Agent execution cancelled for task: %s", taskId)));
111+
return;
112+
}
113+
105114
if (i == 0) {
106115
MLToolSpec toolSpec = toolSpecs.get(i);
107116
Tool tool = createTool(toolSpec, mlAgent.getTenantId());

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
2020
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs;
2121
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs;
22+
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.isTaskMarkedForCancel;
2223
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE;
2324
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.MAX_ITERATION;
2425
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.saveTraceData;
@@ -65,9 +66,6 @@
6566
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
6667
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
6768
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
68-
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
69-
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
70-
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
7169
import org.opensearch.ml.common.utils.StringUtils;
7270
import org.opensearch.ml.engine.encryptor.Encryptor;
7371
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
@@ -407,13 +405,9 @@ private void executePlanningLoop(
407405

408406
// check if task has been marked to cancel
409407
String taskId = allParams.get(TASK_ID_FIELD);
410-
if (taskId != null && !taskId.isEmpty()) {
411-
MLTaskGetRequest taskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
412-
MLTaskGetResponse taskResponse = client.execute(MLTaskGetAction.INSTANCE, taskGetRequest).actionGet();
413-
if (taskResponse.getMlTask().getState().equals(MLTaskState.CANCELLING)) {
414-
finalListener.onFailure(new CancellationException(String.format("Agent execution cancelled for task: %s", taskId)));
415-
return;
416-
}
408+
if (isTaskMarkedForCancel(taskId, client)) {
409+
finalListener.onFailure(new CancellationException(String.format("Agent execution cancelled for task: %s", taskId)));
410+
return;
417411
}
418412

419413
client.execute(MLExecuteTaskAction.INSTANCE, executeRequest, ActionListener.wrap(executeResponse -> {

0 commit comments

Comments
 (0)