Skip to content

Commit daebac4

Browse files
committed
Check Pod Labels first and add unit tests
Signed-off-by: Ryan O'Leary <[email protected]>
1 parent cc581c2 commit daebac4

File tree

2 files changed

+263
-59
lines changed

2 files changed

+263
-59
lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,34 +1066,68 @@ func addDefaultRayNodeLabels(pod *corev1.Pod) {
10661066
})
10671067
}
10681068
if !containsEnvVar(*rayContainer, utils.RayNodeZone) {
1069-
// uses downward api to set the ray.io/availability-zone node label
1070-
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesiozone
1071-
envVars = append(envVars, corev1.EnvVar{
1072-
Name: utils.RayNodeZone,
1073-
ValueFrom: &corev1.EnvVarSource{
1074-
FieldRef: &corev1.ObjectFieldSelector{
1075-
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyZoneLabel),
1076-
},
1077-
},
1078-
})
1069+
envVars = append(envVars, getPodZoneEnvVar(pod))
10791070
}
10801071
if !containsEnvVar(*rayContainer, utils.RayNodeRegion) {
1081-
// uses downward api to set the ray.io/availability-region node label
1082-
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesioregion
1083-
envVars = append(envVars, corev1.EnvVar{
1084-
Name: utils.RayNodeRegion,
1085-
ValueFrom: &corev1.EnvVarSource{
1086-
FieldRef: &corev1.ObjectFieldSelector{
1087-
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyRegionLabel),
1088-
},
1089-
},
1090-
})
1072+
envVars = append(envVars, getPodRegionEnvVar(pod))
10911073
}
10921074
rayContainer.Env = envVars
10931075
}
10941076

1095-
// getPodMarketTypeFromNodeSelector is a helper function to determine the ray.io/market-type label
1096-
// based on a Kubernetes Pod spec.
1077+
// getPodZoneEnvVar is a helper function to determine the ray.io/availability-zone label value
1078+
// based on a Pod spec - checking labels, nodeSelectors, and then falling back to downward API.
1079+
func getPodZoneEnvVar(pod *corev1.Pod) corev1.EnvVar {
1080+
if podZone, ok := pod.Labels[utils.K8sTopologyZoneLabel]; ok && podZone != "" {
1081+
return corev1.EnvVar{
1082+
Name: utils.RayNodeZone,
1083+
Value: podZone,
1084+
}
1085+
} else if podZone, ok := pod.Spec.NodeSelector[utils.K8sTopologyZoneLabel]; ok && podZone != "" {
1086+
return corev1.EnvVar{
1087+
Name: utils.RayNodeZone,
1088+
Value: podZone,
1089+
}
1090+
}
1091+
// uses downward api to set the ray.io/availability-zone node label
1092+
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesiozone
1093+
return corev1.EnvVar{
1094+
Name: utils.RayNodeZone,
1095+
ValueFrom: &corev1.EnvVarSource{
1096+
FieldRef: &corev1.ObjectFieldSelector{
1097+
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyZoneLabel),
1098+
},
1099+
},
1100+
}
1101+
}
1102+
1103+
// getPodRegionEnvVar is a helper function to determine the ray.io/availability-region label value
1104+
// based on a Pod spec - checking labels, nodeSelectors, and then falling back to downward API.
1105+
func getPodRegionEnvVar(pod *corev1.Pod) corev1.EnvVar {
1106+
if podRegion, ok := pod.Labels[utils.K8sTopologyRegionLabel]; ok && podRegion != "" {
1107+
return corev1.EnvVar{
1108+
Name: utils.RayNodeRegion,
1109+
Value: podRegion,
1110+
}
1111+
} else if podRegion, ok := pod.Spec.NodeSelector[utils.K8sTopologyRegionLabel]; ok && podRegion != "" {
1112+
return corev1.EnvVar{
1113+
Name: utils.RayNodeRegion,
1114+
Value: podRegion,
1115+
}
1116+
}
1117+
// uses downward api to set the ray.io/availability-region node label
1118+
// Ref: https://kubernetes.io/docs/reference/labels-annotations-taints/#topologykubernetesioregion
1119+
return corev1.EnvVar{
1120+
Name: utils.RayNodeRegion,
1121+
ValueFrom: &corev1.EnvVarSource{
1122+
FieldRef: &corev1.ObjectFieldSelector{
1123+
FieldPath: fmt.Sprintf("metadata.labels['%s']", utils.K8sTopologyRegionLabel),
1124+
},
1125+
},
1126+
}
1127+
}
1128+
1129+
// getPodMarketType is a helper function to determine the ray.io/market-type
1130+
// label value based on a Kubernetes Pod spec - checking labels, nodeSelector, and nodeAffinity.
10971131
func getPodMarketType(pod *corev1.Pod) utils.PodMarketType {
10981132
marketType := getPodMarketTypeFromNodeSelector(pod.Spec.NodeSelector)
10991133

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 207 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,54 +2046,224 @@ func TestGetPodMarketType(t *testing.T) {
20462046
}
20472047
}
20482048

