Skip to content

Commit

Permalink
feat(lmeval): Filter protected env vars set from CR directly
Browse files Browse the repository at this point in the history
  • Loading branch information
ruivieira committed Dec 11, 2024
1 parent 6f88955 commit 31d279d
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 1 deletion.
24 changes: 23 additions & 1 deletion controllers/lmes/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, lo

func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod {

var envVars = job.Spec.Pod.GetContainer().GetEnv()
var envVars = removeProtectedEnvVars(job.Spec.Pod.GetContainer().GetEnv())

var volumeMounts = []corev1.VolumeMount{
{
Expand Down Expand Up @@ -1069,3 +1069,25 @@ func getContainerByName(status *corev1.PodStatus, name string) int {
return s.Name == name
})
}

var ProtectedEnvVarNames = []string{"TRUST_REMOTE_CODE", "HF_DATASETS_TRUST_REMOTE_CODE", "HF_DATASETS_OFFLINE", "HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE", "HF_EVALUATE_OFFLINE"}

// removeProtectedEnvVars removes protected EnvVars from a list of EnvVars
func removeProtectedEnvVars(envVars []corev1.EnvVar) []corev1.EnvVar {
var allowedEnvVars []corev1.EnvVar

for _, env := range envVars {
exclude := false
for _, name := range ProtectedEnvVarNames {
if env.Name == name {
exclude = true
break
}
}
if !exclude {
allowedEnvVars = append(allowedEnvVars, env)
}
}

return allowedEnvVars
}
187 changes: 187 additions & 0 deletions controllers/lmes/lmevaljob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1683,6 +1683,193 @@ func Test_OfflineMode(t *testing.T) {
assert.Equal(t, expect, newPod)
}

// Test_ProtectedVars tests that if the protected env vars are set from spec.pod mode
// they will not be changed in the pod
func Test_ProtectedVars(t *testing.T) {
log := log.FromContext(context.Background())
svcOpts := &serviceOptions{
PodImage: "podimage:latest",
DriverImage: "driver:latest",
ImagePullPolicy: corev1.PullAlways,
}

jobName := "test"
pvcName := "my-pvc"
var job = &lmesv1alpha1.LMEvalJob{
ObjectMeta: metav1.ObjectMeta{
Name: jobName,
Namespace: "default",
UID: "for-testing",
},
TypeMeta: metav1.TypeMeta{
Kind: lmesv1alpha1.KindName,
APIVersion: lmesv1alpha1.Version,
},
Spec: lmesv1alpha1.LMEvalJobSpec{
Model: "test",
ModelArgs: []lmesv1alpha1.Arg{
{Name: "arg1", Value: "value1"},
},
TaskList: lmesv1alpha1.TaskList{
TaskNames: []string{"task1", "task2"},
},
Offline: &lmesv1alpha1.OfflineSpec{
StorageSpec: lmesv1alpha1.OfflineStorageSpec{
PersistentVolumeClaimName: pvcName,
},
},
Pod: &lmesv1alpha1.LMEvalPodSpec{
Container: &lmesv1alpha1.LMEvalContainer{
Env: []corev1.EnvVar{
{
Name: "HF_HUB_OFFLINE",
Value: "0",
},
{
Name: "NOT_PROTECTED",
Value: "True",
},
{
Name: "TRUST_REMOTE_CODE",
Value: "1",
},
},
},
},
},
}

expect := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
Namespace: "default",
Labels: map[string]string{
"app.kubernetes.io/name": "ta-lmes",
},
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: lmesv1alpha1.Version,
Kind: lmesv1alpha1.KindName,
Name: "test",
Controller: &isController,
UID: "for-testing",
},
},
},
TypeMeta: metav1.TypeMeta{
Kind: "Pod",
APIVersion: "v1",
},
Spec: corev1.PodSpec{
InitContainers: []corev1.Container{
{
Name: "driver",
Image: svcOpts.DriverImage,
ImagePullPolicy: svcOpts.ImagePullPolicy,
Command: []string{DriverPath, "--copy", DestDriverPath},
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{
"ALL",
},
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "shared",
MountPath: "/opt/app-root/src/bin",
},
},
},
},
Containers: []corev1.Container{
{
Name: "main",
Image: svcOpts.PodImage,
ImagePullPolicy: svcOpts.ImagePullPolicy,
Command: generateCmd(svcOpts, job),
Args: generateArgs(svcOpts, job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{
"ALL",
},
},
},
Env: []corev1.EnvVar{
{
Name: "NOT_PROTECTED",
Value: "True",
},
{
Name: "TRUST_REMOTE_CODE",
Value: "0",
},
{
Name: "HF_DATASETS_TRUST_REMOTE_CODE",
Value: "0",
},
{
Name: "HF_DATASETS_OFFLINE",
Value: "1",
},
{
Name: "HF_HUB_OFFLINE",
Value: "1",
},
{
Name: "TRANSFORMERS_OFFLINE",
Value: "1",
},
{
Name: "HF_EVALUATE_OFFLINE",
Value: "1",
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "shared",
MountPath: "/opt/app-root/src/bin",
},
{
Name: "offline",
MountPath: "/opt/app-root/src/hf_home",
},
},
},
},
SecurityContext: &corev1.PodSecurityContext{
RunAsNonRoot: &runAsNonRootUser,
SeccompProfile: &corev1.SeccompProfile{
Type: corev1.SeccompProfileTypeRuntimeDefault,
},
},
Volumes: []corev1.Volume{
{
Name: "shared", VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
{
Name: "offline", VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: pvcName,
ReadOnly: false,
},
},
},
},
RestartPolicy: corev1.RestartPolicyNever,
},
}

newPod := CreatePod(svcOpts, job, log)

assert.Equal(t, expect, newPod)
}

// Test_OnlineMode tests that if the online mode is set the configuration is correct
func Test_OnlineMode(t *testing.T) {
log := log.FromContext(context.Background())
Expand Down

0 comments on commit 31d279d

Please sign in to comment.