77
88import static org .opensearch .action .support .clustermanager .ClusterManagerNodeRequest .DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT ;
99import static org .opensearch .ml .common .CommonValue .TOOL_INPUT_SCHEMA_FIELD ;
10- import static org .opensearch .ml .common .settings .MLCommonsSettings .ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE ;
1110import static org .opensearch .ml .common .utils .StringUtils .gson ;
1211import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .DEFAULT_DATETIME_FORMAT ;
1312import static org .opensearch .ml .engine .algorithms .agent .AgentUtils .getCurrentDateTime ;
1413import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_QUERY ;
1514import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT ;
1615import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_QUERY_PLANNING_USER_PROMPT ;
1716import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_SEARCH_TEMPLATE ;
18- import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .TEMPLATE_SELECTION_SYSTEM_PROMPT ;
19- import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .TEMPLATE_SELECTION_USER_PROMPT ;
17+ import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_TEMPLATE_SELECTION_SYSTEM_PROMPT ;
18+ import static org .opensearch .ml .engine .tools .QueryPlanningPromptTemplate .DEFAULT_TEMPLATE_SELECTION_USER_PROMPT ;
2019
2120import java .io .IOException ;
2221import java .util .HashMap ;
2322import java .util .List ;
2423import java .util .Map ;
24+ import java .util .Set ;
25+ import java .util .stream .Collectors ;
2526
2627import org .apache .commons .text .StringSubstitutor ;
27- import org .opensearch .OpenSearchException ;
2828import org .opensearch .action .admin .cluster .storedscripts .GetStoredScriptRequest ;
2929import org .opensearch .action .admin .indices .get .GetIndexRequest ;
3030import org .opensearch .action .admin .indices .get .GetIndexResponse ;
3535import org .opensearch .core .action .ActionListener ;
3636import org .opensearch .index .IndexNotFoundException ;
3737import org .opensearch .index .query .QueryBuilders ;
38- import org .opensearch .ml .common .settings .MLFeatureEnabledSetting ;
3938import org .opensearch .ml .common .spi .tools .ToolAnnotation ;
4039import org .opensearch .ml .common .spi .tools .WithModelTool ;
4140import org .opensearch .ml .common .utils .ToolUtils ;
@@ -59,8 +58,12 @@ public class QueryPlanningTool implements WithModelTool {
5958 public static final String TYPE = "QueryPlanningTool" ;
6059 public static final String MODEL_ID_FIELD = "model_id" ;
6160 private final MLModelTool queryGenerationTool ;
62- public static final String SYSTEM_PROMPT_FIELD = "query_planner_system_prompt" ;
63- public static final String USER_PROMPT_FIELD = "query_planner_user_prompt" ;
61+ public static final String SYSTEM_PROMPT_FIELD = "system_prompt" ;
62+ public static final String USER_PROMPT_FIELD = "user_prompt" ;
63+ public static final String QUERY_PLANNER_SYSTEM_PROMPT_FIELD = "query_planner_system_prompt" ;
64+ public static final String QUERY_PLANNER_USER_PROMPT_FIELD = "query_planner_user_prompt" ;
65+ public static final String TEMPLATE_SELECTION_SYSTEM_PROMPT_FIELD = "template_selection_system_prompt" ;
66+ public static final String TEMPLATE_SELECTION_USER_PROMPT_FIELD = "template_selection_user_prompt" ;
6467 public static final String INDEX_MAPPING_FIELD = "index_mapping" ;
6568 public static final String QUERY_FIELDS_FIELD = "query_fields" ;
6669 public static final String GENERATION_TYPE_FIELD = "generation_type" ;
@@ -77,6 +80,13 @@ public class QueryPlanningTool implements WithModelTool {
7780 public static final String INDEX_NAME_FIELD = "index_name" ;
7881 private static final int MAX_TRUNCATE_CHARS = 250 ;
7982 private static final String TRUNC_PREFIX = "[truncated]" ;
83+ // Agent context parameter keys to ignore
84+ private static final String CHAT_HISTORY_FIELD = "_chat_history" ;
85+ private static final String TOOLS_FIELD = "_tools" ;
86+ private static final String INTERACTIONS_FIELD = "_interactions" ;
87+ private static final String TOOL_CONFIGS_FIELD = "tool_configs" ;
88+ private static final Set <String > AGENT_CONTEXT_EXCLUDED_PARAMS = Set
89+ .of (CHAT_HISTORY_FIELD , TOOLS_FIELD , INTERACTIONS_FIELD , TOOL_CONFIGS_FIELD );
8090
8191 @ Getter
8292 private final String generationType ;
@@ -121,10 +131,22 @@ public QueryPlanningTool(String generationType, MLModelTool queryGenerationTool,
121131 this .attributes = new HashMap <>(DEFAULT_ATTRIBUTES );
122132 }
123133
134+ private Map <String , String > stripAgentContextParameters (Map <String , String > originalParameters ) {
135+ // Drop agent-specific metadata that can bias or slow query planning; keep all other non-null params.
136+ // This enables using the same LLM for both the agent and the Query Planning Tool.
137+ // Excluded keys: _chat_history, _tools, _interactions, tool_configs
138+
139+ return originalParameters
140+ .entrySet ()
141+ .stream ()
142+ .filter (entry -> entry .getValue () != null && !AGENT_CONTEXT_EXCLUDED_PARAMS .contains (entry .getKey ()))
143+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
144+ }
145+
124146 @ Override
125147 public <T > void run (Map <String , String > originalParameters , ActionListener <T > listener ) {
126148 try {
127- Map <String , String > parameters = ToolUtils .extractInputParameters (originalParameters , attributes );
149+ Map <String , String > parameters = stripAgentContextParameters ( ToolUtils .extractInputParameters (originalParameters , attributes ) );
128150 if (!validate (parameters )) {
129151 listener
130152 .onFailure (
@@ -149,8 +171,19 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
149171
150172 // Template Selection, replace user and system prompts
151173 Map <String , String > templateSelectionParameters = new HashMap <>(parameters );
152- templateSelectionParameters .put (SYSTEM_PROMPT_FIELD , TEMPLATE_SELECTION_SYSTEM_PROMPT );
153- templateSelectionParameters .put (USER_PROMPT_FIELD , TEMPLATE_SELECTION_USER_PROMPT );
174+ templateSelectionParameters
175+ .put (
176+ SYSTEM_PROMPT_FIELD ,
177+ templateSelectionParameters
178+ .getOrDefault (TEMPLATE_SELECTION_SYSTEM_PROMPT_FIELD , DEFAULT_TEMPLATE_SELECTION_SYSTEM_PROMPT )
179+ );
180+
181+ templateSelectionParameters
182+ .put (
183+ USER_PROMPT_FIELD ,
184+ templateSelectionParameters .getOrDefault (TEMPLATE_SELECTION_USER_PROMPT_FIELD , DEFAULT_TEMPLATE_SELECTION_USER_PROMPT )
185+ );
186+
154187 templateSelectionParameters .put (SEARCH_TEMPLATES_FIELD , searchTemplates );
155188
156189 ActionListener <T > templateSelectionListener = ActionListener .wrap (r -> {
@@ -185,16 +218,14 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
185218 }
186219 }
187220
221+ @ SuppressWarnings ("unchecked" )
188222 private <T > void executeQueryPlanning (Map <String , String > parameters , ActionListener <T > listener ) {
189223 try {
190224 // Execute Query Planning, replace System and User prompt fields
191- if (!parameters .containsKey (SYSTEM_PROMPT_FIELD )) {
192- parameters .put (SYSTEM_PROMPT_FIELD , DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT );
193- }
225+ parameters
226+ .put (SYSTEM_PROMPT_FIELD , parameters .getOrDefault (QUERY_PLANNER_SYSTEM_PROMPT_FIELD , DEFAULT_QUERY_PLANNING_SYSTEM_PROMPT ));
194227
195- if (!parameters .containsKey (USER_PROMPT_FIELD )) {
196- parameters .put (USER_PROMPT_FIELD , DEFAULT_QUERY_PLANNING_USER_PROMPT );
197- }
228+ parameters .put (USER_PROMPT_FIELD , parameters .getOrDefault (QUERY_PLANNER_USER_PROMPT_FIELD , DEFAULT_QUERY_PLANNING_USER_PROMPT ));
198229
199230 if (parameters .containsKey (QUERY_FIELDS_FIELD )) {
200231 parameters .put (QUERY_FIELDS_FIELD , gson .toJson (parameters .get (QUERY_FIELDS_FIELD )));
@@ -360,7 +391,6 @@ public boolean validate(Map<String, String> parameters) {
360391 public static class Factory implements WithModelTool .Factory <QueryPlanningTool > {
361392 private Client client ;
362393 private static volatile Factory INSTANCE ;
363- private static MLFeatureEnabledSetting mlFeatureEnabledSetting ;
364394
365395 public static Factory getInstance () {
366396 if (INSTANCE != null ) {
@@ -375,18 +405,13 @@ public static Factory getInstance() {
375405 }
376406 }
377407
378- public void init (Client client , MLFeatureEnabledSetting mlFeatureEnabledSetting ) {
408+ public void init (Client client ) {
379409 this .client = client ;
380- this .mlFeatureEnabledSetting = mlFeatureEnabledSetting ;
381410 }
382411
383412 @ Override
384413 public QueryPlanningTool create (Map <String , Object > map ) {
385414
386- if (!mlFeatureEnabledSetting .isAgenticSearchEnabled ()) {
387- throw new OpenSearchException (ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE );
388- }
389-
390415 MLModelTool queryGenerationTool = MLModelTool .Factory .getInstance ().create (map );
391416
392417 String type = (String ) map .get (GENERATION_TYPE_FIELD );
0 commit comments