Skip to content

Commit

Permalink
Run auto-deploy for remote models in partially deployed status
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Jan 23, 2025
1 parent 1659a60 commit 0671963
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,21 @@ public String[] getWorkerNodes(String modelId) {
return modelCache.getWorkerNodes();
}

/**
* Get target worker nodes of model.
*
* @param modelId model id
* @return array of node ids; returns null if the model does not exist in the cache
*/
public String[] getTargetWorkerNodes(String modelId) {
    // Look the model up once; an id with no cache entry yields null rather than an empty array.
    final MLModelCache cached = modelCaches.get(modelId);
    return cached == null ? null : cached.getTargetWorkerNodes();
}


/**
* Add worker node of model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,10 @@ public int getWorkerNodesSize(String modelId, FunctionName functionName) {
return getWorkerNodes(modelId, functionName, false).length;
}

/**
 * Get target worker nodes of model.
 *
 * Pure delegate: the lookup is performed by {@code modelCacheHelper} and its
 * result is handed back unchanged.
 *
 * @param modelId model id
 * @return whatever {@code modelCacheHelper.getTargetWorkerNodes(modelId)} returns
 */
public String[] getTargetWorkerNodes(String modelId) {
    final String[] targetNodes = modelCacheHelper.getTargetWorkerNodes(modelId);
    return targetNodes;
}

/**
* Get predictable instance with model id.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ public void dispatchTask(
}
}, listener::onFailure);
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
if (workerNodes == null || workerNodes.length == 0) {
String[] targetWorkerNodes = mlModelManager.getTargetWorkerNodes(modelId);

if (requiresAutoDeployment(workerNodes, targetWorkerNodes)) {
if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> {
Expand Down Expand Up @@ -568,4 +570,9 @@ public void validateOutputSchema(String modelId, ModelTensorOutput output) {
}
}
}

/**
 * Decides whether auto deployment should be attempted for a model.
 *
 * Returns {@code true} when there are no worker nodes at all, or when the
 * target worker nodes are known and outnumber the current worker nodes.
 *
 * @param workerNodes       current worker nodes; may be null or empty
 * @param targetWorkerNodes target worker nodes; may be null, in which case
 *                          only the "no worker nodes" rule applies
 * @return true if auto deployment is required, false otherwise
 */
private boolean requiresAutoDeployment(String[] workerNodes, String[] targetWorkerNodes) {
    // Nothing is serving the model yet.
    if (workerNodes == null || workerNodes.length == 0) {
        return true;
    }
    // Without a known target there is nothing to compare the current nodes against.
    if (targetWorkerNodes == null) {
        return false;
    }
    // Fewer current worker nodes than targeted.
    return targetWorkerNodes.length > workerNodes.length;
}
}

0 comments on commit 0671963

Please sign in to comment.