Skip to content

Commit c04f537

Browse files
[Agentic Search]Use same model for Agent and QPT (#4262)
1 parent c663291 commit c04f537

File tree

13 files changed

+560
-248
lines changed

13 files changed

+560
-248
lines changed

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,6 @@ private MLCommonsSettings() {}
311311
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
312312
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
313313

314-
public static final Setting<Boolean> ML_COMMONS_AGENTIC_SEARCH_ENABLED = Setting
315-
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "agentic_search_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
316-
public static final String ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE =
317-
"The QueryPlanningTool tool for Agentic Search is not enabled. To enable, please update the setting "
318-
+ ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey();
319-
320314
public static final Setting<Boolean> ML_COMMONS_MCP_CONNECTOR_ENABLED = Setting
321315
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "mcp_connector_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
322316
public static final String ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE =

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.ml.common.settings;
77

88
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED;
9-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
109
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
1110
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
1211
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
@@ -57,8 +56,6 @@ public class MLFeatureEnabledSetting {
5756

5857
private volatile Boolean isExecuteToolEnabled;
5958

60-
private volatile Boolean isAgenticSearchEnabled;
61-
6259
private volatile Boolean isMcpConnectorEnabled;
6360

6461
private volatile Boolean isAgenticMemoryEnabled;
@@ -83,7 +80,6 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
8380
isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings);
8481
isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings);
8582
isExecuteToolEnabled = ML_COMMONS_EXECUTE_TOOL_ENABLED.get(settings);
86-
isAgenticSearchEnabled = ML_COMMONS_AGENTIC_SEARCH_ENABLED.get(settings);
8783
isMcpConnectorEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(settings);
8884
isAgenticMemoryEnabled = ML_COMMONS_AGENTIC_MEMORY_ENABLED.get(settings);
8985
isIndexInsightEnabled = ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED.get(settings);
@@ -111,7 +107,6 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
111107
.getClusterSettings()
112108
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it);
113109
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_EXECUTE_TOOL_ENABLED, it -> isExecuteToolEnabled = it);
114-
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> isAgenticSearchEnabled = it);
115110
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it);
116111
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isAgenticMemoryEnabled = it);
117112
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
@@ -237,10 +232,6 @@ public void notifyMultiTenancyListeners(boolean isEnabled) {
237232
}
238233
}
239234

240-
public boolean isAgenticSearchEnabled() {
241-
return isAgenticSearchEnabled;
242-
}
243-
244235
public boolean isMcpConnectorEnabled() {
245236
return isMcpConnectorEnabled;
246237
}

