Skip to content

Commit

Permalink
model access mode correction
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Jun 30, 2023
1 parent 0086fe3 commit 5a562eb
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {
public static final String CONNECTOR_ID_FIELD = "connector_id";
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String ACCESS_MODE = "access_mode"; //optional
public static final String MODEL_ACCESS_MODE = "model_access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional

private FunctionName functionName;
Expand All @@ -73,7 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable {

private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode accessMode;
private AccessMode modelAccessMode;

@Builder(toBuilder = true)
public MLRegisterModelInput(FunctionName functionName,
Expand All @@ -91,7 +91,7 @@ public MLRegisterModelInput(FunctionName functionName,
String connectorId,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode accessMode
AccessMode modelAccessMode
) {
if (functionName == null) {
this.functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -123,7 +123,7 @@ public MLRegisterModelInput(FunctionName functionName,
this.connectorId = connectorId;
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.accessMode = accessMode;
this.modelAccessMode = modelAccessMode;
}


Expand Down Expand Up @@ -153,7 +153,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
}
this.addAllBackendRoles = in.readOptionalBoolean();
if (in.readBoolean()) {
this.accessMode = in.readEnum(AccessMode.class);
this.modelAccessMode = in.readEnum(AccessMode.class);
}
}

Expand Down Expand Up @@ -195,9 +195,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(addAllBackendRoles);
if (accessMode != null) {
if (modelAccessMode != null) {
out.writeBoolean(true);
out.writeEnum(accessMode);
out.writeEnum(modelAccessMode);
} else {
out.writeBoolean(false);
}
Expand Down Expand Up @@ -245,8 +245,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (addAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES_FIELD, addAllBackendRoles);
}
if (accessMode != null) {
builder.field(ACCESS_MODE_FIELD, accessMode);
if (modelAccessMode != null) {
builder.field(ACCESS_MODE_FIELD, modelAccessMode);
}
builder.endObject();
return builder;
Expand All @@ -265,7 +265,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
String connectorId = null;
List<String> backendRoles = new ArrayList<>();
Boolean addAllBackendRoles = null;
AccessMode accessMode = null;
AccessMode modelAccessMode = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -319,14 +319,14 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName
addAllBackendRoles = parser.booleanValue();
break;
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
modelAccessMode = modelAccessMode.from(parser.text());
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, modelAccessMode);
}

public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException {
Expand All @@ -343,7 +343,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
Connector connector = null;
String connectorId = null;
List<String> backendRoles = new ArrayList<>();
AccessMode accessMode = null;
AccessMode modelAccessMode = null;
Boolean addAllBackendRoles = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
Expand Down Expand Up @@ -405,13 +405,13 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo
addAllBackendRoles = parser.booleanValue();
break;
case ACCESS_MODE_FIELD:
accessMode = AccessMode.from(parser.text());
modelAccessMode = modelAccessMode.from(parser.text());
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode);
return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, modelAccessMode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
private MLModelConfig modelConfig;
private Integer totalChunks;
private List<String> backendRoles;
private AccessMode AccessMode;
private AccessMode modelAccessMode;
private Boolean isAddAllBackendRoles;

@Builder(toBuilder = true)
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List<String> backendRoles,
AccessMode AccessMode,
AccessMode modelAccessMode,
Boolean isAddAllBackendRoles) {
if (name == null) {
throw new IllegalArgumentException("model name is null");
Expand Down Expand Up @@ -101,7 +101,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
this.modelConfig = modelConfig;
this.totalChunks = totalChunks;
this.backendRoles = backendRoles;
this.AccessMode = AccessMode;
this.modelAccessMode = modelAccessMode;
this.isAddAllBackendRoles = isAddAllBackendRoles;
}

Expand All @@ -125,7 +125,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
this.totalChunks = in.readInt();
this.backendRoles = in.readOptionalStringList();
if (in.readBoolean()) {
AccessMode = in.readEnum(AccessMode.class);
modelAccessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
}
Expand Down Expand Up @@ -164,9 +164,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (AccessMode != null) {
if (modelAccessMode != null) {
out.writeBoolean(true);
out.writeEnum(AccessMode);
out.writeEnum(modelAccessMode);
} else {
out.writeBoolean(false);
}
Expand Down Expand Up @@ -200,8 +200,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (backendRoles != null && backendRoles.size() > 0) {
builder.field(BACKEND_ROLES_FIELD, backendRoles);
}
if (AccessMode != null) {
builder.field(MODEL_ACCESS_MODE, AccessMode);
if (modelAccessMode != null) {
builder.field(MODEL_ACCESS_MODE, modelAccessMode);
}
if (isAddAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
Expand All @@ -223,7 +223,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
MLModelConfig modelConfig = null;
Integer totalChunks = null;
List<String> backendRoles = null;
AccessMode AccessMode = null;
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
Expand Down Expand Up @@ -272,7 +272,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
}
break;
case MODEL_ACCESS_MODE:
AccessMode = AccessMode.from(parser.text().toLowerCase(Locale.ROOT));
modelAccessMode = AccessMode.from(parser.text().toLowerCase(Locale.ROOT));
break;
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = parser.booleanValue();
Expand All @@ -282,7 +282,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
break;
}
}
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, AccessMode, isAddAllBackendRoles);
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, modelAccessMode, isAddAllBackendRoles);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode
.name(registerModelInput.getModelName())
.description(registerModelInput.getDescription())
.backendRoles(registerModelInput.getBackendRoles())
.modelAccessMode(registerModelInput.getAccessMode())
.modelAccessMode(registerModelInput.getModelAccessMode())
.isAddAllBackendRoles(registerModelInput.getAddAllBackendRoles())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterMode
.name(mlUploadInput.getName())
.description(mlUploadInput.getDescription())
.backendRoles(mlUploadInput.getBackendRoles())
.modelAccessMode(mlUploadInput.getAccessMode())
.modelAccessMode(mlUploadInput.getModelAccessMode())
.isAddAllBackendRoles(mlUploadInput.getIsAddAllBackendRoles())
.build();
}
Expand Down

0 comments on commit 5a562eb

Please sign in to comment.