Skip to content

Commit

Permalink
Updated support for external ZenML Pro server enrollment
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Dec 20, 2024
1 parent b9096af commit 33af3ce
Show file tree
Hide file tree
Showing 19 changed files with 355 additions and 118 deletions.
13 changes: 7 additions & 6 deletions examples/e2e/pipelines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,13 @@ def e2e_use_case_training(
target=target,
)
########## Promotion stage ##########
latest_metric, current_metric = (
compute_performance_metrics_on_current_data(
dataset_tst=dataset_tst,
target_env=target_env,
after=["model_evaluator"],
)
(
latest_metric,
current_metric,
) = compute_performance_metrics_on_current_data(
dataset_tst=dataset_tst,
target_env=target_env,
after=["model_evaluator"],
)

promote_with_metric_compare(
Expand Down
8 changes: 5 additions & 3 deletions examples/quickstart/pipelines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def english_translation_pipeline(
tokenized_dataset, tokenizer = tokenize_data(
dataset=full_dataset, model_type=model_type
)
tokenized_train_dataset, tokenized_eval_dataset, tokenized_test_dataset = (
split_dataset(tokenized_dataset)
)
(
tokenized_train_dataset,
tokenized_eval_dataset,
tokenized_test_dataset,
) = split_dataset(tokenized_dataset)
model = train_model(
tokenized_dataset=tokenized_train_dataset,
model_type=model_type,
Expand Down
71 changes: 69 additions & 2 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Functionality to support ZenML GlobalConfiguration."""
"""Functionality to support ZenML Server Configuration."""

import json
import os
from secrets import token_hex
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
PositiveInt,
model_validator,
field_validator,
)

from zenml.constants import (
DEFAULT_ZENML_JWT_TOKEN_ALGORITHM,
Expand All @@ -44,6 +51,7 @@
DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP,
DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE,
ENV_ZENML_SERVER_PREFIX,
ENV_ZENML_SERVER_PRO_PREFIX,
)
from zenml.enums import AuthScheme
from zenml.logger import get_logger
Expand Down Expand Up @@ -539,6 +547,9 @@ def get_server_config(cls) -> "ServerConfiguration":
for k, v in os.environ.items():
if v == "":
continue
if k.startswith(ENV_ZENML_SERVER_PRO_PREFIX):
# Skip Pro configuration
continue
if k.startswith(ENV_ZENML_SERVER_PREFIX):
env_server_config[
k[len(ENV_ZENML_SERVER_PREFIX) :].lower()
Expand All @@ -551,3 +562,59 @@ def get_server_config(cls) -> "ServerConfiguration":
# permit downgrading
extra="allow",
)


class ServerProConfiguration(BaseModel):
"""ZenML Server Pro configuration attributes.
All these attributes can be set through the environment with the
`ZENML_SERVER_PRO_`-Prefix. E.g. the value of the `ZENML_SERVER_PRO_API_URL`
environment variable will be extracted to api_url.
Attributes:
api_url: The ZenML Pro API URL.
oauth2_client_secret: The ZenML Pro OAuth2 client secret used to
authenticate the ZenML server with the ZenML Pro API.
oauth2_audience: The OAuth2 audience.
"""

api_url: str
oauth2_client_secret: str
oauth2_audience: str

@field_validator("api_url")
@classmethod
def _strip_trailing_slashes_url(cls, url: str) -> str:
"""Strip any trailing slashes on the API URL.
Args:
url: The API URL.
Returns:
The API URL with potential trailing slashes removed.
"""
return url.rstrip("/")

@classmethod
def get_server_config(cls) -> "ServerProConfiguration":
"""Get the server Pro configuration.
Returns:
The server Pro configuration.
"""
env_server_config: Dict[str, Any] = {}
for k, v in os.environ.items():
if v == "":
continue
if k.startswith(ENV_ZENML_SERVER_PRO_PREFIX):
env_server_config[
k[len(ENV_ZENML_SERVER_PRO_PREFIX) :].lower()
] = v

return ServerProConfiguration(**env_server_config)

model_config = ConfigDict(
# Allow extra attributes from configs of previous ZenML versions to
# permit downgrading
extra="allow",
)
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# ZenML Server environment variables
ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"
ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_"
ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE"
ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME"
ENV_ZENML_SERVER_REPORTABLE_RESOURCES = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,11 @@ def _compute_orchestrator_url(
the URL to the dashboard view in SageMaker.
"""
try:
region_name, pipeline_name, execution_id = (
dissect_pipeline_execution_arn(pipeline_execution.arn)
)
(
region_name,
pipeline_name,
execution_id,
) = dissect_pipeline_execution_arn(pipeline_execution.arn)

# Get the Sagemaker session
session = pipeline_execution.sagemaker_session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cast,
)

from pydantic import field_validator, BaseModel
from pydantic import BaseModel, field_validator

from zenml.config.base_settings import BaseSettings
from zenml.experiment_trackers.base_experiment_tracker import (
Expand Down Expand Up @@ -69,8 +69,8 @@ def _convert_settings(cls, value: Any) -> Any:
import wandb

if isinstance(value, wandb.Settings):
# Depending on the wandb version, either `model_dump`,
# `make_static` or `to_dict` is available to convert the settings
# Depending on the wandb version, either `model_dump`,
# `make_static` or `to_dict` is available to convert the settings
# to a dictionary
if isinstance(value, BaseModel):
return value.model_dump()
Expand Down
10 changes: 10 additions & 0 deletions src/zenml/zen_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from zenml.zen_server.exceptions import http_exception_from_error
from zenml.zen_server.jwt import JWTToken
from zenml.zen_server.utils import (
get_zenml_headers,
is_same_or_subdomain,
server_config,
zen_store,
Expand Down Expand Up @@ -307,6 +308,14 @@ def authenticate_credentials(

device_model: Optional[OAuthDeviceInternalResponse] = None
if decoded_token.device_id:
if server_config().auth_scheme in [
AuthScheme.NO_AUTH,
AuthScheme.EXTERNAL,
]:
error = "Authentication error: device authorization is not supported."
logger.error(error)
raise CredentialsNotValid(error)

# Access tokens that have been issued for a device are only valid
# for that device, so we need to check if the device ID matches any
# of the valid devices in the database.
Expand Down Expand Up @@ -685,6 +694,7 @@ def authenticate_external_user(
# Get the user information from the external authenticator
user_info_url = config.external_user_info_url
headers = {"Authorization": "Bearer " + external_access_token}
headers.update(get_zenml_headers())
query_params = dict(server_id=str(config.get_external_server_id()))

try:
Expand Down
61 changes: 9 additions & 52 deletions src/zenml/zen_server/cloud_utils.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,23 @@
"""Utils concerning anything concerning the cloud control plane backend."""

import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import requests
from pydantic import BaseModel, ConfigDict, field_validator
from requests.adapters import HTTPAdapter, Retry

from zenml.config.server_config import ServerProConfiguration
from zenml.exceptions import SubscriptionUpgradeRequiredError
from zenml.zen_server.utils import server_config

ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_"
from zenml.zen_server.utils import get_zenml_headers, server_config

_cloud_connection: Optional["ZenMLCloudConnection"] = None


class ZenMLCloudConfiguration(BaseModel):
"""ZenML Pro RBAC configuration."""

api_url: str
oauth2_client_id: str
oauth2_client_secret: str
oauth2_audience: str

@field_validator("api_url")
@classmethod
def _strip_trailing_slashes_url(cls, url: str) -> str:
"""Strip any trailing slashes on the API URL.
Args:
url: The API URL.
Returns:
The API URL with potential trailing slashes removed.
"""
return url.rstrip("/")

@classmethod
def from_environment(cls) -> "ZenMLCloudConfiguration":
"""Get the RBAC configuration from environment variables.
Returns:
The RBAC configuration.
"""
env_config: Dict[str, Any] = {}
for k, v in os.environ.items():
if v == "":
continue
if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX):
env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v

return ZenMLCloudConfiguration(**env_config)

model_config = ConfigDict(
# Allow extra attributes from configs of previous ZenML versions to
# permit downgrading
extra="allow"
)


class ZenMLCloudConnection:
"""Class to use for communication between server and control plane."""

def __init__(self) -> None:
"""Initialize the RBAC component."""
self._config = ZenMLCloudConfiguration.from_environment()
self._config = ServerProConfiguration.get_server_config()
self._session: Optional[requests.Session] = None
self._token: Optional[str] = None
self._token_expires_at: Optional[datetime] = None
Expand Down Expand Up @@ -169,6 +121,8 @@ def session(self) -> requests.Session:
self._session = requests.Session()
token = self._fetch_auth_token()
self._session.headers.update({"Authorization": "Bearer " + token})
# Add the ZenML specific headers
self._session.headers.update(get_zenml_headers())

retries = Retry(
total=5, backoff_factor=0.1, status_forcelist=[502, 504]
Expand Down Expand Up @@ -213,8 +167,11 @@ def _fetch_auth_token(self) -> str:
# Get an auth token from the Cloud API
login_url = f"{self._config.api_url}/auth/login"
headers = {"content-type": "application/x-www-form-urlencoded"}
# Add zenml specific headers to the request
headers.update(get_zenml_headers())
payload = {
"client_id": self._config.oauth2_client_id,
# The client ID is the external server ID
"client_id": str(server_config().get_external_server_id()),
"client_secret": self._config.oauth2_client_secret,
"audience": self._config.oauth2_audience,
"grant_type": "client_credentials",
Expand Down
Loading

0 comments on commit 33af3ce

Please sign in to comment.