Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/multi_tenancy] Add more missing tenant ids in requests #2959

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse

sdkClient
.putDataObjectAsync(
PutDataObjectRequest.builder().index(ML_AGENT_INDEX).dataObject(mlAgent).build(),
PutDataObjectRequest.builder().index(ML_AGENT_INDEX).tenantId(tenantId).dataObject(mlAgent).build(),
client.threadPool().executor(GENERAL_THREAD_POOL)
)
.whenComplete((r, throwable) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
.builder()
.index(ML_MODEL_INDEX)
.id(modelId)
.tenantId(tenantId)
.fetchSourceContext(fetchSourceContext)
.build();
User user = RestActionUtils.getUserContext(client);
Expand Down Expand Up @@ -174,7 +175,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
);
} else {
if (isModelNotDeployed(mlModelState)) {
deleteModel(modelId, algorithmName, isHidden, actionListener);
deleteModel(modelId, tenantId, algorithmName, isHidden, actionListener);
} else {
wrappedListener
.onFailure(
Expand All @@ -201,7 +202,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
)
);
} else if (isModelNotDeployed(mlModelState)) {
deleteModel(modelId, mlModel.getAlgorithm().name(), isHidden, actionListener);
deleteModel(
modelId,
tenantId,
mlModel.getAlgorithm().name(),
isHidden,
actionListener
);
} else {
wrappedListener
.onFailure(
Expand Down Expand Up @@ -281,8 +288,19 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
}

private void deleteModel(String modelId, String functionName, Boolean isHidden, ActionListener<DeleteResponse> actionListener) {
DeleteDataObjectRequest deleteDataObjectRequest = DeleteDataObjectRequest.builder().index(ML_MODEL_INDEX).id(modelId).build();
private void deleteModel(
String modelId,
String tenantId,
String functionName,
Boolean isHidden,
ActionListener<DeleteResponse> actionListener
) {
DeleteDataObjectRequest deleteDataObjectRequest = DeleteDataObjectRequest
.builder()
.index(ML_MODEL_INDEX)
.id(modelId)
.tenantId(tenantId)
.build();
sdkClient
.deleteDataObjectAsync(deleteDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL))
.whenComplete((r, throwable) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ private void updateRemoteOrTextEmbeddingModel(
updateModelWithRegisteringToAnotherModelGroup(
modelId,
newModelGroupId,
tenantId,
user,
updateModelInput,
wrappedListener,
Expand Down Expand Up @@ -291,6 +292,7 @@ private void updateRemoteOrTextEmbeddingModel(
updateModelWithRegisteringToAnotherModelGroup(
modelId,
newModelGroupId,
tenantId,
user,
updateModelInput,
wrappedListener,
Expand Down Expand Up @@ -336,6 +338,7 @@ private void updateModelWithNewStandAloneConnector(
updateModelWithRegisteringToAnotherModelGroup(
modelId,
newModelGroupId,
tenantId,
user,
updateModelInput,
wrappedListener,
Expand Down Expand Up @@ -369,6 +372,7 @@ private void updateModelWithNewStandAloneConnector(
private void updateModelWithRegisteringToAnotherModelGroup(
String modelId,
String newModelGroupId,
String tenantId,
User user,
MLUpdateModelInput updateModelInput,
ActionListener<UpdateResponse> wrappedListener,
Expand Down Expand Up @@ -414,12 +418,13 @@ private void updateModelWithRegisteringToAnotherModelGroup(
wrappedListener.onFailure(exception);
}));
} else {
buildUpdateRequest(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache);
buildUpdateRequest(modelId, tenantId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache);
}
}

private void buildUpdateRequest(
String modelId,
String tenantId,
UpdateRequest updateRequest,
MLUpdateModelInput updateModelInput,
ActionListener<UpdateResponse> wrappedListener,
Expand All @@ -430,6 +435,7 @@ private void buildUpdateRequest(
.builder()
.index(updateRequest.index())
.id(updateRequest.id())
.tenantId(tenantId)
.dataObject(updateModelInput)
.build();
// TODO: This should probably be default on update data object:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ public void validateModelGroupAccess(
listener.onResponse(true);
return;
}
GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).build();
GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.tenantId(tenantId)
.build();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> wrappedListener = ActionListener.runBefore(listener, context::restore);
sdkClient
Expand Down
Loading