diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java index 07d75a32ce..d64050a8a3 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -28,7 +28,6 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; import java.security.AccessController; @@ -51,7 +50,7 @@ public class ModelGuardrail extends Guardrail { public static final String MODEL_ID_FIELD = "model_id"; public static final String RESPONSE_FILTER_FIELD = "response_filter"; - public static final String RESPONSE_ACCEPT_FIELD = "response_accept"; + public static final String RESPONSE_VALIDATION_REGEX_FIELD = "response_validation_regex"; private String modelId; private String responseFilter; @@ -67,7 +66,7 @@ public ModelGuardrail(String modelId, String responseFilter, String responseAcce this.responseAccept = responseAccept; } public ModelGuardrail(@NonNull Map params) { - this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_ACCEPT_FIELD)); + this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_VALIDATION_REGEX_FIELD)); } public ModelGuardrail(StreamInput input) throws IOException { @@ -157,7 +156,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(RESPONSE_FILTER_FIELD, responseFilter); } if (responseAccept != null) { - builder.field(RESPONSE_ACCEPT_FIELD, responseAccept); + builder.field(RESPONSE_VALIDATION_REGEX_FIELD, responseAccept); } builder.endObject(); return builder; @@ -180,7 +179,7 @@ public static ModelGuardrail parse(XContentParser parser) throws IOException { case RESPONSE_FILTER_FIELD: responseFilter = parser.text(); break; - case RESPONSE_ACCEPT_FIELD: + case RESPONSE_VALIDATION_REGEX_FIELD: responseAccept = parser.text(); break; default: diff --git a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java index 9c30dc06e1..ebbbd0b9a4 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java @@ -83,12 +83,12 @@ public void toXContent() throws IOException { modelGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_accept\":\"^accept$\"}", content); + Assert.assertEquals("{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}", content); } @Test public void parse() throws IOException { - String jsonStr = "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_accept\":\"^accept$\"}"; + String jsonStr = "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}"; XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index 0ca15a6220..44533d3ae5 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -467,13 +467,13 @@ protected Response registerRemoteModelWithModelGuardrails(String name, String co + " \"model_id\": \"" + guardrailModelId + "\",\n" - + " \"response_accept\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\"" + + " \"response_validation_regex\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\"" + " },\n" + " \"output_guardrail\": {\n" + " \"model_id\": \"" + guardrailModelId + "\",\n" - + " \"response_accept\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\"" + + " \"response_validation_regex\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\"" + " }\n" + " }\n" + "}";