From 654b666ee7d1b7c81711c748f9cd8c1498b7c7bb Mon Sep 17 00:00:00 2001
From: Jiaxin Shan <seedjeffwan@gmail.com>
Date: Mon, 30 Aug 2021 14:01:25 -0700
Subject: [PATCH] Revert #135: change rtype to commonv1.ReplicaType (#158)

https://github.com/kubeflow/common/pull/135 This change brings some extra side effects and make training operator dependency upgrade fail. Since this change is kind of refactor, we determine to revert it at this moment.
---
 pkg/apis/common/v1/interface.go               |  2 +-
 pkg/controller.v1/common/job.go               |  9 ++++--
 pkg/controller.v1/common/pod.go               | 28 ++++++++++---------
 pkg/controller.v1/common/service.go           | 28 +++++++++++--------
 pkg/controller.v1/common/util.go              |  4 +--
 pkg/controller.v1/common/util_test.go         |  4 +--
 pkg/controller.v1/expectation/util.go         |  9 +++---
 pkg/util/logger.go                            |  3 +-
 .../test_job/test_job_controller.go           |  2 +-
 9 files changed, 48 insertions(+), 41 deletions(-)

diff --git a/pkg/apis/common/v1/interface.go b/pkg/apis/common/v1/interface.go
index 661a96c6..255d7b6e 100644
--- a/pkg/apis/common/v1/interface.go
+++ b/pkg/apis/common/v1/interface.go
@@ -44,7 +44,7 @@ type ControllerInterface interface {
 	UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error
 
 	// SetClusterSpec sets the cluster spec for the pod
-	SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error
+	SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error
 
 	// Returns the default container name in pod
 	GetDefaultContainerName() string
diff --git a/pkg/controller.v1/common/job.go b/pkg/controller.v1/common/job.go
index 606f0dc6..ad03d690 100644
--- a/pkg/controller.v1/common/job.go
+++ b/pkg/controller.v1/common/job.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"reflect"
 	"sort"
+	"strings"
 	"time"
 
 	apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
@@ -326,9 +327,10 @@ func (jc *JobController) ReconcileJobs(
 // ResetExpectations reset the expectation for creates and deletes of pod/service to zero.
 func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec)  {
 	for rtype := range replicas {
-		expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
+		rt := strings.ToLower(string(rtype))
+		expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt)
 		jc.Expectations.SetExpectations(expectationPodsKey, 0, 0)
-		expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
+		expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt)
 		jc.Expectations.SetExpectations(expectationServicesKey, 0, 0)
 	}
 }
@@ -359,7 +361,8 @@ func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPo
 			continue
 		}
 		// Convert ReplicaType to lower string.
-		pods, err := jc.FilterPodsForReplicaType(pods, rtype)
+		rt := strings.ToLower(string(rtype))
+		pods, err := jc.FilterPodsForReplicaType(pods, rt)
 		if err != nil {
 			return false, err
 		}
diff --git a/pkg/controller.v1/common/pod.go b/pkg/controller.v1/common/pod.go
index 507b33f7..f5243a82 100644
--- a/pkg/controller.v1/common/pod.go
+++ b/pkg/controller.v1/common/pod.go
@@ -18,6 +18,7 @@ import (
 	"fmt"
 	"reflect"
 	"strconv"
+	"strings"
 
 	apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
 	"github.com/kubeflow/common/pkg/controller.v1/control"
@@ -103,7 +104,7 @@ func (jc *JobController) AddPod(obj interface{}) {
 		}
 
 		rtype := pod.Labels[apiv1.ReplicaTypeLabel]
-		expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype))
+		expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
 
 		jc.Expectations.CreationObserved(expectationPodsKey)
 		// TODO: we may need add backoff here
@@ -204,7 +205,7 @@ func (jc *JobController) DeletePod(obj interface{}) {
 	}
 
 	rtype := pod.Labels[apiv1.ReplicaTypeLabel]
-	expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype))
+	expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
 
 	jc.Expectations.DeletionObserved(expectationPodsKey)
 	deletedPodsCount.Inc()
@@ -253,14 +254,14 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error)
 }
 
 // FilterPodsForReplicaType returns pods belong to a replicaType.
