Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Implement Spark pod template tolerations #411

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,5 @@ require (
)

replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d

// NOTE(review): this replace targets an absolute filesystem path (/Users/andrew/...), so the
// module can only resolve where that directory exists. Presumably a temporary override for
// developing against a local flyteidl fork — confirm, and drop it or pin a published
// flyteidl version before merging.
replace github.com/flyteorg/flyteidl => /Users/andrew/dev/forks/flyteidl
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.5.13 h1:IQ2Cw+u36ew3BPyRDAcHdzc/GyNEOXOxhKy9jbS4hbo=
github.com/flyteorg/flyteidl v1.5.13/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flytestdlib v1.0.24 h1:jDvymcjlsTRCwOtxPapro0WZBe3isTz+T3Tiq+mZUuk=
github.com/flyteorg/flytestdlib v1.0.24/go.mod h1:6nXa5g00qFIsgdvQ7jKQMJmDniqO0hG6Z5X5olfduqQ=
github.com/flyteorg/stow v0.3.7 h1:Cx7j8/Ux6+toD5hp5fy++927V+yAcAttDeQAlUD/864=
Expand Down
37 changes: 33 additions & 4 deletions go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"google.golang.org/protobuf/types/known/structpb"
"sigs.k8s.io/controller-runtime/pkg/client"

"strconv"
Expand All @@ -20,14 +21,17 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"k8s.io/client-go/kubernetes/scheme"

sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"regexp"
"strings"
Expand Down Expand Up @@ -61,6 +65,21 @@ func (sparkResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

// getTolerations returns the tolerations to set on a Spark driver or executor pod.
//
// The result always begins with the DefaultTolerations from the K8s plugin config. When
// podSpecPb is non-nil it is unmarshalled into a v1.PodSpec and that spec's tolerations are
// appended after the defaults. A new slice is allocated on every call, so the appends never
// write into the slice held by the plugin config.
//
// If podSpecPb cannot be unmarshalled, a BadTaskSpecification error wrapping the underlying
// failure is returned together with a nil slice.
func getTolerations(podSpecPb *structpb.Struct) ([]v1.Toleration, error) {
	defaults := config.GetK8sPluginConfig().DefaultTolerations

	var podSpec v1.PodSpec
	if podSpecPb != nil {
		if err := utils.UnmarshalStruct(podSpecPb, &podSpec); err != nil {
			// Report the input that failed to parse: podSpec is the unmarshal target and
			// holds nothing useful when unmarshalling did not succeed.
			return nil, errors.Wrapf(errors.BadTaskSpecification, err,
				"invalid pod spec [%v], failed to unmarshal", podSpecPb)
		}
	}

	// With a nil podSpecPb, podSpec stays at its zero value and contributes no tolerations.
	tolerations := make([]v1.Toleration, 0, len(defaults)+len(podSpec.Tolerations))
	tolerations = append(tolerations, defaults...)
	tolerations = append(tolerations, podSpec.Tolerations...)
	return tolerations, nil
}

// Creates a new Job that will execute the main container as well as any generated types that result from the execution.
func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down Expand Up @@ -99,6 +118,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
if len(serviceAccountName) == 0 {
serviceAccountName = sparkTaskType
}

tolerations, err := getTolerations(sparkJob.GetDriverPod().GetPodSpec())
if err != nil {
return nil, err
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Expand All @@ -108,14 +132,19 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
Tolerations: tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
ServiceAccount: &serviceAccountName,
}

tolerations, err = getTolerations(sparkJob.GetExecutorPod().GetPodSpec())
if err != nil {
return nil, err
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(),
Expand All @@ -125,7 +154,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
Tolerations: tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
Expand Down
92 changes: 82 additions & 10 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ import (
pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
)

const sparkMainClass = "MainClass"
Expand Down Expand Up @@ -87,7 +88,8 @@ func TestGetEventInfo(t *testing.T) {
},
},
}))
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", dummySparkConf), false)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", sparkJob), false)
info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState))
assert.NoError(t, err)
assert.Len(t, info.Logs, 6)
Expand Down Expand Up @@ -157,7 +159,8 @@ func TestGetTaskPhase(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

ctx := context.TODO()
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", dummySparkConf), false)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", sparkJob), false)
taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState))
assert.NoError(t, err)
assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued)
Expand Down Expand Up @@ -242,17 +245,14 @@ func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication {

// dummySparkCustomObj builds a SparkJob test fixture of application type PYTHON. The main
// class and main application file come from sparkMainClass and sparkApplicationFile, and the
// Spark configuration is the sparkConf map supplied by the caller.
func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
	return &plugins.SparkJob{
		ApplicationType:     plugins.SparkApplication_PYTHON,
		MainApplicationFile: sparkApplicationFile,
		MainClass:           sparkMainClass,
		SparkConf:           sparkConf,
	}
}

