diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 5f82c501d05..963b83e3163 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -24,8 +24,8 @@ ConfigDict, Field, PositiveInt, - model_validator, field_validator, + model_validator, ) from zenml.constants import ( diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index 6a52cc4877c..b604636146c 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -12,6 +12,7 @@ _cloud_connection: Optional["ZenMLCloudConnection"] = None + class ZenMLCloudConnection: """Class to use for communication between server and control plane.""" @@ -103,6 +104,48 @@ def post( return response + def patch( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> requests.Response: + """Send a PATCH request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + data: Data to include in the request. + + Raises: + RuntimeError: If the request failed. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.post( + url=url, params=params, json=data, timeout=7 + ) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.patch( + url=url, params=params, json=data, timeout=7 + ) + + try: + response.raise_for_status() + except requests.HTTPError as e: + raise RuntimeError( + f"Failed while trying to contact the central zenml pro " + f"service: {e}" + ) + + return response + @property def session(self) -> requests.Session: """Authenticate to the ZenML Pro Management Plane. @@ -220,3 +263,8 @@ def cloud_connection() -> ZenMLCloudConnection: _cloud_connection = ZenMLCloudConnection() return _cloud_connection + + +def send_pro_tenant_status_update() -> None: + """Send a tenant status update to the Cloud API.""" + cloud_connection().patch("/tenants/status_updates") diff --git a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl index f100f87cc0a..b0f724bceb8 100644 --- a/src/zenml/zen_server/deploy/helm/templates/_environment.tpl +++ b/src/zenml/zen_server/deploy/helm/templates/_environment.tpl @@ -143,7 +143,7 @@ Returns: {{- if .ZenML.pro.enabled }} -auth_scheme: "external" +auth_scheme: EXTERNAL deployment_type: cloud cors_allow_origins: "{{ .ZenML.pro.dashboardURL }},{{ .ZenML.pro.serverURL }}" external_login_url: "{{ .ZenML.pro.dashboardURL }}/api/auth/login" diff --git a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml index fcf24841773..4d3e0f27065 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-secret.yaml @@ -13,7 +13,7 @@ data: {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} - {{- range $k, $v := include "zenml.serverSecretConfigurationAttrs" . | fromYaml}} + {{- range $k, $v := include "zenml.serverSecretEnvVariables" . | fromYaml}} {{ $k }}: {{ $v | b64enc | quote }} {{- end }} {{- range $k, $v := include "zenml.secretsStoreSecretEnvVariables" . | fromYaml}} diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 3c4b20638ad..060e9db59cb 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -54,6 +54,7 @@ ) from zenml.enums import AuthScheme, SourceContextTypes from zenml.models import ServerDeploymentType +from zenml.zen_server.cloud_utils import send_pro_tenant_status_update from zenml.zen_server.exceptions import error_detail from zenml.zen_server.routers import ( actions_endpoints, @@ -389,6 +390,10 @@ def initialize() -> None: initialize_plugins() initialize_secure_headers() initialize_memcache(cfg.memcache_max_capacity, cfg.memcache_default_expiry) + if cfg.deployment_type == ServerDeploymentType.CLOUD: + # Send a tenant status update to the Cloud API to indicate that the + # ZenML server is running or to update the version and server URL. + send_pro_tenant_status_update() DASHBOARD_REDIRECT_URL = None