-func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) {
+func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
 	var result []*v1.Pod
 
 	replicaSelector := &metav1.LabelSelector{
 		MatchLabels: make(map[string]string),
 	}
 
-	replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType)
+	replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType
 
 	for _, pod := range pods {
 		selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
@@ -339,12 +340,13 @@ func (jc *JobController) ReconcilePods(
 		utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", job, err))
 		return err
 	}
-	expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
-
 	// Convert ReplicaType to lower string.
-	logger := commonutil.LoggerForReplica(metaObject, rtype)
+	rt := strings.ToLower(string(rtype))
+	expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rt)
+
+	logger := commonutil.LoggerForReplica(metaObject, rt)
 	// Get all pods for the type rt.
-	pods, err = jc.FilterPodsForReplicaType(pods, rtype)
+	pods, err = jc.FilterPodsForReplicaType(pods, rt)
 	if err != nil {
 		return err
 	}
@@ -362,13 +364,13 @@ func (jc *JobController) ReconcilePods(
 	podSlices := jc.GetPodSlices(pods, numReplicas, logger)
 	for index, podSlice := range podSlices {
 		if len(podSlice) > 1 {
-			logger.Warningf("We have too many pods for %s %d", rtype, index)
+			logger.Warningf("We have too many pods for %s %d", rt, index)
 		} else if len(podSlice) == 0 {
-			logger.Infof("Need to create new pod: %s-%d", rtype, index)
+			logger.Infof("Need to create new pod: %s-%d", rt, index)
 
 			// check if this replica is the master role
 			masterRole = jc.Controller.IsMasterRole(replicas, rtype, index)
-			err = jc.createNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas)
+			err = jc.createNewPod(job, rt, strconv.Itoa(index), spec, masterRole, replicas)
 			if err != nil {
 				return err
 			}
@@ -416,7 +418,7 @@ func (jc *JobController) ReconcilePods(
 }
 
 // createNewPod creates a new pod for the given index and type.
-func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index string, spec *apiv1.ReplicaSpec, masterRole bool,
+func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *apiv1.ReplicaSpec, masterRole bool,
 	replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error {
 
 	metaObject, ok := job.(metav1.Object)
@@ -436,7 +438,7 @@ func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, ind
 
 	// Set type and index for the worker.
 	labels := jc.GenLabels(metaObject.GetName())
-	labels[apiv1.ReplicaTypeLabel] = string(rt)
+	labels[apiv1.ReplicaTypeLabel] = rt
 	labels[apiv1.ReplicaIndexLabel] = index
 
 	if masterRole {
diff --git a/pkg/controller.v1/common/service.go b/pkg/controller.v1/common/service.go
index 9d993f34..ccf5aca5 100644
--- a/pkg/controller.v1/common/service.go
+++ b/pkg/controller.v1/common/service.go
@@ -16,6 +16,7 @@ package common
 import (
 	"fmt"
 	"strconv"
+	"strings"
 
 	apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
 	"github.com/kubeflow/common/pkg/controller.v1/control"
@@ -71,8 +72,8 @@ func (jc *JobController) AddService(obj interface{}) {
 			return
 		}
 
-		rtypeValue := service.Labels[apiv1.ReplicaTypeLabel]
-		expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, apiv1.ReplicaType(rtypeValue))
+		rtype := service.Labels[apiv1.ReplicaTypeLabel]
+		expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
 
 		jc.Expectations.CreationObserved(expectationServicesKey)
 		// TODO: we may need add backoff here
@@ -137,14 +138,14 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service
 }
 
 // FilterServicesForReplicaType returns service belong to a replicaType.
-func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) {
+func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) {
 	var result []*v1.Service
 
 	replicaSelector := &metav1.LabelSelector{
 		MatchLabels: make(map[string]string),
 	}
 
-	replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType)
+	replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType
 
 	for _, service := range services {
 		selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
@@ -209,9 +210,11 @@ func (jc *JobController) ReconcileServices(
 	rtype apiv1.ReplicaType,
 	spec *apiv1.ReplicaSpec) error {
 
+	// Convert ReplicaType to lower string.
+	rt := strings.ToLower(string(rtype))
 	replicas := int(*spec.Replicas)
 	// Get all services for the type rt.
-	services, err := jc.FilterServicesForReplicaType(services, rtype)
+	services, err := jc.FilterServicesForReplicaType(services, rt)
 	if err != nil {
 		return err
 	}
@@ -222,13 +225,13 @@ func (jc *JobController) ReconcileServices(
 	// If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created.
 	//
 	// If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted.
-	serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype))
+	serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt))
 
 	for index, serviceSlice := range serviceSlices {
 		if len(serviceSlice) > 1 {
-			commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index)
+			commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index)
 		} else if len(serviceSlice) == 0 {
-			commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index)
+			commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index)
 			err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index))
 			if err != nil {
 				return err
@@ -279,9 +282,12 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
 		return err
 	}
 
