Skip to content
2 changes: 1 addition & 1 deletion apiserver/pkg/server/ray_job_submission_service_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type RayJobSubmissionServiceServer struct {
// Create RayJobSubmissionServiceServer
func NewRayJobSubmissionServiceServer(clusterServer *ClusterServer, options *RayJobSubmissionServiceServerOptions) *RayJobSubmissionServiceServer {
zl := zerolog.New(os.Stdout).Level(zerolog.DebugLevel)
return &RayJobSubmissionServiceServer{clusterServer: clusterServer, options: options, log: zerologr.New(&zl).WithName("jobsubmissionservice"), dashboardClientFunc: utils.GetRayDashboardClientFunc(nil, false)}
return &RayJobSubmissionServiceServer{clusterServer: clusterServer, options: options, log: zerologr.New(&zl).WithName("jobsubmissionservice"), dashboardClientFunc: utils.GetRayDashboardClientFunc(nil, false, nil, nil)}
}

// Submit Ray job
Expand Down
6 changes: 4 additions & 2 deletions ray-operator/apis/config/v1alpha1/configuration_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package v1alpha1

import (
"sync"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/manager"
Expand Down Expand Up @@ -85,8 +87,8 @@ type Configuration struct {
EnableMetrics bool `json:"enableMetrics,omitempty"`
}

func (config Configuration) GetDashboardClient(mgr manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
return utils.GetRayDashboardClientFunc(mgr, config.UseKubernetesProxy)
func (config Configuration) GetDashboardClient(mgr manager.Manager, taskQueue chan func(), jobInfoMap *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
return utils.GetRayDashboardClientFunc(mgr, config.UseKubernetesProxy, taskQueue, jobInfoMap)
}

func (config Configuration) GetHttpProxyClient(mgr manager.Manager) func(hostIp, podNamespace, podName string, port int) utils.RayHttpProxyClientInterface {
Expand Down
31 changes: 23 additions & 8 deletions ray-operator/controllers/ray/rayjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"

"github.com/go-logr/logr"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/ray-project/kuberay/ray-operator/controllers/ray/metrics"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils/dashboardclient"
utiltypes "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils/types"
"github.com/ray-project/kuberay/ray-operator/pkg/features"
)

Expand All @@ -40,11 +42,12 @@ const (
// RayJobReconciler reconciles a RayJob object
type RayJobReconciler struct {
client.Client
Scheme *runtime.Scheme
Recorder record.EventRecorder

Scheme *runtime.Scheme
Recorder record.EventRecorder
JobInfoMap *sync.Map
dashboardClientFunc func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error)
options RayJobReconcilerOptions
workerPool *dashboardclient.WorkerPool
}

type RayJobReconcilerOptions struct {
Expand All @@ -53,13 +56,18 @@ type RayJobReconcilerOptions struct {

// NewRayJobReconciler returns a new reconcile.Reconciler
func NewRayJobReconciler(_ context.Context, mgr manager.Manager, options RayJobReconcilerOptions, provider utils.ClientProvider) *RayJobReconciler {
dashboardClientFunc := provider.GetDashboardClient(mgr)
taskQueue := make(chan func(), 1000)
JobInfoMap := &sync.Map{}
workerPool := dashboardclient.NewWorkerPool(taskQueue)
dashboardClientFunc := provider.GetDashboardClient(mgr, taskQueue, JobInfoMap)
return &RayJobReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("rayjob-controller"),
JobInfoMap: JobInfoMap,
dashboardClientFunc: dashboardClientFunc,
options: options,
workerPool: workerPool,
}
}

Expand Down Expand Up @@ -263,9 +271,10 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
if err != nil {
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
}

jobInfo, err := rayDashboardClient.GetJobInfo(ctx, rayJobInstance.Status.JobId)
if err != nil {
var jobInfo *utiltypes.RayJobInfo
if loadedJobInfo, ok := r.JobInfoMap.Load(rayJobInstance.Status.JobId); ok {
jobInfo = loadedJobInfo.(*utiltypes.RayJobInfo)
} else {
// If the Ray job was not found, GetJobInfo returns a BadRequest error.
if rayJobInstance.Spec.SubmissionMode == rayv1.HTTPMode && errors.IsBadRequest(err) {
logger.Info("The Ray job was not found. Submit a Ray job via an HTTP request.", "JobId", rayJobInstance.Status.JobId)
Expand All @@ -275,10 +284,16 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
}
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil
}
logger.Error(err, "Failed to get job info", "JobId", rayJobInstance.Status.JobId)
logger.Info("Job info not found in map", "JobId", rayJobInstance.Status.JobId)
rayDashboardClient.AsyncGetJobInfo(ctx, rayJobInstance.Status.JobId)
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
}

rayDashboardClient.AsyncGetJobInfo(ctx, rayJobInstance.Status.JobId)
if jobInfo == nil {
logger.Error(err, "Failed to get job info", "JobId", rayJobInstance.Status.JobId)
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
}
// If the JobStatus is in a terminal status, such as SUCCEEDED, FAILED, or STOPPED, it is impossible for the Ray job
// to transition to any other. Additionally, RayJob does not currently support retries. Hence, we can mark the RayJob
// as "Complete" or "Failed" to avoid unnecessary reconciliation.
Expand Down
2 changes: 1 addition & 1 deletion ray-operator/controllers/ray/rayservice_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type RayServiceReconciler struct {

// NewRayServiceReconciler returns a new reconcile.Reconciler
func NewRayServiceReconciler(_ context.Context, mgr manager.Manager, provider utils.ClientProvider) *RayServiceReconciler {
dashboardClientFunc := provider.GetDashboardClient(mgr)
dashboardClientFunc := provider.GetDashboardClient(mgr, nil, nil)
httpProxyClientFunc := provider.GetHttpProxyClient(mgr)
return &RayServiceReconciler{
Client: mgr.GetClient(),
Expand Down
3 changes: 2 additions & 1 deletion ray-operator/controllers/ray/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ray
import (
"os"
"path/filepath"
"sync"
"testing"

. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -52,7 +53,7 @@ var (

type TestClientProvider struct{}

func (testProvider TestClientProvider) GetDashboardClient(_ manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
func (testProvider TestClientProvider) GetDashboardClient(_ manager.Manager, _ chan func(), _ *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
return func(_ *rayv1.RayCluster, _ string) (dashboardclient.RayDashboardClientInterface, error) {
return fakeRayDashboardClient, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"strings"
"sync"

"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/util/json"
Expand All @@ -25,12 +26,13 @@ var (
)

type RayDashboardClientInterface interface {
InitClient(client *http.Client, dashboardURL string)
InitClient(client *http.Client, dashboardURL string, taskQueue chan func(), jobInfoMap *sync.Map)
UpdateDeployments(ctx context.Context, configJson []byte) error
// V2/multi-app Rest API
GetServeDetails(ctx context.Context) (*utiltypes.ServeDetails, error)
GetMultiApplicationStatus(context.Context) (map[string]*utiltypes.ServeApplicationStatus, error)
GetJobInfo(ctx context.Context, jobId string) (*utiltypes.RayJobInfo, error)
AsyncGetJobInfo(ctx context.Context, jobId string)
ListJobs(ctx context.Context) (*[]utiltypes.RayJobInfo, error)
SubmitJob(ctx context.Context, rayJob *rayv1.RayJob) (string, error)
SubmitJobReq(ctx context.Context, request *utiltypes.RayJobRequest) (string, error)
Expand All @@ -41,12 +43,16 @@ type RayDashboardClientInterface interface {

type RayDashboardClient struct {
client *http.Client
taskQueue chan func()
jobInfoMap *sync.Map
dashboardURL string
}

func (r *RayDashboardClient) InitClient(client *http.Client, dashboardURL string) {
func (r *RayDashboardClient) InitClient(client *http.Client, dashboardURL string, taskQueue chan func(), jobInfoMap *sync.Map) {
r.client = client
r.dashboardURL = dashboardURL
r.taskQueue = taskQueue
r.jobInfoMap = jobInfoMap
}

// UpdateDeployments update the deployments in the Ray cluster.
Expand Down Expand Up @@ -161,6 +167,18 @@ func (r *RayDashboardClient) GetJobInfo(ctx context.Context, jobId string) (*uti
return &jobInfo, nil
}

func (r *RayDashboardClient) AsyncGetJobInfo(ctx context.Context, jobId string) {
r.taskQueue <- func() {
jobInfo, err := r.GetJobInfo(ctx, jobId)
if err != nil {
fmt.Printf("AsyncGetJobInfo: error: %v\n", err)
}
if jobInfo != nil {
r.jobInfoMap.Store(jobId, jobInfo)
}
}
}

func (r *RayDashboardClient) ListJobs(ctx context.Context) (*[]utiltypes.RayJobInfo, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.dashboardURL+JobPath, nil)
if err != nil {
Expand Down Expand Up @@ -211,6 +229,7 @@ func (r *RayDashboardClient) SubmitJobReq(ctx context.Context, request *utiltype
}

req.Header.Set("Content-Type", "application/json")

resp, err := r.client.Do(req)
if err != nil {
return
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dashboardclient

import (
"sync"
)

type WorkerPool struct {
taskQueue chan func()
stop chan struct{}
wg sync.WaitGroup
workers int
}

func NewWorkerPool(taskQueue chan func()) *WorkerPool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func NewWorkerPool(taskQueue chan func()) *WorkerPool {
func NewWorkerPool(workers int) *WorkerPool {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a task queue channel is weird. Specifying a worker count is more understandable. You can also make a buffered channel based on the worker count internally.

wp := &WorkerPool{
taskQueue: taskQueue,
workers: 10,
stop: make(chan struct{}),
}

// Start workers immediately
wp.Start()
return wp
}

// Start launches worker goroutines to consume from queue
func (wp *WorkerPool) Start() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this private.

for i := 0; i < wp.workers; i++ {
wp.wg.Add(1)
go wp.worker()
}
}

// worker consumes and executes tasks from the queue
func (wp *WorkerPool) worker() {
defer wp.wg.Done()

for {
select {
case <-wp.stop:
return
case task := <-wp.taskQueue:
if task != nil {
task() // Execute the job
}
}
}
}

// Stop shuts down all workers
func (wp *WorkerPool) Stop() {
close(wp.stop)
wp.wg.Wait()
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
Expand All @@ -19,7 +20,7 @@ type FakeRayDashboardClient struct {

var _ dashboardclient.RayDashboardClientInterface = (*FakeRayDashboardClient)(nil)

func (r *FakeRayDashboardClient) InitClient(_ *http.Client, _ string) {
func (r *FakeRayDashboardClient) InitClient(_ *http.Client, _ string, _ chan func(), _ *sync.Map) {
}

func (r *FakeRayDashboardClient) UpdateDeployments(_ context.Context, _ []byte) error {
Expand All @@ -46,6 +47,9 @@ func (r *FakeRayDashboardClient) GetJobInfo(ctx context.Context, jobId string) (
return &utiltypes.RayJobInfo{JobStatus: rayv1.JobStatusRunning}, nil
}

func (r *FakeRayDashboardClient) AsyncGetJobInfo(_ context.Context, _ string) {
}

func (r *FakeRayDashboardClient) ListJobs(ctx context.Context) (*[]utiltypes.RayJobInfo, error) {
if mock := r.GetJobInfoMock.Load(); mock != nil {
info, err := (*mock)(ctx, "job_id")
Expand Down
9 changes: 6 additions & 3 deletions ray-operator/controllers/ray/utils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"time"
"unicode"

Expand Down Expand Up @@ -641,7 +642,7 @@ func EnvVarByName(envName string, envVars []corev1.EnvVar) (corev1.EnvVar, bool)
}

type ClientProvider interface {
GetDashboardClient(mgr manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error)
GetDashboardClient(mgr manager.Manager, taskQueue chan func(), jobInfoMap *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error)
GetHttpProxyClient(mgr manager.Manager) func(hostIp, podNamespace, podName string, port int) RayHttpProxyClientInterface
}

Expand Down Expand Up @@ -758,7 +759,7 @@ func FetchHeadServiceURL(ctx context.Context, cli client.Client, rayCluster *ray
return headServiceURL, nil
}

func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool, taskQueue chan func(), jobInfoMap *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
return func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
dashboardClient := &dashboardclient.RayDashboardClient{}
if useKubernetesProxy {
Expand All @@ -777,13 +778,15 @@ func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool) fun
// configured to communicate with the Kubernetes API server.
mgr.GetHTTPClient(),
fmt.Sprintf("%s/api/v1/namespaces/%s/services/%s:dashboard/proxy", mgr.GetConfig().Host, rayCluster.Namespace, headSvcName),
taskQueue,
jobInfoMap,
)
return dashboardClient, nil
}

dashboardClient.InitClient(&http.Client{
Timeout: 2 * time.Second,
}, "http://"+url)
}, "http://"+url, taskQueue, jobInfoMap)
return dashboardClient, nil
}
}
Expand Down
2 changes: 1 addition & 1 deletion ray-operator/rayjob-submitter/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func main() {
}
rayDashboardClient := &dashboardclient.RayDashboardClient{}
address = rayjobsubmitter.JobSubmissionURL(address)
rayDashboardClient.InitClient(&http.Client{Timeout: time.Second * 10}, address)
rayDashboardClient.InitClient(&http.Client{Timeout: time.Second * 10}, address, nil, nil)
submissionId, err := rayDashboardClient.SubmitJobReq(context.Background(), &req)
if err != nil {
if strings.Contains(err.Error(), "Please use a different submission_id") {
Expand Down
2 changes: 1 addition & 1 deletion ray-operator/test/sampleyaml/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func QueryDashboardGetAppStatus(t Test, rayCluster *rayv1.RayCluster) func(Gomeg

g.Expect(err).ToNot(HaveOccurred())
url := fmt.Sprintf("127.0.0.1:%d", localPort)
rayDashboardClientFunc := utils.GetRayDashboardClientFunc(nil, false)
rayDashboardClientFunc := utils.GetRayDashboardClientFunc(nil, false, nil, nil)
rayDashboardClient, err := rayDashboardClientFunc(rayCluster, url)
g.Expect(err).ToNot(HaveOccurred())
serveDetails, err := rayDashboardClient.GetServeDetails(t.Ctx())
Expand Down
Loading