Skip to content

Commit

Permalink
Pipeline run API token fixes and improvements (#3242)
Browse files Browse the repository at this point in the history
* Remove the workload tokens expiry

* Allow clients to use generic API tokens instead of workload API tokens with pipeline runs

* Actually use the set expiration time in the endpoint
  • Loading branch information
stefannica authored Dec 4, 2024
1 parent e72aef7 commit 61988c0
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 35 deletions.
4 changes: 4 additions & 0 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING,
DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT,
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_MAX_LIFETIME,
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY,
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE,
DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS,
Expand Down Expand Up @@ -269,6 +270,9 @@ class ServerConfiguration(BaseModel):
generic_api_token_lifetime: PositiveInt = (
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME
)
generic_api_token_max_lifetime: PositiveInt = (
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_MAX_LIFETIME
)

external_login_url: Optional[str] = None
external_user_info_url: Optional[str] = None
Expand Down
10 changes: 10 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT"
ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
"ZENML_PIPELINE_API_TOKEN_EXPIRATION"
)

# ZenML Server environment variables
ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"
Expand Down Expand Up @@ -268,6 +271,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME = 60 * 60 # 1 hour
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_MAX_LIFETIME = (
60 * 60 * 24 * 7
) # 7 days

DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS = (
"max-age=63072000; includeSubdomains"
Expand Down Expand Up @@ -466,3 +472,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:


STACK_DEPLOYMENT_API_TOKEN_EXPIRATION = 60 * 6 # 6 hours

ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = handle_int_env_var(
ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION, default=0
)
82 changes: 55 additions & 27 deletions src/zenml/orchestrators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
ENV_ZENML_ACTIVE_STACK_ID,
ENV_ZENML_ACTIVE_WORKSPACE_ID,
ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING,
ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION,
ENV_ZENML_SERVER,
ENV_ZENML_STORE_PREFIX,
ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION,
)
from zenml.enums import AuthScheme, StackComponentType, StoreType
from zenml.enums import APITokenType, AuthScheme, StackComponentType, StoreType
from zenml.logger import get_logger
from zenml.stack import StackComponent