2049-
func TestAddDefaultRayNodeLabels_GKESpot(t *testing.T) {
2050-
pod := &corev1.Pod{
2051-
ObjectMeta: metav1.ObjectMeta{
2052-
Labels: map[string]string{
2053-
"ray.io/group": "test-worker-group-1",
2054-
"topology.kubernetes.io/region": "us-central2",
2055-
"topology.kubernetes.io/zone": "us-central2-b",
2049+
func TestAddDefaultRayNodeLabels(t *testing.T) {
2050+
tests := []struct {
2051+
labels map[string]string
2052+
nodeSelector map[string]string
2053+
nodeAffinity *corev1.NodeAffinity
2054+
expectedEnv map[string]string
2055+
name string
2056+
}{
2057+
{
2058+
name: "Availability zone vars set from region and zone topology labels",
2059+
labels: map[string]string{
2060+
utils.K8sTopologyRegionLabel: "us-west4",
2061+
utils.K8sTopologyZoneLabel: "us-west4-a",
2062+
},
2063+
expectedEnv: map[string]string{
2064+
utils.RayNodeRegion: "us-west4",
2065+
utils.RayNodeZone: "us-west4-a",
20562066
},
20572067
},
2058-
Spec: corev1.PodSpec{
2059-
Containers: []corev1.Container{
2060-
{Name: "ray"},
2068+
{
2069+
name: "Availability zone vars set from region and zone topology nodeSelectors",
2070+
nodeSelector: map[string]string{
2071+
utils.K8sTopologyRegionLabel: "us-central2",
2072+
utils.K8sTopologyZoneLabel: "us-central2-b",
20612073
},
2062-
NodeSelector: map[string]string{
2063-
"cloud.google.com/gke-spot": "true",
2074+
expectedEnv: map[string]string{
2075+
utils.RayNodeRegion: "us-central2",
2076+
utils.RayNodeZone: "us-central2-b",
2077+
},
2078+
},
2079+
{
2080+
name: "Availability zone vars set from downward API",
2081+
expectedEnv: map[string]string{
2082+
utils.RayNodeRegion: "metadata.labels['topology.kubernetes.io/region']",
2083+
utils.RayNodeZone: "metadata.labels['topology.kubernetes.io/zone']",
2084+
},
2085+
},
2086+
{
2087+
name: "Market type env var set from GKE Spot nodeSelector",
2088+
nodeSelector: map[string]string{
2089+
utils.GKESpotLabel: "true",
2090+
utils.K8sTopologyRegionLabel: "me-central1",
2091+
utils.K8sTopologyZoneLabel: "me-central1-a",
2092+
},
2093+
expectedEnv: map[string]string{
2094+
utils.RayNodeMarketType: string(utils.SpotMarketType),
2095+
utils.RayNodeRegion: "me-central1",
2096+
utils.RayNodeZone: "me-central1-a",
2097+
},
2098+
},
2099+
{
2100+
name: "Market type env var set from EKS Spot nodeSelector",
2101+
nodeSelector: map[string]string{
2102+
utils.EKSCapacityTypeLabel: "SPOT",
2103+
utils.K8sTopologyRegionLabel: "us-central1",
2104+
utils.K8sTopologyZoneLabel: "us-central1-c",
2105+
},
2106+
expectedEnv: map[string]string{
2107+
utils.RayNodeMarketType: string(utils.SpotMarketType),
2108+
utils.RayNodeRegion: "us-central1",
2109+
utils.RayNodeZone: "us-central1-c",
2110+
},
2111+
},
2112+
{
2113+
name: "Market type env var set from nodeAffinity",
2114+
nodeAffinity: &corev1.NodeAffinity{
2115+
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
2116+
NodeSelectorTerms: []corev1.NodeSelectorTerm{
2117+
{
2118+
MatchExpressions: []corev1.NodeSelectorRequirement{
2119+
{
2120+
Key: utils.EKSCapacityTypeLabel,
2121+
Operator: corev1.NodeSelectorOpIn,
2122+
Values: []string{"SPOT"},
2123+
},
2124+
},
2125+
},
2126+
},
2127+
},
2128+
},
2129+
expectedEnv: map[string]string{
2130+
utils.RayNodeMarketType: string(utils.SpotMarketType),
2131+
utils.RayNodeRegion: "metadata.labels['topology.kubernetes.io/region']",
2132+
utils.RayNodeZone: "metadata.labels['topology.kubernetes.io/zone']",
20642133
},
20652134
},
20662135
}
20672136

2068-
addDefaultRayNodeLabels(pod)
2069-
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
2070-
checkContainerEnv(t, rayContainer, "RAY_NODE_MARKET_TYPE", "spot")
2071-
checkContainerEnv(t, rayContainer, "RAY_NODE_REGION", "metadata.labels['topology.kubernetes.io/region']")
2072-
checkContainerEnv(t, rayContainer, "RAY_NODE_ZONE", "metadata.labels['topology.kubernetes.io/zone']")
2137+
for _, tt := range tests {
2138+
t.Run(tt.name, func(t *testing.T) {
2139+
pod := &corev1.Pod{
2140+
ObjectMeta: metav1.ObjectMeta{
2141+
Labels: tt.labels,
2142+
},
2143+
Spec: corev1.PodSpec{
2144+
Containers: []corev1.Container{{Name: "ray"}},
2145+
NodeSelector: tt.nodeSelector,
2146+
},
2147+
}
2148+
if tt.nodeAffinity != nil {
2149+
pod.Spec.Affinity = &corev1.Affinity{NodeAffinity: tt.nodeAffinity}
2150+
}
2151+
// validate default labels are set correctly from Pod spec as env vars
2152+
addDefaultRayNodeLabels(pod)
2153+
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
2154+
for key, expectedVar := range tt.expectedEnv {
2155+
foundVar := false
2156+
for _, env := range rayContainer.Env {
2157+
if env.Name == key {
2158+
if env.Value != "" {
2159+
if env.Value != expectedVar {
2160+
t.Errorf("%s: got value %q, but expected %q", key, env.Value, expectedVar)
2161+
}
2162+
} else if env.ValueFrom != nil && env.ValueFrom.FieldRef != nil {
2163+
if env.ValueFrom.FieldRef.FieldPath != expectedVar {
2164+
t.Errorf("%s: got FieldPath %q, but expected %q", key, env.ValueFrom.FieldRef.FieldPath, expectedVar)
2165+
}
2166+
} else {
2167+
t.Errorf("%s: environment var not set as expected", key)
2168+
}
2169+
foundVar = true
2170+
break
2171+
}
2172+
}
2173+
if !foundVar {
2174+
t.Errorf("%s: not found in container env", key)
2175+
}
2176+
}
2177+
})
2178+
}
20732179
}
20742180

2075-
func TestAddDefaultRayNodeLabels_EKSSpot(t *testing.T) {
2076-
pod := &corev1.Pod{
2077-
ObjectMeta: metav1.ObjectMeta{
2078-
Labels: map[string]string{
2079-
"ray.io/group": "test-worker-group-2",
2080-
"topology.kubernetes.io/region": "us-west4",
2081-
"topology.kubernetes.io/zone": "us-west4-a",
2082-
},
2181+
func TestGetPodZoneEnvVar(t *testing.T) {
2182+
tests := []struct {
2183+
name string
2184+
labels map[string]string
2185+
nodeSelector map[string]string
2186+
expectedVar string
2187+
}{
2188+
{
2189+
name: "Retrieve topology zone from labels",
2190+
labels: map[string]string{utils.K8sTopologyZoneLabel: "us-west4-a"},
2191+
expectedVar: "us-west4-a",
20832192
},
2084-
Spec: corev1.PodSpec{
2085-
Containers: []corev1.Container{
2086-
{Name: "ray"},
2087-
},
2088-
NodeSelector: map[string]string{
2089-
"eks.amazonaws.com/capacityType": "SPOT",
2090-
},
2193+
{
2194+
name: "Retrieve topology zone from nodeSelector",
2195+
nodeSelector: map[string]string{utils.K8sTopologyZoneLabel: "us-central2-b"},
2196+
expectedVar: "us-central2-b",
20912197
},
2198+
{
2199+
name: "Zone set using downward API",
2200+
expectedVar: "metadata.labels['topology.kubernetes.io/zone']",
2201+
},
2202+
}
2203+
for _, tt := range tests {
2204+
t.Run(tt.name, func(t *testing.T) {
2205+
pod := &corev1.Pod{
2206+
ObjectMeta: metav1.ObjectMeta{Labels: tt.labels},
2207+
Spec: corev1.PodSpec{NodeSelector: tt.nodeSelector},
2208+
}
2209+
// validate expected zone env var is parsed from Pod spec
2210+
result := getPodZoneEnvVar(pod)
2211+
if result.Value != "" {
2212+
if result.Value != tt.expectedVar {
2213+
t.Errorf("got env var %q, but expected %q", result.Value, tt.expectedVar)
2214+
}
2215+
} else if result.ValueFrom != nil {
2216+
if result.ValueFrom.FieldRef.FieldPath != tt.expectedVar {
2217+
t.Errorf("got FieldPath %q, but expected %q", result.ValueFrom.FieldRef.FieldPath, tt.expectedVar)
2218+
}
2219+
} else {
2220+
t.Errorf("getPodZoneEnvVar did not return expected env value")
2221+
}
2222+
})
20922223
}
2224+
}
20932225

2094-
addDefaultRayNodeLabels(pod)
2095-
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
2096-
checkContainerEnv(t, rayContainer, utils.RayNodeMarketType, "spot")
2097-
checkContainerEnv(t, rayContainer, utils.RayNodeRegion, "metadata.labels['topology.kubernetes.io/region']")
2098-
checkContainerEnv(t, rayContainer, utils.RayNodeZone, "metadata.labels['topology.kubernetes.io/zone']")
2226+
func TestGetPodRegionEnvVar(t *testing.T) {
2227+
tests := []struct {
2228+
name string
2229+
labels map[string]string
2230+
nodeSelector map[string]string
2231+
expectedVar string
2232+
}{
2233+
{
2234+
name: "Retrieve topology region from labels",
2235+
labels: map[string]string{utils.K8sTopologyRegionLabel: "us-central1"},
2236+
expectedVar: "us-central1",
2237+
},
2238+
{
2239+
name: "Retrieve topology region from nodeSelector",
2240+
nodeSelector: map[string]string{utils.K8sTopologyRegionLabel: "us-central2"},
2241+
expectedVar: "us-central2",
2242+
},
2243+
{
2244+
name: "Region set using downward API",
2245+
expectedVar: "metadata.labels['topology.kubernetes.io/region']",
2246+
},
2247+
}
2248+
for _, tt := range tests {
2249+
t.Run(tt.name, func(t *testing.T) {
2250+
pod := &corev1.Pod{
2251+
ObjectMeta: metav1.ObjectMeta{Labels: tt.labels},
2252+
Spec: corev1.PodSpec{NodeSelector: tt.nodeSelector},
2253+
}
2254+
// validate expected region env var is parsed from Pod spec
2255+
result := getPodRegionEnvVar(pod)
2256+
if result.Value != "" {
2257+
if result.Value != tt.expectedVar {
2258+
t.Errorf("got env var %q, but expected %q", result.Value, tt.expectedVar)
2259+
}
2260+
} else if result.ValueFrom != nil {
2261+
if result.ValueFrom.FieldRef.FieldPath != tt.expectedVar {
2262+
t.Errorf("got FieldPath %q, but expected %q", result.ValueFrom.FieldRef.FieldPath, tt.expectedVar)
2263+
}
2264+
} else {
2265+
t.Errorf("getPodRegionEnvVar did not return expected env value")
2266+
}
2267+
})
2268+
}
20992269
}

0 commit comments

Comments
 (0)