func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate {

sparkJob := dummySparkCustomObj(sparkConf)
func dummySparkTaskTemplate(id string, sparkJob *plugins.SparkJob) *core.TaskTemplate {
sparkJobJSON, err := utils.MarshalToString(sparkJob)
if err != nil {
panic(err)
Expand Down Expand Up @@ -335,7 +335,8 @@ func TestBuildResourceSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

// Case1: Valid Spark Task-Template
taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskTemplate := dummySparkTaskTemplate("blah-1", sparkJob)

// Set spark custom feature config.
assert.NoError(t, setSparkConfig(&Config{
Expand Down Expand Up @@ -619,7 +620,8 @@ func TestBuildResourceSpark(t *testing.T) {
dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3"
dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4"

taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest)
sparkJob = dummySparkCustomObj(dummyConfWithRequest)
taskTemplate = dummySparkTaskTemplate("blah-1", sparkJob)
resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)
assert.NotNil(t, resource)
Expand Down Expand Up @@ -678,6 +680,76 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Nil(t, resource)
}

// TestBuildResourcePodTemplate checks that tolerations supplied through the driver and
// executor pod specs of a SparkJob end up on the built SparkApplication, appended after the
// default tolerations taken from the K8s plugin config.
func TestBuildResourcePodTemplate(t *testing.T) {
	defaultToleration := corev1.Toleration{
		Key:      "x/flyte",
		Value:    "default",
		Operator: "Equal",
	}
	// NOTE(review): this replaces the package-level K8s plugin config and does not restore
	// it afterwards — confirm no other test relies on the previous value.
	err := config.SetK8sPluginConfig(&config.K8sPluginConfig{
		DefaultTolerations: []corev1.Toleration{
			defaultToleration,
		},
	})
	assert.NoError(t, err)

	sparkJob := dummySparkCustomObj(dummySparkConf)

	// Driver pod spec carrying one extra toleration.
	extraDriverToleration := corev1.Toleration{
		Key:      "x/flyte",
		Value:    "extra-driver",
		Operator: "Equal",
	}
	driverPodSpec := corev1.PodSpec{
		Tolerations: []corev1.Toleration{
			extraDriverToleration,
		},
	}
	driverPodSpecPb := structpb.Struct{}
	err = utils.MarshalStruct(&driverPodSpec, &driverPodSpecPb)
	assert.NoError(t, err)
	sparkJob.DriverPodValue = &plugins.SparkJob_DriverPod{
		DriverPod: &core.K8SPod{
			PodSpec: &driverPodSpecPb,
		},
	}

	// Executor pod spec carrying a different extra toleration.
	extraExecutorToleration := corev1.Toleration{
		Key:      "x/flyte",
		Value:    "extra-executor",
		Operator: "Equal",
	}
	executorPodSpec := corev1.PodSpec{
		Tolerations: []corev1.Toleration{
			extraExecutorToleration,
		},
	}
	execPodSpecPb := structpb.Struct{}
	err = utils.MarshalStruct(&executorPodSpec, &execPodSpecPb)
	assert.NoError(t, err)
	sparkJob.ExecutorPodValue = &plugins.SparkJob_ExecutorPod{
		ExecutorPod: &core.K8SPod{
			PodSpec: &execPodSpecPb,
		},
	}

	taskTemplate := dummySparkTaskTemplate("blah-1", sparkJob)
	sparkResourceHandler := sparkResourceHandler{}
	resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
	assert.NoError(t, err)
	assert.NotNil(t, resource)

	sparkApp, ok := resource.(*sj.SparkApplication)
	// Stop here on a failed type assertion: sparkApp would be nil and the field accesses
	// below would panic instead of reporting a test failure.
	if !assert.True(t, ok) {
		return
	}

	// Expected value first, actual second, so failure output labels them correctly.
	assert.Len(t, sparkApp.Spec.Driver.Tolerations, 2)
	assert.Equal(t, []corev1.Toleration{
		defaultToleration,
		extraDriverToleration,
	}, sparkApp.Spec.Driver.Tolerations)

	assert.Len(t, sparkApp.Spec.Executor.Tolerations, 2)
	assert.Equal(t, []corev1.Toleration{
		defaultToleration,
		extraExecutorToleration,
	}, sparkApp.Spec.Executor.Tolerations)
}

func TestGetPropertiesSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}
expected := k8s.PluginProperties{}
Expand Down
Loading