Skip to content

Commit 18998f8

Browse files
committed
fix: change logic for starting job to support single node cluster as well
Signed-off-by: Pavan Yekbote <[email protected]>
1 parent a584903 commit 18998f8

File tree

3 files changed

+115
-77
lines changed

3 files changed

+115
-77
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.mockito.Mockito.when;
1212
import static org.opensearch.ml.common.CommonValue.META;
1313
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
14+
import static org.opensearch.ml.common.CommonValue.ML_JOBS_INDEX;
1415
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX;
1516
import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX;
1617
import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD;
@@ -216,4 +217,36 @@ public void initMLConnectorIndex_ResourceAlreadyExistsException_RaceCondition()
216217
verify(listener).onResponse(argumentCaptor.capture());
217218
assertEquals(true, argumentCaptor.getValue());
218219
}
220+
221+
@Test
222+
public void initMLJobsIndex() {
223+
ActionListener<Boolean> listener = mock(ActionListener.class);
224+
doAnswer(invocation -> {
225+
ActionListener<AcknowledgedResponse> actionListener = invocation.getArgument(1);
226+
actionListener.onResponse(new AcknowledgedResponse(true));
227+
return null;
228+
}).when(indicesAdminClient).putMapping(any(), any());
229+
ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class);
230+
indicesHandler.initMLJobsIndex(listener);
231+
232+
verify(listener).onResponse(argumentCaptor.capture());
233+
assertEquals(true, argumentCaptor.getValue());
234+
}
235+
236+
@Test
237+
public void initMLJobsIndexNoIndex() {
238+
ActionListener<Boolean> listener = mock(ActionListener.class);
239+
when(metadata.hasIndex(anyString())).thenReturn(false);
240+
doAnswer(invocation -> {
241+
ActionListener<CreateIndexResponse> actionListener = invocation.getArgument(1);
242+
actionListener.onResponse(new CreateIndexResponse(true, true, ML_JOBS_INDEX));
243+
return null;
244+
}).when(indicesAdminClient).create(any(), any());
245+
ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class);
246+
indicesHandler.initMLJobsIndex(listener);
247+
248+
verify(indicesAdminClient).create(isA(CreateIndexRequest.class), any());
249+
verify(listener).onResponse(argumentCaptor.capture());
250+
assertEquals(true, argumentCaptor.getValue());
251+
}
219252
}

plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterEventListener.java

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,31 @@ public void clusterChanged(ClusterChangedEvent event) {
7575
Set<String> removedNodeIds = delta.removedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
7676
mlModelManager.removeWorkerNodes(removedNodeIds, false);
7777
} else if (delta.added()) {
78-
for (DiscoveryNode node : delta.addedNodes()) {
79-
// 3.1 introduces a new index for the job scheduler to track jobs
80-
// the statsCollectorJob needs to be run when a cluster is started with the stats settings enabled
81-
// As a result, we need to wait for a data node to come up before creating the new jobs index
82-
if (node.isDataNode() && Version.V_3_1_0.onOrAfter(node.getVersion())) {
83-
if (mlFeatureEnabledSetting.isMetricCollectionEnabled() && mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()) {
84-
mlTaskManager.startStatsCollectorJob();
85-
}
78+
List<String> addedNodesIds = delta.addedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toList());
79+
mlModelAutoReDeployer.buildAutoReloadArrangement(addedNodesIds, state.getNodes().getClusterManagerNodeId());
80+
}
8681

87-
if (clusterService.state().getMetadata().hasIndex(TASK_POLLING_JOB_INDEX)) {
88-
mlTaskManager.startTaskPollingJob();
89-
}
82+
/*
83+
* In version 3.1, a new index `.plugins-ml-jobs` replaces the old `.ml_commons_task_polling_job` index for the job scheduler.
84+
* Version 3.1 also introduces a stats collector job that should run at startup if the relevant settings are enabled.
85+
* When upgrading from 3.0 to 3.1, we need to ensure the new `.plugins-ml-jobs` index is created if either:
86+
* - The stats collector job is enabled, or
87+
* - The batch polling task job was already running.
88+
* To avoid issues during blue/green or rolling upgrades, we wait for a data node running 3.1 or later before creating the new jobs index and starting the jobs.
89+
* The following logic implements this behavior.
90+
*/
91+
for (DiscoveryNode node : state.nodes()) {
92+
if (node.isDataNode() && Version.V_3_1_0.onOrAfter(node.getVersion())) {
93+
if (mlFeatureEnabledSetting.isMetricCollectionEnabled() && mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()) {
94+
mlTaskManager.startStatsCollectorJob();
9095
}
91-
}
9296

93-
List<String> addedNodesIds = delta.addedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toList());
94-
mlModelAutoReDeployer.buildAutoReloadArrangement(addedNodesIds, state.getNodes().getClusterManagerNodeId());
97+
if (clusterService.state().getMetadata().hasIndex(TASK_POLLING_JOB_INDEX)) {
98+
mlTaskManager.startTaskPollingJob();
99+
}
100+
101+
break;
102+
}
95103
}
96104
}
97105
}

plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java

