Skip to content

Commit

Permalink
Add support for environment variables for Training images
Browse files Browse the repository at this point in the history
  • Loading branch information
ChughShilpa authored and openshift-merge-bot[bot] committed Nov 21, 2024
1 parent 2a7c1fc commit e99e941
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
2 changes: 2 additions & 0 deletions support/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ const (
RayROCmImage = "quay.io/modh/ray:2.35.0-py311-rocm61"
RayTorchCudaImage = "quay.io/rhoai/ray:2.35.0-py311-cu121-torch24-fa26"
RayTorchROCmImage = "quay.io/rhoai/ray:2.35.0-py311-rocm61-torch24-fa26"
TrainingCudaImage = "quay.io/modh/training:py311-cuda121-torch241"
TrainingROCmImage = "quay.io/modh/training:py311-rocm61-torch241"
)
15 changes: 12 additions & 3 deletions support/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ const (
// The environment variables hereafter can be used to change the components
// used for testing.

CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
CodeFlareTestTrainingImage = "CODEFLARE_TEST_TRAINING_IMAGE"

// The testing output directory, to write output files into.
CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR"
Expand Down Expand Up @@ -97,6 +98,14 @@ func GetPyTorchImage() string {
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
}

func GetCudaTrainingImage() string {
return lookupEnvOrDefault(CodeFlareTestTrainingImage, TrainingCudaImage)
}

func GetROCmTrainingImage() string {
return lookupEnvOrDefault(CodeFlareTestTrainingImage, TrainingROCmImage)
}

func GetInstascaleOcmSecret() (string, string) {
res := strings.SplitN(lookupEnvOrDefault(InstaScaleOcmSecret, "default/instascale-ocm-secret"), "/", 2)
return res[0], res[1]
Expand Down
15 changes: 15 additions & 0 deletions support/environment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ func TestGetPyTorchImage(t *testing.T) {

}

func TestGetTrainingImage(t *testing.T) {

g := gomega.NewGomegaWithT(t)
// Set the environment variable.
os.Setenv(CodeFlareTestTrainingImage, "training/training:latest")

// Get the image.
image := GetCudaTrainingImage()

// Assert that the image is correct.

g.Expect(image).To(gomega.Equal("training/training:latest"), "Expected image training/training:latest, but got %s", image)

}

func TestGetClusterID(t *testing.T) {

g := gomega.NewGomegaWithT(t)
Expand Down

0 comments on commit e99e941

Please sign in to comment.