diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index b87be6c..2995020 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -666,7 +666,7 @@ func (r *LMEvalJobReconciler) validateCustomCard(job *lmesv1alpha1.LMEvalJob, lo func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { - var envVars = job.Spec.Pod.GetContainer().GetEnv() + var envVars = removeProtectedEnvVars(job.Spec.Pod.GetContainer().GetEnv()) var volumeMounts = []corev1.VolumeMount{ { @@ -1069,3 +1069,25 @@ func getContainerByName(status *corev1.PodStatus, name string) int { return s.Name == name }) } + +var ProtectedEnvVarNames = []string{"TRUST_REMOTE_CODE", "HF_DATASETS_TRUST_REMOTE_CODE", "HF_DATASETS_OFFLINE", "HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE", "HF_EVALUATE_OFFLINE"} + +// removeProtectedEnvVars removes protected EnvVars from a list of EnvVars +func removeProtectedEnvVars(envVars []corev1.EnvVar) []corev1.EnvVar { + var allowedEnvVars []corev1.EnvVar + + for _, env := range envVars { + exclude := false + for _, name := range ProtectedEnvVarNames { + if env.Name == name { + exclude = true + break + } + } + if !exclude { + allowedEnvVars = append(allowedEnvVars, env) + } + } + + return allowedEnvVars +} diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index 2519e13..b29f0d6 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -1683,6 +1683,193 @@ func Test_OfflineMode(t *testing.T) { assert.Equal(t, expect, newPod) } +// Test_ProtectedVars tests that if the protected env vars are set from spec.pod mode +// they will not be changed in the pod +func Test_ProtectedVars(t *testing.T) { + log := log.FromContext(context.Background()) + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + } + + jobName := "test" + pvcName := "my-pvc" + var job = &lmesv1alpha1.LMEvalJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: jobName, + Namespace: "default", + UID: "for-testing", + }, + TypeMeta: metav1.TypeMeta{ + Kind: lmesv1alpha1.KindName, + APIVersion: lmesv1alpha1.Version, + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + ModelArgs: []lmesv1alpha1.Arg{ + {Name: "arg1", Value: "value1"}, + }, + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1", "task2"}, + }, + Offline: &lmesv1alpha1.OfflineSpec{ + StorageSpec: lmesv1alpha1.OfflineStorageSpec{ + PersistentVolumeClaimName: pvcName, + }, + }, + Pod: &lmesv1alpha1.LMEvalPodSpec{ + Container: &lmesv1alpha1.LMEvalContainer{ + Env: []corev1.EnvVar{ + { + Name: "HF_HUB_OFFLINE", + Value: "0", + }, + { + Name: "NOT_PROTECTED", + Value: "True", + }, + { + Name: "TRUST_REMOTE_CODE", + Value: "1", + }, + }, + }, + }, + }, + } + + expect := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + Labels: map[string]string{ + "app.kubernetes.io/name": "ta-lmes", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: lmesv1alpha1.Version, + Kind: lmesv1alpha1.KindName, + Name: "test", + Controller: &isController, + UID: "for-testing", + }, + }, + }, + TypeMeta: metav1.TypeMeta{ + Kind: "Pod", + APIVersion: "v1", + }, + Spec: corev1.PodSpec{ + InitContainers: []corev1.Container{ + { + Name: "driver", + Image: svcOpts.DriverImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: []string{DriverPath, "--copy", DestDriverPath}, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: &allowPrivilegeEscalation, + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{ + "ALL", + }, + }, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: "shared", + MountPath: "/opt/app-root/src/bin", + }, + }, + }, + }, + Containers: []corev1.Container{ + { + Name: "main", + Image: svcOpts.PodImage, + ImagePullPolicy: svcOpts.ImagePullPolicy, + Command: generateCmd(svcOpts, job), + Args: generateArgs(svcOpts, job, log), + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: &allowPrivilegeEscalation, + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{ + "ALL", + }, + }, + }, + Env: []corev1.EnvVar{ + { + Name: "NOT_PROTECTED", + Value: "True", + }, + { + Name: "TRUST_REMOTE_CODE", + Value: "0", + }, + { + Name: "HF_DATASETS_TRUST_REMOTE_CODE", + Value: "0", + }, + { + Name: "HF_DATASETS_OFFLINE", + Value: "1", + }, + { + Name: "HF_HUB_OFFLINE", + Value: "1", + }, + { + Name: "TRANSFORMERS_OFFLINE", + Value: "1", + }, + { + Name: "HF_EVALUATE_OFFLINE", + Value: "1", + }, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: "shared", + MountPath: "/opt/app-root/src/bin", + }, + { + Name: "offline", + MountPath: "/opt/app-root/src/hf_home", + }, + }, + }, + }, + SecurityContext: &corev1.PodSecurityContext{ + RunAsNonRoot: &runAsNonRootUser, + SeccompProfile: &corev1.SeccompProfile{ + Type: corev1.SeccompProfileTypeRuntimeDefault, + }, + }, + Volumes: []corev1.Volume{ + { + Name: "shared", VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + { + Name: "offline", VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: pvcName, + ReadOnly: false, + }, + }, + }, + }, + RestartPolicy: corev1.RestartPolicyNever, + }, + } + + newPod := CreatePod(svcOpts, job, log) + + assert.Equal(t, expect, newPod) +} + // Test_OnlineMode tests that if the online mode is set the configuration is correct func Test_OnlineMode(t *testing.T) { log := log.FromContext(context.Background())