Lines changed: 60 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public class MLTaskManager {
8282
private final MLIndicesHandler mlIndicesHandler;
8383
private final Map<MLTaskType, AtomicInteger> runningTasksCount;
8484
private boolean taskPollingJobStarted;
85+
private boolean statsCollectorJobStarted;
8586
public static final ImmutableSet<MLTaskState> TASK_DONE_STATES = ImmutableSet
8687
.of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);
8788

@@ -546,74 +547,70 @@ public void startTaskPollingJob() {
546547
return;
547548
}
548549

549-
mlIndicesHandler.initMLJobsIndex(ActionListener.wrap(success -> {
550-
if (success) {
551-
String id = "ml_batch_task_polling_job";
552-
String jobName = "poll_batch_jobs";
553-
String interval = "1";
554-
Long lockDurationSeconds = 20L;
555-
556-
MLJobParameter jobParameter = new MLJobParameter(
557-
jobName,
558-
new IntervalSchedule(Instant.now(), Integer.parseInt(interval), ChronoUnit.MINUTES),
559-
lockDurationSeconds,
560-
null,
561-
MLJobType.BATCH_TASK_UPDATE
562-
);
563-
IndexRequest indexRequest = new IndexRequest()
564-
.index(CommonValue.ML_JOBS_INDEX)
565-
.id(id)
566-
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
567-
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
568-
569-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
570-
client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
571-
log.info("Indexed ml task polling job successfully");
572-
this.taskPollingJobStarted = true;
573-
}, e -> log.error("Failed to index task polling job", e)), context::restore));
574-
}
575-
}
576-
}, e -> log.error("Failed to initialize ML jobs index", e)));
550+
try {
551+
MLJobParameter jobParameter = new MLJobParameter(
552+
MLJobType.BATCH_TASK_UPDATE.name(),
553+
new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES),
554+
20L,
555+
null,
556+
MLJobType.BATCH_TASK_UPDATE
557+
);
558+
559+
IndexRequest indexRequest = new IndexRequest()
560+
.index(CommonValue.ML_JOBS_INDEX)
561+
.id(MLJobType.BATCH_TASK_UPDATE.name())
562+
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
563+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
564+
565+
startJob(indexRequest, MLJobType.BATCH_TASK_UPDATE, () -> this.taskPollingJobStarted = true);
566+
} catch (IOException e) {
567+
log.error("Failed to index task polling job", e);
568+
}
577569
}
578570

579571
public void startStatsCollectorJob() {
572+
if (statsCollectorJobStarted) {
573+
return;
574+
}
575+
576+
try {
577+
MLJobParameter jobParameter = new MLJobParameter(
578+
MLJobType.STATS_COLLECTOR.name(),
579+
new IntervalSchedule(Instant.now(), 5, ChronoUnit.MINUTES),
580+
60L,
581+
null,
582+
MLJobType.STATS_COLLECTOR
583+
);
584+
585+
IndexRequest indexRequest = new IndexRequest()
586+
.index(CommonValue.ML_JOBS_INDEX)
587+
.id(MLJobType.STATS_COLLECTOR.name())
588+
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
589+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
590+
591+
startJob(indexRequest, MLJobType.STATS_COLLECTOR, () -> this.statsCollectorJobStarted = true);
592+
} catch (IOException e) {
593+
log.error("Failed to index stats collection job", e);
594+
}
595+
}
596+
597+
/**
598+
* Start a job by indexing the job parameter to ML jobs index.
599+
*
600+
* @param indexRequest the index request containing the job parameter
601+
* @param jobType the type of job being started
602+
* @param successCallback callback to execute on successful job indexing
603+
*/
604+
private void startJob(IndexRequest indexRequest, MLJobType jobType, Runnable successCallback) {
580605
mlIndicesHandler.initMLJobsIndex(ActionListener.wrap(success -> {
581606
if (success) {
582-
try {
583-
int intervalInMinutes = 5;
584-
Long lockDurationSeconds = 60L;
585-
586-
MLJobParameter jobParameter = new MLJobParameter(
587-
MLJobType.STATS_COLLECTOR.name(),
588-
new IntervalSchedule(Instant.now(), intervalInMinutes, ChronoUnit.MINUTES),
589-
lockDurationSeconds,
590-
null,
591-
MLJobType.STATS_COLLECTOR
592-
);
593-
594-
IndexRequest indexRequest = new IndexRequest()
595-
.index(CommonValue.ML_JOBS_INDEX)
596-
.id(MLJobType.STATS_COLLECTOR.name())
597-
.source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
598-
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
599-
600-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
601-
client
602-
.index(
603-
indexRequest,
604-
ActionListener
605-
.runBefore(
606-
ActionListener
607-
.wrap(
608-
r -> log.info("Indexed ml stats collection job successfully"),
609-
e -> log.error("Failed to index stats collection job", e)
610-
),
611-
context::restore
612-
)
613-
);
614-
}
615-
} catch (IOException e) {
616-
log.error("Failed to index stats collection job", e);
607+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
608+
client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
609+
log.info("Indexed {} successfully", jobType.name());
610+
if (successCallback != null) {
611+
successCallback.run();
612+
}
613+
}, e -> log.error("Failed to index {} job", jobType.name(), e)), context::restore));
617614
}
618615
}
619616
}, e -> log.error("Failed to initialize ML jobs index", e)));

0 commit comments

Comments
 (0)