diff --git a/go/tasks/config/config.go b/go/tasks/config/config.go
index 0cb5c6b92..01cfa0b37 100755
--- a/go/tasks/config/config.go
+++ b/go/tasks/config/config.go
@@ -12,11 +12,11 @@ var (
     rootSection = config.MustRegisterSection(configSectionKey, &Config{})
 )
 
-// Top level plugins config.
+// Config is the top level plugins config.
 type Config struct {
 }
 
-// Retrieves the current config value or default.
+// GetConfig retrieves the current config value or default.
 func GetConfig() *Config {
     return rootSection.GetConfig().(*Config)
 }
@@ -24,3 +24,7 @@ func GetConfig() *Config {
 func MustRegisterSubSection(subSectionKey string, section config.Config) config.Section {
     return rootSection.MustRegisterSection(subSectionKey, section)
 }
+
+func MustRegisterSubSectionWithUpdates(subSectionKey string, section config.Config, sectionUpdatedFn config.SectionUpdated) config.Section {
+    return rootSection.MustRegisterSectionWithUpdates(subSectionKey, section, sectionUpdatedFn)
+}
diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go
index 51d3e4e81..93fd3067d 100644
--- a/go/tasks/pluginmachinery/core/phase.go
+++ b/go/tasks/pluginmachinery/core/phase.go
@@ -184,7 +184,7 @@ func phaseInfo(p Phase, v uint32, err *core.ExecutionError, info *TaskInfo, clea
     }
 }
 
-// Return in the case the plugin is not ready to start
+// PhaseInfoNotReady represents the case the plugin is not ready to start
 func PhaseInfoNotReady(t time.Time, version uint32, reason string) PhaseInfo {
     pi := phaseInfo(PhaseNotReady, version, nil, &TaskInfo{OccurredAt: &t}, false)
     pi.reason = reason
@@ -198,7 +198,7 @@ func PhaseInfoWaitingForResources(t time.Time, version uint32, reason string) Ph
     return pi
 }
 
