Skip to content

Commit 1e52764

Browse files
committed
[Feature Enhancement] Set ordered replica index label to support multi-slice
Signed-off-by: Ryan O'Leary <[email protected]>
1 parent 696b08b commit 1e52764

File tree

6 files changed

+386
-149
lines changed

6 files changed

+386
-149
lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ func getEnableProbesInjection() bool {
251251
}
252252

253253
// DefaultWorkerPodTemplate sets the config values
254-
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec {
254+
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, replicaIndex int, numHostIndex int) corev1.PodTemplateSpec {
255255
podTemplate := workerSpec.Template
256256
podTemplate.GenerateName = podName
257257
// Pods created by RayCluster should be restricted to the namespace of the RayCluster.
@@ -329,11 +329,15 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
329329
mergedLabels := mergeLabels(workerSpec.Template.ObjectMeta.Labels, workerSpec.Labels)
330330
podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, mergedLabels)
331331

332-
// Add additional labels for RayMultihostIndexing
333-
multihostIndexingEnabled := features.Enabled(features.RayMultiHostIndexing) && workerSpec.NumOfHosts > 1
334-
if multihostIndexingEnabled {
335-
podTemplate.Labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
336-
podTemplate.Labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
332+
// Add additional labels when RayMultihostIndexing is enabled.
333+
if features.Enabled(features.RayMultiHostIndexing) {
334+
// The ordered replica index can be used for the single-host, multi-slice case.
335+
podTemplate.Labels[utils.RayWorkerReplicaIndexKey] = strconv.Itoa(replicaIndex)
336+
if workerSpec.NumOfHosts > 1 {
337+
// These labels are specific to multi-host group setup and reconciliation.
338+
podTemplate.Labels[utils.RayWorkerReplicaIDKey] = replicaGrpName
339+
podTemplate.Labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
340+
}
337341
}
338342
workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP)
339343

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ func TestBuildPod(t *testing.T) {
687687
worker := cluster.Spec.WorkerGroupSpecs[0]
688688
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
689689
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
690-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
690+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
691691
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, defaultContainerEnvs)
692692

693693
// Check resources
@@ -761,7 +761,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
761761
worker := cluster.Spec.WorkerGroupSpecs[0]
762762
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
763763
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
764-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
764+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
765765
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil)
766766
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
767767
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
@@ -792,7 +792,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
792792
worker := cluster.Spec.WorkerGroupSpecs[0]
793793
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
794794
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
795-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
795+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
796796
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil)
797797
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
798798
assert.Equal(t, []string{"I am worker"}, workerContainer.Command)
@@ -847,7 +847,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
847847
worker := cluster.Spec.WorkerGroupSpecs[0]
848848
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
849849
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
850-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
850+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
851851
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil)
852852

853853
val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
@@ -903,7 +903,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) {
903903
worker := cluster.Spec.WorkerGroupSpecs[0]
904904
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
905905
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
906-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
906+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
907907
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil)
908908

909909
// Verify worker container command
@@ -1166,7 +1166,7 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) {
11661166
expectedWorker := *worker.DeepCopy()
11671167

11681168
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1169-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
1169+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0)
11701170
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
11711171
assert.Equal(t, expectedWorker, worker)
11721172
}
@@ -1187,9 +1187,10 @@ func TestDeafultWorkerPodTemplateWithReplicaGrpAndIndex(t *testing.T) {
11871187
groupReplicaName := utils.GenerateRayWorkerReplicaGroupName(worker.GroupName)
11881188

11891189
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1190-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 2)
1190+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 0, 2)
11911191
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
1192-
assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey], groupReplicaName)
1192+
assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaIDKey], groupReplicaName)
1193+
assert.Equal(t, "0", podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey])
11931194
assert.Equal(t, "2", podTemplateSpec.Labels[utils.RayHostIndexKey])
11941195
}
11951196

@@ -1235,7 +1236,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12351236
worker := cluster.Spec.WorkerGroupSpecs[0]
12361237
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
12371238
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
1238-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
1239+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
12391240
// DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it.
12401241
// Verify the default metrics port exists.
12411242
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort)))
@@ -1245,7 +1246,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12451246
ContainerPort: customMetricsPort,
12461247
}
12471248
cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort}
1248-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
1249+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
12491250
// Verify the custom metrics port exists.
12501251
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort))
12511252
}
@@ -1284,7 +1285,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) {
12841285

12851286
for name, tc := range tests {
12861287
t.Run(name, func(t *testing.T) {
1287-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0)
1288+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0, 0)
12881289
assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy)
12891290
})
12901291
}
@@ -1300,7 +1301,7 @@ func TestDefaultInitContainer(t *testing.T) {
13001301
expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1
13011302

13021303
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1303-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
1304+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0)
13041305
numInitContainers := len(podTemplateSpec.Spec.InitContainers)
13051306
assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.")
13061307

@@ -1359,7 +1360,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) {
13591360
// set ray container imagePullPolicy
13601361
worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy
13611362

1362-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
1363+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0)
13631364

13641365
healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1]
13651366
assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.")

0 commit comments

Comments
 (0)