Skip to content

Commit

Permalink
Validate pytorchjob workers are configured when elasticpolicy is conf…
Browse files Browse the repository at this point in the history
…igured (#2320)

* Update pytorchjob webhook to validate elastic policy without worker

Signed-off-by: tarat44 <[email protected]>
Co-authored-by: ricardov1 <[email protected]>
Co-authored-by: alenawang <[email protected]>

* Update pkg/webhooks/pytorch/pytorchjob_webhook.go

Co-authored-by: Andrey Velichkevich <[email protected]>
Signed-off-by: Tara Tufano <[email protected]>

* Make replica minimum separate validation, fix field path

Signed-off-by: tarat44 <[email protected]>
Co-authored-by: ricardov1 <[email protected]>

---------

Signed-off-by: tarat44 <[email protected]>
Signed-off-by: Tara Tufano <[email protected]>
Co-authored-by: ricardov1 <[email protected]>
Co-authored-by: alenawang <[email protected]>
Co-authored-by: Andrey Velichkevich <[email protected]>
  • Loading branch information
4 people authored Dec 5, 2024
1 parent 265d4c7 commit 2392c36
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 7 deletions.
22 changes: 15 additions & 7 deletions pkg/webhooks/pytorch/pytorchjob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,19 @@ func validatePyTorchJob(oldJob, newJob *trainingoperator.PyTorchJob) (admission.
func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
var warnings admission.Warnings

if spec.ElasticPolicy != nil && spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
if spec.ElasticPolicy != nil {
_, ok := spec.PyTorchReplicaSpecs[trainingoperator.PyTorchJobReplicaTypeWorker]
workerPath := pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeWorker))
if !ok {
allErrs = append(allErrs, field.Required(workerPath, "must be configured if elastic policy is used"))
}
if spec.ElasticPolicy.NProcPerNode != nil {
elasticNProcPerNodePath := specPath.Child("elasticPolicy").Child("nProcPerNode")
nprocPerNodePath := specPath.Child("nprocPerNode")
warnings = append(warnings, fmt.Sprintf("%s is deprecated, use %s instead", elasticNProcPerNodePath.String(), nprocPerNodePath.String()))
if spec.NprocPerNode != nil {
allErrs = append(allErrs, field.Forbidden(elasticNProcPerNodePath, fmt.Sprintf("must not be used with %s", nprocPerNodePath)))
}
}
}
allErrs = append(allErrs, validatePyTorchReplicaSpecs(spec.PyTorchReplicaSpecs)...)
Expand Down Expand Up @@ -147,6 +153,8 @@ func validatePyTorchReplicaSpecs(rSpecs map[trainingoperator.ReplicaType]*traini
if rSpec.Replicas == nil || int(*rSpec.Replicas) != 1 {
allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be 1"))
}
} else if rSpec.Replicas != nil && int(*rSpec.Replicas) < 1 {
allErrs = append(allErrs, field.Forbidden(rolePath.Child("replicas"), "must be at least 1"))
}
}
return allErrs
Expand Down
85 changes: 85 additions & 0 deletions pkg/webhooks/pytorch/pytorchjob_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ func TestValidateV1PyTorchJob(t *testing.T) {
RunPolicy: trainingoperator.RunPolicy{
ManagedBy: ptr.To(trainingoperator.KubeflowJobsController),
},
ElasticPolicy: &trainingoperator.ElasticPolicy{
RDZVBackend: ptr.To(trainingoperator.BackendC10D),
},
PyTorchReplicaSpecs: validPyTorchReplicaSpecs,
},
},
Expand Down Expand Up @@ -247,6 +250,19 @@ func TestValidateV1PyTorchJob(t *testing.T) {
},
},
},
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -279,6 +295,19 @@ func TestValidateV1PyTorchJob(t *testing.T) {
},
},
},
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -335,6 +364,62 @@ func TestValidateV1PyTorchJob(t *testing.T) {
field.Invalid(field.NewPath("spec", "runPolicy", "managedBy"), trainingoperator.MultiKueueController, apivalidation.FieldImmutableErrorMsg),
},
},
"attempt to configure elasticPolicy when no worker is configured": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeMaster: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)), ""),
},
},
"attempt to configure worker with 0 replicas": {
pytorchJob: &trainingoperator.PyTorchJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.PyTorchJobSpec{
ElasticPolicy: &trainingoperator.ElasticPolicy{},
PyTorchReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.PyTorchJobReplicaTypeWorker: {
Replicas: ptr.To[int32](0),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "pytorch",
Image: "gcr.io/kubeflow-ci/pytorch-dist-mnist_test:1.0",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Forbidden(pytorchReplicaSpecPath.Key(string(trainingoperator.PyTorchJobReplicaTypeWorker)).Child("replicas"), ""),
},
},
}

for name, tc := range testCases {
Expand Down

0 comments on commit 2392c36

Please sign in to comment.