diff --git a/README.md b/README.md index 86a15019683..b7398b89ba3 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ Or, through our CLI command: zenml stack deploy --provider aws ``` -Alternatively, if the necessary pieces of infrastructure is already deployed, you can register a cloud stack seamlessly through the stack wizard: +Alternatively, if the necessary pieces of infrastructure are already deployed, you can register a cloud stack seamlessly through the stack wizard: ```bash zenml stack register --provider aws @@ -195,9 +195,9 @@ def trainer(training_df: pd.DataFrame) -> Annotated["model", torch.nn.Module]: ![Exploring ZenML Models](/docs/book/.gitbook/assets/readme_mcp.gif) -### Purpose built for machine learning with integration to you favorite tools +### Purpose built for machine learning with integrations to your favorite tools -While ZenML brings a lot of value of the box, it also integrates into your existing tooling and infrastructure without you having to be locked in. +While ZenML brings a lot of value out of the box, it also integrates into your existing tooling and infrastructure without you having to be locked in. ```python from bentoml._internal.bento import bento diff --git a/docs/book/component-guide/step-operators/modal.md b/docs/book/component-guide/step-operators/modal.md index b7f46227c85..4492152050d 100644 --- a/docs/book/component-guide/step-operators/modal.md +++ b/docs/book/component-guide/step-operators/modal.md @@ -86,6 +86,12 @@ def my_modal_step(): ... ``` +{% hint style="info" %} +Note that the `cpu` parameter in `ResourceSettings` currently only accepts a single integer value. This specifies a soft minimum limit - Modal will guarantee at least this many physical cores, but the actual usage could be higher. The CPU cores/hour will also determine the minimum price paid for the compute resources. 
+ +For example, with the configuration above (2 CPUs and 32GB memory), the minimum cost would be approximately $1.03 per hour ((0.135 * 2) + (0.024 * 32) = $1.03). +{% endhint %} + This will run `my_modal_step` on a Modal instance with 1 A100 GPU, 2 CPUs, and 32GB of CPU memory. diff --git a/docs/book/how-to/pipeline-development/build-pipelines/hyper-parameter-tuning.md b/docs/book/how-to/pipeline-development/build-pipelines/hyper-parameter-tuning.md index 35ce3c93c2b..49f8ae72a3e 100644 --- a/docs/book/how-to/pipeline-development/build-pipelines/hyper-parameter-tuning.md +++ b/docs/book/how-to/pipeline-development/build-pipelines/hyper-parameter-tuning.md @@ -16,7 +16,7 @@ def my_pipeline(step_count: int) -> None: data = load_data_step() after = [] for i in range(step_count): - train_step(data, learning_rate=i * 0.0001, name=f"train_step_{i}") + train_step(data, learning_rate=i * 0.0001, id=f"train_step_{i}") after.append(f"train_step_{i}") model = select_model_step(..., after=after) ``` diff --git a/docs/book/how-to/project-setup-and-management/setting-up-a-project-repository/connect-your-git-repository.md b/docs/book/how-to/project-setup-and-management/setting-up-a-project-repository/connect-your-git-repository.md index 063f316fef4..d2e82e82a58 100644 --- a/docs/book/how-to/project-setup-and-management/setting-up-a-project-repository/connect-your-git-repository.md +++ b/docs/book/how-to/project-setup-and-management/setting-up-a-project-repository/connect-your-git-repository.md @@ -54,6 +54,21 @@ zenml code-repository register --type=github \ where \ is the name of the code repository you are registering, \ is the owner of the repository, \ is the name of the repository, \ is your GitHub Personal Access Token and \ is the URL of the GitHub instance which defaults to `https://github.com.` You will need to set a URL if you are using GitHub Enterprise. 
+{% hint style="warning" %} +Please refer to the section on using secrets for stack configuration in order to securely store your GitHub +Personal Access Token. + +```shell +# Using central secrets management +zenml secret create github_secret \ +  --pa_token= + +# Then reference the token +zenml code-repository register ... --token={{github_secret.pa_token}} + ... +``` +{% endhint %} + After registering the GitHub code repository, ZenML will automatically detect if your source files are being tracked by GitHub and store the commit hash for each pipeline run.
@@ -96,6 +111,21 @@ zenml code-repository register --type=gitlab \ where `` is the name of the code repository you are registering, `` is the group of the project, `` is the name of the project, \ is your GitLab Personal Access Token, and \ is the URL of the GitLab instance which defaults to `https://gitlab.com.` You will need to set a URL if you have a self-hosted GitLab instance. +{% hint style="warning" %} +Please refer to the section on using secrets for stack configuration in order to securely store your GitLab +Personal Access Token. + +```shell +# Using central secrets management +zenml secret create gitlab_secret \ +  --pa_token= + +# Then reference the token +zenml code-repository register ... --token={{gitlab_secret.pa_token}} + ... +``` +{% endhint %} + After registering the GitLab code repository, ZenML will automatically detect if your source files are being tracked by GitLab and store the commit hash for each pipeline run.
diff --git a/scripts/test-migrations.sh b/scripts/test-migrations.sh index 145c0df7b27..170ff03e2e3 100755 --- a/scripts/test-migrations.sh +++ b/scripts/test-migrations.sh @@ -23,7 +23,7 @@ else fi # List of versions to test -VERSIONS=("0.40.3" "0.43.0" "0.44.3" "0.45.6" "0.47.0" "0.50.0" "0.51.0" "0.52.0" "0.53.1" "0.54.1" "0.55.5" "0.56.4" "0.57.1" "0.60.0" "0.61.0" "0.62.0" "0.63.0" "0.64.0" "0.65.0" "0.68.0" "0.70.0") +VERSIONS=("0.40.3" "0.43.0" "0.44.3" "0.45.6" "0.47.0" "0.50.0" "0.51.0" "0.52.0" "0.53.1" "0.54.1" "0.55.5" "0.56.4" "0.57.1" "0.60.0" "0.61.0" "0.62.0" "0.63.0" "0.64.0" "0.65.0" "0.68.0" "0.70.0" "0.71.0") # Try to get the latest version using pip index version=$(pip index versions zenml 2>/dev/null | grep -v YANKED | head -n1 | awk '{print $2}' | tr -d '()') diff --git a/src/zenml/artifacts/artifact_config.py b/src/zenml/artifacts/artifact_config.py index 236d88ec9e5..9aa578207d6 100644 --- a/src/zenml/artifacts/artifact_config.py +++ b/src/zenml/artifacts/artifact_config.py @@ -104,15 +104,18 @@ def _remove_old_attributes(cls, data: Dict[str, Any]) -> Dict[str, Any]: ) elif is_model_artifact: logger.warning( - "`ArtifactConfig.is_model_artifact` is deprecated and will be " - "removed soon. Use `ArtifactConfig.artifact_type` instead." + "`ArtifactConfig(..., is_model_artifact=True)` is deprecated " + "and will be removed soon. Use `ArtifactConfig(..., " + "artifact_type=ArtifactType.MODEL)` instead. For more info: " + "https://docs.zenml.io/user-guide/starter-guide/manage-artifacts" ) data.setdefault("artifact_type", ArtifactType.MODEL) elif is_deployment_artifact: logger.warning( - "`ArtifactConfig.is_deployment_artifact` is deprecated and " - "will be removed soon. Use `ArtifactConfig.artifact_type` " - "instead." + "`ArtifactConfig(..., is_deployment_artifact=True)` is " + "deprecated and will be removed soon. Use `ArtifactConfig(..., " + "artifact_type=ArtifactType.SERVICE)` instead. 
For more info: " + "https://docs.zenml.io/user-guide/starter-guide/manage-artifacts" ) data.setdefault("artifact_type", ArtifactType.SERVICE) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 2573964aa77..7930acc5c77 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -414,7 +414,9 @@ def log_artifact_metadata( """ logger.warning( "The `log_artifact_metadata` function is deprecated and will soon be " - "removed. Please use `log_metadata` instead." + "removed. Instead, you can consider using: " + "`log_metadata(metadata={...}, infer_artifact=True, ...)` instead. For more " + "info: https://docs.zenml.io/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact" ) from zenml import log_metadata diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index e5c0b135f5d..39c241c377b 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -587,25 +587,6 @@ def server_list(verbose: bool = False, all: bool = False) -> None: accessible_pro_servers = client.tenant.list(member_only=not all) except AuthorizationException as e: cli_utils.warning(f"ZenML Pro authorization error: {e}") - else: - if not all: - accessible_pro_servers = [ - s - for s in accessible_pro_servers - if s.status == TenantStatus.AVAILABLE - ] - - if not accessible_pro_servers: - cli_utils.declare( - "No ZenML Pro servers that are accessible to the current " - "user could be found." - ) - if not all: - cli_utils.declare( - "Hint: use the `--all` flag to show all ZenML servers, " - "including those that the client is not currently " - "authorized to access or are not running." 
- ) # We update the list of stored ZenML Pro servers with the ones that the # client is a member of @@ -633,6 +614,25 @@ def server_list(verbose: bool = False, all: bool = False) -> None: stored_server.update_server_info(accessible_server) pro_servers.append(stored_server) + if not all: + accessible_pro_servers = [ + s + for s in accessible_pro_servers + if s.status == TenantStatus.AVAILABLE + ] + + if not accessible_pro_servers: + cli_utils.declare( + "No ZenML Pro servers that are accessible to the current " + "user could be found." + ) + if not all: + cli_utils.declare( + "Hint: use the `--all` flag to show all ZenML servers, " + "including those that the client is not currently " + "authorized to access or are not running." + ) + elif pro_servers: cli_utils.warning( "The ZenML Pro authentication has expired. Please re-login " diff --git a/src/zenml/client.py b/src/zenml/client.py index 995f2d8bdb3..b3dddd3e1c9 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -109,7 +109,6 @@ EventSourceResponse, EventSourceUpdate, FlavorFilter, - FlavorRequest, FlavorResponse, ModelFilter, ModelRequest, @@ -1702,6 +1701,7 @@ def list_services( updated: Optional[datetime] = None, type: Optional[str] = None, flavor: Optional[str] = None, + user: Optional[Union[UUID, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, @@ -1727,6 +1727,7 @@ def list_services( flavor: Use the service flavor for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. 
running: Use the running status for filtering @@ -1753,6 +1754,7 @@ def list_services( flavor=flavor, workspace_id=workspace_id, user_id=user_id, + user=user, running=running, name=service_name, pipeline_name=pipeline_name, @@ -2198,17 +2200,8 @@ def create_flavor( "configuration class' docstring." ) - create_flavor_request = FlavorRequest( - source=source, - type=flavor.type, - name=flavor.name, - config_schema=flavor.config_schema, - integration="custom", - user=self.active_user.id, - workspace=self.active_workspace.id, - ) - - return self.zen_store.create_flavor(flavor=create_flavor_request) + flavor_request = flavor.to_model(integration="custom", is_custom=True) + return self.zen_store.create_flavor(flavor=flavor_request) def get_flavor( self, @@ -2249,6 +2242,7 @@ def list_flavors( type: Optional[str] = None, integration: Optional[str] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[FlavorResponse]: """Fetches all the flavor models. @@ -2262,6 +2256,7 @@ def list_flavors( created: Use to flavors by time of creation updated: Use the last updated date for filtering user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the flavor to filter by. type: The type of the flavor to filter by. integration: The integration of the flavor to filter by. 
@@ -2277,6 +2272,7 @@ def list_flavors( sort_by=sort_by, logical_operator=logical_operator, user_id=user_id, + user=user, name=name, type=type, integration=integration, @@ -2661,13 +2657,16 @@ def list_builds( updated: Optional[Union[datetime, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, + container_registry_id: Optional[Union[UUID, str]] = None, is_local: Optional[bool] = None, contains_code: Optional[bool] = None, zenml_version: Optional[str] = None, python_version: Optional[str] = None, checksum: Optional[str] = None, + stack_checksum: Optional[str] = None, hydrate: bool = False, ) -> Page[PipelineBuildResponse]: """List all builds. @@ -2682,13 +2681,17 @@ def list_builds( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. + container_registry_id: The id of the container registry to + filter by. is_local: Use to filter local builds. contains_code: Use to filter builds that contain code. zenml_version: The version of ZenML to filter by. python_version: The Python version to filter by. checksum: The build checksum to filter by. + stack_checksum: The stack checksum to filter by. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. 
@@ -2705,13 +2708,16 @@ def list_builds( updated=updated, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, stack_id=stack_id, + container_registry_id=container_registry_id, is_local=is_local, contains_code=contains_code, zenml_version=zenml_version, python_version=python_version, checksum=checksum, + stack_checksum=stack_checksum, ) build_filter_model.set_scope_workspace(self.active_workspace.id) return self.zen_store.list_builds( @@ -2771,7 +2777,7 @@ def get_event_source( allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> EventSourceResponse: - """Get a event source by name, ID or prefix. + """Get an event source by name, ID or prefix. Args: name_id_or_prefix: The name, ID or prefix of the stack. @@ -2804,6 +2810,7 @@ def list_event_sources( event_source_type: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[EventSourceResponse]: """Lists all event_sources. @@ -2818,6 +2825,7 @@ def list_event_sources( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the event_source to filter by. flavor: The flavor of the event_source to filter by. event_source_type: The subtype of the event_source to filter by. @@ -2834,6 +2842,7 @@ def list_event_sources( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, flavor=flavor, plugin_subtype=event_source_type, @@ -3001,6 +3010,7 @@ def list_actions( action_type: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[ActionResponse]: """List actions. 
@@ -3015,6 +3025,7 @@ def list_actions( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the action to filter by. flavor: The flavor of the action to filter by. action_type: The type of the action to filter by. @@ -3031,6 +3042,7 @@ def list_actions( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, id=id, flavor=flavor, @@ -3179,6 +3191,7 @@ def list_triggers( action_subtype: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[TriggerResponse]: """Lists all triggers. @@ -3193,6 +3206,7 @@ def list_triggers( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the trigger to filter by. event_source_id: The event source associated with the trigger. action_id: The action associated with the trigger. @@ -3215,6 +3229,7 @@ def list_triggers( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, event_source_id=event_source_id, action_id=action_id, @@ -3365,6 +3380,7 @@ def list_deployments( updated: Optional[Union[datetime, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, build_id: Optional[Union[str, UUID]] = None, @@ -3383,6 +3399,7 @@ def list_deployments( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. 
pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. build_id: The id of the build to filter by. @@ -3403,6 +3420,7 @@ def list_deployments( updated=updated, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, stack_id=stack_id, build_id=build_id, @@ -3488,6 +3506,7 @@ def list_run_templates( logical_operator: LogicalOperators = LogicalOperators.AND, created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, + id: Optional[Union[UUID, str]] = None, name: Optional[str] = None, tag: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, @@ -3510,6 +3529,7 @@ def list_run_templates( logical_operator: Which logical operator to use [and, or]. created: Filter by the creation date. updated: Filter by the last updated date. + id: Filter by run template ID. name: Filter by run template name. tag: Filter by run template tags. workspace_id: Filter by workspace ID. @@ -3534,6 +3554,7 @@ def list_run_templates( logical_operator=logical_operator, created=created, updated=updated, + id=id, name=name, tag=tag, workspace_id=workspace_id, @@ -3650,6 +3671,7 @@ def list_schedules( name: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, orchestrator_id: Optional[Union[str, UUID]] = None, active: Optional[Union[str, bool]] = None, @@ -3674,6 +3696,7 @@ def list_schedules( name: The name of the stack to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. orchestrator_id: The id of the orchestrator to filter by. active: Use to filter by active status. 
@@ -3700,6 +3723,7 @@ def list_schedules( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, orchestrator_id=orchestrator_id, active=active, @@ -3940,6 +3964,7 @@ def list_run_steps( original_step_run_id: Optional[Union[str, UUID]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, model_version_id: Optional[Union[str, UUID]] = None, model: Optional[Union[UUID, str]] = None, hydrate: bool = False, @@ -3958,6 +3983,7 @@ def list_run_steps( end_time: Use to filter by the time when the step finished running workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_run_id: The id of the pipeline run to filter by. deployment_id: The id of the deployment to filter by. original_step_run_id: The id of the original step run to filter by. @@ -3992,6 +4018,7 @@ def list_run_steps( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, model_version_id=model_version_id, model=model, ) @@ -4664,6 +4691,7 @@ def list_secrets( scope: Optional[SecretScope] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[SecretResponse]: """Fetches all the secret models. @@ -4683,6 +4711,7 @@ def list_secrets( scope: The scope of the secret to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. 
@@ -4699,6 +4728,7 @@ def list_secrets( sort_by=sort_by, logical_operator=logical_operator, user_id=user_id, + user=user, workspace_id=workspace_id, name=name, scope=scope, @@ -5013,6 +5043,7 @@ def list_code_repositories( name: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[CodeRepositoryResponse]: """List all code repositories. @@ -5028,6 +5059,7 @@ def list_code_repositories( name: The name of the code repository to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -5045,6 +5077,7 @@ def list_code_repositories( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, ) filter_model.set_scope_workspace(self.active_workspace.id) return self.zen_store.list_code_repositories( @@ -5405,6 +5438,7 @@ def list_service_connectors( resource_id: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, labels: Optional[Dict[str, Optional[str]]] = None, secret_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, @@ -5427,6 +5461,7 @@ def list_service_connectors( they can give access to. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the service connector to filter by. labels: The labels of the service connector to filter by. 
secret_id: Filter by the id of the secret that is referenced by the @@ -5444,6 +5479,7 @@ def list_service_connectors( logical_operator=logical_operator, workspace_id=workspace_id or self.active_workspace.id, user_id=user_id, + user=user, name=name, connector_type=connector_type, auth_method=auth_method, @@ -6596,6 +6632,7 @@ def list_authorized_devices( client_id: Union[UUID, str, None] = None, status: Union[OAuthDeviceStatus, str, None] = None, trusted_device: Union[bool, str, None] = None, + user: Optional[Union[UUID, str]] = None, failed_auth_attempts: Union[int, str, None] = None, last_login: Optional[Union[datetime, str, None]] = None, hydrate: bool = False, @@ -6613,6 +6650,7 @@ def list_authorized_devices( expires: Use the expiration date for filtering. client_id: Use the client id for filtering. status: Use the status for filtering. + user: Filter by user name/ID. trusted_device: Use the trusted device flag for filtering. failed_auth_attempts: Use the failed auth attempts for filtering. last_login: Use the last login date for filtering. @@ -6632,6 +6670,7 @@ def list_authorized_devices( updated=updated, expires=expires, client_id=client_id, + user=user, status=status, trusted_device=trusted_device, failed_auth_attempts=failed_auth_attempts, @@ -6730,7 +6769,7 @@ def get_trigger_execution( trigger_execution_id: UUID, hydrate: bool = True, ) -> TriggerExecutionResponse: - """Get an trigger execution by ID. + """Get a trigger execution by ID. Args: trigger_execution_id: The ID of the trigger execution to get. @@ -6751,6 +6790,7 @@ def list_trigger_executions( size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, trigger_id: Optional[UUID] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[TriggerExecutionResponse]: """List all trigger executions matching the given filter criteria. @@ -6761,6 +6801,7 @@ def list_trigger_executions( size: The maximum size of all pages. 
logical_operator: Which logical operator to use [and, or]. trigger_id: ID of the trigger to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -6772,6 +6813,7 @@ def list_trigger_executions( sort_by=sort_by, page=page, size=size, + user=user, logical_operator=logical_operator, ) filter_model.set_scope_workspace(self.active_workspace.id) diff --git a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py index 0b7b01b546d..52b19af2afe 100644 --- a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +++ b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py @@ -33,7 +33,6 @@ from zenml.step_operators import BaseStepOperator if TYPE_CHECKING: - from zenml.config.base_settings import BaseSettings from zenml.config.step_run_info import StepRunInfo from zenml.models import PipelineDeploymentBase diff --git a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py index 11e84ca2d38..7a1a7321704 100644 --- a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +++ b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py @@ -23,7 +23,7 @@ cast, ) -from pydantic import field_validator +from pydantic import field_validator, BaseModel from zenml.config.base_settings import BaseSettings from zenml.experiment_trackers.base_experiment_tracker import ( @@ -60,18 +60,26 @@ def _convert_settings(cls, value: Any) -> Any: Args: value: The settings. + Raises: + ValueError: If converting the settings failed. + Returns: Dict representation of the settings. 
""" import wandb if isinstance(value, wandb.Settings): - # Depending on the wandb version, either `make_static` or `to_dict` - # is available to convert the settings to a dictionary - if hasattr(value, "make_static"): + # Depending on the wandb version, either `model_dump`, + # `make_static` or `to_dict` is available to convert the settings + # to a dictionary + if isinstance(value, BaseModel): + return value.model_dump() + elif hasattr(value, "make_static"): return cast(Dict[str, Any], value.make_static()) - else: + elif hasattr(value, "to_dict"): return value.to_dict() + else: + raise ValueError("Unable to convert wandb settings to dict.") else: return value diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index b7f3c591518..f7987fc1b58 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -509,22 +509,6 @@ def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]: raise ValueError( "`model_version_id` field is for internal use only" ) - - version = data.get("version", None) - - if ( - version in [stage.value for stage in ModelStages] - and not suppress_class_validation_warnings - ): - logger.info( - f"Version `{version}` matches one of the possible " - "`ModelStages` and will be fetched using stage." - ) - if str(version).isnumeric() and not suppress_class_validation_warnings: - logger.info( - f"`version` `{version}` is numeric and will be fetched " - "using version number." - ) data["suppress_class_validation_warnings"] = True return data @@ -603,6 +587,18 @@ def _get_model_version( hydrate=hydrate, ) else: + if self.version in ModelStages.values(): + logger.info( + f"Version `{self.version}` for model {self.name} matches " + "one of the possible `ModelStages` and will be fetched " + "using stage." + ) + if str(self.version).isnumeric(): + logger.info( + f"Version `{self.version}` for model {self.name} is " + "numeric and will be fetched using version number." 
+ ) + mv = zenml_client.get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index a3612fc2c12..2c87d83b6d0 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -56,7 +56,9 @@ def log_model_metadata( """ logger.warning( "The `log_model_metadata` function is deprecated and will soon be " - "removed. Please use `log_metadata` instead." + "removed. Instead, you can consider using: " + "`log_metadata(metadata={...}, infer_model=True)` instead. For more " + "info: https://docs.zenml.io/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model" ) from zenml import log_metadata diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 1c4d2cccfb5..d2aa8380be4 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -113,7 +113,7 @@ def validate_operation(cls, value: Any) -> Any: def generate_query_conditions( self, table: Type[SQLModel], - ) -> Union["ColumnElement[bool]"]: + ) -> "ColumnElement[bool]": """Generate the query conditions for the database. 
This method converts the Filter class into an appropriate SQLModel @@ -291,11 +291,19 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: import sqlalchemy from sqlalchemy_utils.functions import cast_if + from zenml.utils import uuid_utils + # For equality checks, compare the UUID directly if self.operation == GenericFilterOps.EQUALS: + if not uuid_utils.is_valid_uuid(self.value): + return False + return column == self.value if self.operation == GenericFilterOps.NOT_EQUALS: + if not uuid_utils.is_valid_uuid(self.value): + return True + return column != self.value # For all other operations, cast and handle the column as string @@ -436,7 +444,6 @@ class BaseFilter(BaseModel): le=PAGE_SIZE_MAXIMUM, description="Page size", ) - id: Optional[Union[UUID, str]] = Field( default=None, description="Id for this resource", @@ -491,13 +498,13 @@ def validate_sort_by(cls, value: Any) -> Any: ) value = column - if column in cls.FILTER_EXCLUDE_FIELDS: + if column in cls.CUSTOM_SORTING_OPTIONS: + return value + elif column in cls.FILTER_EXCLUDE_FIELDS: raise ValueError( f"This resource can not be sorted by this field: '{value}'" ) - elif column in cls.model_fields: - return value - elif column in cls.CUSTOM_SORTING_OPTIONS: + if column in cls.model_fields: return value else: raise ValueError( @@ -703,16 +710,10 @@ def generate_name_or_id_query_conditions( conditions = [] - try: - filter_ = FilterGenerator(table).define_filter( - column="id", value=value, operator=operator - ) - conditions.append(filter_.generate_query_conditions(table=table)) - except ValueError: - # UUID filter with equal operators and no full UUID fail with - # a ValueError. In this case, we already know that the filter - # will not produce any result and can simply ignore it. 
- pass + filter_ = FilterGenerator(table).define_filter( + column="id", value=value, operator=operator + ) + conditions.append(filter_.generate_query_conditions(table=table)) filter_ = FilterGenerator(table).define_filter( column="name", value=value, operator=operator @@ -759,7 +760,7 @@ def offset(self) -> int: return self.size * (self.page - 1) def generate_filter( - self, table: Type[SQLModel] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. @@ -779,7 +780,7 @@ def generate_filter( filters.append( column_filter.generate_query_conditions(table=table) ) - for custom_filter in self.get_custom_filters(): + for custom_filter in self.get_custom_filters(table): filters.append(custom_filter) if self.logical_operator == LogicalOperators.OR: return or_(False, *filters) @@ -788,12 +789,17 @@ def generate_filter( else: raise RuntimeError("No valid logical operator was supplied.") - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. This can be overridden by subclasses to define custom filters that are not based on the columns of the underlying table. + Args: + table: The query table. + Returns: A list of custom filters. """ @@ -1101,18 +1107,8 @@ def _define_uuid_filter( A Filter object. Raises: - ValueError: If the value is not a valid UUID. + ValueError: If the value for a oneof filter is not a list. """ - # For equality checks, ensure that the value is a valid UUID. - if operator == GenericFilterOps.EQUALS and not isinstance(value, UUID): - try: - UUID(value) - except ValueError as e: - raise ValueError( - "Invalid value passed as UUID query parameter." - ) from e - - # For equality checks, ensure that the value is a valid UUID. 
if operator == GenericFilterOps.ONEOF and not isinstance(value, list): raise ValueError(ONEOF_ERROR) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f563b6dc81c..f5267f4840d 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -23,6 +23,7 @@ Optional, Type, TypeVar, + Union, ) from uuid import UUID @@ -151,16 +152,32 @@ class UserScopedFilter(BaseFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, + "user", "scope_user", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *BaseFilter.CLI_EXCLUDE_FIELDS, + "user_id", "scope_user", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *BaseFilter.CUSTOM_SORTING_OPTIONS, + "user", + ] + scope_user: Optional[UUID] = Field( default=None, description="The user to scope this query to.", ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, + description="UUID of the user that created the entity.", + union_mode="left_to_right", + ) + user: Optional[Union[UUID, str]] = Field( + default=None, + description="Name/ID of the user that created the entity.", + ) def set_scope_user(self, user_id: UUID) -> None: """Set the user that is performing the filtering to scope the response. @@ -170,6 +187,73 @@ def set_scope_user(self, user_id: UUID) -> None: """ self.scope_user = user_id + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Args: + table: The query table. + + Returns: + A list of custom filters. 
+ """ + custom_filters = super().get_custom_filters(table) + + from sqlmodel import and_ + + from zenml.zen_stores.schemas import UserSchema + + if self.user: + user_filter = and_( + getattr(table, "user_id") == UserSchema.id, + self.generate_name_or_id_query_conditions( + value=self.user, + table=UserSchema, + additional_columns=["full_name"], + ), + ) + custom_filters.append(user_filter) + + return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import UserSchema + + sort_by, operand = self.sorting_params + + if sort_by == "user": + column = UserSchema.name + + query = query.join( + UserSchema, getattr(table, "user_id") == UserSchema.id + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by(asc(column)) + else: + query = query.order_by(desc(column)) + + return query + + return super().apply_sorting(query=query, table=table) + def apply_filter( self, query: AnyQuery, @@ -240,21 +324,37 @@ def workspace(self) -> "WorkspaceResponse": return self.get_metadata().workspace -class WorkspaceScopedFilter(BaseFilter): +class WorkspaceScopedFilter(UserScopedFilter): """Model to enable advanced scoping with workspace.""" FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *BaseFilter.FILTER_EXCLUDE_FIELDS, + *UserScopedFilter.FILTER_EXCLUDE_FIELDS, + "workspace", "scope_workspace", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *BaseFilter.CLI_EXCLUDE_FIELDS, + *UserScopedFilter.CLI_EXCLUDE_FIELDS, + "workspace_id", + "workspace", "scope_workspace", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *UserScopedFilter.CUSTOM_SORTING_OPTIONS, + "workspace", + ] scope_workspace: Optional[UUID] = Field( default=None, description="The 
workspace to scope this query to.", ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, + description="UUID of the workspace that this entity belongs to.", + union_mode="left_to_right", + ) + workspace: Optional[Union[UUID, str]] = Field( + default=None, + description="Name/ID of the workspace that this entity belongs to.", + ) def set_scope_workspace(self, workspace_id: UUID) -> None: """Set the workspace to scope this response. @@ -264,6 +364,35 @@ def set_scope_workspace(self, workspace_id: UUID) -> None: """ self.scope_workspace = workspace_id + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Args: + table: The query table. + + Returns: + A list of custom filters. + """ + custom_filters = super().get_custom_filters(table) + + from sqlmodel import and_ + + from zenml.zen_stores.schemas import WorkspaceSchema + + if self.workspace: + workspace_filter = and_( + getattr(table, "workspace_id") == WorkspaceSchema.id, + self.generate_name_or_id_query_conditions( + value=self.workspace, + table=WorkspaceSchema, + ), + ) + custom_filters.append(workspace_filter) + + return custom_filters + def apply_filter( self, query: AnyQuery, @@ -291,6 +420,44 @@ def apply_filter( return query + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. 
+ """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import WorkspaceSchema + + sort_by, operand = self.sorting_params + + if sort_by == "workspace": + column = WorkspaceSchema.name + + query = query.join( + WorkspaceSchema, + getattr(table, "workspace_id") == WorkspaceSchema.id, + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by(asc(column)) + else: + query = query.order_by(desc(column)) + + return query + + return super().apply_sorting(query=query, table=table) + class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): """Model to enable advanced scoping with workspace and tagging.""" @@ -304,6 +471,11 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): "tag", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, + "tag", + ] + def apply_filter( self, query: AnyQuery, @@ -330,15 +502,20 @@ def apply_filter( return query - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom tag filters. + Args: + table: The query table. + Returns: A list of custom filters. """ from zenml.zen_stores.schemas import TagSchema - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) if self.tag: custom_filters.append( self.generate_custom_query_conditions_for_column( @@ -347,3 +524,79 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: ) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. 
+ """ + sort_by, operand = self.sorting_params + + if sort_by == "tag": + from sqlmodel import and_, asc, desc, func + + from zenml.enums import SorterOps, TaggableResourceTypes + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ModelSchema, + ModelVersionSchema, + PipelineRunSchema, + PipelineSchema, + RunTemplateSchema, + TagResourceSchema, + TagSchema, + ) + + resource_type_mapping = { + ArtifactSchema: TaggableResourceTypes.ARTIFACT, + ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, + ModelSchema: TaggableResourceTypes.MODEL, + ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION, + PipelineSchema: TaggableResourceTypes.PIPELINE, + PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, + RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, + } + + query = ( + query.outerjoin( + TagResourceSchema, + and_( + table.id == TagResourceSchema.resource_id, + TagResourceSchema.resource_type + == resource_type_mapping[table], + ), + ) + .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) + .group_by(table.id) + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc( + func.group_concat(TagSchema.name, ",").label( + "tags_list" + ) + ) + ) + else: + query = query.order_by( + desc( + func.group_concat(TagSchema.name, ",").label( + "tags_list" + ) + ) + ) + + return query + + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index cd5089a3db4..a6998b92b3c 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -20,6 +20,8 @@ Dict, List, Optional, + Type, + TypeVar, Union, ) from uuid import UUID @@ -58,6 +60,10 @@ ) from zenml.models.v2.core.pipeline_run import PipelineRunResponse from zenml.models.v2.core.step_run import StepRunResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = 
TypeVar("AnySchema", bound=BaseSchema) + logger = get_logger(__name__) @@ -471,7 +477,6 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): "name", "only_unused", "has_custom_name", - "user", "model", "pipeline_run", "model_version_id", @@ -516,19 +521,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): description="Artifact store for this artifact", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace for this artifact", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that produced this artifact", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, - description="ID of the model version that is associated with this artifact version.", + description="ID of the model version that is associated with this " + "artifact version.", union_mode="left_to_right", ) only_unused: Optional[bool] = Field( @@ -559,13 +555,18 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List[Union["ColumnElement[bool]"]]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. 
""" - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, or_, select @@ -581,7 +582,6 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: StepRunInputArtifactSchema, StepRunOutputArtifactSchema, StepRunSchema, - UserSchema, ) if self.name: @@ -629,17 +629,6 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ) custom_filters.append(custom_name_filter) - if self.user: - user_filter = and_( - ArtifactVersionSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.model: model_filter = and_( ArtifactVersionSchema.id diff --git a/src/zenml/models/v2/core/code_repository.py b/src/zenml/models/v2/core/code_repository.py index c0a5430468b..485f710b7de 100644 --- a/src/zenml/models/v2/core/code_repository.py +++ b/src/zenml/models/v2/core/code_repository.py @@ -13,8 +13,7 @@ # permissions and limitations under the License. 
"""Models representing code repositories.""" -from typing import Any, Dict, Optional, Union -from uuid import UUID +from typing import Any, Dict, Optional from pydantic import Field @@ -189,13 +188,3 @@ class CodeRepositoryFilter(WorkspaceScopedFilter): description="Name of the code repository.", default=None, ) - workspace_id: Optional[Union[UUID, str]] = Field( - description="Workspace of the code repository.", - default=None, - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - description="User that created the code repository.", - default=None, - union_mode="left_to_right", - ) diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index a4f52be884c..98418589222 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -21,6 +21,7 @@ List, Optional, Type, + TypeVar, Union, ) from uuid import UUID @@ -42,9 +43,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement - from sqlmodel import SQLModel from zenml.models import FlavorResponse, ServiceConnectorResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Base Model ------------------ @@ -356,7 +359,6 @@ class ComponentFilter(WorkspaceScopedFilter): *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "scope_type", "stack_id", - "user", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, @@ -366,7 +368,6 @@ class ComponentFilter(WorkspaceScopedFilter): default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( default=None, description="Name of the stack component", @@ -379,16 +380,6 @@ class ComponentFilter(WorkspaceScopedFilter): default=None, description="Type of the stack component", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack component", - 
union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack component", - union_mode="left_to_right", - ) connector_id: Optional[Union[UUID, str]] = Field( default=None, description="Connector linked to the stack component", @@ -399,10 +390,6 @@ class ComponentFilter(WorkspaceScopedFilter): description="Stack of the stack component", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the component.", - ) def set_scope_type(self, component_type: str) -> None: """Set the type of component on which to perform the filtering to scope the response. @@ -413,7 +400,7 @@ def set_scope_type(self, component_type: str) -> None: self.scope_type = component_type def generate_filter( - self, table: Type["SQLModel"] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. @@ -449,31 +436,3 @@ def generate_filter( base_filter = operator(base_filter, stack_filter) return base_filter - - def get_custom_filters(self) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - StackComponentSchema, - UserSchema, - ) - - custom_filters = super().get_custom_filters() - - if self.user: - user_filter = and_( - StackComponentSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters diff --git a/src/zenml/models/v2/core/flavor.py b/src/zenml/models/v2/core/flavor.py index fd4110300c3..77fe774c073 100644 --- a/src/zenml/models/v2/core/flavor.py +++ b/src/zenml/models/v2/core/flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing flavors.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional from uuid import UUID from pydantic import Field @@ -428,13 +428,3 @@ class FlavorFilter(WorkspaceScopedFilter): default=None, description="Integration associated with the flavor", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack", - union_mode="left_to_right", - ) diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 0eb3b749c88..0b5272ab7e6 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing models.""" -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from pydantic import BaseModel, Field @@ -30,8 +30,6 @@ from zenml.utils.pagination_utils import depaginate if TYPE_CHECKING: - from sqlalchemy.sql.elements import ColumnElement - from zenml.model.model import Model from zenml.models.v2.core.tag import TagResponse @@ -318,61 +316,7 @@ def versions(self) -> List["Model"]: class ModelFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, - "workspace_id", - "user_id", - ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", - ] - name: Optional[str] = Field( default=None, description="Name of the Model", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Model", - union_mode="left_to_right", - ) - 
user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the Model", - union_mode="left_to_right", - ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the model.", - ) - - def get_custom_filters( - self, - ) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - custom_filters = super().get_custom_filters() - - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - ModelSchema, - UserSchema, - ) - - if self.user: - user_filter = and_( - ModelSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index d1a7a951978..80880f1e70e 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -77,10 +77,6 @@ class ModelVersionRequest(WorkspaceScopedRequest): default=None, ) - number: Optional[int] = Field( - description="The number of the model version", - default=None, - ) model: UUID = Field( description="The ID of the model containing version", ) @@ -585,7 +581,6 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", "run_metadata", ] @@ -597,25 +592,11 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): default=None, description="The number of the Model Version", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The workspace of the Model Version", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The user of the Model Version", - union_mode="left_to_right", - ) stage: Optional[Union[str, 
ModelStages]] = Field( description="The model version stage", default=None, union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the model version.", - ) run_metadata: Optional[Dict[str, str]] = Field( default=None, description="The run_metadata to filter the model versions by.", @@ -639,14 +620,17 @@ def set_scope_model(self, model_name_or_id: Union[str, UUID]) -> None: self._model_id = model_id def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ @@ -654,20 +638,8 @@ def get_custom_filters( ModelVersionSchema, RunMetadataResourceSchema, RunMetadataSchema, - UserSchema, ) - if self.user: - user_filter = and_( - ModelVersionSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.run_metadata is not None: from zenml.enums import MetadataResourceTypes diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index f3a677a86e9..6c9514b9735 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing the link between model versions and artifacts.""" -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -32,6 +32,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -164,13 +167,18 @@ class ModelVersionArtifactFilter(BaseFilter): # careful we might overwrite some fields protected by pydantic. model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List[Union["ColumnElement[bool]"]]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, col diff --git a/src/zenml/models/v2/core/model_version_pipeline_run.py b/src/zenml/models/v2/core/model_version_pipeline_run.py index 6181c2ffbb1..40e7f823d9c 100644 --- a/src/zenml/models/v2/core/model_version_pipeline_run.py +++ b/src/zenml/models/v2/core/model_version_pipeline_run.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing the link between model versions and pipeline runs.""" -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -30,6 +30,12 @@ from zenml.models.v2.base.filter import BaseFilter, StrFilter from zenml.models.v2.core.pipeline_run import PipelineRunResponse +if TYPE_CHECKING: + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + # ------------------ Request Model ------------------ @@ -147,13 +153,18 @@ class ModelVersionPipelineRunFilter(BaseFilter): # careful we might overwrite some fields protected by pydantic. model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. 
""" - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 5166e0abb9c..199e9cce959 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -21,7 +21,6 @@ Optional, Type, TypeVar, - Union, ) from uuid import UUID @@ -45,9 +44,7 @@ from zenml.models.v2.core.tag import TagResponse if TYPE_CHECKING: - from sqlalchemy.sql.elements import ColumnElement - - from zenml.models.v2.core.pipeline_run import PipelineRunResponse + from zenml.models import PipelineRunResponse, UserResponse from zenml.zen_stores.schemas import BaseSchema AnySchema = TypeVar("AnySchema", bound=BaseSchema) @@ -122,6 +119,10 @@ class PipelineResponseMetadata(WorkspaceScopedResponseMetadata): class PipelineResponseResources(WorkspaceScopedResponseResources): """Class for all resource models associated with the pipeline entity.""" + latest_run_user: Optional["UserResponse"] = Field( + default=None, + title="The user that created the latest run of this pipeline.", + ) tags: List[TagResponse] = Field( title="Tags associated with the pipeline.", ) @@ -258,10 +259,12 @@ def tags(self) -> List[TagResponse]: class PipelineFilter(WorkspaceScopedTaggableFilter): """Pipeline filter model.""" - CUSTOM_SORTING_OPTIONS = [SORT_PIPELINES_BY_LATEST_RUN_KEY] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_PIPELINES_BY_LATEST_RUN_KEY, + ] FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", "latest_run_status", ] @@ -274,20 +277,6 @@ class PipelineFilter(WorkspaceScopedTaggableFilter): description="Filter by the status of the latest run of a pipeline. 
" "This will always be applied as an `AND` filter for now.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Pipeline", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the Pipeline", - union_mode="left_to_right", - ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the pipeline.", - ) def apply_filter( self, query: AnyQuery, table: Type["AnySchema"] @@ -343,36 +332,6 @@ def apply_filter( return query - def get_custom_filters( - self, - ) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - custom_filters = super().get_custom_filters() - - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - PipelineSchema, - UserSchema, - ) - - if self.user: - user_filter = and_( - PipelineSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters - def apply_sorting( self, query: AnyQuery, @@ -387,12 +346,45 @@ def apply_sorting( Returns: The query with sorting applied. 
""" - column, _ = self.sorting_params + from sqlmodel import asc, case, col, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema + + sort_by, operand = self.sorting_params + + if sort_by == SORT_PIPELINES_BY_LATEST_RUN_KEY: + # Subquery to find the latest run per pipeline + latest_run_subquery = ( + select( + PipelineRunSchema.pipeline_id, + case( + ( + func.max(PipelineRunSchema.created).is_(None), + PipelineSchema.created, + ), + else_=func.max(PipelineRunSchema.created), + ).label("latest_run"), + ) + .group_by(col(PipelineRunSchema.pipeline_id)) + .subquery() + ) + + # Join the subquery with the pipelines + query = query.outerjoin( + latest_run_subquery, + PipelineSchema.id == latest_run_subquery.c.pipeline_id, + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_run_subquery.c.latest_run) + ).order_by(col(PipelineSchema.id)) + else: + query = query.order_by( + desc(latest_run_subquery.c.latest_run) + ).order_by(col(PipelineSchema.id)) - if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: - # If sorting by the latest run, the sorting is already done in the - # base query in `SqlZenStore.list_pipelines(...)` and we don't need - # to to anything here return query else: return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline_build.py b/src/zenml/models/v2/core/pipeline_build.py index 3cb6dcb4e47..19dc89ccbf0 100644 --- a/src/zenml/models/v2/core/pipeline_build.py +++ b/src/zenml/models/v2/core/pipeline_build.py @@ -14,7 +14,17 @@ """Models representing pipeline builds.""" import json -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field @@ -31,8 +41,13 @@ from zenml.models.v2.misc.build_item import BuildItem if TYPE_CHECKING: + from 
sqlalchemy.sql.elements import ColumnElement + from zenml.models.v2.core.pipeline import PipelineResponse from zenml.models.v2.core.stack import StackResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -446,23 +461,23 @@ def contains_code(self) -> bool: class PipelineBuildFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all pipeline builds.""" - workspace_id: Optional[Union[UUID, str]] = Field( - description="Workspace for this pipeline build.", - default=None, - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - description="User that produced this pipeline build.", - default=None, - union_mode="left_to_right", - ) + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, + "container_registry_id", + ] + pipeline_id: Optional[Union[UUID, str]] = Field( description="Pipeline associated with the pipeline build.", default=None, union_mode="left_to_right", ) stack_id: Optional[Union[UUID, str]] = Field( - description="Stack used for the Pipeline Run", + description="Stack associated with the pipeline build.", + default=None, + union_mode="left_to_right", + ) + container_registry_id: Optional[Union[UUID, str]] = Field( + description="Container registry associated with the pipeline build.", default=None, union_mode="left_to_right", ) @@ -484,3 +499,43 @@ class PipelineBuildFilter(WorkspaceScopedFilter): checksum: Optional[str] = Field( description="The build checksum.", default=None ) + stack_checksum: Optional[str] = Field( + description="The stack checksum.", default=None + ) + + def get_custom_filters( + self, + table: Type["AnySchema"], + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Args: + table: The query table. + + Returns: + A list of custom filters. 
+ """ + custom_filters = super().get_custom_filters(table) + + from sqlmodel import and_ + + from zenml.enums import StackComponentType + from zenml.zen_stores.schemas import ( + PipelineBuildSchema, + StackComponentSchema, + StackCompositionSchema, + StackSchema, + ) + + if self.container_registry_id: + container_registry_filter = and_( + PipelineBuildSchema.stack_id == StackSchema.id, + StackSchema.id == StackCompositionSchema.stack_id, + StackCompositionSchema.component_id == StackComponentSchema.id, + StackComponentSchema.type + == StackComponentType.CONTAINER_REGISTRY.value, + StackComponentSchema.id == self.container_registry_id, + ) + custom_filters.append(container_registry_filter) + + return custom_filters diff --git a/src/zenml/models/v2/core/pipeline_deployment.py b/src/zenml/models/v2/core/pipeline_deployment.py index 760f65f1a35..94dbc431507 100644 --- a/src/zenml/models/v2/core/pipeline_deployment.py +++ b/src/zenml/models/v2/core/pipeline_deployment.py @@ -358,16 +358,6 @@ def template_id(self) -> Optional[UUID]: class PipelineDeploymentFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all pipeline deployments.""" - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace for this deployment.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created this deployment.", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline associated with the deployment.", diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 958d662a515..3a22f642953 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -16,10 +16,13 @@ from datetime import datetime from typing import ( TYPE_CHECKING, + Any, ClassVar, Dict, List, Optional, + Type, + TypeVar, Union, cast, ) @@ -55,6 +58,11 @@ from 
zenml.models.v2.core.schedule import ScheduleResponse from zenml.models.v2.core.stack import StackResponse from zenml.models.v2.core.step_run import StepRunResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -584,6 +592,15 @@ def tags(self) -> List[TagResponse]: class PipelineRunFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + "tag", + "stack", + "pipeline", + "model", + "model_version", + ] + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "unlisted", @@ -592,7 +609,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): "schedule_id", "stack_id", "template_id", - "user", "pipeline", "stack", "code_repository", @@ -615,16 +631,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): description="Pipeline associated with the Pipeline Run", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Pipeline Run", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the Pipeline Run", - union_mode="left_to_right", - ) stack_id: Optional[Union[UUID, str]] = Field( default=None, description="Stack used for the Pipeline Run", @@ -675,16 +681,12 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): union_mode="left_to_right", ) unlisted: Optional[bool] = None - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the run.", - ) run_metadata: Optional[Dict[str, str]] = Field( default=None, description="The run_metadata to filter the pipeline runs by.", ) # TODO: Remove once 
frontend is ready for it. This is replaced by the more - # generic `pipeline` filter below. + # generic `pipeline` filter below. pipeline_name: Optional[str] = Field( default=None, description="Name of the pipeline associated with the run", @@ -716,13 +718,17 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): def get_custom_filters( self, + table: Type["AnySchema"], ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, col, or_ @@ -741,7 +747,6 @@ def get_custom_filters( StackComponentSchema, StackCompositionSchema, StackSchema, - UserSchema, ) if self.unlisted is not None: @@ -792,17 +797,6 @@ def get_custom_filters( ) custom_filters.append(run_template_filter) - if self.user: - user_filter = and_( - PipelineRunSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.pipeline: pipeline_filter = and_( PipelineRunSchema.pipeline_id == PipelineSchema.id, @@ -926,3 +920,71 @@ def get_custom_filters( custom_filters.append(additional_filter) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. 
+ """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ModelSchema, + ModelVersionSchema, + PipelineDeploymentSchema, + PipelineRunSchema, + PipelineSchema, + StackSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == "pipeline": + query = query.join( + PipelineSchema, + PipelineRunSchema.pipeline_id == PipelineSchema.id, + ) + column = PipelineSchema.name + elif sort_by == "stack": + query = query.join( + PipelineDeploymentSchema, + PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id, + ).join( + StackSchema, + PipelineDeploymentSchema.stack_id == StackSchema.id, + ) + column = StackSchema.name + elif sort_by == "model": + query = query.join( + ModelVersionSchema, + PipelineRunSchema.model_version_id == ModelVersionSchema.id, + ).join( + ModelSchema, + ModelVersionSchema.model_id == ModelSchema.id, + ) + column = ModelSchema.name + elif sort_by == "model_version": + query = query.join( + ModelVersionSchema, + PipelineRunSchema.model_version_id == ModelVersionSchema.id, + ) + column = ModelVersionSchema.name + else: + return super().apply_sorting(query=query, table=table) + + if operand == SorterOps.ASCENDING: + query = query.order_by(asc(column)) + else: + query = query.order_by(desc(column)) + + return query diff --git a/src/zenml/models/v2/core/run_template.py b/src/zenml/models/v2/core/run_template.py index b1aae8a325a..2bc177c043e 100644 --- a/src/zenml/models/v2/core/run_template.py +++ b/src/zenml/models/v2/core/run_template.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. 
"""Models representing pipeline templates.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field @@ -45,6 +55,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + # ------------------ Request Model ------------------ @@ -310,16 +325,6 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): default=None, description="Name of the run template.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace associated with the template.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the template.", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline associated with the template.", @@ -340,10 +345,6 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): description="Code repository associated with the template.", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the template.", - ) pipeline: Optional[Union[UUID, str]] = Field( default=None, description="Name/ID of the pipeline associated with the template.", @@ -354,14 +355,17 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): ) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. 
""" - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ @@ -371,7 +375,6 @@ def get_custom_filters( PipelineSchema, RunTemplateSchema, StackSchema, - UserSchema, ) if self.code_repository_id: @@ -409,17 +412,6 @@ def get_custom_filters( ) custom_filters.append(pipeline_filter) - if self.user: - user_filter = and_( - RunTemplateSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.pipeline: pipeline_filter = and_( RunTemplateSchema.source_deployment_id diff --git a/src/zenml/models/v2/core/schedule.py b/src/zenml/models/v2/core/schedule.py index af838f17ccc..0e7dc01c421 100644 --- a/src/zenml/models/v2/core/schedule.py +++ b/src/zenml/models/v2/core/schedule.py @@ -279,16 +279,6 @@ def pipeline_id(self) -> Optional[UUID]: class ScheduleFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all Users.""" - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace scope of the schedule.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the schedule", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline that the schedule is attached to.", diff --git a/src/zenml/models/v2/core/secret.py b/src/zenml/models/v2/core/secret.py index 79e50cd1841..3f29b57de22 100644 --- a/src/zenml/models/v2/core/secret.py +++ b/src/zenml/models/v2/core/secret.py @@ -15,7 +15,6 @@ from datetime import datetime from typing import Any, ClassVar, Dict, List, Optional, Union -from uuid import UUID from pydantic import Field, SecretStr @@ -253,25 +252,12 @@ class SecretFilter(WorkspaceScopedFilter): default=None, description="Name of the secret", ) - scope: 
Optional[Union[SecretScope, str]] = Field( default=None, description="Scope in which to filter secrets", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Secret", - union_mode="left_to_right", - ) - - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the Secret", - union_mode="left_to_right", - ) - @staticmethod def _get_filtering_value(value: Optional[Any]) -> str: """Convert the value to a string that can be used for lexicographical filtering and sorting. diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py index c3dcbd7cfc8..2ad9724b20a 100644 --- a/src/zenml/models/v2/core/service.py +++ b/src/zenml/models/v2/core/service.py @@ -15,19 +15,20 @@ from datetime import datetime from typing import ( + TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Type, + TypeVar, Union, ) from uuid import UUID from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.sql.elements import ColumnElement -from sqlmodel import SQLModel from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.models.v2.base.scoped import ( @@ -37,11 +38,15 @@ WorkspaceScopedResponseBody, WorkspaceScopedResponseMetadata, WorkspaceScopedResponseResources, - WorkspaceScopedTaggableFilter, ) from zenml.services.service_status import ServiceState from zenml.services.service_type import ServiceType +if TYPE_CHECKING: + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + # ------------------ Request Model ------------------ @@ -376,16 +381,6 @@ class ServiceFilter(WorkspaceScopedFilter): description="Name of the service. 
Use this to filter services by " "their name.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the service", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the service", - union_mode="left_to_right", - ) type: Optional[str] = Field( default=None, description="Type of the service. Filter services by their type.", @@ -457,9 +452,7 @@ def set_flavor(self, flavor: str) -> None: "config", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, - "workspace_id", - "user_id", + *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, "flavor", "type", "pipeline_step_name", @@ -468,7 +461,7 @@ def set_flavor(self, flavor: str) -> None: ] def generate_filter( - self, table: Type["SQLModel"] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. diff --git a/src/zenml/models/v2/core/service_connector.py b/src/zenml/models/v2/core/service_connector.py index 806e6100072..8c71106ae22 100644 --- a/src/zenml/models/v2/core/service_connector.py +++ b/src/zenml/models/v2/core/service_connector.py @@ -801,7 +801,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( default=None, description="The name to filter by", @@ -810,16 +809,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): default=None, description="The type of service connector to filter by", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace to filter by", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User to filter by", - union_mode="left_to_right", - ) auth_method: Optional[str] = Field( default=None, title="Filter by the authentication method configured for the " diff --git 
a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index 3d8ad20a2c1..1e49eb1544b 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -14,7 +14,17 @@ """Models representing stacks.""" import json -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, model_validator @@ -39,6 +49,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.component import ComponentResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -323,7 +336,6 @@ class StackFilter(WorkspaceScopedFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "component_id", - "user", "component", ] @@ -334,42 +346,32 @@ class StackFilter(WorkspaceScopedFilter): description: Optional[str] = Field( default=None, description="Description of the stack" ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack", - union_mode="left_to_right", - ) component_id: Optional[Union[UUID, str]] = Field( default=None, description="Component in the stack", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the stack.", - ) component: Optional[Union[UUID, str]] = Field( default=None, description="Name/ID of a component in the stack." ) - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. 
+ Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from zenml.zen_stores.schemas import ( StackComponentSchema, StackCompositionSchema, StackSchema, - UserSchema, ) if self.component_id: @@ -379,17 +381,6 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: ) custom_filters.append(component_id_filter) - if self.user: - user_filter = and_( - StackSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.component: component_filter = and_( StackCompositionSchema.stack_id == StackSchema.id, diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index d9ac5e0354a..0a505539d07 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -14,7 +14,16 @@ """Models representing steps runs.""" from datetime import datetime -from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import BaseModel, ConfigDict, Field @@ -41,6 +50,9 @@ LogsRequest, LogsResponse, ) + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) class StepRunInputResponse(ArtifactVersionResponse): @@ -553,16 +565,6 @@ class StepRunFilter(WorkspaceScopedFilter): description="Original id for this step run", union_mode="left_to_right", ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that produced this step run", - union_mode="left_to_right", - ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of this step run", - union_mode="left_to_right", - ) model_version_id: 
Optional[Union[UUID, str]] = Field( default=None, description="Model version associated with the step run.", @@ -576,18 +578,20 @@ class StepRunFilter(WorkspaceScopedFilter): default=None, description="The run_metadata to filter the step runs by.", ) - model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/trigger.py b/src/zenml/models/v2/core/trigger.py index daef211ed7b..45fc23a501c 100644 --- a/src/zenml/models/v2/core/trigger.py +++ b/src/zenml/models/v2/core/trigger.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. """Collection of all models concerning triggers.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, model_validator @@ -39,6 +49,9 @@ ActionResponse, ) from zenml.models.v2.core.event_source import EventSourceResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -358,10 +371,13 @@ class TriggerFilter(WorkspaceScopedFilter): ) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. 
""" @@ -373,7 +389,7 @@ def get_custom_filters( TriggerSchema, ) - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) if self.event_source_flavor: event_source_flavor_filter = and_( diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 6db9c085a89..1141172bf31 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -179,12 +179,10 @@ def launch(self) -> None: pipeline_run_id=pipeline_run.id, pipeline_run_metadata=pipeline_run_metadata, ) - - pipeline_model_version, pipeline_run = ( - step_run_utils.prepare_pipeline_run_model_version( - pipeline_run - ) - ) + if model_version := pipeline_run.model_version: + step_run_utils.log_model_version_dashboard_url( + model_version=model_version + ) request_factory = step_run_utils.StepRunRequestFactory( deployment=self._deployment, @@ -209,12 +207,10 @@ def launch(self) -> None: step_run = Client().zen_store.create_run_step( step_run_request ) - - step_model_version, step_run = ( - step_run_utils.prepare_step_run_model_version( - step_run=step_run, pipeline_run=pipeline_run + if model_version := step_run.model_version: + step_run_utils.log_model_version_dashboard_url( + model_version=model_version ) - ) if not step_run.status.is_finished: logger.info(f"Step `{self._step_name}` has started.") @@ -289,8 +285,8 @@ def _bypass() -> None: f"Using cached version of step `{self._step_name}`." 
) if ( - model_version := step_model_version - or pipeline_model_version + model_version := step_run.model_version + or pipeline_run.model_version ): step_run_utils.link_output_artifacts_to_model_version( artifacts=step_run.outputs, diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index e371b4c509a..6451a4cc0a4 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -14,7 +14,7 @@ """Utilities for creating step runs.""" from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from zenml.client import Client from zenml.config.step_configurations import Step @@ -24,21 +24,13 @@ from zenml.model.utils import link_artifact_version_to_model_version from zenml.models import ( ArtifactVersionResponse, - ModelVersionPipelineRunRequest, ModelVersionResponse, PipelineDeploymentResponse, PipelineRunResponse, - PipelineRunUpdate, StepRunRequest, - StepRunResponse, - StepRunUpdate, ) from zenml.orchestrators import cache_utils, input_utils, utils from zenml.stack import Stack -from zenml.utils import pagination_utils, string_utils - -if TYPE_CHECKING: - from zenml.model.model import Model logger = get_logger(__name__) @@ -293,10 +285,6 @@ def create_cached_step_runs( deployment=deployment, pipeline_run=pipeline_run, stack=stack ) - pipeline_model_version, pipeline_run = prepare_pipeline_run_model_version( - pipeline_run=pipeline_run - ) - while ( cache_candidates := find_cacheable_invocation_candidates( deployment=deployment, @@ -311,7 +299,9 @@ def create_cached_step_runs( # Make sure the request factory has the most up to date pipeline # run to avoid hydration calls - request_factory.pipeline_run = pipeline_run + request_factory.pipeline_run = Client().get_pipeline_run( + pipeline_run.id + ) try: step_run_request = request_factory.create_request( invocation_id @@ -336,15 +326,10 @@ def 
create_cached_step_runs( step_run = Client().zen_store.create_run_step(step_run_request) - # Refresh the pipeline run here to make sure we have the latest - # state - pipeline_run = Client().get_pipeline_run(pipeline_run.id) - - step_model_version, step_run = prepare_step_run_model_version( - step_run=step_run, pipeline_run=pipeline_run - ) - - if model_version := step_model_version or pipeline_model_version: + if ( + model_version := step_run.model_version + or pipeline_run.model_version + ): link_output_artifacts_to_model_version( artifacts=step_run.outputs, model_version=model_version, @@ -356,169 +341,6 @@ def create_cached_step_runs( return cached_invocations -def get_or_create_model_version_for_pipeline_run( - model: "Model", - pipeline_run: PipelineRunResponse, - substitutions: Dict[str, str], -) -> Tuple[ModelVersionResponse, bool]: - """Get or create a model version as part of a pipeline run. - - Args: - model: The model to get or create. - pipeline_run: The pipeline run for which the model should be created. - substitutions: Substitutions to apply to the model version name. - - Returns: - The model version and a boolean indicating whether it was newly created - or not. - """ - # Copy the model before modifying it so we don't accidently modify - # configurations in which the model object is potentially referenced - model = model.model_copy() - - if model.model_version_id: - return model._get_model_version(), False - elif model.version: - if isinstance(model.version, str): - model.version = string_utils.format_name_template( - model.version, - substitutions=substitutions, - ) - model.name = string_utils.format_name_template( - model.name, - substitutions=substitutions, - ) - - return ( - model._get_or_create_model_version(), - model._created_model_version, - ) - - # The model version should be created as part of this run - # -> We first check if it was already created as part of this run, and if - # not we do create it. 
If this is running in two parallel steps, we might - # run into issues that this will create two versions. Ideally, all model - # versions required for a pipeline run and its steps could be created - # server-side at run creation time before the first step starts. - if model_version := get_model_version_created_by_pipeline_run( - model_name=model.name, pipeline_run=pipeline_run - ): - return model_version, False - else: - return model._get_or_create_model_version(), True - - -def get_model_version_created_by_pipeline_run( - model_name: str, pipeline_run: PipelineRunResponse -) -> Optional[ModelVersionResponse]: - """Get a model version that was created by a specific pipeline run. - - This function does not refresh the pipeline run, so it will only try to - fetch the model version from existing steps if they're already part of the - response. - - Args: - model_name: The model name for which to get the version. - pipeline_run: The pipeline run for which to get the version. - - Returns: - A model version with the given name created by the run, or None if such - a model version does not exist. - """ - if pipeline_run.config.model and pipeline_run.model_version: - if ( - pipeline_run.config.model.name == model_name - and pipeline_run.config.model.version is None - ): - return pipeline_run.model_version - - # We fetch a list of hydrated step runs here in order to avoid hydration - # calls for each step separately. 
- candidate_step_runs = pagination_utils.depaginate( - Client().list_run_steps, - pipeline_run_id=pipeline_run.id, - model=model_name, - hydrate=True, - ) - for step_run in candidate_step_runs: - if step_run.config.model and step_run.model_version: - if ( - step_run.config.model.name == model_name - and step_run.config.model.version is None - ): - return step_run.model_version - - return None - - -def prepare_pipeline_run_model_version( - pipeline_run: PipelineRunResponse, -) -> Tuple[Optional[ModelVersionResponse], PipelineRunResponse]: - """Prepare the model version for a pipeline run. - - Args: - pipeline_run: The pipeline run for which to prepare the model version. - - Returns: - The prepared model version and the updated pipeline run. - """ - model_version = None - - if pipeline_run.model_version: - model_version = pipeline_run.model_version - elif config_model := pipeline_run.config.model: - model_version, _ = get_or_create_model_version_for_pipeline_run( - model=config_model, - pipeline_run=pipeline_run, - substitutions=pipeline_run.config.substitutions, - ) - pipeline_run = Client().zen_store.update_run( - run_id=pipeline_run.id, - run_update=PipelineRunUpdate(model_version_id=model_version.id), - ) - link_pipeline_run_to_model_version( - pipeline_run=pipeline_run, model_version=model_version - ) - log_model_version_dashboard_url(model_version) - - return model_version, pipeline_run - - -def prepare_step_run_model_version( - step_run: StepRunResponse, pipeline_run: PipelineRunResponse -) -> Tuple[Optional[ModelVersionResponse], StepRunResponse]: - """Prepare the model version for a step run. - - Args: - step_run: The step run for which to prepare the model version. - pipeline_run: The pipeline run of the step. - - Returns: - The prepared model version and the updated step run. 
- """ - model_version = None - - if step_run.model_version: - model_version = step_run.model_version - elif config_model := step_run.config.model: - model_version, created = get_or_create_model_version_for_pipeline_run( - model=config_model, - pipeline_run=pipeline_run, - substitutions=step_run.config.substitutions, - ) - step_run = Client().zen_store.update_run_step( - step_run_id=step_run.id, - step_run_update=StepRunUpdate(model_version_id=model_version.id), - ) - link_pipeline_run_to_model_version( - pipeline_run=pipeline_run, model_version=model_version - ) - if created: - log_model_version_dashboard_url(model_version) - - return model_version, step_run - - def log_model_version_dashboard_url( model_version: ModelVersionResponse, ) -> None: @@ -546,24 +368,6 @@ def log_model_version_dashboard_url( ) -def link_pipeline_run_to_model_version( - pipeline_run: PipelineRunResponse, model_version: ModelVersionResponse -) -> None: - """Link a pipeline run to a model version. - - Args: - pipeline_run: The pipeline run to link. - model_version: The model version to link. - """ - client = Client() - client.zen_store.create_model_version_pipeline_run_link( - ModelVersionPipelineRunRequest( - pipeline_run=pipeline_run.id, - model_version=model_version.id, - ) - ) - - def link_output_artifacts_to_model_version( artifacts: Dict[str, List[ArtifactVersionResponse]], model_version: ModelVersionResponse, diff --git a/src/zenml/pipelines/build_utils.py b/src/zenml/pipelines/build_utils.py index eacbd1d07da..810f8d5f177 100644 --- a/src/zenml/pipelines/build_utils.py +++ b/src/zenml/pipelines/build_utils.py @@ -249,6 +249,11 @@ def find_existing_build( client = Client() stack = client.active_stack + if not stack.container_registry: + # There can be no non-local builds that we can reuse if there is no + # container registry in the stack. 
+ return None + python_version_prefix = ".".join(platform.python_version_tuple()[:2]) required_builds = stack.get_docker_builds(deployment=deployment) @@ -263,6 +268,13 @@ def find_existing_build( sort_by="desc:created", size=1, stack_id=stack.id, + # Until we implement stack versioning, users can still update their + # stack to update/remove the container registry. In that case, we might + # try to pull an image from a container registry that we don't have + # access to. This is why we add an additional check for the container + # registry ID here. (This is still not perfect as users can update the + # container registry URI or config, but the best we can do) + container_registry_id=stack.container_registry.id, # The build is local and it's not clear whether the images # exist on the current machine or if they've been overwritten. # TODO: Should we support this by storing the unique Docker ID for diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index a7918f3e81c..80290f091cc 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -413,10 +413,7 @@ def get_pipeline_run_status( logger.error(error) raise CredentialsNotValid(error) - if pipeline_run_status in [ - ExecutionStatus.FAILED, - ExecutionStatus.COMPLETED, - ]: + if pipeline_run_status.is_finished: error = ( f"The execution of pipeline run " f"{decoded_token.pipeline_run_id} has already concluded and " @@ -461,10 +458,7 @@ def get_step_run_status( logger.error(error) raise CredentialsNotValid(error) - if step_run_status in [ - ExecutionStatus.FAILED, - ExecutionStatus.COMPLETED, - ]: + if step_run_status.is_finished: error = ( f"The execution of step run " f"{decoded_token.step_run_id} has already concluded and " diff --git a/src/zenml/zen_server/rbac/rbac_sql_zen_store.py b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py new file mode 100644 index 00000000000..1d6082a9e75 --- /dev/null +++ b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py @@ -0,0 +1,173 @@ +# 
Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""RBAC SQL Zen Store implementation.""" + +from typing import ( + Optional, + Tuple, +) +from uuid import UUID + +from zenml.logger import get_logger +from zenml.models import ( + ModelRequest, + ModelResponse, + ModelVersionRequest, + ModelVersionResponse, +) +from zenml.zen_server.feature_gate.endpoint_utils import ( + check_entitlement, + report_usage, +) +from zenml.zen_server.rbac.models import Action, ResourceType +from zenml.zen_server.rbac.utils import ( + verify_permission, + verify_permission_for_model, +) +from zenml.zen_stores.sql_zen_store import SqlZenStore + +logger = get_logger(__name__) + + +class RBACSqlZenStore(SqlZenStore): + """Wrapper around the SQLZenStore that implements RBAC functionality.""" + + def _get_or_create_model( + self, model_request: ModelRequest + ) -> Tuple[bool, ModelResponse]: + """Get or create a model. + + Args: + model_request: The model request. + + # noqa: DAR401 + Raises: + Exception: If the user is not allowed to create a model. + + Returns: + A boolean whether the model was created or not, and the model. 
+ """ + allow_model_creation = True + error = None + + try: + verify_permission( + resource_type=ResourceType.MODEL, action=Action.CREATE + ) + check_entitlement(resource_type=ResourceType.MODEL) + except Exception as e: + allow_model_creation = False + error = e + + if allow_model_creation: + created, model_response = super()._get_or_create_model( + model_request + ) + else: + try: + model_response = self.get_model(model_request.name) + created = False + except KeyError: + # The model does not exist. We now raise the error that + # explains why the model could not be created, instead of just + # the KeyError that it doesn't exist + assert error + raise error from None + + if created: + report_usage( + resource_type=ResourceType.MODEL, resource_id=model_response.id + ) + else: + verify_permission_for_model(model_response, action=Action.READ) + + return created, model_response + + def _get_model_version( + self, + model_id: UUID, + version_name: Optional[str] = None, + producer_run_id: Optional[UUID] = None, + ) -> ModelVersionResponse: + """Get a model version. + + Args: + model_id: The ID of the model. + version_name: The name of the model version. + producer_run_id: The ID of the producer pipeline run. If this is + set, only numeric versions created as part of the pipeline run + will be returned. + + Returns: + The model version. + """ + model_version = super()._get_model_version( + model_id=model_id, + version_name=version_name, + producer_run_id=producer_run_id, + ) + verify_permission_for_model(model_version, action=Action.READ) + return model_version + + def _get_or_create_model_version( + self, + model_version_request: ModelVersionRequest, + producer_run_id: Optional[UUID] = None, + ) -> Tuple[bool, ModelVersionResponse]: + """Get or create a model version. + + Args: + model_version_request: The model version request. + producer_run_id: ID of the producer pipeline run. 
+ + # noqa: DAR401 + Raises: + Exception: If the authenticated user is not allowed to + create a model version. + + Returns: + A boolean whether the model version was created or not, and the + model version. + """ + allow_creation = True + error = None + + try: + verify_permission( + resource_type=ResourceType.MODEL_VERSION, action=Action.CREATE + ) + except Exception as e: + allow_creation = False + error = e + + if allow_creation: + created, model_version_response = ( + super()._get_or_create_model_version(model_version_request, producer_run_id=producer_run_id) + ) + else: + try: + model_version_response = self._get_model_version( + model_id=model_version_request.model, + version_name=model_version_request.name, + producer_run_id=producer_run_id, + ) + created = False + except KeyError: + # The model version does not exist. We now raise the error that + # explains why the version could not be created, instead of just + # the KeyError that it doesn't exist + assert error + raise error from None + + return created, model_version_response diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index e970ba535e7..a1339c10bf3 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -41,7 +41,6 @@ from zenml.enums import ( APITokenType, AuthScheme, - ExecutionStatus, OAuthDeviceStatus, OAuthGrantTypes, ) @@ -589,10 +588,7 @@ def api_token( "security reasons." ) - if pipeline_run.status in [ - ExecutionStatus.FAILED, - ExecutionStatus.COMPLETED, - ]: + if pipeline_run.status.is_finished: raise ValueError( f"The execution of pipeline run {pipeline_run_id} has already " "concluded and API tokens can no longer be generated for it " @@ -609,10 +605,7 @@ def api_token( "be generated for non-existent step runs for security reasons." 
) - if step_run.status in [ - ExecutionStatus.FAILED, - ExecutionStatus.COMPLETED, - ]: + if step_run.status.is_finished: raise ValueError( f"The execution of step run {step_run_id} has already " "concluded and API tokens can no longer be generated for it " diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index ff96c7a640c..86414385ff1 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -421,6 +421,8 @@ def f(model: Model = Depends(make_dependable(Model))): """ from fastapi import Query + from zenml.zen_server.exceptions import error_detail + def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel: from fastapi import HTTPException @@ -428,9 +430,8 @@ def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel: inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs) return cls(*args, **kwargs) except ValidationError as e: - for error in e.errors(): - error["loc"] = tuple(["query"] + list(error["loc"])) - raise HTTPException(422, detail=e.errors()) + detail = error_detail(e, exception_type=ValueError) + raise HTTPException(422, detail=detail) params = {v.name: v for v in inspect.signature(cls).parameters.values()} query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", []) diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 210f6b8b1ed..11467c44814 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -36,6 +36,7 @@ DEFAULT_STACK_AND_COMPONENT_NAME, DEFAULT_WORKSPACE_NAME, ENV_ZENML_DEFAULT_WORKSPACE_NAME, + ENV_ZENML_SERVER, IS_DEBUG_ENV, ) from zenml.enums import ( @@ -155,9 +156,16 @@ def get_store_class(store_type: StoreType) -> Type["BaseZenStore"]: TypeError: If the store type is unsupported. 
""" if store_type == StoreType.SQL: - from zenml.zen_stores.sql_zen_store import SqlZenStore + if os.environ.get(ENV_ZENML_SERVER): + from zenml.zen_server.rbac.rbac_sql_zen_store import ( + RBACSqlZenStore, + ) + + return RBACSqlZenStore + else: + from zenml.zen_stores.sql_zen_store import SqlZenStore - return SqlZenStore + return SqlZenStore elif store_type == StoreType.REST: from zenml.zen_stores.rest_zen_store import RestZenStore diff --git a/src/zenml/zen_stores/migrations/versions/26351d482b9e_add_step_run_unique_constraint.py b/src/zenml/zen_stores/migrations/versions/26351d482b9e_add_step_run_unique_constraint.py new file mode 100644 index 00000000000..a9f1b31563a --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/26351d482b9e_add_step_run_unique_constraint.py @@ -0,0 +1,37 @@ +"""Add step run unique constraint [26351d482b9e]. + +Revision ID: 26351d482b9e +Revises: 0.71.0 +Create Date: 2024-12-03 11:46:57.541578 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "26351d482b9e" +down_revision = "0.71.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("step_run", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique_step_name_for_pipeline_run", ["name", "pipeline_run_id"] + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("step_run", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_step_name_for_pipeline_run", type_="unique" + ) + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py b/src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py new file mode 100644 index 00000000000..007b5ddbb8a --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py @@ -0,0 +1,68 @@ +"""Add model version producer run unique constraint [a1237ba94fd8]. + +Revision ID: a1237ba94fd8 +Revises: 26351d482b9e +Create Date: 2024-12-13 10:28:55.432414 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a1237ba94fd8" +down_revision = "26351d482b9e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "producer_run_id_if_numeric", + sqlmodel.sql.sqltypes.GUID(), + nullable=True, + ) + ) + + # Set the producer_run_id_if_numeric column to the model version ID for + # existing rows + connection = op.get_bind() + metadata = sa.MetaData() + metadata.reflect(only=("model_version",), bind=connection) + model_version_table = sa.Table("model_version", metadata) + + connection.execute( + model_version_table.update().values( + producer_run_id_if_numeric=model_version_table.c.id + ) + ) + + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.alter_column( + "producer_run_id_if_numeric", + existing_type=sqlmodel.sql.sqltypes.GUID(), + nullable=False, + ) + batch_op.create_unique_constraint( + "unique_numeric_version_for_pipeline_run", + ["model_id", "producer_run_id_if_numeric"], + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("model_version", schema=None) as batch_op: + batch_op.drop_constraint( + "unique_numeric_version_for_pipeline_run", type_="unique" + ) + batch_op.drop_column("producer_run_id_if_numeric") + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 974875beafe..e3e29759b21 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -4349,46 +4349,74 @@ def _request( {source_context.name: source_context.get().value} ) - try: - return self._handle_response( - self.session.request( - method, - url, - params=params, - verify=self.config.verify_ssl, - timeout=timeout or self.config.http_timeout, - **kwargs, - ) - ) - except CredentialsNotValid: - # NOTE: CredentialsNotValid is raised only when the server - # explicitly indicates that the credentials are not valid and they - # can be thrown away. - - # We authenticate or re-authenticate here and then try the request - # again, this time with a valid API token in the header. - self.authenticate( - # If the last request was authenticated with an API token, - # we force a re-authentication to get a fresh token. - force=self._api_token is not None - ) - - try: - return self._handle_response( - self.session.request( - method, - url, - params=params, - verify=self.config.verify_ssl, - timeout=self.config.http_timeout, - **kwargs, + # If the server replies with a credentials validation (401 Unauthorized) + # error, we (re-)authenticate and retry the request here in the + # following cases: + # + # 1. initial authentication: the last request was not authenticated + # with an API token. + # 2. re-authentication: the last request was authenticated with an API + # token that was rejected by the server. This is to cover the case + # of expired tokens that can be refreshed by the client automatically + # without user intervention from other sources (e.g. API keys). 
+ # + # NOTE: it can happen that the same request is retried here for up to + # two times: once after initial authentication and once after + # re-authentication. + re_authenticated = False + while True: + try: + return self._handle_response( + self.session.request( + method, + url, + params=params, + verify=self.config.verify_ssl, + timeout=timeout or self.config.http_timeout, + **kwargs, + ) ) - ) - except CredentialsNotValid as e: - raise CredentialsNotValid( - "The current credentials are no longer valid. Please log in " - "again using 'zenml login'." - ) from e + except CredentialsNotValid as e: + # NOTE: CredentialsNotValid is raised only when the server + # explicitly indicates that the credentials are not valid and + # they can be thrown away or when the request is not + # authenticated at all. + + if self._api_token is None: + # The last request was not authenticated with an API + # token at all. We authenticate here and then try the + # request again, this time with a valid API token in the + # header. + logger.debug( + f"The last request was not authenticated: {e}\n" + "Re-authenticating and retrying..." + ) + self.authenticate() + elif not re_authenticated: + # The last request was authenticated with an API token + # that was rejected by the server. We attempt a + # re-authentication here and then retry the request. + logger.debug( + "The last request was authenticated with an API token " + f"that was rejected by the server: {e}\n" + "Re-authenticating and retrying..." + ) + re_authenticated = True + self.authenticate( + # Ignore the current token and force a re-authentication + force=True + ) + else: + # The last request was made after re-authenticating but + # still failed. Bailing out. + logger.debug( + f"The last request failed after re-authenticating: {e}\n" + "Bailing out..." + ) + raise CredentialsNotValid( + "The current credentials are no longer valid. Please " + "log in again using 'zenml login'." 
+ ) from e def get( self, diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index feb4a93dc80..41c186c75ca 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -15,10 +15,16 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast -from uuid import UUID +from uuid import UUID, uuid4 from pydantic import ConfigDict -from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint +from sqlalchemy import ( + BOOLEAN, + INTEGER, + TEXT, + Column, + UniqueConstraint, +) from sqlmodel import Field, Relationship from zenml.enums import ( @@ -228,11 +234,13 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): __tablename__ = MODEL_VERSION_TABLENAME __table_args__ = ( - # We need two unique constraints here: + # We need three unique constraints here: # - The first to ensure that each model version for a # model has a unique version number # - The second one to ensure that explicit names given by # users are unique + # - The third one to ensure that a pipeline run only produces a single + # auto-incremented version per model UniqueConstraint( "number", "model_id", @@ -243,6 +251,11 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): "model_id", name="unique_version_for_model_id", ), + UniqueConstraint( + "model_id", + "producer_run_id_if_numeric", + name="unique_numeric_version_for_pipeline_run", + ), ) workspace_id: UUID = build_foreign_key_field( @@ -312,12 +325,23 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): ), ) pipeline_runs: List["PipelineRunSchema"] = Relationship( - back_populates="model_version" + back_populates="model_version", ) step_runs: List["StepRunSchema"] = Relationship( back_populates="model_version" ) + # We want to make sure each pipeline run only creates a single numeric + # version for each model. 
To solve this, we need to add a unique constraint. + # If a value of a unique constraint is NULL it is ignored and the + # remaining values in the unique constraint have to be unique. In + # our case however, we only want the unique constraint applied in + # case there is a producer run and only for numeric versions. To solve this, + # we fall back to the model version ID (which is the primary key and + # therefore unique) in case there is no producer run or the version is not + # numeric. + producer_run_id_if_numeric: UUID + # TODO: In Pydantic v2, the `model_` is a protected namespaces for all # fields defined under base models. If not handled, this raises a warning. # It is possible to suppress this warning message with the following @@ -328,24 +352,36 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): @classmethod def from_request( - cls, model_version_request: ModelVersionRequest + cls, + model_version_request: ModelVersionRequest, + model_version_number: int, + producer_run_id: Optional[UUID] = None, ) -> "ModelVersionSchema": """Convert an `ModelVersionRequest` to an `ModelVersionSchema`. Args: model_version_request: The request model version to convert. + model_version_number: The model version number. + producer_run_id: The ID of the producer run. Returns: The converted schema. 
""" + id_ = uuid4() + is_numeric = str(model_version_number) == model_version_request.name + return cls( + id=id_, workspace_id=model_version_request.workspace, user_id=model_version_request.user, model_id=model_version_request.model, name=model_version_request.name, - number=model_version_request.number, + number=model_version_number, description=model_version_request.description, stage=model_version_request.stage, + producer_run_id_if_numeric=producer_run_id + if (producer_run_id and is_numeric) + else id_, ) def to_model( diff --git a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py index ae2fe609bcf..409bd2eebfc 100644 --- a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py @@ -228,13 +228,6 @@ def to_model( Returns: The created `PipelineDeploymentResponse`. """ - pipeline_configuration = PipelineConfiguration.model_validate_json( - self.pipeline_configuration - ) - step_configurations = json.loads(self.step_configurations) - for s, c in step_configurations.items(): - step_configurations[s] = Step.model_validate(c) - body = PipelineDeploymentResponseBody( user=self.user.to_model() if self.user else None, created=self.created, @@ -242,6 +235,13 @@ def to_model( ) metadata = None if include_metadata: + pipeline_configuration = PipelineConfiguration.model_validate_json( + self.pipeline_configuration + ) + step_configurations = json.loads(self.step_configurations) + for s, c in step_configurations.items(): + step_configurations[s] = Step.model_validate(c) + metadata = PipelineDeploymentResponseMetadata( workspace=self.workspace.to_model(), run_name_template=self.run_name_template, diff --git a/src/zenml/zen_stores/schemas/pipeline_schemas.py b/src/zenml/zen_stores/schemas/pipeline_schemas.py index 1f287720ee6..3719a64b207 100644 --- a/src/zenml/zen_stores/schemas/pipeline_schemas.py +++ 
b/src/zenml/zen_stores/schemas/pipeline_schemas.py @@ -156,7 +156,12 @@ def to_model( resources = None if include_resources: + latest_run_user = self.runs[-1].user if self.runs else None + resources = PipelineResponseResources( + latest_run_user=latest_run_user.to_model() + if latest_run_user + else None, tags=[t.tag.to_model() for t in self.tags], ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index f8788505156..ea01de1ab24 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -19,7 +19,7 @@ from uuid import UUID from pydantic import ConfigDict -from sqlalchemy import TEXT, Column, String +from sqlalchemy import TEXT, Column, String, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlmodel import Field, Relationship, SQLModel @@ -67,6 +67,13 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for steps of pipeline runs.""" __tablename__ = "step_run" + __table_args__ = ( + UniqueConstraint( + "name", + "pipeline_run_id", + name="unique_step_name_for_pipeline_run", + ), + ) # Fields start_time: Optional[datetime] = Field(nullable=True) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index bb3a77befbd..19bdda8b28f 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -55,7 +55,7 @@ field_validator, model_validator, ) -from sqlalchemy import asc, case, desc, func +from sqlalchemy import func from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( ArgumentError, @@ -71,6 +71,7 @@ col, create_engine, delete, + desc, or_, select, ) @@ -100,7 +101,6 @@ ENV_ZENML_SERVER, FINISHED_ONBOARDING_SURVEY_KEY, MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION, - SORT_PIPELINES_BY_LATEST_RUN_KEY, SQL_STORE_BACKUP_DIRECTORY_NAME, TEXT_FIELD_MAX_LENGTH, handle_bool_env_var, @@ -117,7 +117,6 @@ 
OnboardingStep, SecretScope, SecretsStoreType, - SorterOps, StackComponentType, StackDeploymentProvider, StepRunInputArtifactType, @@ -298,7 +297,11 @@ replace_localhost_with_internal_hostname, ) from zenml.utils.pydantic_utils import before_validator_handler -from zenml.utils.string_utils import random_str, validate_name +from zenml.utils.string_utils import ( + format_name_template, + random_str, + validate_name, +) from zenml.zen_stores import template_utils from zenml.zen_stores.base_zen_store import ( BaseZenStore, @@ -4358,69 +4361,14 @@ def list_pipelines( Returns: A list of all pipelines matching the filter criteria. """ - query: Union[Select[Any], SelectOfScalar[Any]] = select(PipelineSchema) - _custom_conversion: Optional[Callable[[Any], PipelineResponse]] = None - - column, operand = pipeline_filter_model.sorting_params - if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: - with Session(self.engine) as session: - max_date_subquery = ( - # If no run exists for the pipeline yet, we use the pipeline - # creation date as a fallback, otherwise newly created - # pipeline would always be at the top/bottom - select( - PipelineSchema.id, - case( - ( - func.max(PipelineRunSchema.created).is_(None), - PipelineSchema.created, - ), - else_=func.max(PipelineRunSchema.created), - ).label("run_or_created"), - ) - .outerjoin( - PipelineRunSchema, - PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type] - ) - .group_by(col(PipelineSchema.id)) - .subquery() - ) - - if operand == SorterOps.DESCENDING: - sort_clause = desc - else: - sort_clause = asc - - query = ( - # We need to include the subquery in the select here to - # make this query work with the distinct statement. 
This - # result will be removed in the custom conversion function - # applied later - select(PipelineSchema, max_date_subquery.c.run_or_created) - .where(PipelineSchema.id == max_date_subquery.c.id) - .order_by(sort_clause(max_date_subquery.c.run_or_created)) - # We always add the `id` column as a tiebreaker to ensure a - # stable, repeatable order of items, otherwise subsequent - # pages might contain the same items. - .order_by(col(PipelineSchema.id)) - ) - - def _custom_conversion(row: Any) -> PipelineResponse: - return cast( - PipelineResponse, - row[0].to_model( - include_metadata=hydrate, include_resources=True - ), - ) - with Session(self.engine) as session: + query = select(PipelineSchema) return self.filter_and_paginate( session=session, query=query, table=PipelineSchema, filter_model=pipeline_filter_model, hydrate=hydrate, - custom_schema_to_model_conversion=_custom_conversion, ) def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int: @@ -5211,6 +5159,20 @@ def create_run( "already exists." ) + if model_version_id := self._get_or_create_model_version_for_run( + new_run + ): + new_run.model_version_id = model_version_id + session.add(new_run) + session.commit() + + self.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequest( + model_version=model_version_id, pipeline_run=new_run.id + ) + ) + session.refresh(new_run) + return new_run.to_model( include_metadata=True, include_resources=True ) @@ -8167,25 +8129,17 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: f"with ID '{step_run.pipeline_run_id}' found." 
) - # Check if the step name already exists in the pipeline run - existing_step_run = session.exec( - select(StepRunSchema) - .where(StepRunSchema.name == step_run.name) - .where( - StepRunSchema.pipeline_run_id == step_run.pipeline_run_id - ) - ).first() - if existing_step_run is not None: + step_schema = StepRunSchema.from_request(step_run) + session.add(step_schema) + try: + session.commit() + except IntegrityError: raise EntityExistsError( f"Unable to create step `{step_run.name}`: A step with " f"this name already exists in the pipeline run with ID " f"'{step_run.pipeline_run_id}'." ) - # Create the step - step_schema = StepRunSchema.from_request(step_run) - session.add(step_schema) - # Add logs entry for the step if exists if step_run.logs is not None: log_entry = LogsSchema( @@ -8281,6 +8235,21 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: session.commit() session.refresh(step_schema) + if model_version_id := self._get_or_create_model_version_for_run( + step_schema + ): + step_schema.model_version_id = model_version_id + session.add(step_schema) + session.commit() + + self.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequest( + model_version=model_version_id, + pipeline_run=step_schema.pipeline_run_id, + ) + ) + session.refresh(step_schema) + return step_schema.to_model( include_metadata=True, include_resources=True ) @@ -10283,6 +10252,22 @@ def update_model( # ----------------------------- Model Versions ----------------------------- + def _get_or_create_model( + self, model_request: ModelRequest + ) -> Tuple[bool, ModelResponse]: + """Get or create a model. + + Args: + model_request: The model request. + + Returns: + A boolean whether the model was created or not, and the model. 
+ """ + try: + return True, self.create_model(model_request) + except EntityExistsError: + return False, self.get_model(model_request.name) + def _get_next_numeric_version_for_model( self, session: Session, model_id: UUID ) -> int: @@ -10307,55 +10292,276 @@ def _get_next_numeric_version_for_model( else: return int(current_max_version) + 1 - def _model_version_exists(self, model_id: UUID, version: str) -> bool: + def _model_version_exists( + self, + model_id: UUID, + version: Optional[str] = None, + producer_run_id: Optional[UUID] = None, + ) -> bool: """Check if a model version with a certain version exists. Args: model_id: The model ID of the version. version: The version name. + producer_run_id: The producer run ID. If given, checks if a numeric + version for the producer run exists. Returns: - If a model version with the given version name exists. + If a model version for the given arguments exists. """ + query = select(ModelVersionSchema.id).where( + ModelVersionSchema.model_id == model_id + ) + + if version: + query = query.where(ModelVersionSchema.name == version) + + if producer_run_id: + query = query.where( + ModelVersionSchema.producer_run_id_if_numeric + == producer_run_id, + ) + with Session(self.engine) as session: - return ( - session.exec( - select(ModelVersionSchema.id) - .where(ModelVersionSchema.model_id == model_id) - .where(ModelVersionSchema.name == version) - ).first() - is not None + return session.exec(query).first() is not None + + def _get_model_version( + self, + model_id: UUID, + version_name: Optional[str] = None, + producer_run_id: Optional[UUID] = None, + ) -> ModelVersionResponse: + """Get a model version. + + Args: + model_id: The ID of the model. + version_name: The name of the model version. + producer_run_id: The ID of the producer pipeline run. If this is + set, only numeric versions created as part of the pipeline run + will be returned. + + Raises: + ValueError: If no version name or producer run ID was provided. 
+ KeyError: If no model version was found. + + Returns: + The model version. + """ + query = select(ModelVersionSchema).where( + ModelVersionSchema.model_id == model_id + ) + + if version_name: + if version_name.isnumeric(): + query = query.where( + ModelVersionSchema.number == int(version_name) + ) + error_text = ( + f"No version with number {version_name} found " + f"for model {model_id}." + ) + elif version_name in ModelStages.values(): + if version_name == ModelStages.LATEST: + query = query.order_by( + desc(col(ModelVersionSchema.number)) + ).limit(1) + else: + query = query.where( + ModelVersionSchema.stage == version_name + ) + error_text = ( + f"No {version_name} stage version found for " + f"model {model_id}." + ) + else: + query = query.where(ModelVersionSchema.name == version_name) + error_text = ( + f"No {version_name} version found for model {model_id}." + ) + + elif producer_run_id: + query = query.where( + ModelVersionSchema.producer_run_id_if_numeric + == producer_run_id, + ) + error_text = ( + f"No numeric model version found for model {model_id} " + f"and producer run {producer_run_id}." + ) + else: + raise ValueError( + "Version name or producer run id need to be specified." ) - @track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION) - def create_model_version( - self, model_version: ModelVersionRequest + with Session(self.engine) as session: + schema = session.exec(query).one_or_none() + + if not schema: + raise KeyError(error_text) + + return schema.to_model( + include_metadata=True, include_resources=True + ) + + def _get_or_create_model_version( + self, + model_version_request: ModelVersionRequest, + producer_run_id: Optional[UUID] = None, + ) -> Tuple[bool, ModelVersionResponse]: + """Get or create a model version. + + Args: + model_version_request: The model version request. + producer_run_id: ID of the producer pipeline run. + + Raises: + EntityCreationError: If the model version creation failed. 
+ + Returns: + A boolean whether the model version was created or not, and the + model version. + """ + try: + model_version = self._create_model_version( + model_version=model_version_request, + producer_run_id=producer_run_id, + ) + track(event=AnalyticsEvent.CREATED_MODEL_VERSION) + return True, model_version + except EntityCreationError: + # Need to explicitly re-raise this here as otherwise the catching + # of the RuntimeError would include this + raise + except RuntimeError: + return False, self._get_model_version( + model_id=model_version_request.model, + producer_run_id=producer_run_id, + ) + except EntityExistsError: + return False, self._get_model_version( + model_id=model_version_request.model, + version_name=model_version_request.name, + ) + + def _get_or_create_model_version_for_run( + self, pipeline_or_step_run: Union[PipelineRunSchema, StepRunSchema] + ) -> Optional[UUID]: + """Get or create a model version for a pipeline or step run. + + Args: + pipeline_or_step_run: The pipeline or step run for which to create + the model version. + + Returns: + The model version. 
+ """ + if isinstance(pipeline_or_step_run, PipelineRunSchema): + producer_run_id = pipeline_or_step_run.id + pipeline_run = pipeline_or_step_run.to_model(include_metadata=True) + configured_model = pipeline_run.config.model + substitutions = pipeline_run.config.substitutions + else: + producer_run_id = pipeline_or_step_run.pipeline_run_id + step_run = pipeline_or_step_run.to_model(include_metadata=True) + configured_model = step_run.config.model + substitutions = step_run.config.substitutions + + if not configured_model: + return None + + model_request = ModelRequest( + name=format_name_template( + configured_model.name, substitutions=substitutions + ), + license=configured_model.license, + description=configured_model.description, + audience=configured_model.audience, + use_cases=configured_model.use_cases, + limitations=configured_model.limitations, + trade_offs=configured_model.trade_offs, + ethics=configured_model.ethics, + save_models_to_registry=configured_model.save_models_to_registry, + user=pipeline_or_step_run.user_id, + workspace=pipeline_or_step_run.workspace_id, + ) + + _, model_response = self._get_or_create_model( + model_request=model_request + ) + + version_name = None + if configured_model.version is not None: + version_name = format_name_template( + str(configured_model.version), substitutions=substitutions + ) + + # If the model version was specified to be a numeric version or + # stage we don't try to create it (which will fail because it is not + # allowed) but try to fetch it immediately + if ( + version_name.isnumeric() + or version_name in ModelStages.values() + ): + return self._get_model_version( + model_id=model_response.id, version_name=version_name + ).id + + model_version_request = ModelVersionRequest( + model=model_response.id, + name=version_name, + description=configured_model.description, + tags=configured_model.tags, + user=pipeline_or_step_run.user_id, + workspace=pipeline_or_step_run.workspace_id, + ) + + _, 
model_version_response = self._get_or_create_model_version( + model_version_request=model_version_request, + producer_run_id=producer_run_id, + ) + return model_version_response.id + + def _create_model_version( + self, + model_version: ModelVersionRequest, + producer_run_id: Optional[UUID] = None, ) -> ModelVersionResponse: """Creates a new model version. Args: model_version: the Model Version to be created. + producer_run_id: ID of the pipeline run that produced this model + version. Returns: The newly created model version. Raises: - ValueError: If `number` is not None during model version creation. + ValueError: If the requested version name is invalid. EntityExistsError: If a model version with the given name already exists. EntityCreationError: If the model version creation failed. + RuntimeError: If an auto-incremented model version already exists + for the producer run. """ - if model_version.number is not None: - raise ValueError( - "`number` field must be None during model version creation." - ) + has_custom_name = False + if model_version.name: + has_custom_name = True + validate_name(model_version) - model = self.get_model(model_version.model) + if model_version.name.isnumeric(): + raise ValueError( + "Can't create model version with custom numeric model " + "version name." + ) - has_custom_name = model_version.name is not None - if has_custom_name: - validate_name(model_version) + if str(model_version.name).lower() in ModelStages.values(): + raise ValueError( + "Can't create model version with a name that is used as a " + f"model version stage ({ModelStages.values()})." 
+ ) + model = self.get_model(model_version.model) model_version_id = None remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION @@ -10363,17 +10569,19 @@ def create_model_version( remaining_tries -= 1 try: with Session(self.engine) as session: - model_version.number = ( + model_version_number = ( self._get_next_numeric_version_for_model( session=session, model_id=model.id, ) ) if not has_custom_name: - model_version.name = str(model_version.number) + model_version.name = str(model_version_number) model_version_schema = ModelVersionSchema.from_request( - model_version + model_version, + model_version_number=model_version_number, + producer_run_id=producer_run_id, ) session.add(model_version_schema) session.commit() @@ -10394,6 +10602,13 @@ def create_model_version( f"{model_version.name}): A model with the " "same name and version already exists." ) + elif producer_run_id and self._model_version_exists( + model_id=model.id, producer_run_id=producer_run_id + ): + raise RuntimeError( + "Auto-incremented model version already exists for " + f"producer run {producer_run_id}." + ) elif remaining_tries == 0: raise EntityCreationError( f"Failed to create version for model " @@ -10412,10 +10627,9 @@ def create_model_version( ) logger.debug( "Failed to create model version %s " - "(version %s) due to an integrity error. " + "due to an integrity error. " "Retrying in %f seconds.", model.name, - model_version.number, sleep_duration, ) time.sleep(sleep_duration) @@ -10430,6 +10644,20 @@ def create_model_version( return self.get_model_version(model_version_id) + @track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION) + def create_model_version( + self, model_version: ModelVersionRequest + ) -> ModelVersionResponse: + """Creates a new model version. + + Args: + model_version: the Model Version to be created. + + Returns: + The newly created model version. 
+ """ + return self._create_model_version(model_version=model_version) + def get_model_version( self, model_version_id: UUID, hydrate: bool = True ) -> ModelVersionResponse: diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index d16b9dc31bd..9d91e1eb76e 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -268,13 +268,11 @@ def test_model_fetch_model_and_version_latest(self): def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" with ModelContext(create_model=False) as (mdl_name, _, _): - with mock.patch("zenml.model.model.logger.info") as logger: - mv = Model( - name=mdl_name, - version=ModelStages.PRODUCTION.value, - ) - logger.assert_called_once() - assert mv.version == ModelStages.PRODUCTION.value + mv = Model( + name=mdl_name, + version=ModelStages.PRODUCTION.value, + ) + assert mv.version == ModelStages.PRODUCTION.value mv = Model(name=mdl_name, version=ModelStages.PRODUCTION) assert mv.version == ModelStages.PRODUCTION diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index f070ca0272f..e3049d6cb59 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -93,7 +93,7 @@ def test_that_argument_as_get_artifact_of_model_in_pipeline_context_fails_if_not clean_client: "Client", ): producer_pipe(False) - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): consumer_pipe() diff --git a/tests/integration/functional/steps/test_model_version.py b/tests/integration/functional/steps/test_model_version.py index 2100890bd8a..3990494a7f2 100644 --- a/tests/integration/functional/steps/test_model_version.py +++ 
b/tests/integration/functional/steps/test_model_version.py @@ -22,7 +22,7 @@ from zenml import get_pipeline_context, get_step_context, pipeline, step from zenml.artifacts.artifact_config import ArtifactConfig from zenml.client import Client -from zenml.enums import ModelStages +from zenml.enums import ExecutionStatus, ModelStages from zenml.model.model import Model @@ -571,7 +571,7 @@ def _inner_pipeline(): # this will run all steps, including one requesting new version run_1 = f"run_{uuid4()}" # model is configured with latest stage, so a warm-up needed - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): _inner_pipeline.with_options(run_name=run_1)() run_2 = f"run_{uuid4()}" Model(name="step")._get_or_create_model_version() @@ -812,3 +812,147 @@ def _inner_pipeline(): assert "{time}" not in versions[1].version assert len(versions[1]._get_model_version().data_artifact_ids["data"]) == 2 assert versions[1].version != first_version_name + + +@step +def noop() -> None: + pass + + +def test_model_version_creation(clean_client: "Client"): + """Tests that model versions get created correctly for a pipeline run.""" + shared_model_name = random_resource_name() + custom_model_name = random_resource_name() + + @pipeline(model=Model(name=shared_model_name), enable_cache=False) + def _inner_pipeline(): + noop.with_options(model=Model(name=shared_model_name))(id="shared") + noop.with_options( + model=Model(name=shared_model_name, version="custom") + )(id="custom_version") + noop.with_options(model=Model(name=custom_model_name))( + id="custom_model" + ) + + run_1 = _inner_pipeline() + shared_versions = clean_client.list_model_versions(shared_model_name) + assert len(shared_versions) == 2 + implicit_version = shared_versions[-2] + explicit_version = shared_versions[-1] + + custom_versions = clean_client.list_model_versions(custom_model_name) + assert len(custom_versions) == 1 + custom_version = custom_versions[-1] + + assert run_1.model_version_id == 
implicit_version.id + for name, step_ in run_1.steps.items(): + if name == "shared": + assert step_.model_version_id == implicit_version.id + elif name == "custom_version": + assert step_.model_version_id == explicit_version.id + else: + assert step_.model_version_id == custom_version.id + links = clean_client.list_model_version_pipeline_run_links( + pipeline_run_id=run_1.id + ) + assert len(links) == 3 + + run_2 = _inner_pipeline() + shared_versions = clean_client.list_model_versions(shared_model_name) + assert len(shared_versions) == 3 + implicit_version = shared_versions[-1] + explicit_version = shared_versions[-2] + + custom_versions = clean_client.list_model_versions(custom_model_name) + assert len(custom_versions) == 2 + custom_version = custom_versions[-1] + + assert run_2.model_version_id == implicit_version.id + for name, step_ in run_2.steps.items(): + if name == "shared": + assert step_.model_version_id == implicit_version.id + elif name == "custom_version": + assert step_.model_version_id == explicit_version.id + else: + assert step_.model_version_id == custom_version.id + links = clean_client.list_model_version_pipeline_run_links( + pipeline_run_id=run_2.id + ) + assert len(links) == 3 + + # Run with caching enabled to see if everything still works + run_3 = _inner_pipeline.with_options(enable_cache=True)() + shared_versions = clean_client.list_model_versions(shared_model_name) + assert len(shared_versions) == 4 + implicit_version = shared_versions[-1] + explicit_version = shared_versions[-3] + + custom_versions = clean_client.list_model_versions(custom_model_name) + assert len(custom_versions) == 3 + custom_version = custom_versions[-1] + + assert run_3.model_version_id == implicit_version.id + for name, step_ in run_3.steps.items(): + assert step_.status == ExecutionStatus.CACHED + + if name == "shared": + assert step_.model_version_id == implicit_version.id + elif name == "custom_version": + assert step_.model_version_id == explicit_version.id + 
else: + assert step_.model_version_id == custom_version.id + links = clean_client.list_model_version_pipeline_run_links( + pipeline_run_id=run_3.id + ) + assert len(links) == 3 + + +def test_model_version_fetching_by_stage(clean_client: "Client"): + """Tests that model versions can be fetched by number or stage.""" + model_name = random_resource_name() + + @pipeline(model=Model(name=model_name), enable_cache=False) + def _creator_pipeline(): + noop() + + @pipeline(model=Model(name=model_name, version=1), enable_cache=False) + def _fetch_by_version_number_pipeline(): + noop() + + @pipeline( + model=Model(name=model_name, version="latest"), enable_cache=False + ) + def _fetch_latest_version_pipeline(): + noop() + + @pipeline( + model=Model(name=model_name, version="production"), enable_cache=False + ) + def _fetch_prod_version_pipeline(): + noop() + + with pytest.raises(KeyError): + _fetch_by_version_number_pipeline() + + with pytest.raises(KeyError): + _fetch_latest_version_pipeline() + + with pytest.raises(KeyError): + _fetch_prod_version_pipeline() + + _creator_pipeline() + _creator_pipeline() + + versions = clean_client.list_model_versions(model_name) + assert len(versions) == 2 + mv_1, mv_2 = versions + mv_1.set_stage("production") + + run = _fetch_by_version_number_pipeline() + assert run.model_version_id == mv_1.id + + run = _fetch_latest_version_pipeline() + assert run.model_version_id == mv_2.id + + run = _fetch_prod_version_pipeline() + assert run.model_version_id == mv_1.id diff --git a/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py b/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py index 6ffde7bddac..8e5c41f0d32 100644 --- a/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py +++ b/tests/integration/integrations/gcp/orchestrators/test_vertex_orchestrator.py @@ -141,9 +141,13 @@ def test_vertex_orchestrator_stack_validation( {"cpu_limit": "4", "gpu_limit": 4, "memory_limit": 
"1G"}, { "accelerator": { + "count": "1", + "type": "NVIDIA_TESLA_K80", "resourceCount": "1", "resourceType": "NVIDIA_TESLA_K80", }, + "cpuLimit": 1.0, + "memoryLimit": 1.0, "resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G", }, @@ -154,9 +158,13 @@ def test_vertex_orchestrator_stack_validation( {"cpu_limit": "1.0", "gpu_limit": 1, "memory_limit": "1G"}, { "accelerator": { + "count": "1", + "type": "NVIDIA_TESLA_K80", "resourceCount": "1", "resourceType": "NVIDIA_TESLA_K80", }, + "cpuLimit": 1.0, + "memoryLimit": 1.0, "resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G", }, @@ -166,6 +174,8 @@ def test_vertex_orchestrator_stack_validation( ResourceSettings(cpu_count=1, gpu_count=None, memory="1GB"), {"cpu_limit": None, "gpu_limit": None, "memory_limit": None}, { + "cpuLimit": 1.0, + "memoryLimit": 1.0, "resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G", }, @@ -174,7 +184,12 @@ def test_vertex_orchestrator_stack_validation( ( ResourceSettings(cpu_count=1, gpu_count=0, memory="1GB"), {"cpu_limit": None, "gpu_limit": None, "memory_limit": None}, - {"resourceCpuLimit": "1.0", "resourceMemoryLimit": "1G"}, + { + "cpuLimit": 1.0, + "memoryLimit": 1.0, + "resourceCpuLimit": "1.0", + "resourceMemoryLimit": "1G", + }, ), ], ) @@ -233,13 +248,16 @@ def _build_kfp_pipeline() -> None: job_spec = pipeline_json["deploymentSpec"]["executors"][ f"exec-{step_name}" ]["container"] + if "accelerator" in job_spec["resources"]: - if "count" in job_spec["resources"]["accelerator"]: - expected_resources["accelerator"]["count"] = expected_resources[ - "accelerator" - ]["resourceCount"] - if "type" in job_spec["resources"]["accelerator"]: - expected_resources["accelerator"]["type"] = expected_resources[ - "accelerator" - ]["resourceType"] + if "resourceCount" not in job_spec["resources"]["accelerator"]: + expected_resources["accelerator"].pop("resourceCount", None) + if "resourceType" not in job_spec["resources"]["accelerator"]: + 
expected_resources["accelerator"].pop("resourceType", None) + + if "resourceCpuLimit" not in job_spec["resources"]: + expected_resources.pop("resourceCpuLimit", None) + if "resourceMemoryLimit" not in job_spec["resources"]: + expected_resources.pop("resourceMemoryLimit", None) + assert job_spec["resources"] == expected_resources diff --git a/tests/unit/model/test_model_version_init.py b/tests/unit/model/test_model_version_init.py deleted file mode 100644 index 21009b96d5c..00000000000 --- a/tests/unit/model/test_model_version_init.py +++ /dev/null @@ -1,27 +0,0 @@ -from unittest.mock import patch - -import pytest - -from zenml.model.model import Model - - -@pytest.mark.parametrize( - "version_name,logger", - [ - ["staging", "info"], - ["1", "info"], - [1, "info"], - ], - ids=[ - "Pick model by text stage", - "Pick model by text version number", - "Pick model by integer version number", - ], -) -def test_init_warns(version_name, logger): - with patch(f"zenml.model.model.logger.{logger}") as logger: - Model( - name="foo", - version=version_name, - ) - logger.assert_called_once() diff --git a/tests/unit/models/test_filter_models.py b/tests/unit/models/test_filter_models.py index 46b711bb7cc..c0d69ea4d23 100644 --- a/tests/unit/models/test_filter_models.py +++ b/tests/unit/models/test_filter_models.py @@ -235,21 +235,11 @@ def test_uuid_filter_model(): ) -def test_uuid_filter_model_fails_for_invalid_uuids_on_equality(): - """Test filtering for equality with invalid UUID fails.""" - with pytest.raises(ValueError): - uuid_value = "a92k34" - SomeFilterModel(uuid_field=f"{GenericFilterOps.EQUALS}:{uuid_value}") - - def test_uuid_filter_model_succeeds_for_invalid_uuid_on_non_equality(): """Test filtering with other UUID operations is possible with non-UUIDs.""" filter_value = "a92k34" for filter_op in UUIDFilter.ALLOWED_OPS: - if ( - filter_op == GenericFilterOps.EQUALS - or filter_op == GenericFilterOps.ONEOF - ): + if filter_op == GenericFilterOps.ONEOF: continue 
filter_model = SomeFilterModel( uuid_field=f"{filter_op}:{filter_value}" diff --git a/tests/unit/pipelines/test_build_utils.py b/tests/unit/pipelines/test_build_utils.py index 73684af306d..de278fac778 100644 --- a/tests/unit/pipelines/test_build_utils.py +++ b/tests/unit/pipelines/test_build_utils.py @@ -518,7 +518,9 @@ def test_local_repo_verification( assert isinstance(code_repo, StubCodeRepository) -def test_finding_existing_build(mocker, sample_deployment_response_model): +def test_finding_existing_build( + mocker, sample_deployment_response_model, remote_container_registry +): """Tests finding an existing build.""" mock_list_builds = mocker.patch( "zenml.client.Client.list_builds", @@ -551,14 +553,30 @@ def test_finding_existing_build(mocker, sample_deployment_response_model): ], ) + build_utils.find_existing_build( + deployment=sample_deployment_response_model, + code_repository=StubCodeRepository(), + ) + # No container registry -> no non-local build to pull + mock_list_builds.assert_not_called() + + mocker.patch.object( + Stack, + "container_registry", + new_callable=mocker.PropertyMock, + return_value=remote_container_registry, + ) + build = build_utils.find_existing_build( deployment=sample_deployment_response_model, code_repository=StubCodeRepository(), ) + mock_list_builds.assert_called_once_with( sort_by="desc:created", size=1, stack_id=Client().active_stack.id, + container_registry_id=remote_container_registry.id, is_local=False, contains_code=False, zenml_version=zenml.__version__,