-// Return in the case the plugin is not ready to start
+// PhaseInfoWaitingForResourcesInfo represents the case the plugin is not ready to start
 func PhaseInfoWaitingForResourcesInfo(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo {
     pi := phaseInfo(PhaseWaitingForResources, version, nil, info, false)
     pi.reason = reason
diff --git a/go/tasks/plugins/k8s/ray/config.go b/go/tasks/plugins/k8s/ray/config.go
index 8b699cd79..2c30621bf 100644
--- a/go/tasks/plugins/k8s/ray/config.go
+++ b/go/tasks/plugins/k8s/ray/config.go
@@ -1,8 +1,12 @@
 package ray
 
 import (
+    "context"
+
     pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config"
+    "github.com/flyteorg/flyteplugins/go/tasks/logs"
     pluginmachinery "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
+    "github.com/flyteorg/flytestdlib/config"
 )
 
 //go:generate pflags Config --default-var=defaultConfig
@@ -14,10 +18,39 @@ var (
         ServiceType:      "NodePort",
         IncludeDashboard: true,
         DashboardHost:    "0.0.0.0",
-        NodeIPAddress:    "$MY_POD_IP",
+        EnableUsageStats: false,
+        Defaults: DefaultConfig{
+            HeadNode: NodeConfig{
+                StartParameters: map[string]string{
+                    // Disable usage reporting by default: https://docs.ray.io/en/latest/cluster/usage-stats.html
+                    DisableUsageStatsStartParameter: "true",
+                },
+                IPAddress: "$MY_POD_IP",
+            },
+            WorkerNode: NodeConfig{
+                StartParameters: map[string]string{
+                    // Disable usage reporting by default: https://docs.ray.io/en/latest/cluster/usage-stats.html
+                    DisableUsageStatsStartParameter: "true",
+                },
+                IPAddress: "$MY_POD_IP",
+            },
+        },
     }
 
-    configSection = pluginsConfig.MustRegisterSubSection("ray", &defaultConfig)
+    configSection = pluginsConfig.MustRegisterSubSectionWithUpdates("ray", &defaultConfig,
+        func(ctx context.Context, newValue config.Config) {
+            if newValue == nil {
+                return
+            }
+
+            if len(newValue.(*Config).Defaults.HeadNode.IPAddress) == 0 {
+                newValue.(*Config).Defaults.HeadNode.IPAddress = newValue.(*Config).DeprecatedNodeIPAddress
+            }
+
+            if len(newValue.(*Config).Defaults.WorkerNode.IPAddress) == 0 {
+                newValue.(*Config).Defaults.WorkerNode.IPAddress = newValue.(*Config).DeprecatedNodeIPAddress
+            }
+        })
 )
 
 // Config is config for 'ray' plugin
@@ -39,11 +72,24 @@ type Config struct {
     // or 0.0.0.0 (available from all interfaces). By default, this is localhost.
     DashboardHost string `json:"dashboardHost,omitempty"`
 
-    // NodeIPAddress the IP address of the head node. By default, this is pod ip address.
-    NodeIPAddress string `json:"nodeIPAddress,omitempty"`
+    // DeprecatedNodeIPAddress the IP address of the head node. By default, this is pod ip address.
+    DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"`
 
     // Remote Ray Cluster Config
     RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for ray jobs"`
+    Logs                logs.LogConfig                `json:"logs" pflag:"-,Log configuration for ray jobs"`
+    Defaults            DefaultConfig                 `json:"defaults" pflag:"-,Default configuration for ray jobs"`
+    EnableUsageStats    bool                          `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"`
+}
+
+type DefaultConfig struct {
+    HeadNode   NodeConfig `json:"headNode,omitempty" pflag:"-,Default configuration for head node of ray jobs"`
+    WorkerNode NodeConfig `json:"workerNode,omitempty" pflag:"-,Default configuration for worker node of ray jobs"`
+}
+
+type NodeConfig struct {
+    StartParameters map[string]string `json:"startParameters,omitempty" pflag:"-,Start parameters for the node"`
+    IPAddress       string            `json:"ipAddress,omitempty" pflag:"-,IP address of the node"`
 }
 
 func GetConfig() *Config {
diff --git a/go/tasks/plugins/k8s/ray/config_flags.go b/go/tasks/plugins/k8s/ray/config_flags.go
index f8e983056..8113a2627 100755
--- a/go/tasks/plugins/k8s/ray/config_flags.go
+++ b/go/tasks/plugins/k8s/ray/config_flags.go
@@ -55,9 +55,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet {
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceType"), defaultConfig.ServiceType, "")
     cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "")
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "")
-    cmdFlags.String(fmt.Sprintf("%v%v", prefix, "nodeIPAddress"), defaultConfig.NodeIPAddress, "")
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster")
     cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint")
     cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable")
+    cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "enableUsageStats"), defaultConfig.EnableUsageStats, "Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html")
     return cmdFlags
 }
diff --git a/go/tasks/plugins/k8s/ray/config_flags_test.go b/go/tasks/plugins/k8s/ray/config_flags_test.go
index 60761b900..f05c62c8e 100755
--- a/go/tasks/plugins/k8s/ray/config_flags_test.go
+++ b/go/tasks/plugins/k8s/ray/config_flags_test.go
@@ -169,20 +169,6 @@ func TestConfig_SetFlags(t *testing.T) {
             }
         })
     })
-    t.Run("Test_nodeIPAddress", func(t *testing.T) {
-
-        t.Run("Override", func(t *testing.T) {
-            testValue := "1"
-
-            cmdFlags.Set("nodeIPAddress", testValue)
-            if vString, err := cmdFlags.GetString("nodeIPAddress"); err == nil {
-                testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.NodeIPAddress)
-
-            } else {
-                assert.FailNow(t, err.Error())
-            }
-        })
-    })
     t.Run("Test_remoteClusterConfig.name", func(t *testing.T) {
 
         t.Run("Override", func(t *testing.T) {
@@ -225,4 +211,18 @@ func TestConfig_SetFlags(t *testing.T) {
             }
         })
     })
