diff --git a/tests/kfto/kfto_mnist_training_test.go b/tests/kfto/kfto_mnist_training_test.go index 63596c07..ed947433 100644 --- a/tests/kfto/kfto_mnist_training_test.go +++ b/tests/kfto/kfto_mnist_training_test.go @@ -31,7 +31,7 @@ import ( ) func TestPyTorchJobMnistCpu(t *testing.T) { - runKFTOPyTorchMnistJob(t, 0, 2, "", GetCudaTrainingImage(), "resources/requirements.txt") + runKFTOPyTorchMnistJob(t, 0, 1, "", GetCudaTrainingImage(), "resources/requirements.txt") } func TestPyTorchJobMnistWithCuda(t *testing.T) { @@ -149,6 +149,12 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config MountPath: "/mnt/output", }, }, + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("6Gi"), + }, + }, }, }, Volumes: []corev1.Volume{ @@ -226,6 +232,12 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config MountPath: "/tmp", }, }, + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("6Gi"), + }, + }, }, }, Volumes: []corev1.Volume{ @@ -255,21 +267,9 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config } if useGPU { - // Update resource lists - tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Containers[0].Resources = corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - corev1.ResourceMemory: resource.MustParse("8Gi"), - corev1.ResourceName(gpuLabel): resource.MustParse(fmt.Sprint(numGpus)), - }, - } - tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Containers[0].Resources = corev1.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - corev1.ResourceMemory: resource.MustParse("8Gi"), - corev1.ResourceName(gpuLabel): resource.MustParse(fmt.Sprint(numGpus)), - }, - } + // Update resource lists for GPU (NVIDIA/ROCm) usecase + tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName(gpuLabel)] = resource.MustParse(fmt.Sprint(numGpus)) + tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName(gpuLabel)] = resource.MustParse(fmt.Sprint(numGpus)) // Update tolerations tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Tolerations = []corev1.Toleration{ diff --git a/tests/kfto/resources/mnist.py b/tests/kfto/resources/mnist.py index 7d8d445d..91b1cbd3 100644 --- a/tests/kfto/resources/mnist.py +++ b/tests/kfto/resources/mnist.py @@ -1,3 +1,17 @@ +# Copyright 2023. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import os diff --git a/tests/kfto/resources/requirements-rocm.txt b/tests/kfto/resources/requirements-rocm.txt index 1880dc8f..6e2f7b93 100644 --- a/tests/kfto/resources/requirements-rocm.txt +++ b/tests/kfto/resources/requirements-rocm.txt @@ -1,3 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/rocm6.1 torchvision==0.19.0 -tensorboard==2.18.0 \ No newline at end of file +tensorboard==2.18.0 +fsspec[http]==2024.6.1 +numpy==2.0.2 \ No newline at end of file diff --git a/tests/kfto/resources/requirements.txt b/tests/kfto/resources/requirements.txt index e3ae7b3e..9352f8b6 100644 --- a/tests/kfto/resources/requirements.txt +++ b/tests/kfto/resources/requirements.txt @@ -1,2 +1,4 @@ torchvision==0.19.0 -tensorboard==2.18.0 \ No newline at end of file +tensorboard==2.18.0 +fsspec[http]==2024.6.1 +numpy==2.0.2 \ No newline at end of file