diff --git a/apiserver/pkg/server/ray_job_submission_service_server.go b/apiserver/pkg/server/ray_job_submission_service_server.go
index d03a30a9346..e35557a8a1c 100644
--- a/apiserver/pkg/server/ray_job_submission_service_server.go
+++ b/apiserver/pkg/server/ray_job_submission_service_server.go
@@ -18,6 +18,7 @@ import (
 	"sigs.k8s.io/yaml"
 
 	api "github.com/ray-project/kuberay/proto/go_client"
+	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
 	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
 )
 
@@ -31,7 +32,7 @@ type RayJobSubmissionServiceServer struct {
 	api.UnimplementedRayJobSubmissionServiceServer
 	options             *RayJobSubmissionServiceServerOptions
 	clusterServer       *ClusterServer
-	dashboardClientFunc func() utils.RayDashboardClientInterface
+	dashboardClientFunc func(rayCluster *rayv1.RayCluster, url string) (utils.RayDashboardClientInterface, error)
 	log                 logr.Logger
 }
 
@@ -49,9 +50,8 @@ func (s *RayJobSubmissionServiceServer) SubmitRayJob(ctx context.Context, req *a
 	if err != nil {
 		return nil, err
 	}
-	rayDashboardClient := s.dashboardClientFunc()
-	// TODO: support proxy subresources in kuberay-apiserver
-	if err := rayDashboardClient.InitClient(ctx, *url, nil); err != nil {
+	rayDashboardClient, err := s.dashboardClientFunc(nil, *url)
+	if err != nil {
 		return nil, err
 	}
 	request := &utils.RayJobRequest{Entrypoint: req.Jobsubmission.Entrypoint}
@@ -104,9 +104,8 @@ func (s *RayJobSubmissionServiceServer) GetJobDetails(ctx context.Context, req *
 	if err != nil {
 		return nil, err
 	}
-	rayDashboardClient := s.dashboardClientFunc()
-	// TODO: support proxy subresources in kuberay-apiserver
-	if err := rayDashboardClient.InitClient(ctx, *url, nil); err != nil {
+	rayDashboardClient, err := s.dashboardClientFunc(nil, *url)
+	if err != nil {
 		return nil, err
 	}
 	nodeInfo, err := rayDashboardClient.GetJobInfo(ctx, req.Submissionid)
@@ -127,9 +126,8 @@ func (s *RayJobSubmissionServiceServer) GetJobLog(ctx context.Context, req *api.
 	if err != nil {
 		return nil, err
 	}
-	rayDashboardClient := s.dashboardClientFunc()
-	// TODO: support proxy subresources in kuberay-apiserver
-	if err := rayDashboardClient.InitClient(ctx, *url, nil); err != nil {
+	rayDashboardClient, err := s.dashboardClientFunc(nil, *url)
+	if err != nil {
 		return nil, err
 	}
 	jlog, err := rayDashboardClient.GetJobLog(ctx, req.Submissionid)
@@ -150,9 +148,8 @@ func (s *RayJobSubmissionServiceServer) ListJobDetails(ctx context.Context, req
 	if err != nil {
 		return nil, err
 	}
-	rayDashboardClient := s.dashboardClientFunc()
-	// TODO: support proxy subresources in kuberay-apiserver
-	if err := rayDashboardClient.InitClient(ctx, *url, nil); err != nil {
+	rayDashboardClient, err := s.dashboardClientFunc(nil, *url)
+	if err != nil {
 		return nil, err
 	}
 	nodesInfo, err := rayDashboardClient.ListJobs(ctx)
@@ -174,9 +171,8 @@ func (s *RayJobSubmissionServiceServer) StopRayJob(ctx context.Context, req *api
 	if err != nil {
 		return nil, err
 	}
-	rayDashboardClient := s.dashboardClientFunc()
-	// TODO: support proxy subresources in kuberay-apiserver
-	if err := rayDashboardClient.InitClient(ctx, *url, nil); err != nil {
+	rayDashboardClient, err := s.dashboardClientFunc(nil, *url)
+	if err != nil {
 		return nil, err
 	}
 	err = rayDashboardClient.StopJob(ctx, req.Submissionid)
@@ -194,9 +190,8 @@ func (s *RayJobSubmissionServiceServer) DeleteRayJob(ctx context.Context, req *a
 	if err != nil {
 		return nil, err
 	}
-	rayDashboardClient := s.dashboardClientFunc()
-	// TODO: support proxy subresources in kuberay-apiserver
-	if err := rayDashboardClient.InitClient(ctx, *url, nil); err != nil {
+	rayDashboardClient, err := s.dashboardClientFunc(nil, *url)
+	if err != nil {
 		return nil, err
 	}
 	err = rayDashboardClient.DeleteJob(ctx, req.Submissionid)
diff --git a/ray-operator/apis/config/v1alpha1/configuration_types.go b/ray-operator/apis/config/v1alpha1/configuration_types.go
index c58864d85ff..7c390715415 100644
--- a/ray-operator/apis/config/v1alpha1/configuration_types.go
+++ b/ray-operator/apis/config/v1alpha1/configuration_types.go
@@ -5,6 +5,7 @@ import (
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"sigs.k8s.io/controller-runtime/pkg/manager"
 
+	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
 	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
 )
 
@@ -75,7 +76,7 @@ type Configuration struct {
 	EnableMetrics bool `json:"enableMetrics,omitempty"`
 }
 
-func (config Configuration) GetDashboardClient(mgr manager.Manager) func() utils.RayDashboardClientInterface {
+func (config Configuration) GetDashboardClient(mgr manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (utils.RayDashboardClientInterface, error) {
 	return utils.GetRayDashboardClientFunc(mgr, config.UseKubernetesProxy)
 }
 
diff --git a/ray-operator/controllers/ray/rayjob_controller.go b/ray-operator/controllers/ray/rayjob_controller.go
index e3d4f2a4330..17dab71e345 100644
--- a/ray-operator/controllers/ray/rayjob_controller.go
+++ b/ray-operator/controllers/ray/rayjob_controller.go
@@ -42,7 +42,7 @@ type RayJobReconciler struct {
 	Scheme   *runtime.Scheme
 	Recorder record.EventRecorder
 
-	dashboardClientFunc func() utils.RayDashboardClientInterface
+	dashboardClientFunc func(rayCluster *rayv1.RayCluster, url string) (utils.RayDashboardClientInterface, error)
 	options             RayJobReconcilerOptions
 }
 
@@ -115,9 +115,10 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
 				logger.Error(err, "Failed to get RayCluster")
 			}
 
-			rayDashboardClient := r.dashboardClientFunc()
-			if err := rayDashboardClient.InitClient(ctx, rayJobInstance.Status.DashboardURL, rayClusterInstance); err != nil {
-				logger.Error(err, "Failed to initialize dashboard client")
+			rayDashboardClient, err := r.dashboardClientFunc(rayClusterInstance, rayJobInstance.Status.DashboardURL)
+			if err != nil {
+				logger.Error(err, "Failed to get dashboard client for RayJob")
+				return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
 			}
 			if err := rayDashboardClient.StopJob(ctx, rayJobInstance.Status.JobId); err != nil {
 				logger.Error(err, "Failed to stop job for RayJob")
@@ -260,8 +261,9 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
 		}
 
 		// Check the current status of ray jobs
-		rayDashboardClient := r.dashboardClientFunc()
-		if err := rayDashboardClient.InitClient(ctx, rayJobInstance.Status.DashboardURL, rayClusterInstance); err != nil {
+		rayDashboardClient, err := r.dashboardClientFunc(rayClusterInstance, rayJobInstance.Status.DashboardURL)
+		if err != nil {
+			logger.Error(err, "Failed to get dashboard client for RayJob")
 			return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
 		}
 
diff --git a/ray-operator/controllers/ray/rayservice_controller.go b/ray-operator/controllers/ray/rayservice_controller.go
index 498720f73db..31940973c8a 100644
--- a/ray-operator/controllers/ray/rayservice_controller.go
+++ b/ray-operator/controllers/ray/rayservice_controller.go
@@ -52,7 +52,7 @@ type RayServiceReconciler struct {
 	// Cache value is map of RayCluster name to Serve application config.
 	ServeConfigs                 *lru.Cache
 	RayClusterDeletionTimestamps cmap.ConcurrentMap[string, time.Time]
-	dashboardClientFunc          func() utils.RayDashboardClientInterface
+	dashboardClientFunc          func(rayCluster *rayv1.RayCluster, url string) (utils.RayDashboardClientInterface, error)
 	httpProxyClientFunc          func() utils.RayHttpProxyClientInterface
 }
 
@@ -943,8 +943,8 @@ func (r *RayServiceReconciler) reconcileServe(ctx context.Context, rayServiceIns
 		return false, serveApplications, err
 	}
 
-	rayDashboardClient := r.dashboardClientFunc()
-	if err := rayDashboardClient.InitClient(ctx, clientURL, rayClusterInstance); err != nil {
+	rayDashboardClient, err := r.dashboardClientFunc(rayClusterInstance, clientURL)
+	if err != nil {
 		return false, serveApplications, err
 	}
 
diff --git a/ray-operator/controllers/ray/suite_test.go b/ray-operator/controllers/ray/suite_test.go
index 7841d3ac893..1d45dfffcea 100644
--- a/ray-operator/controllers/ray/suite_test.go
+++ b/ray-operator/controllers/ray/suite_test.go
@@ -51,9 +51,9 @@ var (
 
 type TestClientProvider struct{}
 
-func (testProvider TestClientProvider) GetDashboardClient(_ manager.Manager) func() utils.RayDashboardClientInterface {
-	return func() utils.RayDashboardClientInterface {
-		return fakeRayDashboardClient
+func (testProvider TestClientProvider) GetDashboardClient(_ manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (utils.RayDashboardClientInterface, error) {
+	return func(_ *rayv1.RayCluster, _ string) (utils.RayDashboardClientInterface, error) {
+		return fakeRayDashboardClient, nil
 	}
 }
 
diff --git a/ray-operator/controllers/ray/utils/dashboard_httpclient.go b/ray-operator/controllers/ray/utils/dashboard_httpclient.go
index a6e82775d45..fd058074eb0 100644
--- a/ray-operator/controllers/ray/utils/dashboard_httpclient.go
+++ b/ray-operator/controllers/ray/utils/dashboard_httpclient.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"io"
 	"net/http"
-	"time"
 
 	"k8s.io/apimachinery/pkg/api/errors"
 	"k8s.io/apimachinery/pkg/util/json"
@@ -25,7 +24,6 @@ var (
 )
 
 type RayDashboardClientInterface interface {
-	InitClient(ctx context.Context, url string, rayCluster *rayv1.RayCluster) error
 	UpdateDeployments(ctx context.Context, configJson []byte) error
 	// V2/multi-app Rest API
 	GetServeDetails(ctx context.Context) (*ServeDetails, error)
@@ -39,53 +37,11 @@ type RayDashboardClientInterface interface {
 	DeleteJob(ctx context.Context, jobName string) error
 }
 
-type BaseDashboardClient struct {
+type RayDashboardClient struct {
 	client       *http.Client
 	dashboardURL string
 }
 
-func GetRayDashboardClientFunc(mgr ctrl.Manager, useKubernetesProxy bool) func() RayDashboardClientInterface {
-	return func() RayDashboardClientInterface {
-		return &RayDashboardClient{
-			mgr:                mgr,
-			useKubernetesProxy: useKubernetesProxy,
-		}
-	}
-}
-
-type RayDashboardClient struct {
-	mgr ctrl.Manager
-	BaseDashboardClient
-	useKubernetesProxy bool
-}
-
-func (r *RayDashboardClient) InitClient(ctx context.Context, url string, rayCluster *rayv1.RayCluster) error {
-	log := ctrl.LoggerFrom(ctx)
-
-	if r.useKubernetesProxy {
-		var err error
-		headSvcName := rayCluster.Status.Head.ServiceName
-		if headSvcName == "" {
-			log.Info("RayCluster is missing .status.head.serviceName, calling GenerateHeadServiceName instead...", "RayCluster name", rayCluster.Name, "namespace", rayCluster.Namespace)
-			headSvcName, err = GenerateHeadServiceName(RayClusterCRD, rayCluster.Spec, rayCluster.Name)
-			if err != nil {
-				return err
-			}
-		}
-
-		r.client = r.mgr.GetHTTPClient()
-		r.dashboardURL = fmt.Sprintf("%s/api/v1/namespaces/%s/services/%s:dashboard/proxy", r.mgr.GetConfig().Host, rayCluster.Namespace, headSvcName)
-		return nil
-	}
-
-	r.client = &http.Client{
-		Timeout: 2 * time.Second,
-	}
-
-	r.dashboardURL = "http://" + url
-	return nil
-}
-
 // UpdateDeployments update the deployments in the Ray cluster.
 func (r *RayDashboardClient) UpdateDeployments(ctx context.Context, configJson []byte) error {
 	var req *http.Request
diff --git a/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go b/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go
index cefd01297b0..51d95cc873c 100644
--- a/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go
+++ b/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go
@@ -53,8 +53,8 @@ var _ = Describe("RayFrameworkGenerator", func() {
 		}
 
 		rayDashboardClient = &RayDashboardClient{}
-		err := rayDashboardClient.InitClient(context.Background(), "127.0.0.1:8090", nil)
-		Expect(err).ToNot(HaveOccurred())
+		rayDashboardClient.dashboardURL = "http://127.0.0.1:8090"
+		rayDashboardClient.client = &http.Client{}
 	})
 
 	It("Test ConvertRayJobToReq", func() {
diff --git a/ray-operator/controllers/ray/utils/fake_serve_httpclient.go b/ray-operator/controllers/ray/utils/fake_serve_httpclient.go
index 35de7ef866b..2de90cb4510 100644
--- a/ray-operator/controllers/ray/utils/fake_serve_httpclient.go
+++ b/ray-operator/controllers/ray/utils/fake_serve_httpclient.go
@@ -3,7 +3,6 @@ package utils
 import (
 	"context"
 	"fmt"
-	"net/http"
 	"sync/atomic"
 
 	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
@@ -12,18 +11,11 @@ import (
 type FakeRayDashboardClient struct {
 	multiAppStatuses map[string]*ServeApplicationStatus
 	GetJobInfoMock   atomic.Pointer[func(context.Context, string) (*RayJobInfo, error)]
-	BaseDashboardClient
-	serveDetails ServeDetails
+	serveDetails     ServeDetails
 }
 
 var _ RayDashboardClientInterface = (*FakeRayDashboardClient)(nil)
 
-func (r *FakeRayDashboardClient) InitClient(_ context.Context, url string, _ *rayv1.RayCluster) error {
-	r.client = &http.Client{}
-	r.dashboardURL = "http://" + url
-	return nil
-}
-
 func (r *FakeRayDashboardClient) UpdateDeployments(_ context.Context, _ []byte) error {
 	fmt.Print("UpdateDeployments fake succeeds.")
 	return nil
diff --git a/ray-operator/controllers/ray/utils/util.go b/ray-operator/controllers/ray/utils/util.go
index df2efe0fc39..f814c8eab22 100644
--- a/ray-operator/controllers/ray/utils/util.go
+++ b/ray-operator/controllers/ray/utils/util.go
@@ -6,10 +6,12 @@ import (
 	"encoding/base32"
 	"fmt"
 	"math"
+	"net/http"
 	"os"
 	"reflect"
 	"strconv"
 	"strings"
+	"time"
 	"unicode"
 
 	batchv1 "k8s.io/api/batch/v1"
@@ -638,7 +640,7 @@ func EnvVarByName(envName string, envVars []corev1.EnvVar) (corev1.EnvVar, bool)
 }
 
 type ClientProvider interface {
-	GetDashboardClient(mgr manager.Manager) func() RayDashboardClientInterface
+	GetDashboardClient(mgr manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (RayDashboardClientInterface, error)
 	GetHttpProxyClient(mgr manager.Manager) func() RayHttpProxyClientInterface
 }
 
@@ -754,3 +756,42 @@ func FetchHeadServiceURL(ctx context.Context, cli client.Client, rayCluster *ray
 		port)
 	return headServiceURL, nil
 }
+
+// GetRayDashboardClientFunc returns a factory that builds a Ray dashboard client from a
+// RayCluster and a dashboard address.
+//
+// With useKubernetesProxy set, the client reuses mgr.GetHTTPClient() and targets
+// <mgr.GetConfig().Host>/api/v1/namespaces/<ns>/services/<head-svc>:dashboard/proxy; url is
+// ignored and rayCluster must be non-nil. Otherwise the client targets "http://" + url with a
+// 2-second timeout and rayCluster is not consulted.
+func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool) func(rayCluster *rayv1.RayCluster, url string) (RayDashboardClientInterface, error) {
+	return func(rayCluster *rayv1.RayCluster, url string) (RayDashboardClientInterface, error) {
+		if useKubernetesProxy {
+			// Callers that only know the URL pass a nil RayCluster; fail with an error
+			// instead of panicking on the dereferences below.
+			if rayCluster == nil {
+				return nil, fmt.Errorf("cannot build proxied dashboard client for %q: rayCluster is nil", url)
+			}
+
+			var err error
+			headSvcName := rayCluster.Status.Head.ServiceName
+			if headSvcName == "" {
+				headSvcName, err = GenerateHeadServiceName(RayClusterCRD, rayCluster.Spec, rayCluster.Name)
+				if err != nil {
+					return nil, err
+				}
+			}
+			return &RayDashboardClient{
+				client:       mgr.GetHTTPClient(),
+				dashboardURL: fmt.Sprintf("%s/api/v1/namespaces/%s/services/%s:dashboard/proxy", mgr.GetConfig().Host, rayCluster.Namespace, headSvcName),
+			}, nil
+		}
+
+		return &RayDashboardClient{
+			client: &http.Client{
+				Timeout: 2 * time.Second,
+			},
+			dashboardURL: "http://" + url,
+		}, nil
+	}
+}
diff --git a/ray-operator/test/sampleyaml/support.go b/ray-operator/test/sampleyaml/support.go
index 0ab8de835ae..42ffd19e0a2 100644
--- a/ray-operator/test/sampleyaml/support.go
+++ b/ray-operator/test/sampleyaml/support.go
@@ -65,7 +65,6 @@ func AllAppsRunning(rayService *rayv1.RayService) bool {
 
 func QueryDashboardGetAppStatus(t Test, rayCluster *rayv1.RayCluster) func(Gomega) {
 	return func(g Gomega) {
-		rayDashboardClient := &utils.RayDashboardClient{}
 		pod, err := GetHeadPod(t, rayCluster)
 		g.Expect(err).ToNot(HaveOccurred())
 
@@ -76,8 +75,8 @@ func QueryDashboardGetAppStatus(t Test, rayCluster *rayv1.RayCluster) func(Gomeg
 		g.Expect(err).ToNot(HaveOccurred())
 
 		url := fmt.Sprintf("127.0.0.1:%d", localPort)
-
-		err = rayDashboardClient.InitClient(t.Ctx(), url, rayCluster)
+		rayDashboardClientFunc := utils.GetRayDashboardClientFunc(nil, false)
+		rayDashboardClient, err := rayDashboardClientFunc(rayCluster, url)
 		g.Expect(err).ToNot(HaveOccurred())
 		serveDetails, err := rayDashboardClient.GetServeDetails(t.Ctx())
 		g.Expect(err).ToNot(HaveOccurred())