+    t.Run("Test_enableUsageStats", func(t *testing.T) {
+
+        t.Run("Override", func(t *testing.T) {
+            testValue := "1"
+
+            cmdFlags.Set("enableUsageStats", testValue)
+            if vBool, err := cmdFlags.GetBool("enableUsageStats"); err == nil {
+                testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.EnableUsageStats)
+
+            } else {
+                assert.FailNow(t, err.Error())
+            }
+        })
+    })
 }
diff --git a/go/tasks/plugins/k8s/ray/ray.go b/go/tasks/plugins/k8s/ray/ray.go
index b82e06ddd..0435216e9 100644
--- a/go/tasks/plugins/k8s/ray/ray.go
+++ b/go/tasks/plugins/k8s/ray/ray.go
@@ -5,9 +5,9 @@ import (
     "fmt"
     "strconv"
     "strings"
-    "time"
 
-    "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+    "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"
+
     "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
     "github.com/flyteorg/flyteplugins/go/tasks/logs"
     "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
@@ -27,11 +27,12 @@ import (
 )
 
 const (
-    rayTaskType      = "ray"
-    KindRayJob       = "RayJob"
-    IncludeDashboard = "include-dashboard"
-    NodeIPAddress    = "node-ip-address"
-    DashboardHost    = "dashboard-host"
+    rayTaskType                     = "ray"
+    KindRayJob                      = "RayJob"
+    IncludeDashboard                = "include-dashboard"
+    NodeIPAddress                   = "node-ip-address"
+    DashboardHost                   = "dashboard-host"
+    DisableUsageStatsStartParameter = "disable-usage-stats"
 )
 
 type rayJobResourceHandler struct {
@@ -57,7 +58,6 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
     }
 
     podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
-
     if err != nil {
         return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
     }
@@ -76,26 +76,36 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
         return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error())
     }
 
+    cfg := GetConfig()
     headReplicas := int32(1)
     headNodeRayStartParams := make(map[string]string)
     if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil {
         headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams
+    } else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 {
+        headNodeRayStartParams = headNode.StartParameters
     }
+
     if _, exist := headNodeRayStartParams[IncludeDashboard]; !exist {
         headNodeRayStartParams[IncludeDashboard] = strconv.FormatBool(GetConfig().IncludeDashboard)
     }
+
     if _, exist := headNodeRayStartParams[NodeIPAddress]; !exist {
-        headNodeRayStartParams[NodeIPAddress] = GetConfig().NodeIPAddress
+        headNodeRayStartParams[NodeIPAddress] = cfg.Defaults.HeadNode.IPAddress
     }
+
     if _, exist := headNodeRayStartParams[DashboardHost]; !exist {
-        headNodeRayStartParams[DashboardHost] = GetConfig().DashboardHost
+        headNodeRayStartParams[DashboardHost] = cfg.DashboardHost
+    }
+
+    if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
+        headNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
     }
 
     enableIngress := true
     rayClusterSpec := rayv1alpha1.RayClusterSpec{
         HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
             Template:       buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx),
-            ServiceType:    v1.ServiceType(GetConfig().ServiceType),
+            ServiceType:    v1.ServiceType(cfg.ServiceType),
             Replicas:       &headReplicas,
             EnableIngress:  &enableIngress,
             RayStartParams: headNodeRayStartParams,
@@ -111,6 +121,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
         if spec.MinReplicas != 0 {
             minReplicas = spec.MinReplicas
         }
+
         if spec.MaxReplicas != 0 {
             maxReplicas = spec.MaxReplicas
         }
@@ -118,9 +129,16 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
         workerNodeRayStartParams := make(map[string]string)
         if spec.RayStartParams != nil {
             workerNodeRayStartParams = spec.RayStartParams
+        } else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 {
+            workerNodeRayStartParams = workerNode.StartParameters
         }
