Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions ray-operator/controllers/ray/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func getEnableProbesInjection() bool {
}

// DefaultWorkerPodTemplate sets the config values
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec {
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, replicaIndex int, numHostIndex int) corev1.PodTemplateSpec {
podTemplate := workerSpec.Template
podTemplate.GenerateName = podName
// Pods created by RayCluster should be restricted to the namespace of the RayCluster.
Expand Down Expand Up @@ -329,11 +329,15 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
mergedLabels := mergeLabels(workerSpec.Template.ObjectMeta.Labels, workerSpec.Labels)
podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, mergedLabels)

// Add additional labels for RayMultihostIndexing
multihostIndexingEnabled := features.Enabled(features.RayMultiHostIndexing) && workerSpec.NumOfHosts > 1
if multihostIndexingEnabled {
podTemplate.Labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
podTemplate.Labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
// Add additional labels when RayMultihostIndexing is enabled.
if features.Enabled(features.RayMultiHostIndexing) {
// The ordered replica index can be used for the single-host, multi-slice case.
podTemplate.Labels[utils.RayWorkerReplicaIndexKey] = strconv.Itoa(replicaIndex)
if workerSpec.NumOfHosts > 1 {
// These labels are specific to multi-host group setup and reconciliation.
podTemplate.Labels[utils.RayWorkerReplicaNameKey] = replicaGrpName
podTemplate.Labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
}
}
workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP)

Expand Down
27 changes: 14 additions & 13 deletions ray-operator/controllers/ray/common/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ func TestBuildPod(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, defaultContainerEnvs)

// Check resources
Expand Down Expand Up @@ -761,7 +761,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil)
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
Expand Down Expand Up @@ -792,7 +792,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP, nil)
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
assert.Equal(t, []string{"I am worker"}, workerContainer.Command)
Expand Down Expand Up @@ -847,7 +847,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil)

val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
Expand Down Expand Up @@ -903,7 +903,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP, nil)

// Verify worker container command
Expand Down Expand Up @@ -1166,7 +1166,7 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) {
expectedWorker := *worker.DeepCopy()

// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0)
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
assert.Equal(t, expectedWorker, worker)
}
Expand All @@ -1187,9 +1187,10 @@ func TestDeafultWorkerPodTemplateWithReplicaGrpAndIndex(t *testing.T) {
groupReplicaName := utils.GenerateRayWorkerReplicaGroupName(worker.GroupName)

// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 2)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", groupReplicaName, 0, 2)
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey], groupReplicaName)
assert.Equal(t, podTemplateSpec.Labels[utils.RayWorkerReplicaNameKey], groupReplicaName)
assert.Equal(t, "0", podTemplateSpec.Labels[utils.RayWorkerReplicaIndexKey])
assert.Equal(t, "2", podTemplateSpec.Labels[utils.RayHostIndexKey])
}

Expand Down Expand Up @@ -1235,7 +1236,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
worker := cluster.Spec.WorkerGroupSpecs[0]
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
// DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it.
// Verify the default metrics port exists.
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort)))
Expand All @@ -1245,7 +1246,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
ContainerPort: customMetricsPort,
}
cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort}
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0, 0)
// Verify the custom metrics port exists.
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort))
}
Expand Down Expand Up @@ -1284,7 +1285,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) {

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0, 0)
assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy)
})
}
Expand All @@ -1300,7 +1301,7 @@ func TestDefaultInitContainer(t *testing.T) {
expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1

// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0)
numInitContainers := len(podTemplateSpec.Spec.InitContainers)
assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.")

Expand Down Expand Up @@ -1359,7 +1360,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) {
// set ray container imagePullPolicy
worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy

podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0, 0)

healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1]
assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.")
Expand Down
Loading
Loading