Expand Down Expand Up @@ -137,37 +139,63 @@ def get_config_environment_vars(
url = global_config.store_configuration.url
api_token = credentials_store.get_token(url, allow_expired=False)
if schedule_id or pipeline_run_id or step_run_id:
# When connected to an authenticated ZenML server, if a schedule ID,
# pipeline run ID or step run ID is supplied, we need to fetch a new
# workload API token scoped to the schedule, pipeline run or step
# run.
assert isinstance(global_config.zen_store, RestZenStore)

# If only a schedule is given, the pipeline run credentials will
# be valid for the entire duration of the schedule.
api_key = credentials_store.get_api_key(url)
if not api_key and not pipeline_run_id and not step_run_id:
# The user has the option to manually set an expiration for the API
# token generated for a pipeline run. In this case, we generate a new
# generic API token that will be valid for the indicated duration.
if (
pipeline_run_id
and ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION != 0
):
logger.warning(
"An API token without an expiration time will be generated "
"and used to run this pipeline on a schedule. This is very "
"insecure because the API token will be valid for the "
"entire lifetime of the schedule and can be used to access "
"your user account if accidentally leaked. When deploying "
"a pipeline on a schedule, it is strongly advised to use a "
"service account API key to authenticate to the ZenML "
"server instead of your regular user account. For more "
"information, see "
"https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account"
f"An unscoped API token will be generated for this pipeline "
f"run that will expire after "
f"{ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION} "
f"seconds instead of being scoped to the pipeline run "
f"and not having an expiration time. This is more insecure "
f"because the API token will remain valid even after the "
f"pipeline run completes its execution. This option has "
"been explicitly enabled by setting the "
f"{ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION} environment "
f"variable"
)
new_api_token = global_config.zen_store.get_api_token(
token_type=APITokenType.GENERIC,
expires_in=ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION,
)

# The schedule, pipeline run or step run credentials are scoped to
# the schedule, pipeline run or step run and will only be valid for
# the duration of the schedule/pipeline run/step run.
new_api_token = global_config.zen_store.get_api_token(
schedule_id=schedule_id,
pipeline_run_id=pipeline_run_id,
step_run_id=step_run_id,
)
else:
# If a schedule ID, pipeline run ID or step run ID is supplied,
# we need to fetch a new workload API token scoped to the
# schedule, pipeline run or step run.

# If only a schedule is given, the pipeline run credentials will
# be valid for the entire duration of the schedule.
api_key = credentials_store.get_api_key(url)
if not api_key and not pipeline_run_id and not step_run_id:
logger.warning(
"An API token without an expiration time will be generated "
"and used to run this pipeline on a schedule. This is very "
"insecure because the API token will be valid for the "
"entire lifetime of the schedule and can be used to access "
"your user account if accidentally leaked. When deploying "
"a pipeline on a schedule, it is strongly advised to use a "
"service account API key to authenticate to the ZenML "
"server instead of your regular user account. For more "
"information, see "
"https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account"
)

# The schedule, pipeline run or step run credentials are scoped to
# the schedule, pipeline run or step run and will only be valid for
# the duration of the schedule/pipeline run/step run.
new_api_token = global_config.zen_store.get_api_token(
token_type=APITokenType.WORKLOAD,
schedule_id=schedule_id,
pipeline_run_id=pipeline_run_id,
step_run_id=step_run_id,
)

environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = (
new_api_token
Expand Down
9 changes: 7 additions & 2 deletions src/zenml/zen_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,10 @@ def generate_access_token(
response: The FastAPI response object.
device: The device used for authentication.
api_key: The service account API key used for authentication.
expires_in: The number of seconds until the token expires.
expires_in: The number of seconds until the token expires. If not set,
the default value is determined automatically based on the server
configuration and type of token. If set to 0, the token will not
expire.
schedule_id: The ID of the schedule to scope the token to.
pipeline_run_id: The ID of the pipeline run to scope the token to.
step_run_id: The ID of the step run to scope the token to.
Expand All @@ -835,7 +838,9 @@ def generate_access_token(
# according to the values configured in the server config. Device tokens are
# handled separately from regular user tokens.
expires: Optional[datetime] = None
if expires_in:
if expires_in == 0:
expires_in = None
elif expires_in is not None:
expires = datetime.utcnow() + timedelta(seconds=expires_in)
elif device:
# If a device was used for authentication, the token will expire
Expand Down
22 changes: 20 additions & 2 deletions src/zenml/zen_server/routers/auth_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def device_authorization(
@handle_exceptions
def api_token(
token_type: APITokenType = APITokenType.GENERIC,
expires_in: Optional[int] = None,
schedule_id: Optional[UUID] = None,
pipeline_run_id: Optional[UUID] = None,
step_run_id: Optional[UUID] = None,
Expand All @@ -463,7 +464,8 @@ def api_token(
of API tokens are supported:
* Generic API token: This token is short-lived and can be used for
generic automation tasks.
generic automation tasks. The expiration can be set by the user, but the
server will impose a maximum expiration time.
* Workload API token: This token is scoped to a specific pipeline run, step
run or schedule and is used by pipeline workloads to authenticate with the
server. A pipeline run ID, step run ID or schedule ID must be provided and
Expand All @@ -475,6 +477,10 @@ def api_token(
Args:
token_type: The type of API token to generate.
expires_in: The expiration time of the generic API token in seconds.
If not set, the server will use the default expiration time for
generic API tokens. The server also imposes a maximum expiration
time.
schedule_id: The ID of the schedule to scope the workload API token to.
pipeline_run_id: The ID of the pipeline run to scope the workload API
token to.
Expand Down Expand Up @@ -502,9 +508,19 @@ def api_token(

config = server_config()

if not expires_in:
expires_in = config.generic_api_token_lifetime

if expires_in > config.generic_api_token_max_lifetime:
raise ValueError(
f"The maximum expiration time for generic API tokens allowed "
f"by this server is {config.generic_api_token_max_lifetime} "
"seconds."
)

return generate_access_token(
user_id=token.user_id,
expires_in=config.generic_api_token_lifetime,
expires_in=expires_in,
).access_token

verify_permission(
Expand Down Expand Up @@ -611,4 +627,6 @@ def api_token(
schedule_id=schedule_id,
pipeline_run_id=pipeline_run_id,
step_run_id=step_run_id,
# Never expire the token
expires_in=0,
).access_token
4 changes: 3 additions & 1 deletion src/zenml/zen_server/template_execution/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ def run_template(
placeholder_run = create_placeholder_run(deployment=new_deployment)
assert placeholder_run

# We create an API token scoped to the pipeline run
# We create an API token scoped to the pipeline run that never expires
api_token = generate_access_token(
user_id=auth_context.user.id,
pipeline_run_id=placeholder_run.id,
# Keep the original API key or device scopes, if any
api_key=auth_context.api_key,
device=auth_context.device,
# Never expire the token
expires_in=0,
).access_token

environment = {
Expand Down
11 changes: 8 additions & 3 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3873,13 +3873,17 @@ def delete_authorized_device(self, device_id: UUID) -> None:

def get_api_token(
self,
token_type: APITokenType = APITokenType.WORKLOAD,
expires_in: Optional[int] = None,
schedule_id: Optional[UUID] = None,
pipeline_run_id: Optional[UUID] = None,
step_run_id: Optional[UUID] = None,
) -> str:
"""Get an API token for a workload.
"""Get an API token.
Args:
token_type: The type of the token to get.
expires_in: The time in seconds until the token expires.
schedule_id: The ID of the schedule to get a token for.
pipeline_run_id: The ID of the pipeline run to get a token for.
step_run_id: The ID of the step run to get a token for.
Expand All @@ -3891,9 +3895,10 @@ def get_api_token(
ValueError: if the server response is not valid.
"""
params: Dict[str, Any] = {
# Python clients may only request workload tokens.
"token_type": APITokenType.WORKLOAD.value,
"token_type": token_type.value,
}
if expires_in:
params["expires_in"] = expires_in
if schedule_id:
params["schedule_id"] = schedule_id
if pipeline_run_id:
Expand Down

0 comments on commit 61988c0

Please sign in to comment.