+
         if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist {
-            workerNodeRayStartParams[NodeIPAddress] = GetConfig().NodeIPAddress
+            workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress
+        }
+
+        if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
+            workerNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
         }
 
         workerNodeSpec := rayv1alpha1.WorkerGroupSpec{
@@ -145,8 +163,8 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
     jobSpec := rayv1alpha1.RayJobSpec{
         RayClusterSpec:           rayClusterSpec,
         Entrypoint:               strings.Join(container.Args, " "),
-        ShutdownAfterJobFinishes: GetConfig().ShutdownAfterJobFinishes,
-        TTLSecondsAfterFinished:  &GetConfig().TTLSecondsAfterFinished,
+        ShutdownAfterJobFinishes: cfg.ShutdownAfterJobFinishes,
+        TTLSecondsAfterFinished:  &cfg.TTLSecondsAfterFinished,
         RuntimeEnv:               rayJob.RuntimeEnv,
     }
 
@@ -347,12 +365,10 @@ func (rayJobResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx
     }, nil
 }
 
-func getEventInfoForRayJob() (*pluginsCore.TaskInfo, error) {
-    taskLogs := make([]*core.TaskLog, 0, 3)
-    logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())
-
+func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1alpha1.RayJob) (*pluginsCore.TaskInfo, error) {
+    logPlugin, err := logs.InitializeLogPlugins(&logConfig)
     if err != nil {
-        return nil, err
+        return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err)
     }
 
     if logPlugin == nil {
@@ -362,22 +378,31 @@ func getEventInfoForRayJob() (*pluginsCore.TaskInfo, error) {
 
     // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
     // RayJob CRD does not include the name of the worker or head pod for now
-    // TODO: Add ray Dashboard URI to task logs
+    taskID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID()
+    logOutput, err := logPlugin.GetTaskLogs(tasklog.Input{
+        Namespace:               rayJob.Namespace,
+        TaskExecutionIdentifier: &taskID,
+    })
+
+    if err != nil {
+        return nil, fmt.Errorf("failed to generate task logs. Error: %w", err)
+    }
 
     return &pluginsCore.TaskInfo{
-        Logs: taskLogs,
+        Logs: logOutput.TaskLogs,
     }, nil
 }
 
-func (rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
+func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
     rayJob := resource.(*rayv1alpha1.RayJob)
-    info, err := getEventInfoForRayJob()
+    info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob)
     if err != nil {
         return pluginsCore.PhaseInfoUndefined, err
     }
+
     switch rayJob.Status.JobStatus {
     case rayv1alpha1.JobStatusPending:
-        return pluginsCore.PhaseInfoNotReady(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending"), nil
+        return pluginsCore.PhaseInfoInitializing(rayJob.Status.StartTime.Time, pluginsCore.DefaultPhaseVersion, "job is pending", info), nil
     case rayv1alpha1.JobStatusFailed:
         reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name)
         return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
@@ -386,7 +411,8 @@ func (rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s
     case rayv1alpha1.JobStatusRunning:
         return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
     }
-    return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
+
+    return pluginsCore.PhaseInfoQueued(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
 }
 
 func init() {
diff --git a/go/tasks/plugins/k8s/ray/ray_test.go b/go/tasks/plugins/k8s/ray/ray_test.go
index e3e4b5585..c7b5042c8 100644
--- a/go/tasks/plugins/k8s/ray/ray_test.go
+++ b/go/tasks/plugins/k8s/ray/ray_test.go
@@ -3,6 +3,12 @@ package ray
 import (
     "context"
     "testing"
+    "time"
+
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+
+    "github.com/flyteorg/flyteplugins/go/tasks/logs"
+    mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
 
     "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
 
@@ -169,7 +175,58 @@ func TestBuildResourceRay(t *testing.T) {
     assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas, &headReplica)
     assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
     assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams,
-        map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"})
+        map[string]string{
+            "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true",
+            "node-ip-address": "$MY_POD_IP", "num-cpus": "1"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration)
+
+    workerReplica := int32(3)
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, &workerReplica)
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, &workerReplica)
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, &workerReplica)
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName)
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
+}
+
+func TestDefaultStartParameters(t *testing.T) {
+    rayJobResourceHandler := rayJobResourceHandler{}
+    rayJob := &plugins.RayJob{
+        RayCluster: &plugins.RayCluster{
+            HeadGroupSpec:   &plugins.HeadGroupSpec{},
+            WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3}},
+        },
+    }
+
+    taskTemplate := dummyRayTaskTemplate("ray-id", rayJob)
+    toleration := []corev1.Toleration{{
+        Key:      "storage",
+        Value:    "dedicated",
+        Operator: corev1.TolerationOpExists,
+        Effect:   corev1.TaintEffectNoSchedule,
+    }}
+    err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
+    assert.Nil(t, err)
+
+    RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
+    assert.Nil(t, err)
+
+    assert.NotNil(t, RayResource)
+    ray, ok := RayResource.(*rayv1alpha1.RayJob)
+    assert.True(t, ok)
+
+    headReplica := int32(1)
+    assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas, &headReplica)
+    assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
+    assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams,
+        map[string]string{
+            "dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true",
+            "node-ip-address": "$MY_POD_IP"})
     assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"})
     assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"})
     assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration)
