diff --git a/ray-operator/controllers/ray/common/pod.go b/ray-operator/controllers/ray/common/pod.go index 836272ecf37..21533cd8f49 100644 --- a/ray-operator/controllers/ray/common/pod.go +++ b/ray-operator/controllers/ray/common/pod.go @@ -251,7 +251,7 @@ func getEnableProbesInjection() bool { } // DefaultWorkerPodTemplate sets the config values -func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec { +func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, replicaIndex int, numHostIndex int) corev1.PodTemplateSpec { podTemplate := workerSpec.Template podTemplate.GenerateName = podName // Pods created by RayCluster should be restricted to the namespace of the RayCluster. @@ -329,11 +329,15 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo mergedLabels := mergeLabels(workerSpec.Template.ObjectMeta.Labels, workerSpec.Labels) podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, mergedLabels) - // Add additional labels for RayMultihostIndexing - multihostIndexingEnabled := features.Enabled(features.RayMultiHostIndexing) && workerSpec.NumOfHosts > 1 - if multihostIndexingEnabled { - podTemplate.Labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName - podTemplate.Labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex) + // Add additional labels when RayMultihostIndexing is enabled. + if features.Enabled(features.RayMultiHostIndexing) { + // The ordered replica index can be used for the single-host, multi-slice case. + podTemplate.Labels[utils.RayWorkerReplicaIndexKey] = strconv.Itoa(replicaIndex) + if workerSpec.NumOfHosts > 1 { + // These labels are specific to multi-host group setup and reconciliation. 
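As a rough illustration of the label sets the gated branch (continued just below) produces, here is a minimal standalone sketch. It is not the KubeRay code itself; the literal key strings are assumed to match the constants introduced in the constant.go hunk further down.

```go
package main

import (
	"fmt"
	"strconv"
)

// workerIndexLabels mirrors the gated assignment in DefaultWorkerPodTemplate:
// every worker gets an ordered replica index, and multi-host groups additionally
// get a replica name and a per-host index.
func workerIndexLabels(featureEnabled bool, numOfHosts int32, replicaGrpName string, replicaIndex, hostIndex int) map[string]string {
	labels := map[string]string{}
	if !featureEnabled {
		return labels
	}
	labels["ray.io/worker-group-replica-index"] = strconv.Itoa(replicaIndex)
	if numOfHosts > 1 {
		labels["ray.io/worker-group-replica-name"] = replicaGrpName
		labels["ray.io/replica-host-index"] = strconv.Itoa(hostIndex)
	}
	return labels
}

func main() {
	// Multi-host worker: all three labels are present.
	fmt.Println(workerIndexLabels(true, 4, "multi-host-group-xh3hf", 1, 2))
	// Single-host worker: only the replica index is set.
	fmt.Println(workerIndexLabels(true, 1, "", 0, 0))
}
```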
+ podTemplate.Labels[utils.RayWorkerReplicaNameKey] = replicaGrpName + podTemplate.Labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex) + } } workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP) diff --git a/ray-operator/controllers/ray/common/pod_test.go b/ray-operator/controllers/ray/common/pod_test.go index 27c0a3413c9..78ff88122ed 100644 --- a/ray-operator/controllers/ray/common/pod_test.go +++ b/ray-operator/controllers/ray/common/pod_test.go @@ -687,7 +687,7 @@ func TestBuildPod(t *testing.T) { worker := cluster.Spec.WorkerGroupSpecs[0] podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace) - podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, defaultContainerEnvs) // Check resources @@ -761,7 +761,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) { worker := cluster.Spec.WorkerGroupSpecs[0] podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace) - podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil) expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080") actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0]) @@ -792,7 +792,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) { worker := cluster.Spec.WorkerGroupSpecs[0] podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace) - podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil) workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex] assert.Equal(t, []string{"I am worker"}, workerContainer.Command) @@ -847,7 +847,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) { worker := cluster.Spec.WorkerGroupSpecs[0] podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace) - podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec = 
DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil) val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey] @@ -903,7 +903,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) { worker := cluster.Spec.WorkerGroupSpecs[0] podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace) - podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil) // Verify worker container command @@ -1166,7 +1166,7 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) { expectedWorker := *worker.DeepCopy() // Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating. - podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0) assert.Empty(t, podTemplateSpec.ObjectMeta.Name) assert.Equal(t, expectedWorker, worker) } @@ -1187,9 +1187,10 @@ func TestDeafultWorkerPodTemplateWithReplicaGrpAndIndex(t *testing.T) { groupReplicaName := utils.GenerateRayWorkerReplicaGroupName(worker.GroupName) // Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating. - podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 2) + podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 0, 2) assert.Empty(t, podTemplateSpec.ObjectMeta.Name) - assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey], groupReplicaName) + assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaNameKey], groupReplicaName) + assert.Equal(t, "0", podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey]) assert.Equal(t, "2", podTemplateSpec.Labels[utils.RayHostIndexKey]) } @@ -1235,7 +1236,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) { worker := cluster.Spec.WorkerGroupSpecs[0] podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace) - podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) // DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it. // Verify the default metrics port exists. 
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort))) @@ -1245,7 +1246,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) { ContainerPort: customMetricsPort, } cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort} - podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0) // Verify the custom metrics port exists. require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort)) } @@ -1284,7 +1285,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0, 0) assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy) }) } @@ -1300,7 +1301,7 @@ func TestDefaultInitContainer(t *testing.T) { expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1 // Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating. - podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0) numInitContainers := len(podTemplateSpec.Spec.InitContainers) assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.") @@ -1359,7 +1360,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) { // set ray container imagePullPolicy worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy - podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0) + podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0) healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1] assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.") diff --git a/ray-operator/controllers/ray/raycluster_controller.go b/ray-operator/controllers/ray/raycluster_controller.go index 32d730db368..f06f114a609 100644 --- a/ray-operator/controllers/ray/raycluster_controller.go +++ b/ray-operator/controllers/ray/raycluster_controller.go @@ -739,14 +739,42 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv logger.Info("reconcilePods", "workerReplicas", numExpectedWorkerPods, "NumOfHosts", worker.NumOfHosts, "runningPods", len(runningPods.Items), "diff", diff) + // Support replica indices for single-host, multi-slice environments. 
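The block that follows collects the replica indices already in use and then fills gaps starting from the lowest free index. A condensed sketch of that assignment strategy (standalone example, not the reconciler code):

```go
package main

import "fmt"

// nextFreeIndices hands out the lowest unused indices, refilling any gaps left by
// deleted replicas before extending the range. This mirrors the loop below that
// advances newReplicaIndex past indices already claimed by running Pods.
func nextFreeIndices(inUse map[int]bool, count int) []int {
	assigned := make([]int, 0, count)
	next := 0
	for len(assigned) < count {
		for inUse[next] {
			next++
		}
		inUse[next] = true
		assigned = append(assigned, next)
	}
	return assigned
}

func main() {
	// Replicas 0 and 2 are still running; replica 1 was deleted and two Pods are needed.
	fmt.Println(nextFreeIndices(map[int]bool{0: true, 2: true}, 2)) // prints [1 3]
}
```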
+ validReplicaIndices := make(map[int]bool) + if features.Enabled(features.RayMultiHostIndexing) { + for _, pod := range runningPods.Items { + if indexStr, ok := pod.Labels[utils.RayWorkerReplicaIndexKey]; ok { + if index, err := strconv.Atoi(indexStr); err == nil { + validReplicaIndices[index] = true + } + } + } + logger.Info("reconcilePods", "found existing replica indices", "group", worker.GroupName, "indices", validReplicaIndices) + } if diff > 0 { // pods need to be added logger.Info("reconcilePods", "Number workers to add", diff, "Worker group", worker.GroupName) - // create all workers of this group - for i := 0; i < diff; i++ { - logger.Info("reconcilePods", "creating worker for group", worker.GroupName, "index", i, "total", diff) - if err := r.createWorkerPod(ctx, *instance, *worker.DeepCopy()); err != nil { - return errstd.Join(utils.ErrFailedCreateWorkerPod, err) + if features.Enabled(features.RayMultiHostIndexing) { + newReplicaIndex := 0 + // create all workers of this group + for i := 0; i < diff; i++ { + // Find the next available replica index. + for validReplicaIndices[newReplicaIndex] { + newReplicaIndex++ + } + validReplicaIndices[newReplicaIndex] = true + logger.Info("reconcilePods", "creating worker for group", worker.GroupName, "index", i, "total", diff, "replicaIndex", newReplicaIndex) + if err := r.createWorkerPodWithIndex(ctx, *instance, *worker.DeepCopy(), "", newReplicaIndex, 0); err != nil { + return errstd.Join(utils.ErrFailedCreateWorkerPod, err) + } + } + } else { + // create all workers of this group + for i := 0; i < diff; i++ { + logger.Info("reconcilePods", "creating worker for group", worker.GroupName, "index", i, "total", diff) + if err := r.createWorkerPod(ctx, *instance, *worker.DeepCopy()); err != nil { + return errstd.Join(utils.ErrFailedCreateWorkerPod, err) + } } } } else if diff == 0 { @@ -827,7 +855,7 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context // 1. Group existing pods by ray.io/worker-group-replica-index. replicaMap := make(map[string][]corev1.Pod) for _, pod := range workerPods { - if replicaName, ok := pod.Labels[utils.RayWorkerReplicaIndexKey]; ok { + if replicaName, ok := pod.Labels[utils.RayWorkerReplicaNameKey]; ok { replicaMap[replicaName] = append(replicaMap[replicaName], pod) } } @@ -851,7 +879,7 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context continue } if shouldDelete, reason := shouldDeletePod(pod, rayv1.WorkerNode); shouldDelete { - replicaName := pod.Labels[utils.RayWorkerReplicaIndexKey] + replicaName := pod.Labels[utils.RayWorkerReplicaNameKey] podsToDelete, ok := replicaMap[replicaName] if !ok { continue @@ -873,7 +901,7 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context for _, podName := range worker.ScaleStrategy.WorkersToDelete { for _, pod := range workerPods { if pod.Name == podName { - replicaName := pod.Labels[utils.RayWorkerReplicaIndexKey] + replicaName := pod.Labels[utils.RayWorkerReplicaNameKey] for _, p := range replicaMap[replicaName] { podsToDeleteFromStrategy[p.Name] = p } @@ -899,28 +927,61 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context } // 5. Calculate Pod diff for scaling up or down by NumOfHosts. 
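The multi-host accounting that follows counts whole replica groups rather than individual Pods: a group only counts as running if none of its Pods were just deleted, and the desired Pod count must divide evenly by NumOfHosts. A hedged sketch of that arithmetic (hypothetical helper, not the reconciler itself):

```go
package main

import "fmt"

// replicasToCreate computes how many whole replica groups are missing, treating any
// group with a freshly deleted Pod as absent.
func replicasToCreate(replicaPods map[string][]string, deleted map[string]bool, desiredPods, numOfHosts int) (int, error) {
	if desiredPods%numOfHosts != 0 {
		return 0, fmt.Errorf("desired worker pods (%d) is not a multiple of NumOfHosts (%d)", desiredPods, numOfHosts)
	}
	running := 0
	for _, pods := range replicaPods {
		healthy := true
		for _, name := range pods {
			if deleted[name] {
				healthy = false
				break
			}
		}
		if healthy {
			running++
		}
	}
	return desiredPods/numOfHosts - running, nil
}

func main() {
	groups := map[string][]string{"grp-a": {"a-0", "a-1"}, "grp-b": {"b-0", "b-1"}}
	// One Pod of grp-b was deleted, so the whole group must be recreated.
	n, _ := replicasToCreate(groups, map[string]bool{"b-1": true}, 4, 2)
	fmt.Println(n) // prints 1
}
```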
- runningPodsCount := len(workerPods) - len(deletedPods) + + validReplicaGroups := make(map[string]struct{}) + for replicaName, podList := range replicaMap { + isHealthyAndComplete := true + for _, pod := range podList { + if _, isDeleted := deletedPods[pod.Name]; isDeleted { + isHealthyAndComplete = false + break + } + } + if isHealthyAndComplete { + validReplicaGroups[replicaName] = struct{}{} + } + } + numRunningReplicas := len(validReplicaGroups) numExpectedWorkerPods := int(utils.GetWorkerGroupDesiredReplicas(ctx, *worker)) - diff := numExpectedWorkerPods - runningPodsCount - logger.Info("Reconciling multi-host group", "group", worker.GroupName, "expectedPods", numExpectedWorkerPods, "runningPods", runningPodsCount, "diff", diff) - // Scale up NumOfHost workers per replica. - if diff > 0 { - logger.Info("reconcileMultiHostWorkerGroup", "Number workers to add", diff, "Worker group", worker.GroupName) - if diff%int(worker.NumOfHosts) != 0 { - return fmt.Errorf("cannot scale up multi-host group %s: required %d pods, which is not a multiple of NumOfHosts (%d)", worker.GroupName, diff, worker.NumOfHosts) + // Ensure that if numExpectedWorkerPods is not a multiple of NumOfHosts, we log an error. + if numExpectedWorkerPods%int(worker.NumOfHosts) != 0 { + return fmt.Errorf("desired worker pods (%d) is not a multiple of NumOfHosts (%d) for group %s", + numExpectedWorkerPods, worker.NumOfHosts, worker.GroupName) + } + numExpectedReplicas := numExpectedWorkerPods / int(worker.NumOfHosts) + replicasToCreate := numExpectedReplicas - numRunningReplicas + + // Track full replica groups to determine next replica index to assign to. + validReplicaIndices := make(map[int]bool) + for replicaName := range validReplicaGroups { + if len(replicaMap[replicaName]) > 0 { + pod := replicaMap[replicaName][0] + if indexStr, ok := pod.Labels[utils.RayWorkerReplicaIndexKey]; ok { + if index, err := strconv.Atoi(indexStr); err == nil { + validReplicaIndices[index] = true + } + } } - replicasToCreate := diff / int(worker.NumOfHosts) + } + logger.Info("Reconciling multi-host group", "group", worker.GroupName, "expectedReplicas", numExpectedReplicas, "runningReplicas", numRunningReplicas, "replicasToCreate", replicasToCreate, "inUseIndices", validReplicaIndices) + if replicasToCreate > 0 { logger.Info("Scaling up multi-host group", "group", worker.GroupName, "replicasToCreate", replicasToCreate) + newReplicaIndex := 0 // Find the next available index starting from 0 for i := 0; i < replicasToCreate; i++ { + for validReplicaIndices[newReplicaIndex] { + newReplicaIndex++ + } + validReplicaIndices[newReplicaIndex] = true replicaName := utils.GenerateRayWorkerReplicaGroupName(worker.GroupName) + logger.Info("Creating new replica group", "group", worker.GroupName, "replicaName", replicaName, "replicaIndex", newReplicaIndex) for j := 0; j < int(worker.NumOfHosts); j++ { - if err := r.createWorkerPodWithIndex(ctx, *instance, *worker.DeepCopy(), replicaName, j); err != nil { + if err := r.createWorkerPodWithIndex(ctx, *instance, *worker.DeepCopy(), replicaName, newReplicaIndex, j); err != nil { return errstd.Join(utils.ErrFailedCreateWorkerPod, err) } } } - } else if diff < 0 { + } else if replicasToCreate < 0 { // Scale down NumOfHost workers per replica. 
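Scale-down (continued below) always removes whole replica groups. The reconciler does this through r.deletePods on its cached Pod list; purely as an illustration of what the new replica-name label enables, the same group could be removed with a plain client-go label selector. A sketch, assuming the label key matches the value of utils.RayWorkerReplicaNameKey:

```go
package example

import (
	"context"
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)

// deleteReplicaGroup removes every Pod carrying the given replica-name label in one
// DeleteCollection call. Illustrative only; the controller deletes via r.deletePods.
func deleteReplicaGroup(ctx context.Context, c kubernetes.Interface, namespace, replicaName string) error {
	selector := fmt.Sprintf("ray.io/worker-group-replica-name=%s", replicaName)
	return c.CoreV1().Pods(namespace).DeleteCollection(ctx,
		metav1.DeleteOptions{},
		metav1.ListOptions{LabelSelector: selector})
}
```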
enableInTreeAutoscaling := utils.IsAutoscalingEnabled(&instance.Spec) enableRandomPodDelete := false @@ -930,19 +991,19 @@ func (r *RayClusterReconciler) reconcileMultiHostWorkerGroup(ctx context.Context } } if !enableInTreeAutoscaling || enableRandomPodDelete { - workersToRemove := -diff - groupsToRemove := (workersToRemove + int(worker.NumOfHosts) - 1) / int(worker.NumOfHosts) - logger.Info("Scaling down multi-host group by randomly deleting replica groups", "group", worker.GroupName, "groupsToRemove", groupsToRemove) - - groupsDeleted := 0 - for _, podList := range replicaMap { - if groupsDeleted >= groupsToRemove { + replicasToRemove := -replicasToCreate + logger.Info("Scaling down multi-host group by randomly deleting replica groups", "group", worker.GroupName, "groupsToRemove", replicasToRemove) + replicasDeleted := 0 + // Iterate over validReplicaGroups which contains the IDs of replica groups with NumOfHosts running Pods. + for replicaID := range validReplicaGroups { + if replicasDeleted >= replicasToRemove { break } + podList := replicaMap[replicaID] if err := r.deletePods(ctx, instance, podList, worker.GroupName, "scaling down"); err != nil { return err } - groupsDeleted++ + replicasDeleted++ } } else { logger.Info("Random replica group deletion is disabled for the cluster. The only decision-maker for Pod deletions is the Ray Autoscaler.") @@ -1115,7 +1176,7 @@ func (r *RayClusterReconciler) createWorkerPod(ctx context.Context, instance ray logger := ctrl.LoggerFrom(ctx) // build the pod then create it - pod := r.buildWorkerPod(ctx, instance, worker, "", 0) + pod := r.buildWorkerPod(ctx, instance, worker, "", 0, 0) if r.options.BatchSchedulerManager != nil { if scheduler, err := r.options.BatchSchedulerManager.GetScheduler(); err == nil { scheduler.AddMetadataToChildResource(ctx, &instance, &pod, worker.GroupName) @@ -1135,11 +1196,11 @@ func (r *RayClusterReconciler) createWorkerPod(ctx context.Context, instance ray return nil } -func (r *RayClusterReconciler) createWorkerPodWithIndex(ctx context.Context, instance rayv1.RayCluster, worker rayv1.WorkerGroupSpec, replicaGrpName string, hostIndex int) error { +func (r *RayClusterReconciler) createWorkerPodWithIndex(ctx context.Context, instance rayv1.RayCluster, worker rayv1.WorkerGroupSpec, replicaGrpName string, replicaIndex int, hostIndex int) error { logger := ctrl.LoggerFrom(ctx) // build the pod then create it - pod := r.buildWorkerPod(ctx, instance, worker, replicaGrpName, hostIndex) + pod := r.buildWorkerPod(ctx, instance, worker, replicaGrpName, replicaIndex, hostIndex) if r.options.BatchSchedulerManager != nil { if scheduler, err := r.options.BatchSchedulerManager.GetScheduler(); err == nil { scheduler.AddMetadataToChildResource(ctx, &instance, &pod, worker.GroupName) @@ -1188,7 +1249,7 @@ func getCreatorCRDType(instance rayv1.RayCluster) utils.CRDType { } // Build worker instance pods. 
-func (r *RayClusterReconciler) buildWorkerPod(ctx context.Context, instance rayv1.RayCluster, worker rayv1.WorkerGroupSpec, replicaGrpName string, hostIndex int) corev1.Pod { +func (r *RayClusterReconciler) buildWorkerPod(ctx context.Context, instance rayv1.RayCluster, worker rayv1.WorkerGroupSpec, replicaGrpName string, replicaIndex int, hostIndex int) corev1.Pod { logger := ctrl.LoggerFrom(ctx) podName := utils.PodName(fmt.Sprintf("%s-%s", instance.Name, worker.GroupName), rayv1.WorkerNode, true) fqdnRayIP := utils.GenerateFQDNServiceName(ctx, instance, instance.Namespace) // Fully Qualified Domain Name @@ -1196,7 +1257,7 @@ func (r *RayClusterReconciler) buildWorkerPod(ctx context.Context, instance rayv // The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.) headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams) autoscalingEnabled := utils.IsAutoscalingEnabled(&instance.Spec) - podTemplateSpec := common.DefaultWorkerPodTemplate(ctx, instance, worker, podName, fqdnRayIP, headPort, replicaGrpName, hostIndex) + podTemplateSpec := common.DefaultWorkerPodTemplate(ctx, instance, worker, podName, fqdnRayIP, headPort, replicaGrpName, replicaIndex, hostIndex) if len(r.options.WorkerSidecarContainers) > 0 { podTemplateSpec.Spec.Containers = append(podTemplateSpec.Spec.Containers, r.options.WorkerSidecarContainers...) } diff --git a/ray-operator/controllers/ray/raycluster_controller_test.go b/ray-operator/controllers/ray/raycluster_controller_test.go index 56290797589..a11efb2de56 100644 --- a/ray-operator/controllers/ray/raycluster_controller_test.go +++ b/ray-operator/controllers/ray/raycluster_controller_test.go @@ -19,7 +19,6 @@ import ( "context" "errors" "fmt" - "slices" "strconv" "strings" "time" @@ -981,21 +980,56 @@ var _ = Context("Inside the default namespace", func() { }) It("All multi-host pods are properly labeled", func() { - workerGrpReplicaMap := make(map[string][]string) + type ReplicaInfo struct { + HostIndices map[string]bool + ReplicaIndex string + } + replicaGroups := make(map[string]ReplicaInfo) + // Track replicas to ensure unique indices are applied. + seenReplicaIndices := make(map[string]bool) + for _, pod := range workerPods.Items { + // Get all the labels hostIndex := pod.Labels[utils.RayHostIndexKey] - hostGrpId := pod.Labels[utils.RayWorkerReplicaIndexKey] - - grpReplicaIndexList, grpIdExists := workerGrpReplicaMap[hostGrpId] - if grpIdExists { - Expect(strconv.Atoi(hostIndex)).Should(BeNumerically("<", numOfHosts)) - Expect(strconv.Atoi(hostIndex)).Should(BeNumerically(">=", 0)) - Expect(slices.Contains(grpReplicaIndexList, hostIndex)).To(BeFalse()) - workerGrpReplicaMap[hostGrpId] = append(grpReplicaIndexList, hostIndex) + replicaID := pod.Labels[utils.RayWorkerReplicaNameKey] + replicaIndex := pod.Labels[utils.RayWorkerReplicaIndexKey] + + Expect(replicaIndex).NotTo(BeEmpty(), "Pod %s is missing label %s", pod.Name, utils.RayWorkerReplicaIndexKey) + seenReplicaIndices[replicaIndex] = true + + if info, ok := replicaGroups[replicaID]; ok { + // Validate replicaIndex is the same for all pods in this group. + Expect(replicaIndex).To(Equal(info.ReplicaIndex), "Pod %s in group %s has replicaIndex %s, but expected %s", pod.Name, replicaID, replicaIndex, info.ReplicaIndex) + + // Ensure hostIndex is unique within this replica group. 
+ Expect(info.HostIndices[hostIndex]).To(BeFalse(), "Pod %s in group %s has duplicate hostIndex %s", pod.Name, replicaID, hostIndex) + info.HostIndices[hostIndex] = true } else { - workerGrpReplicaMap[hostGrpId] = []string{} - Expect(len(workerGrpReplicaMap)).Should(BeNumerically("<=", int(replicas))) + replicaGroups[replicaID] = ReplicaInfo{ + ReplicaIndex: replicaIndex, + HostIndices: map[string]bool{hostIndex: true}, + } } + + // Check hostIndex correctly set in range 0 to numOfHosts-1. + hostIndexInt, err := strconv.Atoi(hostIndex) + Expect(err).NotTo(HaveOccurred()) + Expect(hostIndexInt).To(BeNumerically("<", numOfHosts)) + Expect(hostIndexInt).To(BeNumerically(">=", 0)) + } + + // Validate we created 'replicas' number of groups. + Expect(replicaGroups).To(HaveLen(int(replicas)), "Expected %d replica groups, but found %d", replicas, len(replicaGroups)) + + // Validate replica indices are unique and indexed from 0 to replicas-1. + Expect(seenReplicaIndices).To(HaveLen(int(replicas)), "Expected %d unique replica indices, but found %d", replicas, len(seenReplicaIndices)) + Expect(seenReplicaIndices["0"]).To(BeTrue()) + Expect(seenReplicaIndices["1"]).To(BeTrue()) + Expect(seenReplicaIndices["2"]).To(BeTrue()) + + // Validate each replica group has 'numOfHosts' Pods. + for replicaID, info := range replicaGroups { + Expect(info.HostIndices).To(HaveLen(int(numOfHosts)), "Replica group %s expected %d hosts, but found %d", replicaID, numOfHosts, len(info.HostIndices)) } }) diff --git a/ray-operator/controllers/ray/utils/constant.go b/ray-operator/controllers/ray/utils/constant.go index f721464e015..e707cb71023 100644 --- a/ray-operator/controllers/ray/utils/constant.go +++ b/ray-operator/controllers/ray/utils/constant.go @@ -28,12 +28,18 @@ const ( KubeRayVersion = "ray.io/kuberay-version" // Labels for feature RayMultihostIndexing - // RayWorkerReplicaIndexKey label is the unique label for the replica in a specific worker group. It is made up of the worker group name - // and a unique identifier. e.g. multi-host-worker-group-xh3hf // - // RayHostIndexKey label represents the index of the host within the RayWorkerReplicaIndexKey. + // RayWorkerReplicaNameKey label is the unique name for the replica in a specific worker group. It is made up + // of the worker group name and a unique identifier (e.g. multi-host-worker-group-xh3hf). This label is unique + // across RayClusters. + RayWorkerReplicaNameKey = "ray.io/worker-group-replica-name" + + // RayWorkerReplicaIndexKey label is the integer index for the replica in it's worker group (0 to replicas-1). + // The value for this label is unique within its worker group, but not across worker groups or RayClusters. RayWorkerReplicaIndexKey = "ray.io/worker-group-replica-index" - RayHostIndexKey = "ray.io/replica-host-index" + + // RayHostIndexKey label represents the index of the host within the replica group. + RayHostIndexKey = "ray.io/replica-host-index" // In KubeRay, the Ray container must be the first application container in a head or worker Pod. RayContainerIndex = 0 diff --git a/ray-operator/test/e2e/raycluster_multi_host_test.go b/ray-operator/test/e2e/raycluster_multi_host_test.go index 3a234a7dba2..ed97c6fe365 100644 --- a/ray-operator/test/e2e/raycluster_multi_host_test.go +++ b/ray-operator/test/e2e/raycluster_multi_host_test.go @@ -1,13 +1,12 @@ package e2e import ( + "strconv" "testing" . 
"github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - corev1ac "k8s.io/client-go/applyconfigurations/core/v1" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" @@ -16,7 +15,198 @@ import ( . "github.com/ray-project/kuberay/ray-operator/test/support" ) -func TestRayClusterMultiHost(t *testing.T) { +// verifyWorkerGroupIndices is a helper function to check that pods in a worker group +// have the correct and unique replica/host index labels. +func verifyWorkerGroupIndices(t *testing.T, rayCluster *rayv1.RayCluster, workerGroupName string, expectedHosts int, expectedReplicas int, expectedIndices []int) { + test := With(t) + g := NewWithT(t) + + allWorkerPods, err := GetWorkerPods(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + groupPods := []corev1.Pod{} + for _, pod := range allWorkerPods { + if pod.Labels[utils.RayNodeGroupLabelKey] == workerGroupName { + groupPods = append(groupPods, pod) + } + } + + // Validate total number of pods for this group. + expectedPodCount := expectedReplicas * expectedHosts + g.Expect(groupPods).To(HaveLen(expectedPodCount), + "Expected %d pods for group %s (%d replicas with %d hosts each), but found %d", + expectedPodCount, workerGroupName, expectedReplicas, expectedHosts, len(groupPods)) + + // Track the indices seen when parsing the worker Pods. + seenReplicaIndices := make(map[int]bool) + expectedIndicesMap := make(map[int]bool) + for _, idx := range expectedIndices { + expectedIndicesMap[idx] = true + } + + if expectedHosts > 1 { + // For multi-host, all three labels should be set. + type ReplicaInfo struct { + HostIndices map[int]bool + ReplicaIndex int + } + replicaGroups := make(map[string]ReplicaInfo) + + for _, pod := range groupPods { + replicaID, ok := pod.Labels[utils.RayWorkerReplicaNameKey] + g.Expect(ok).To(BeTrue(), "Pod %s should have a replica ID label (%s)", pod.Name, utils.RayWorkerReplicaNameKey) + + replicaIndexStr, ok := pod.Labels[utils.RayWorkerReplicaIndexKey] + g.Expect(ok).To(BeTrue(), "Pod %s should have a replica index label (%s)", pod.Name, utils.RayWorkerReplicaIndexKey) + replicaIndex, err := strconv.Atoi(replicaIndexStr) + g.Expect(err).NotTo(HaveOccurred()) + seenReplicaIndices[replicaIndex] = true + + hostIndexStr, ok := pod.Labels[utils.RayHostIndexKey] + g.Expect(ok).To(BeTrue(), "Pod %s should have a host index label (%s)", pod.Name, utils.RayHostIndexKey) + hostIndex, err := strconv.Atoi(hostIndexStr) + g.Expect(err).NotTo(HaveOccurred()) + + // Check for duplicate host index values per replica group. + if info, exists := replicaGroups[replicaID]; exists { + g.Expect(replicaIndex).To(Equal(info.ReplicaIndex), + "Pod %s in group %s has inconsistent replicaIndex. 
Expected %d, got %d", pod.Name, replicaID, info.ReplicaIndex, replicaIndex) + + g.Expect(info.HostIndices[hostIndex]).To(BeFalse(), + "Pod %s in group %s has duplicate hostIndex %d", pod.Name, replicaID, hostIndex) + info.HostIndices[hostIndex] = true + } else { + replicaGroups[replicaID] = ReplicaInfo{ + ReplicaIndex: replicaIndex, + HostIndices: map[int]bool{hostIndex: true}, + } + } + } + + g.Expect(replicaGroups).To(HaveLen(expectedReplicas), "Should have %d replica groups, but found %d", expectedReplicas, len(replicaGroups)) + for replicaID, info := range replicaGroups { + g.Expect(info.HostIndices).To(HaveLen(expectedHosts), "Replica group %s should have %d hosts, but found %d", replicaID, expectedHosts, len(info.HostIndices)) + } + + } else { + // Single-host case, only replica index is set. + for _, pod := range groupPods { + g.Expect(pod.Labels).NotTo(HaveKey(utils.RayWorkerReplicaNameKey), "Pod %s should not have replica ID label for single-host group", pod.Name) + g.Expect(pod.Labels).NotTo(HaveKey(utils.RayHostIndexKey), "Pod %s should not have host index label for single-host group", pod.Name) + + // Check for unique replica index label + indexStr, ok := pod.Labels[utils.RayWorkerReplicaIndexKey] + g.Expect(ok).To(BeTrue(), "Pod %s should have a replica index label (%s)", pod.Name, utils.RayWorkerReplicaIndexKey) + + index, err := strconv.Atoi(indexStr) + g.Expect(err).NotTo(HaveOccurred(), "Failed to parse replica index '%s' for pod %s", indexStr, pod.Name) + + g.Expect(seenReplicaIndices[index]).To(BeFalse(), "Found duplicate replica index %d for pod %s", index, pod.Name) + seenReplicaIndices[index] = true + } + } + + if expectedIndices != nil { + expectedIndicesMap := make(map[int]bool) + for _, idx := range expectedIndices { + expectedIndicesMap[idx] = true + } + g.Expect(seenReplicaIndices).To(Equal(expectedIndicesMap), + "Expected replica indices %v for group %s, but found %v", expectedIndicesMap, workerGroupName, seenReplicaIndices) + } +} + +func TestRayClusterSingleHostMultiSlice(t *testing.T) { + test := With(t) + g := NewWithT(t) + + features.SetFeatureGateDuringTest(t, features.RayMultiHostIndexing, true) + + namespace := test.NewTestNamespace() + const ( + initialReplicas = 3 + clusterName = "raycluster-singlehost" + workerGroupName = "single-host-group" + ) + + // Define the RayCluster spec with a single-host worker group (NumOfHosts = 1). + rayClusterAC := rayv1ac.RayCluster(clusterName, namespace.Name). + WithSpec(rayv1ac.RayClusterSpec(). + WithRayVersion(GetRayVersion()). + WithHeadGroupSpec(rayv1ac.HeadGroupSpec(). + WithRayStartParams(map[string]string{"num-cpus": "0"}). + WithTemplate(HeadPodTemplateApplyConfiguration())). + WithWorkerGroupSpecs(rayv1ac.WorkerGroupSpec(). + WithReplicas(initialReplicas). + WithMinReplicas(0). + WithMaxReplicas(5). + WithNumOfHosts(1). + WithGroupName("single-host-group"). + WithRayStartParams(map[string]string{"num-cpus": "1"}). + WithTemplate(WorkerPodTemplateApplyConfiguration()))) + + // Create the RayCluster. + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(t, "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for the cluster to become Ready and verify the initial Pod count. 
+ LogWithTimestamp(t, "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, rayCluster.Namespace, rayCluster.Name), TestTimeoutLong). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + g.Eventually(func() ([]corev1.Pod, error) { + return GetWorkerPods(test, rayCluster) + }, TestTimeoutShort).Should(HaveLen(initialReplicas)) + + // Verify that all pods are correctly labeled with indices 0 to replicas-1. + LogWithTimestamp(t, "Verifying initial labels on single-host pods for %s/%s", rayCluster.Namespace, rayCluster.Name) + verifyWorkerGroupIndices(t, rayCluster, workerGroupName, 1, initialReplicas, []int{0, 1, 2}) + + // Manually delete the pod with replica index 1. + LogWithTimestamp(t, "Deleting pod with replica index 1 for %s/%s", rayCluster.Namespace, rayCluster.Name) + workerPods, err := GetWorkerPods(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + + var podToDelete *corev1.Pod + for _, pod := range workerPods { + if pod.Labels[utils.RayWorkerReplicaIndexKey] == "1" { + podToDelete = &pod + break + } + } + g.Expect(podToDelete).NotTo(BeNil(), "Could not find pod with replica index 1 to delete") + + err = test.Client().Core().CoreV1().Pods(namespace.Name).Delete(test.Ctx(), podToDelete.Name, metav1.DeleteOptions{}) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(t, "Deleted pod %s", podToDelete.Name) + + // Wait for the controller to reconcile. The pod count should return to 3. + LogWithTimestamp(t, "Waiting for controller to reconcile and fill the gap") + g.Eventually(func() ([]corev1.Pod, error) { + return GetWorkerPods(test, rayCluster) + }, TestTimeoutShort).Should(HaveLen(initialReplicas), "Controller should restore pod count to %d", initialReplicas) + + // Verify that the controller replaced the missing index by creating a new pod with index 1. + LogWithTimestamp(t, "Verifying labels after pod deletion and reconciliation") + verifyWorkerGroupIndices(t, rayCluster, workerGroupName, 1, initialReplicas, []int{0, 1, 2}) + + // Scale up replicas from 3 to 4. + const scaleUpReplicas = 4 + LogWithTimestamp(t, "Scaling up RayCluster %s/%s from %d to %d replicas", rayCluster.Namespace, rayCluster.Name, initialReplicas, scaleUpReplicas) + rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(scaleUpReplicas) + _, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + g.Eventually(func() ([]corev1.Pod, error) { + return GetWorkerPods(test, rayCluster) + }, TestTimeoutShort).Should(HaveLen(scaleUpReplicas), "Should scale up to %d pods", scaleUpReplicas) + + // Verify the new pod got the next available index, 3. + LogWithTimestamp(t, "Verifying labels after scale-up") + verifyWorkerGroupIndices(t, rayCluster, workerGroupName, 1, scaleUpReplicas, []int{0, 1, 2, 3}) +} + +func TestRayClusterMultiHostMultiSlice(t *testing.T) { test := With(t) g := NewWithT(t) @@ -30,12 +220,6 @@ func TestRayClusterMultiHost(t *testing.T) { initialReplicas = 2 clusterName = "raycluster-multihost" ) - sharedMemVolumeAC := corev1ac.Volume(). - WithName("shared-mem"). - WithEmptyDir(corev1ac.EmptyDirVolumeSource(). - WithMedium(corev1.StorageMediumMemory). - WithSizeLimit(resource.MustParse("1Gi")), - ) // Define the RayCluster spec with a multi-host worker group. rayClusterAC := rayv1ac.RayCluster(clusterName, namespace.Name). @@ -43,68 +227,23 @@ func TestRayClusterMultiHost(t *testing.T) { WithRayVersion(GetRayVersion()). 
WithHeadGroupSpec(rayv1ac.HeadGroupSpec(). WithRayStartParams(map[string]string{"dashboard-host": "0.0.0.0"}). - WithTemplate(HeadPodTemplateApplyConfiguration(). - // All PodSpec configurations go inside WithSpec. - WithSpec(corev1ac.PodSpec(). - WithVolumes(sharedMemVolumeAC). - WithRestartPolicy(corev1.RestartPolicyNever). - WithContainers(corev1ac.Container(). - WithName("ray-head"). - WithImage(GetRayImage()). - WithPorts( - corev1ac.ContainerPort().WithName(utils.GcsServerPortName).WithContainerPort(utils.DefaultGcsServerPort), - corev1ac.ContainerPort().WithName(utils.ServingPortName).WithContainerPort(utils.DefaultServingPort), - corev1ac.ContainerPort().WithName(utils.DashboardPortName).WithContainerPort(utils.DefaultDashboardPort), - corev1ac.ContainerPort().WithName(utils.ClientPortName).WithContainerPort(utils.DefaultClientPort), - ). - WithResources(corev1ac.ResourceRequirements(). - WithRequests(corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - corev1.ResourceMemory: resource.MustParse("3Gi"), - }). - WithLimits(corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - corev1.ResourceMemory: resource.MustParse("3Gi"), - })), - ), - ), - ), - ). + WithTemplate(HeadPodTemplateApplyConfiguration())). WithWorkerGroupSpecs(rayv1ac.WorkerGroupSpec(). - WithGroupName("multi-host-group"). WithReplicas(initialReplicas). WithMinReplicas(0). WithMaxReplicas(5). WithNumOfHosts(numOfHosts). - WithTemplate(WorkerPodTemplateApplyConfiguration(). - WithSpec(corev1ac.PodSpec(). - WithVolumes(sharedMemVolumeAC). - WithRestartPolicy(corev1.RestartPolicyNever). - WithContainers(corev1ac.Container(). - WithName("ray-worker"). - WithImage(GetRayImage()). - WithResources(corev1ac.ResourceRequirements(). - WithRequests(corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("300m"), - corev1.ResourceMemory: resource.MustParse("1G"), - }). - WithLimits(corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("500m"), - corev1.ResourceMemory: resource.MustParse("1G"), - })), - ), - ), - ), - ), - ) + WithGroupName("multi-host-group"). + WithRayStartParams(map[string]string{"num-cpus": "1"}). + WithTemplate(WorkerPodTemplateApplyConfiguration()))) // Create the RayCluster. rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) g.Expect(err).NotTo(HaveOccurred()) - LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + LogWithTimestamp(t, "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) // Wait for the cluster to become Ready and verify the initial Pod count. - LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + LogWithTimestamp(t, "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) g.Eventually(RayCluster(test, rayCluster.Namespace, rayCluster.Name), TestTimeoutLong). Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) @@ -113,58 +252,59 @@ func TestRayClusterMultiHost(t *testing.T) { return GetWorkerPods(test, rayCluster) }, TestTimeoutShort).Should(HaveLen(expectedPodCount)) - // Verify that all pods are correctly labeled. 
- LogWithTimestamp(test.T(), "Verifying labels on multi-host pods for %s/%s", rayCluster.Namespace, rayCluster.Name) - workerPods, err := GetWorkerPods(test, rayCluster) - g.Expect(err).NotTo(HaveOccurred()) - replicaMap := make(map[string][]string) - for _, pod := range workerPods { - replicaName, ok := pod.Labels[utils.RayWorkerReplicaIndexKey] - g.Expect(ok).To(BeTrue(), "Pod %s should have a replica index label", pod.Name) - hostIndex, ok := pod.Labels[utils.RayHostIndexKey] - g.Expect(ok).To(BeTrue(), "Pod %s should have a host index label", pod.Name) - replicaMap[replicaName] = append(replicaMap[replicaName], hostIndex) - } - g.Expect(replicaMap).To(HaveLen(initialReplicas), "Should have the correct number of replica groups") - for replicaName, hostIndices := range replicaMap { - g.Expect(hostIndices).To(HaveLen(numOfHosts), "Replica group %s should be complete", replicaName) - } + // Verify that all pods are correctly labeled during replica group scale up. + LogWithTimestamp(t, "Verifying labels on multi-host pods for %s/%s", rayCluster.Namespace, rayCluster.Name) + verifyWorkerGroupIndices(t, rayCluster, "multi-host-group", numOfHosts, initialReplicas, []int{0, 1}) // Scale down replicas from 2 to 1. Verify we scale by a multiple of NumOfHosts. - LogWithTimestamp(test.T(), "Scaling down RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + const scaleDownReplicas = 1 + LogWithTimestamp(t, "Scaling down RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(1) _, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) g.Expect(err).NotTo(HaveOccurred()) - expectedPodCount = 1 * numOfHosts + expectedPodCount = scaleDownReplicas * numOfHosts g.Eventually(func() ([]corev1.Pod, error) { return GetWorkerPods(test, rayCluster) }, TestTimeoutShort).Should(HaveLen(expectedPodCount), "Should scale down to 1 multi-host group (4 pods)") + // Verify labels again after replica group scale down. + LogWithTimestamp(t, "Verifying labels after scale-down for %s/%s", rayCluster.Namespace, rayCluster.Name) + verifyWorkerGroupIndices(t, rayCluster, "multi-host-group", numOfHosts, scaleDownReplicas, nil) + // Test scale up: Increase replicas from 1 to 3. - LogWithTimestamp(test.T(), "Scaling up RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) - rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(3) + const scaleUpReplicas = 3 + LogWithTimestamp(t, "Scaling up RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(scaleUpReplicas) _, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) g.Expect(err).NotTo(HaveOccurred()) - expectedPodCount = 3 * numOfHosts + expectedPodCount = scaleUpReplicas * numOfHosts g.Eventually(func() ([]corev1.Pod, error) { return GetWorkerPods(test, rayCluster) }, TestTimeoutShort).Should(HaveLen(expectedPodCount), "Should scale up to 3 multi-host groups (12 pods)") + // Verify labels are set with expected values after scale up again. + LogWithTimestamp(t, "Verifying labels after scale-up for %s/%s", rayCluster.Namespace, rayCluster.Name) + verifyWorkerGroupIndices(t, rayCluster, "multi-host-group", numOfHosts, scaleUpReplicas, []int{0, 1, 2}) + // Manually delete a single pod and verify the controller atomically re-creates the slice. 
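The invariant this part of the test relies on is that, once reconciliation settles, every replica group observed through the replica-name label again has exactly NumOfHosts Pods. A hedged sketch of such a check against a plain clientset; it assumes the ray.io/group label value KubeRay applies to worker Pods, whereas the e2e test itself goes through GetWorkerPods:

```go
package example

import (
	"context"
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)

// incompleteReplicaGroups returns the replica-name label values whose Pod count does
// not equal numOfHosts. After a settled reconcile this slice should be empty.
func incompleteReplicaGroups(ctx context.Context, c kubernetes.Interface, namespace, groupName string, numOfHosts int) ([]string, error) {
	pods, err := c.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
		LabelSelector: fmt.Sprintf("ray.io/group=%s", groupName),
	})
	if err != nil {
		return nil, err
	}
	counts := map[string]int{}
	for _, pod := range pods.Items {
		counts[pod.Labels["ray.io/worker-group-replica-name"]]++
	}
	var broken []string
	for name, n := range counts {
		if n != numOfHosts {
			broken = append(broken, name)
		}
	}
	return broken, nil
}
```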
- LogWithTimestamp(test.T(), "Testing atomic multi-host group recreation for RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) - workerPods, err = GetWorkerPods(test, rayCluster) + LogWithTimestamp(t, "Testing atomic multi-host group recreation for RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + workerPods, err := GetWorkerPods(test, rayCluster) g.Expect(err).NotTo(HaveOccurred()) podToDelete := workerPods[0] err = test.Client().Core().CoreV1().Pods(namespace.Name).Delete(test.Ctx(), podToDelete.Name, metav1.DeleteOptions{}) g.Expect(err).NotTo(HaveOccurred()) // The controller should first clean up the broken multi-host group (-4 pods), and then re-scale it up (+4 pods). - LogWithTimestamp(test.T(), "Waiting for controller to reconcile multi-host group.") + LogWithTimestamp(t, "Waiting for controller to reconcile multi-host group.") // Reconciliation happens too quickly to catch the state where expectedPodCount-NumOfHosts, but we can test // that externally deleted Pods will be re-created to satisfy the expected number. g.Eventually(func() ([]corev1.Pod, error) { return GetWorkerPods(test, rayCluster) }, TestTimeoutShort).Should(HaveLen(expectedPodCount), "Controller restored cluster to the correct number of pods.") + + // Verify labels are still set correctly after atomic re-creation due to unhealthy Pod. + LogWithTimestamp(t, "Verifying labels after atomic recreation for %s/%s", rayCluster.Namespace, rayCluster.Name) + verifyWorkerGroupIndices(t, rayCluster, "multi-host-group", numOfHosts, scaleUpReplicas, []int{0, 1, 2}) }