From 7e03b170c899c4102303c54f5a2cdeed9ad97d6f Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 5 Jul 2024 14:19:18 -0500 Subject: [PATCH] address comments Signed-off-by: Bhavana Ramaram --- .../opensearch/ml/processor/MLInferenceIngestProcessor.java | 5 +++-- .../ml/processor/MLInferenceIngestProcessorFactoryTests.java | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index 381b57dee4..48e0464fbc 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -72,6 +72,7 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod // At default, ml inference processor allows maximum 10 prediction tasks running in parallel // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + public static final String DEFAULT_MODEL_INPUT = "{ \"parameters\": ${ml_inference.parameters} }"; private final NamedXContentRegistry xContentRegistry; private Configuration suppressExceptionConfiguration = Configuration @@ -494,12 +495,12 @@ public MLInferenceIngestProcessor create( // if model input is not provided for remote models, use default value if (functionName.equalsIgnoreCase("remote")) { - modelInput = (modelInput != null) ? modelInput : "{ \"parameters\": ${ml_inference.parameters} }"; + modelInput = (modelInput != null) ? 
modelInput : DEFAULT_MODEL_INPUT; } else if (modelInput == null) { // if model input is not provided for local models, throw exception since it is mandatory here throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor"); } - boolean defaultFullResponsePath = !functionName.equalsIgnoreCase("remote"); + boolean defaultFullResponsePath = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name()); boolean fullResponsePath = ConfigurationUtils .readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultFullResponsePath); diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java index 5068ae7043..c58c587ef7 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java @@ -207,7 +207,7 @@ public void testModelInputIsNullForLocalModels() throws Exception { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(MODEL_ID, "model2"); - config.put(FUNCTION_NAME, "text_embedding"); + config.put(FUNCTION_NAME, "remote"); Map model_config = new HashMap<>(); model_config.put("return_number", true); config.put(MODEL_CONFIG, model_config);