@@ -180,12 +237,72 @@ func TestBuildResourceRay(t *testing.T) {
     assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, &workerReplica)
     assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName)
     assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
-    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"node-ip-address": "$MY_POD_IP"})
+    assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"})
     assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
     assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
     assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
 }
 
+func newPluginContext() k8s.PluginContext {
+    plg := &mocks2.PluginContext{}
+
+    taskExecID := &mocks.TaskExecutionID{}
+    taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{
+        NodeExecutionId: &core.NodeExecutionIdentifier{
+            ExecutionId: &core.WorkflowExecutionIdentifier{
+                Name:    "my_name",
+                Project: "my_project",
+                Domain:  "my_domain",
+            },
+        },
+    })
+
+    tskCtx := &mocks.TaskExecutionMetadata{}
+    tskCtx.OnGetTaskExecutionID().Return(taskExecID)
+    plg.OnTaskExecutionMetadata().Return(tskCtx)
+    return plg
+}
+
+func init() {
+    f := defaultConfig
+    f.Logs = logs.LogConfig{
+        IsKubernetesEnabled: true,
+    }
+
+    if err := SetConfig(&f); err != nil {
+        panic(err)
+    }
+}
+
+func TestGetTaskPhase(t *testing.T) {
+    ctx := context.Background()
+    rayJobResourceHandler := rayJobResourceHandler{}
+    pluginCtx := newPluginContext()
+
+    testCases := []struct {
+        rayJobPhase       rayv1alpha1.JobStatus
+        expectedCorePhase pluginsCore.Phase
+    }{
+        {"", pluginsCore.PhaseQueued},
+        {rayv1alpha1.JobStatusPending, pluginsCore.PhaseInitializing},
+        {rayv1alpha1.JobStatusRunning, pluginsCore.PhaseRunning},
+        {rayv1alpha1.JobStatusSucceeded, pluginsCore.PhaseSuccess},
+        {rayv1alpha1.JobStatusFailed, pluginsCore.PhasePermanentFailure},
+    }
+
+    for _, tc := range testCases {
+        t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) {
+            rayObject := &rayv1alpha1.RayJob{}
+            rayObject.Status.JobStatus = tc.rayJobPhase
+            startTime := metav1.NewTime(time.Now())
+            rayObject.Status.StartTime = &startTime
+            phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)
+            assert.Nil(t, err)
+            assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String())
+        })
+    }
+}
+
 func TestGetPropertiesRay(t *testing.T) {
     rayJobResourceHandler := rayJobResourceHandler{}
     expected := k8s.PluginProperties{}