common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ public void setUp() {
4545
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
4646
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
4747
MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED,
48-
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED,
4948
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
5049
MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED,
5150
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
@@ -92,7 +91,6 @@ public void testDefaults_allFeaturesEnabled() {
9291
assertTrue(setting.isMetricCollectionEnabled());
9392
assertTrue(setting.isStaticMetricCollectionEnabled());
9493
assertTrue(setting.isMcpConnectorEnabled());
95-
assertTrue(setting.isAgenticSearchEnabled());
9694
assertTrue(setting.isAgenticMemoryEnabled());
9795
assertTrue(setting.isStreamEnabled());
9896
}
@@ -134,7 +132,6 @@ public void testDefaults_someFeaturesDisabled() {
134132
assertFalse(setting.isMetricCollectionEnabled());
135133
assertFalse(setting.isStaticMetricCollectionEnabled());
136134
assertFalse(setting.isMcpConnectorEnabled());
137-
assertFalse(setting.isAgenticSearchEnabled());
138135
assertFalse(setting.isAgenticMemoryEnabled());
139136
assertFalse(setting.isStreamEnabled());
140137
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1515
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1616
import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD;
17-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
1817
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1918
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
2019

@@ -52,7 +51,6 @@
5251
import org.opensearch.ml.common.MLTaskType;
5352
import org.opensearch.ml.common.agent.MLAgent;
5453
import org.opensearch.ml.common.agent.MLMemorySpec;
55-
import org.opensearch.ml.common.agent.MLToolSpec;
5654
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5755
import org.opensearch.ml.common.input.Input;
5856
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -71,7 +69,6 @@
7169
import org.opensearch.ml.engine.indices.MLIndicesHandler;
7270
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
7371
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
74-
import org.opensearch.ml.engine.tools.QueryPlanningTool;
7572
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
7673
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
7774
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
@@ -468,15 +465,6 @@ private void executeAgent(
468465
listener.onFailure(new OpenSearchException(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
469466
return;
470467
}
471-
List<MLToolSpec> tools = mlAgent.getTools();
472-
if (tools != null) {
473-
for (MLToolSpec tool : tools) {
474-
if (tool.getType().equals(QueryPlanningTool.TYPE) && !mlFeatureEnabledSetting.isAgenticSearchEnabled()) {
475-
listener.onFailure(new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE));
476-
return;
477-
}
478-
}
479-
}
480468

481469
MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent);
482470
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ public class QueryPlanningPromptTemplate {
275275
+ "]\n"
276276
+ "Example output : 'product-search-template'";
277277

278-
public static final String TEMPLATE_SELECTION_SYSTEM_PROMPT = TEMPLATE_SELECTION_PURPOSE
278+
public static final String DEFAULT_TEMPLATE_SELECTION_SYSTEM_PROMPT = TEMPLATE_SELECTION_PURPOSE
279279
+ "==== GOAL ====\n"
280280
+ TEMPLATE_SELECTION_GOAL
281281
+ "\n"
@@ -291,7 +291,7 @@ public class QueryPlanningPromptTemplate {
291291
+ "==== EXAMPLES ====\n"
292292
+ TEMPLATE_SELECTION_EXAMPLES;
293293

294-
public static final String TEMPLATE_SELECTION_USER_PROMPT = "==== INPUTS ====\n" + TEMPLATE_SELECTION_INPUTS;
294+
public static final String DEFAULT_TEMPLATE_SELECTION_USER_PROMPT = "==== INPUTS ====\n" + TEMPLATE_SELECTION_INPUTS;
295295

296296
public static final String DEFAULT_SEARCH_TEMPLATE = "{"
297297
+ "\"from\": {{from}}{{^from}}0{{/from}},"

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,24 @@
77

88
import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT;
99
import static org.opensearch.ml.common.CommonValue.TOOL_INPUT_SCHEMA_FIELD;
10-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
1110
import static org.opensearch.ml.common.utils.StringUtils.gson;
1211
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DEFAULT_DATETIME_FORMAT;
1312
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime;
1413
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
1514
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT;
1615
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY_PLANNING_USER_PROMPT;
1716
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_SEARCH_TEMPLATE;
18-
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.TEMPLATE_SELECTION_SYSTEM_PROMPT;
19-
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.TEMPLATE_SELECTION_USER_PROMPT;
17+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_TEMPLATE_SELECTION_SYSTEM_PROMPT;
18+
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_TEMPLATE_SELECTION_USER_PROMPT;
2019

2120
import java.io.IOException;
2221
import java.util.HashMap;
2322
import java.util.List;
2423
import java.util.Map;
24+
import java.util.Set;
25+
import java.util.stream.Collectors;
2526

2627
import org.apache.commons.text.StringSubstitutor;
27-
import org.opensearch.OpenSearchException;
2828
import org.opensearch.action.admin.cluster.storedscripts.GetStoredScriptRequest;
2929
import org.opensearch.action.admin.indices.get.GetIndexRequest;
3030
import org.opensearch.action.admin.indices.get.GetIndexResponse;
@@ -35,7 +35,6 @@
3535
import org.opensearch.core.action.ActionListener;
3636
import org.opensearch.index.IndexNotFoundException;
3737
import org.opensearch.index.query.QueryBuilders;
38-
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3938
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
4039
import org.opensearch.ml.common.spi.tools.WithModelTool;
4140
import org.opensearch.ml.common.utils.ToolUtils;
@@ -59,8 +58,12 @@ public class QueryPlanningTool implements WithModelTool {
5958
public static final String TYPE = "QueryPlanningTool";
6059
public static final String MODEL_ID_FIELD = "model_id";
6160
private final MLModelTool queryGenerationTool;
62-
public static final String SYSTEM_PROMPT_FIELD = "query_planner_system_prompt";
63-
public static final String USER_PROMPT_FIELD = "query_planner_user_prompt";
61+
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
62+
public static final String USER_PROMPT_FIELD = "user_prompt";
63+
public static final String QUERY_PLANNER_SYSTEM_PROMPT_FIELD = "query_planner_system_prompt";
64+
public static final String QUERY_PLANNER_USER_PROMPT_FIELD = "query_planner_user_prompt";
65+
public static final String TEMPLATE_SELECTION_SYSTEM_PROMPT_FIELD = "template_selection_system_prompt";
66+
public static final String TEMPLATE_SELECTION_USER_PROMPT_FIELD = "template_selection_user_prompt";
6467
public static final String INDEX_MAPPING_FIELD = "index_mapping";
6568
public static final String QUERY_FIELDS_FIELD = "query_fields";
6669
public static final String GENERATION_TYPE_FIELD = "generation_type";
@@ -77,6 +80,13 @@ public class QueryPlanningTool implements WithModelTool {
7780
public static final String INDEX_NAME_FIELD = "index_name";
7881
private static final int MAX_TRUNCATE_CHARS = 250;
7982
private static final String TRUNC_PREFIX = "[truncated]";
83+
// Agent context parameter keys to ignore
84+
private static final String CHAT_HISTORY_FIELD = "_chat_history";
85+
private static final String TOOLS_FIELD = "_tools";
86+
private static final String INTERACTIONS_FIELD = "_interactions";
87+
private static final String TOOL_CONFIGS_FIELD = "tool_configs";
88+
private static final Set<String> AGENT_CONTEXT_EXCLUDED_PARAMS = Set
89+
.of(CHAT_HISTORY_FIELD, TOOLS_FIELD, INTERACTIONS_FIELD, TOOL_CONFIGS_FIELD);
8090

8191
@Getter
8292
private final String generationType;
@@ -121,10 +131,22 @@ public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool,
121131
this.attributes = new HashMap<>(DEFAULT_ATTRIBUTES);
122132
}
123133

134+
private Map<String, String> stripAgentContextParameters(Map<String, String> originalParameters) {
135+
// Drop agent-specific metadata that can bias or slow query planning; keep all other non-null params.
136+
// This enables using the same LLM for both the agent and the Query Planning Tool.
137+
// Excluded keys: _chat_history, _tools, _interactions, tool_configs
138+
139+
return originalParameters
140+
.entrySet()
141+
.stream()
142+
.filter(entry -> entry.getValue() != null && !AGENT_CONTEXT_EXCLUDED_PARAMS.contains(entry.getKey()))
143+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
144+
}
145+
124146
@Override
125147
public <T> void run(Map<String, String> originalParameters, ActionListener<T> listener) {
126148
try {
127-
Map<String, String> parameters = ToolUtils.extractInputParameters(originalParameters, attributes);
149+
Map<String, String> parameters = stripAgentContextParameters(ToolUtils.extractInputParameters(originalParameters, attributes));
128150
if (!validate(parameters)) {
129151
listener
130152
.onFailure(
@@ -149,8 +171,19 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
149171

150172
// Template Selection, replace user and system prompts
151173
Map<String, String> templateSelectionParameters = new HashMap<>(parameters);
152-
templateSelectionParameters.put(SYSTEM_PROMPT_FIELD, TEMPLATE_SELECTION_SYSTEM_PROMPT);
153-
templateSelectionParameters.put(USER_PROMPT_FIELD, TEMPLATE_SELECTION_USER_PROMPT);
174+
templateSelectionParameters
175+
.put(
176+
SYSTEM_PROMPT_FIELD,
177+
templateSelectionParameters
178+
.getOrDefault(TEMPLATE_SELECTION_SYSTEM_PROMPT_FIELD, DEFAULT_TEMPLATE_SELECTION_SYSTEM_PROMPT)
179+
);
180+
181+
templateSelectionParameters
182+
.put(
183+
USER_PROMPT_FIELD,
184+
templateSelectionParameters.getOrDefault(TEMPLATE_SELECTION_USER_PROMPT_FIELD, DEFAULT_TEMPLATE_SELECTION_USER_PROMPT)
185+
);
186+
154187
templateSelectionParameters.put(SEARCH_TEMPLATES_FIELD, searchTemplates);
155188

156189
ActionListener<T> templateSelectionListener = ActionListener.wrap(r -> {
@@ -185,16 +218,14 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
185218
}
186219
}
187220

221+
@SuppressWarnings("unchecked")
188222
private <T> void executeQueryPlanning(Map<String, String> parameters, ActionListener<T> listener) {
189223
try {
190224
// Execute Query Planning, replace System and User prompt fields
191-
if (!parameters.containsKey(SYSTEM_PROMPT_FIELD)) {
192-
parameters.put(SYSTEM_PROMPT_FIELD, DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT);
193-
}
225+
parameters
226+
.put(SYSTEM_PROMPT_FIELD, parameters.getOrDefault(QUERY_PLANNER_SYSTEM_PROMPT_FIELD, DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT));
194227

195-
if (!parameters.containsKey(USER_PROMPT_FIELD)) {
196-
parameters.put(USER_PROMPT_FIELD, DEFAULT_QUERY_PLANNING_USER_PROMPT);
197-
}
228+
parameters.put(USER_PROMPT_FIELD, parameters.getOrDefault(QUERY_PLANNER_USER_PROMPT_FIELD, DEFAULT_QUERY_PLANNING_USER_PROMPT));
198229

199230
if (parameters.containsKey(QUERY_FIELDS_FIELD)) {
200231
parameters.put(QUERY_FIELDS_FIELD, gson.toJson(parameters.get(QUERY_FIELDS_FIELD)));
@@ -360,7 +391,6 @@ public boolean validate(Map<String, String> parameters) {
360391
public static class Factory implements WithModelTool.Factory<QueryPlanningTool> {
361392
private Client client;
362393
private static volatile Factory INSTANCE;
363-
private static MLFeatureEnabledSetting mlFeatureEnabledSetting;
364394

365395
public static Factory getInstance() {
366396
if (INSTANCE != null) {
@@ -375,18 +405,13 @@ public static Factory getInstance() {
375405
}
376406
}
377407

378-
public void init(Client client, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
408+
public void init(Client client) {
379409
this.client = client;
380-
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
381410
}
382411

383412
@Override
384413
public QueryPlanningTool create(Map<String, Object> map) {
385414

386-
if (!mlFeatureEnabledSetting.isAgenticSearchEnabled()) {
387-
throw new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE);
388-
}
389-
390415
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
391416

392417
String type = (String) map.get(GENERATION_TYPE_FIELD);

0 commit comments

Comments
 (0)