+	// Convert ReplicaType to lower string.
+	rt := strings.ToLower(string(rtype))
+
 	// Append ReplicaTypeLabel and ReplicaIndexLabel labels.
 	labels := jc.GenLabels(job.GetName())
-	labels[apiv1.ReplicaTypeLabel] = string(rtype)
+	labels[apiv1.ReplicaTypeLabel] = rt
 	labels[apiv1.ReplicaIndexLabel] = index
 
 	ports, err := jc.GetPortsFromJob(spec)
@@ -303,13 +309,13 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
 		service.Spec.Ports = append(service.Spec.Ports, svcPort)
 	}
 
-	service.Name = GenGeneralName(job.GetName(), rtype, index)
+	service.Name = GenGeneralName(job.GetName(), rt, index)
 	service.Labels = labels
 	// Create OwnerReference.
 	controllerRef := jc.GenOwnerReference(job)
 
 	// Creation is expected when there is no error returned
-	expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
+	expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt)
 	jc.Expectations.RaiseExpectations(expectationServicesKey, 1, 0)
 
 	err = jc.ServiceControl.CreateServicesWithControllerRef(job.GetNamespace(), service, job.(runtime.Object), controllerRef)
diff --git a/pkg/controller.v1/common/util.go b/pkg/controller.v1/common/util.go
index f1800210..55a03739 100644
--- a/pkg/controller.v1/common/util.go
+++ b/pkg/controller.v1/common/util.go
@@ -44,8 +44,8 @@ func (p ReplicasPriority) Swap(i, j int) {
 	p[i], p[j] = p[j], p[i]
 }
 
-func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string {
-	n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + index
+func GenGeneralName(jobName string, rtype, index string) string {
+	n := jobName + "-" + strings.ToLower(rtype) + "-" + index
 	return strings.Replace(n, "/", "-", -1)
 }
 
diff --git a/pkg/controller.v1/common/util_test.go b/pkg/controller.v1/common/util_test.go
index 3b87373d..168b9244 100644
--- a/pkg/controller.v1/common/util_test.go
+++ b/pkg/controller.v1/common/util_test.go
@@ -18,15 +18,13 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
-
-	apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
 )
 
 func TestGenGeneralName(t *testing.T) {
 	tcs := []struct {
 		index        string
 		key          string
-		replicaType  apiv1.ReplicaType
+		replicaType  string
 		expectedName string
 	}{
 		{
diff --git a/pkg/controller.v1/expectation/util.go b/pkg/controller.v1/expectation/util.go
index 9061000d..4c57c40f 100644
--- a/pkg/controller.v1/expectation/util.go
+++ b/pkg/controller.v1/expectation/util.go
@@ -1,16 +1,15 @@
 package expectation
 
 import (
-	apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
 	"strings"
 )
 
 // GenExpectationPodsKey generates an expectation key for pods of a job
-func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string {
-	return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods"
+func GenExpectationPodsKey(jobKey, replicaType string) string {
+	return jobKey + "/" + strings.ToLower(replicaType) + "/pods"
 }
 
 // GenExpectationPodsKey generates an expectation key for services of a job
-func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string {
-	return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services"
+func GenExpectationServicesKey(jobKey, replicaType string) string {
+	return jobKey + "/" + strings.ToLower(replicaType) + "/services"
 }
diff --git a/pkg/util/logger.go b/pkg/util/logger.go
index a9719fce..d3ce95e6 100644
--- a/pkg/util/logger.go
+++ b/pkg/util/logger.go
@@ -15,7 +15,6 @@
 package util
 
 import (
-	apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
 	"strings"
 
 	log "github.com/sirupsen/logrus"
@@ -24,7 +23,7 @@ import (
 	metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
 )
 
-func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry {
+func LoggerForReplica(job metav1.Object, rtype string) *log.Entry {
 	return log.WithFields(log.Fields{
 		// We use job to match the key used in controller.go
 		// Its more common in K8s to use a period to indicate namespace.name. So that's what we use.
diff --git a/test_job/controller.v1/test_job/test_job_controller.go b/test_job/controller.v1/test_job/test_job_controller.go
index 80d37cd9..2e6dd210 100644
--- a/test_job/controller.v1/test_job/test_job_controller.go
+++ b/test_job/controller.v1/test_job/test_job_controller.go
@@ -65,7 +65,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu
 	return nil
 }
 
-func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error {
+func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
 	return nil
 }