Skip to content

Commit 55ac657

Browse files
committed
KEP-2170: Implement TrainJob Reconciler to manage objects
Signed-off-by: Yuki Iwai <[email protected]>
1 parent 0149eb0 commit 55ac657

File tree

13 files changed

+401
-73
lines changed

13 files changed

+401
-73
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ bin/
44
/tf-operator
55
vendor/
66
testbin/*
7+
dep-crds/
78
cover.out
89

910
# IDEs

Makefile

+19-3
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,16 @@ HAS_SETUP_ENVTEST := $(shell command -v setup-envtest;)
7474
testall: manifests generate fmt vet golangci-lint test ## Run tests.
7575

7676
test: envtest
77-
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./... -coverprofile cover.out
77+
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" \
78+
go test ./pkg/apis/kubeflow.org/v1/... ./pkg/cert/... ./pkg/common/... ./pkg/config/... ./pkg/controller.v1/... ./pkg/core/... ./pkg/util/... ./pkg/webhooks/... -coverprofile cover.out
7879

7980
.PHONY: test-integrationv2
80-
test-integrationv2: envtest
81+
test-integrationv2: envtest jobset-operator-crd scheduler-plugins-crd
8182
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./test/... -coverprofile cover.out
8283

8384
.PHONY: testv2
8485
testv2:
85-
go test ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out
86+
go test ./pkg/apis/kubeflow.org/v2alpha1/... ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out
8687

8788
envtest:
8889
ifndef HAS_SETUP_ENVTEST
@@ -127,3 +128,18 @@ controller-gen: ## Download controller-gen locally if necessary.
127128
KUSTOMIZE = $(shell pwd)/bin/kustomize
128129
kustomize: ## Download kustomize locally if necessary.
129130
GOBIN=$(PROJECT_DIR)/bin go install sigs.k8s.io/kustomize/kustomize/[email protected]
131+
132+
## Download external CRDs for the integration tests.
133+
EXTERNAL_CRDS_DIR ?= $(PROJECT_DIR)/dep-crds
134+
135+
JOBSET_ROOT = $(shell go list -m -mod=readonly -f "{{.Dir}}" sigs.k8s.io/jobset)
136+
.PHONY: jobset-operator-crd
137+
jobset-operator-crd: ## Copy the CRDs from the jobset-operator to the dep-crds directory.
138+
mkdir -p $(EXTERNAL_CRDS_DIR)/jobset-operator/
139+
cp -f $(JOBSET_ROOT)/config/components/crd/bases/* $(EXTERNAL_CRDS_DIR)/jobset-operator/
140+
141+
# Resolve the scheduler-plugins module directory from the Go module cache.
# -mod=readonly matches the JOBSET_ROOT lookup so both resolve modules the same way.
SCHEDULER_PLUGINS_ROOT = $(shell go list -m -mod=readonly -f "{{.Dir}}" sigs.k8s.io/scheduler-plugins)
.PHONY: scheduler-plugins-crd
scheduler-plugins-crd: ## Copy the CRDs from the scheduler-plugins to the dep-crds directory.
	mkdir -p $(EXTERNAL_CRDS_DIR)/scheduler-plugins/
	# Copy into the same directory that was just created, so overriding EXTERNAL_CRDS_DIR works.
	cp -f $(SCHEDULER_PLUGINS_ROOT)/manifests/coscheduling/* $(EXTERNAL_CRDS_DIR)/scheduler-plugins/

pkg/controller.v2/setup.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (st
2626
if err := NewTrainJobReconciler(
2727
mgr.GetClient(),
2828
mgr.GetEventRecorderFor("training-operator-trainjob-controller"),
29-
).SetupWithManager(mgr, runtimes); err != nil {
29+
runtimes,
30+
).SetupWithManager(mgr); err != nil {
3031
return "TrainJob", err
3132
}
3233
return "", nil

pkg/controller.v2/trainjob_controller.go

+64-4
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ import (
2020
"context"
2121

2222
"github.com/go-logr/logr"
23+
"k8s.io/apimachinery/pkg/runtime/schema"
2324
"k8s.io/client-go/tools/record"
2425
"k8s.io/klog/v2"
26+
"k8s.io/utils/ptr"
2527
ctrl "sigs.k8s.io/controller-runtime"
2628
"sigs.k8s.io/controller-runtime/pkg/client"
29+
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
2730

2831
kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
2932
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
@@ -33,13 +36,15 @@ type TrainJobReconciler struct {
3336
log logr.Logger
3437
client client.Client
3538
recorder record.EventRecorder
39+
runtimes map[string]runtime.Runtime
3640
}
3741

38-
func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder) *TrainJobReconciler {
42+
func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runs map[string]runtime.Runtime) *TrainJobReconciler {
3943
return &TrainJobReconciler{
4044
log: ctrl.Log.WithName("trainjob-controller"),
4145
client: client,
4246
recorder: recorder,
47+
runtimes: runs,
4348
}
4449
}
4550

@@ -49,15 +54,70 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
4954
return ctrl.Result{}, client.IgnoreNotFound(err)
5055
}
5156
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
52-
ctrl.LoggerInto(ctx, log)
57+
ctx = ctrl.LoggerInto(ctx, log)
5358
log.V(2).Info("Reconciling TrainJob")
59+
if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
60+
return ctrl.Result{}, err
61+
}
62+
// TODO (tenzen-y): Update the status.
5463
return ctrl.Result{}, nil
5564
}
5665

57-
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) error {
66+
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
67+
log := ctrl.LoggerFrom(ctx)
68+
69+
// Controller assumes the runtime existence has already verified in the webhook on TrainJob creation.
70+
run := r.runtimes[runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()]
71+
objs, err := run.NewObjects(ctx, trainJob)
72+
if err != nil {
73+
return err
74+
}
75+
for _, obj := range objs {
76+
var gvk schema.GroupVersionKind
77+
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
78+
return err
79+
}
80+
logKeysAndValues := []any{
81+
"groupVersionKind", gvk.String(),
82+
"namespace", obj.GetNamespace(),
83+
"name", obj.GetName(),
84+
}
85+
// TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence.
86+
// Non-empty resourceVersion indicates UPDATE operation.
87+
var creationErr error
88+
var created bool
89+
if obj.GetResourceVersion() == "" {
90+
creationErr = r.client.Create(ctx, obj)
91+
created = creationErr == nil
92+
}
93+
switch {
94+
case created:
95+
log.V(5).Info("Succeeded to create object", logKeysAndValues)
96+
continue
97+
case client.IgnoreAlreadyExists(creationErr) != nil:
98+
return creationErr
99+
default:
100+
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
101+
if err = r.client.Update(ctx, obj); err != nil {
102+
return err
103+
}
104+
log.V(5).Info("Succeeded to update object", logKeysAndValues)
105+
}
106+
}
107+
return nil
108+
}
109+
110+
func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
111+
return schema.GroupKind{
112+
Group: ptr.Deref(runtimeRef.APIGroup, ""),
113+
Kind: ptr.Deref(runtimeRef.Kind, ""),
114+
}
115+
}
116+
117+
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
58118
b := ctrl.NewControllerManagedBy(mgr).
59119
For(&kubeflowv2.TrainJob{})
60-
for _, run := range runtimes {
120+
for _, run := range r.runtimes {
61121
for _, registrar := range run.EventHandlerRegistrars() {
62122
if registrar != nil {
63123
b = registrar(b, mgr.GetClient())

pkg/runtime.v2/core/clustertrainingruntime_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
4646
}{
4747
"succeeded to build JobSet and PodGroup": {
4848
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
49+
Suspend(true).
4950
UID("uid").
5051
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
5152
Trainer(
@@ -57,7 +58,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
5758
clusterTrainingRuntime: baseRuntime.RuntimeSpec(
5859
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
5960
ContainerImage("test:runtime").
60-
PodGroupPolicySchedulingTimeout(120).
61+
PodGroupPolicyCoschedulingSchedulingTimeout(120).
6162
MLPolicyNumNodes(20).
6263
ResourceRequests(0, corev1.ResourceList{
6364
corev1.ResourceCPU: resource.MustParse("1"),
@@ -69,6 +70,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
6970
).Obj(),
7071
wantObjs: []client.Object{
7172
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
73+
Suspend(true).
7274
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
7375
ContainerImage(ptr.To("test:trainjob")).
7476
JobCompletionMode(batchv1.IndexedCompletion).

pkg/runtime.v2/core/trainingruntime_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
4646
}{
4747
"succeeded to build JobSet and PodGroup": {
4848
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
49+
Suspend(true).
4950
UID("uid").
5051
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
5152
SpecLabel("conflictLabel", "override").
@@ -62,7 +63,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
6263
RuntimeSpec(
6364
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
6465
ContainerImage("test:runtime").
65-
PodGroupPolicySchedulingTimeout(120).
66+
PodGroupPolicyCoschedulingSchedulingTimeout(120).
6667
MLPolicyNumNodes(20).
6768
ResourceRequests(0, corev1.ResourceList{
6869
corev1.ResourceCPU: resource.MustParse("1"),
@@ -74,6 +75,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
7475
).Obj(),
7576
wantObjs: []client.Object{
7677
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
78+
Suspend(true).
7779
Label("conflictLabel", "override").
7880
Annotation("conflictAnnotation", "override").
7981
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").

pkg/runtime.v2/framework/core/framework_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
334334
ResourceRequests(1, corev1.ResourceList{
335335
corev1.ResourceCPU: resource.MustParse("1"),
336336
corev1.ResourceMemory: resource.MustParse("2Gi"),
337-
}).
338-
Clone()
337+
})
339338
jobSetWithPropagatedTrainJobParams := jobSetBase.
339+
Clone().
340340
JobCompletionMode(batchv1.IndexedCompletion).
341341
ContainerImage(ptr.To("foo:bar")).
342-
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
343-
Clone()
342+
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid")
344343

345344
cases := map[string]struct {
346345
runtimeInfo *runtime.Info
@@ -361,6 +360,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
361360
Obj(),
362361
runtimeInfo: &runtime.Info{
363362
Obj: jobSetBase.
363+
Clone().
364364
Obj(),
365365
Policy: runtime.Policy{
366366
MLPolicy: &kubeflowv2.MLPolicy{
@@ -403,10 +403,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
403403
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
404404
Obj(),
405405
jobSetWithPropagatedTrainJobParams.
406+
Clone().
406407
Obj(),
407408
},
408409
wantRuntimeInfo: &runtime.Info{
409410
Obj: jobSetWithPropagatedTrainJobParams.
411+
Clone().
410412
Obj(),
411413
Policy: runtime.Policy{
412414
MLPolicy: &kubeflowv2.MLPolicy{

pkg/runtime.v2/framework/plugins/jobset/builder.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ import (
2828
)
2929

3030
type Builder struct {
31-
*jobsetv1alpha2.JobSet
31+
jobsetv1alpha2.JobSet
3232
}
3333

3434
func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec) *Builder {
3535
return &Builder{
36-
JobSet: &jobsetv1alpha2.JobSet{
36+
JobSet: jobsetv1alpha2.JobSet{
3737
TypeMeta: metav1.TypeMeta{
3838
APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
3939
Kind: "JobSet",
@@ -76,8 +76,13 @@ func (b *Builder) PodLabels(labels map[string]string) *Builder {
7676
return b
7777
}
7878

79+
func (b *Builder) Suspend(suspend *bool) *Builder {
80+
b.Spec.Suspend = suspend
81+
return b
82+
}
83+
7984
// TODO: Need to support all TrainJob fields.
8085

8186
func (b *Builder) Build() *jobsetv1alpha2.JobSet {
82-
return b.JobSet
87+
return &b.JobSet
8388
}

pkg/runtime.v2/framework/plugins/jobset/jobset.go

+29-16
Original file line numberDiff line numberDiff line change
@@ -74,29 +74,37 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
7474
if !ok {
7575
return nil, nil
7676
}
77-
jobSetBuilder := NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
78-
ObjectMeta: metav1.ObjectMeta{
79-
Labels: info.Labels,
80-
Annotations: info.Annotations,
81-
},
82-
Spec: raw.Spec,
83-
})
77+
78+
var jobSetBuilder *Builder
79+
oldJobSet := &jobsetv1alpha2.JobSet{}
80+
if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil {
81+
if !apierrors.IsNotFound(err) {
82+
return nil, err
83+
}
84+
jobSetBuilder = NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
85+
ObjectMeta: metav1.ObjectMeta{
86+
Labels: info.Labels,
87+
Annotations: info.Annotations,
88+
},
89+
Spec: raw.Spec,
90+
})
91+
oldJobSet = nil
92+
} else {
93+
jobSetBuilder = &Builder{
94+
JobSet: *oldJobSet.DeepCopy(),
95+
}
96+
}
97+
8498
// TODO (tenzen-y): We should support all field propagation in builder.
8599
jobSet := jobSetBuilder.
100+
Suspend(trainJob.Spec.Suspend).
86101
ContainerImage(trainJob.Spec.Trainer.Image).
87102
JobCompletionMode(batchv1.IndexedCompletion).
88103
PodLabels(info.PodLabels).
89104
Build()
90105
if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil {
91106
return nil, err
92107
}
93-
oldJobSet := &jobsetv1alpha2.JobSet{}
94-
if err := j.client.Get(ctx, client.ObjectKeyFromObject(jobSet), oldJobSet); err != nil {
95-
if !apierrors.IsNotFound(err) {
96-
return nil, err
97-
}
98-
oldJobSet = nil
99-
}
100108
if err := info.Update(jobSet); err != nil {
101109
return nil, err
102110
}
@@ -106,9 +114,14 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
106114
return nil, nil
107115
}
108116

109-
func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, suspended bool) bool {
117+
func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, trainJobIsSuspended bool) bool {
110118
return old == nil ||
111-
suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))
119+
(!trainJobIsSuspended && !ptr.Equal(old.Spec.Suspend, new.Spec.Suspend)) ||
120+
(trainJobIsSuspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)))
121+
}
122+
123+
func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool {
124+
return ptr.Deref(jobSet.Spec.Suspend, false)
112125
}
113126

114127
func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {

0 commit comments

Comments
 (0)