Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport feature/multi_tenancy] [Backport 2.15] change guardrail field name #3042

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.security.AccessController;
Expand All @@ -51,7 +50,7 @@
public class ModelGuardrail extends Guardrail {
public static final String MODEL_ID_FIELD = "model_id";
public static final String RESPONSE_FILTER_FIELD = "response_filter";
public static final String RESPONSE_ACCEPT_FIELD = "response_accept";
public static final String RESPONSE_VALIDATION_REGEX_FIELD = "response_validation_regex";

private String modelId;
private String responseFilter;
Expand All @@ -67,7 +66,7 @@ public ModelGuardrail(String modelId, String responseFilter, String responseAcce
this.responseAccept = responseAccept;
}
public ModelGuardrail(@NonNull Map<String, Object> params) {
this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_ACCEPT_FIELD));
this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_VALIDATION_REGEX_FIELD));
}

public ModelGuardrail(StreamInput input) throws IOException {
Expand Down Expand Up @@ -157,7 +156,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(RESPONSE_FILTER_FIELD, responseFilter);
}
if (responseAccept != null) {
builder.field(RESPONSE_ACCEPT_FIELD, responseAccept);
builder.field(RESPONSE_VALIDATION_REGEX_FIELD, responseAccept);
}
builder.endObject();
return builder;
Expand All @@ -180,7 +179,7 @@ public static ModelGuardrail parse(XContentParser parser) throws IOException {
case RESPONSE_FILTER_FIELD:
responseFilter = parser.text();
break;
case RESPONSE_ACCEPT_FIELD:
case RESPONSE_VALIDATION_REGEX_FIELD:
responseAccept = parser.text();
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ public void toXContent() throws IOException {
modelGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

Assert.assertEquals("{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_accept\":\"^accept$\"}", content);
Assert.assertEquals("{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}", content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_accept\":\"^accept$\"}";
String jsonStr = "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,13 @@ protected Response registerRemoteModelWithModelGuardrails(String name, String co
+ " \"model_id\": \""
+ guardrailModelId
+ "\",\n"
+ " \"response_accept\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\""
+ " \"response_validation_regex\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\""
+ " },\n"
+ " \"output_guardrail\": {\n"
+ " \"model_id\": \""
+ guardrailModelId
+ "\",\n"
+ " \"response_accept\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\""
+ " \"response_validation_regex\": \"^\\\"\\\\s*[Aa]ccept\\\\s*\\\"$\""
+ " }\n"
+ " }\n"
+ "}";
Expand Down
Loading