diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 299a090d189..a20957a33a3 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -48,6 +48,7 @@ from zenml.pipelines import get_pipeline_context, pipeline from zenml.steps import step, get_step_context from zenml.steps.utils import log_step_metadata +from zenml.utils.metadata_utils import log_metadata from zenml.entrypoints import entrypoint __all__ = [ @@ -56,6 +57,7 @@ "get_pipeline_context", "get_step_context", "load_artifact", + "log_metadata", "log_artifact_metadata", "log_model_metadata", "log_step_metadata", diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 83b7693eee6..e18485d42ab 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -408,7 +408,7 @@ def log_artifact_metadata( not provided, when being called inside a step that produces an artifact named `artifact_name`, the metadata will be associated to the corresponding newly created artifact. Or, if not provided when - being called outside of a step, or in a step that does not produce + being called outside a step, or in a step that does not produce any artifact named `artifact_name`, the metadata will be associated to the latest version of that artifact. @@ -417,6 +417,10 @@ def log_artifact_metadata( called inside a step with a single output, or, if neither an artifact nor an output with the given name exists. """ + logger.warning( + "The `log_artifact_metadata` function is deprecated and will soon be " + "removed. Please use `log_metadata` instead." + ) try: step_context = get_step_context() in_step_outputs = (artifact_name in step_context._outputs) or ( diff --git a/src/zenml/client.py b/src/zenml/client.py index 6cc5318b202..154700bce8c 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -3796,6 +3796,7 @@ def list_pipeline_runs( templatable: Optional[bool] = None, tag: Optional[str] = None, user: Optional[Union[UUID, str]] = None, + run_metadata: Optional[Dict[str, str]] = None, pipeline: Optional[Union[UUID, str]] = None, code_repository: Optional[Union[UUID, str]] = None, model: Optional[Union[UUID, str]] = None, @@ -3835,6 +3836,7 @@ def list_pipeline_runs( templatable: If the runs should be templatable or not. tag: Tag to filter by. user: The name/ID of the user to filter by. + run_metadata: The run_metadata of the run to filter by. pipeline: The name/ID of the pipeline to filter by. code_repository: Filter by code repository name/ID. model: Filter by model name/ID. @@ -3874,6 +3876,7 @@ def list_pipeline_runs( tag=tag, unlisted=unlisted, user=user, + run_metadata=run_metadata, pipeline=pipeline, code_repository=code_repository, stack=stack, @@ -4194,7 +4197,7 @@ def get_artifact_version( ), ) except RuntimeError: - pass # Cannot link to step run if called outside of a step + pass # Cannot link to step run if called outside a step return artifact def list_artifact_versions( @@ -4222,6 +4225,7 @@ def list_artifact_versions( user: Optional[Union[UUID, str]] = None, model: Optional[Union[UUID, str]] = None, pipeline_run: Optional[Union[UUID, str]] = None, + run_metadata: Optional[Dict[str, str]] = None, tag: Optional[str] = None, hydrate: bool = False, ) -> Page[ArtifactVersionResponse]: @@ -4253,6 +4257,7 @@ def list_artifact_versions( user: Filter by user name or ID. model: Filter by model name or ID. pipeline_run: Filter by pipeline run name or ID. + run_metadata: Filter by run metadata. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. diff --git a/src/zenml/enums.py b/src/zenml/enums.py index c39b39c43ea..a3a1a7fec56 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -253,6 +253,7 @@ class GenericFilterOps(StrEnum): CONTAINS = "contains" STARTSWITH = "startswith" ENDSWITH = "endswith" + ONEOF = "oneof" GTE = "gte" GT = "gt" LTE = "lte" diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 2593b606d17..5ec2123098f 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -56,6 +56,11 @@ def log_model_metadata( ValueError: If no model name/version is provided and the function is not called inside a step with configured `model` in decorator. """ + logger.warning( + "The `log_model_metadata` function is deprecated and will soon be " + "removed. Please use `log_metadata` instead." + ) + if model_name and model_version: from zenml import Model diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 486226a3a5f..1c4d2cccfb5 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Base filter model definitions.""" +import json from abc import ABC, abstractmethod from datetime import datetime from typing import ( @@ -36,7 +37,7 @@ field_validator, model_validator, ) -from sqlalchemy import asc, desc +from sqlalchemy import Float, and_, asc, cast, desc from sqlmodel import SQLModel from zenml.constants import ( @@ -63,6 +64,11 @@ AnyQuery = TypeVar("AnyQuery", bound=Any) +ONEOF_ERROR = ( + "When you are using the 'oneof:' filtering make sure that the " + "provided value is a json formatted list." +) + class Filter(BaseModel, ABC): """Filter for all fields. @@ -171,8 +177,28 @@ class StrFilter(Filter): GenericFilterOps.STARTSWITH, GenericFilterOps.CONTAINS, GenericFilterOps.ENDSWITH, + GenericFilterOps.ONEOF, + GenericFilterOps.GT, + GenericFilterOps.GTE, + GenericFilterOps.LT, + GenericFilterOps.LTE, ] + @model_validator(mode="after") + def check_value_if_operation_oneof(self) -> "StrFilter": + """Validator to check if value is a list if oneof operation is used. + + Raises: + ValueError: If the value is not a list + + Returns: + self + """ + if self.operation == GenericFilterOps.ONEOF: + if not isinstance(self.value, list): + raise ValueError(ONEOF_ERROR) + return self + def generate_query_conditions_from_column(self, column: Any) -> Any: """Generate query conditions for a string column. @@ -181,6 +207,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: Returns: A list of query conditions. + + Raises: + ValueError: the comparison of the column to a numeric value fails. """ if self.operation == GenericFilterOps.CONTAINS: return column.like(f"%{self.value}%") @@ -190,6 +219,40 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: return column.endswith(f"{self.value}") if self.operation == GenericFilterOps.NOT_EQUALS: return column != self.value + if self.operation == GenericFilterOps.ONEOF: + return column.in_(self.value) + if self.operation in { + GenericFilterOps.GT, + GenericFilterOps.LT, + GenericFilterOps.GTE, + GenericFilterOps.LTE, + }: + try: + numeric_column = cast(column, Float) + + assert self.value is not None + + if self.operation == GenericFilterOps.GT: + return and_( + numeric_column, numeric_column > float(self.value) + ) + if self.operation == GenericFilterOps.LT: + return and_( + numeric_column, numeric_column < float(self.value) + ) + if self.operation == GenericFilterOps.GTE: + return and_( + numeric_column, numeric_column >= float(self.value) + ) + if self.operation == GenericFilterOps.LTE: + return and_( + numeric_column, numeric_column <= float(self.value) + ) + except Exception as e: + raise ValueError( + f"Failed to compare the column '{column}' to the " + f"value '{self.value}' (must be numeric): {e}" + ) return column == self.value @@ -211,6 +274,9 @@ def _remove_hyphens_from_value(cls, value: Any) -> Any: if isinstance(value, str): return value.replace("-", "") + if isinstance(value, list): + return [str(v).replace("-", "") for v in value] + return value def generate_query_conditions_from_column(self, column: Any) -> Any: @@ -588,6 +654,10 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: Returns: A tuple of the filter value and the operator. + + Raises: + ValueError: when we try to use the `oneof` operator with the wrong + value. """ operator = GenericFilterOps.EQUALS # Default operator if isinstance(value, str): @@ -598,6 +668,15 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: ): value = split_value[1] operator = GenericFilterOps(split_value[0]) + + if operator == operator.ONEOF: + try: + value = json.loads(value) + if not isinstance(value, list): + raise ValueError + except ValueError: + raise ValueError(ONEOF_ERROR) + return value, operator def generate_name_or_id_query_conditions( @@ -648,8 +727,8 @@ def generate_name_or_id_query_conditions( return or_(*conditions) + @staticmethod def generate_custom_query_conditions_for_column( - self, value: Any, table: Type[SQLModel], column: str, @@ -833,16 +912,17 @@ def define_filter( # Create str filters if self.is_str_field(column): - return StrFilter( - operation=GenericFilterOps(operator), + return self._define_str_filter( + operator=GenericFilterOps(operator), column=column, value=value, ) # Handle unsupported datatypes logger.warning( - f"The Datatype {self._model_class.model_fields[column].annotation} might " - "not be supported for filtering. Defaulting to a string filter." + f"The Datatype {self._model_class.model_fields[column].annotation} " + "might not be supported for filtering. Defaulting to a string " + "filter." ) return StrFilter( operation=GenericFilterOps(operator), @@ -1032,8 +1112,9 @@ def _define_uuid_filter( "Invalid value passed as UUID query parameter." ) from e - # Cast the value to string for further comparisons. - value = str(value) + # For equality checks, ensure that the value is a valid UUID. + if operator == GenericFilterOps.ONEOF and not isinstance(value, list): + raise ValueError(ONEOF_ERROR) # Generate the filter. uuid_filter = UUIDFilter( @@ -1043,6 +1124,38 @@ def _define_uuid_filter( ) return uuid_filter + @staticmethod + def _define_str_filter( + column: str, value: Any, operator: GenericFilterOps + ) -> StrFilter: + """Define a str filter for a given column. + + Args: + column: The column to filter on. + value: The UUID value by which to filter. + operator: The operator to use for filtering. + + Returns: + A Filter object. + + Raises: + ValueError: If the value is not a proper value. + """ + # For equality checks, ensure that the value is a valid UUID. + if operator == GenericFilterOps.ONEOF and not isinstance(value, list): + raise ValueError( + "If you are using `oneof:` as a filtering op, the value needs " + "to be a json formatted list string." + ) + + # Generate the filter. + str_filter = StrFilter( + operation=GenericFilterOps(operator), + column=column, + value=value, + ) + return str_filter + @staticmethod def _define_bool_filter( column: str, value: Any, operator: GenericFilterOps diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index b5d80e32990..d26b3bceef4 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -474,6 +474,7 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): "user", "model", "pipeline_run", + "run_metadata", ] artifact_id: Optional[Union[UUID, str]] = Field( default=None, @@ -545,6 +546,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): description="Name/ID of a pipeline run that is associated with this " "artifact version.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the artifact versions by.", + ) model_config = ConfigDict(protected_namespaces=()) @@ -564,6 +569,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ModelSchema, ModelVersionArtifactSchema, PipelineRunSchema, + RunMetadataSchema, StepRunInputArtifactSchema, StepRunOutputArtifactSchema, StepRunSchema, @@ -645,6 +651,23 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ) custom_filters.append(pipeline_run_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == ArtifactVersionSchema.id, + RunMetadataSchema.resource_type + == MetadataResourceTypes.ARTIFACT_VERSION, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) + return custom_filters diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index dbc0a0f214f..f2e3a7aa911 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -590,6 +590,7 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "user", + "run_metadata", ] name: Optional[str] = Field( @@ -619,6 +620,10 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): default=None, description="Name/ID of the user that created the model version.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the model versions by.", + ) _model_id: UUID = PrivateAttr(None) @@ -651,6 +656,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, + RunMetadataSchema, UserSchema, ) @@ -665,6 +671,23 @@ def get_custom_filters( ) custom_filters.append(user_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == ModelVersionSchema.id, + RunMetadataSchema.resource_type + == MetadataResourceTypes.MODEL_VERSION, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) + return custom_filters def apply_filter( diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 26f517acdd3..8468c105bee 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -587,6 +587,7 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): "stack_component", "pipeline_name", "templatable", + "run_metadata", ] name: Optional[str] = Field( default=None, @@ -665,6 +666,10 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): default=None, description="Name/ID of the user that created the run.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the pipeline runs by.", + ) # TODO: Remove once frontend is ready for it. This is replaced by the more # generic `pipeline` filter below. pipeline_name: Optional[str] = Field( @@ -694,7 +699,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): templatable: Optional[bool] = Field( default=None, description="Whether the run is templatable." ) - model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( @@ -718,6 +722,7 @@ def get_custom_filters( PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, + RunMetadataSchema, ScheduleSchema, StackComponentSchema, StackCompositionSchema, @@ -887,5 +892,21 @@ def get_custom_filters( ) custom_filters.append(templatable_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == PipelineRunSchema.id, + RunMetadataSchema.resource_type + == MetadataResourceTypes.PIPELINE_RUN, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) return custom_filters diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index f4103f433d3..7052a1b42d7 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -509,6 +509,7 @@ class StepRunFilter(WorkspaceScopedFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "model", + "run_metadata", ] name: Optional[str] = Field( @@ -571,6 +572,10 @@ class StepRunFilter(WorkspaceScopedFilter): default=None, description="Name/ID of the model associated with the step run.", ) + run_metadata: Optional[Dict[str, str]] = Field( + default=None, + description="The run_metadata to filter the step runs by.", + ) model_config = ConfigDict(protected_namespaces=()) @@ -589,6 +594,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelSchema, ModelVersionSchema, + RunMetadataSchema, StepRunSchema, ) @@ -601,5 +607,21 @@ def get_custom_filters( ), ) custom_filters.append(model_filter) + if self.run_metadata is not None: + from zenml.enums import MetadataResourceTypes + + for key, value in self.run_metadata.items(): + additional_filter = and_( + RunMetadataSchema.resource_id == StepRunSchema.id, + RunMetadataSchema.resource_type + == MetadataResourceTypes.STEP_RUN, + RunMetadataSchema.key == key, + self.generate_custom_query_conditions_for_column( + value=value, + table=RunMetadataSchema, + column="value", + ), + ) + custom_filters.append(additional_filter) return custom_filters diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index c45523d8d77..e237d12f9ff 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -442,6 +442,11 @@ def log_step_metadata( from within a step or if no pipeline name or ID is provided and the function is not called from within a step. """ + logger.warning( + "The `log_step_metadata` function is deprecated and will soon be " + "removed. Please use `log_metadata` instead." + ) + step_context = None if not step_name: with contextlib.suppress(RuntimeError): diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py new file mode 100644 index 00000000000..47bd4f06e38 --- /dev/null +++ b/src/zenml/utils/metadata_utils.py @@ -0,0 +1,335 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility functions to handle metadata for ZenML entities.""" + +import contextlib +from typing import Dict, Optional, Union, overload +from uuid import UUID + +from zenml.client import Client +from zenml.enums import MetadataResourceTypes +from zenml.logger import get_logger +from zenml.metadata.metadata_types import MetadataType +from zenml.steps.step_context import get_step_context + +logger = get_logger(__name__) + + +@overload +def log_metadata(metadata: Dict[str, MetadataType]) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + artifact_version_id: UUID, +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + artifact_name: str, + artifact_version: Optional[str] = None, +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + model_version_id: UUID, +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + model_name: str, + model_version: str, +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + step_id: UUID, +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + run_id_name_or_prefix: Union[UUID, str], +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + step_name: str, + run_id_name_or_prefix: Union[UUID, str], +) -> None: ... + + +def log_metadata( + metadata: Dict[str, MetadataType], + # Parameters to manually log metadata for steps and runs + step_id: Optional[UUID] = None, + step_name: Optional[str] = None, + run_id_name_or_prefix: Optional[Union[UUID, str]] = None, + # Parameters to manually log metadata for artifacts + artifact_version_id: Optional[UUID] = None, + artifact_name: Optional[str] = None, + artifact_version: Optional[str] = None, + # Parameters to manually log metadata for models + model_version_id: Optional[UUID] = None, + model_name: Optional[str] = None, + model_version: Optional[str] = None, +) -> None: + """Logs metadata for various resource types in a generalized way. + + Args: + metadata: The metadata to log. + step_id: The ID of the step. + step_name: The name of the step. + run_id_name_or_prefix: The id, name or prefix of the run + artifact_version_id: The ID of the artifact version + artifact_name: The name of the artifact. + artifact_version: The version of the artifact. + model_version_id: The ID of the model version. + model_name: The name of the model. + model_version: The version of the model + + Raises: + ValueError: If no identifiers are provided and the function is not + called from within a step. + """ + client = Client() + + # If a step name is provided, we need a run_id_name_or_prefix and will log + # metadata for the steps pipeline and model accordingly. + if step_name is not None and run_id_name_or_prefix is not None: + run_model = client.get_pipeline_run( + name_id_or_prefix=run_id_name_or_prefix + ) + step_model = run_model.steps[step_name] + + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) + if step_model.model_version: + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + # If a step is identified by id, fetch it directly through the client, + # follow a similar procedure and log metadata for its pipeline and model + # as well. + elif step_id is not None: + step_model = client.get_run_step(step_run_id=step_id) + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.pipeline_run_id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) + if step_model.model_version: + client.create_run_metadata( + metadata=metadata, + resource_id=step_model.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + # If a pipeline run id is identified, we need to log metadata to it and its + # model as well. + elif run_id_name_or_prefix is not None: + run_model = client.get_pipeline_run( + name_id_or_prefix=run_id_name_or_prefix + ) + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + if run_model.model_version: + client.create_run_metadata( + metadata=metadata, + resource_id=run_model.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + # If the user provides a model name and version, we use to model abstraction + # to fetch the model version and attach the corresponding metadata to it. + elif model_name is not None and model_version is not None: + from zenml import Model + + mv = Model(name=model_name, version=model_version) + client.create_run_metadata( + metadata=metadata, + resource_id=mv.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + # If the user provides a model version id, we use the client to fetch it and + # attach the metadata to it. + elif model_version_id is not None: + model_version_id = client.get_model_version( + model_version_name_or_number_or_id=model_version_id + ).id + client.create_run_metadata( + metadata=metadata, + resource_id=model_version_id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + # If the user provides an artifact name, there are three possibilities. If + # an artifact version is also provided with the name, we use both to fetch + # the artifact version and use it to log the metadata. If no version is + # provided, if the function is called within a step we search the artifacts + # of the step if not we fetch the latest version and attach the metadata + # to the latest version. + elif artifact_name is not None: + if artifact_version: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name, version=artifact_version + ) + client.create_run_metadata( + metadata=metadata, + resource_id=artifact_version_model.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + else: + step_context = None + with contextlib.suppress(RuntimeError): + step_context = get_step_context() + + if step_context: + step_context.add_output_metadata( + metadata=metadata, output_name=artifact_name + ) + else: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name + ) + client.create_run_metadata( + metadata=metadata, + resource_id=artifact_version_model.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + + # If the user directly provides an artifact_version_id, we use the client to + # fetch is and attach the metadata accordingly. + elif artifact_version_id is not None: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_version_id, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=artifact_version_model.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + + # If every additional value is None, that means we are calling it bare bones + # and this call needs to happen during a step execution. We will use the + # step context to fetch the step, run and possibly the model version and + # attach the metadata accordingly. + elif all( + v is None + for v in [ + step_id, + step_name, + run_id_name_or_prefix, + artifact_version_id, + artifact_name, + artifact_version, + model_version_id, + model_name, + model_version, + ] + ): + try: + step_context = get_step_context() + except RuntimeError: + raise ValueError( + "You are calling 'log_metadata()' outside of a step execution. " + "If you would like to add metadata to a ZenML entity outside " + "of the step execution, please provide the required " + "identifiers." + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.pipeline_run.id, + resource_type=MetadataResourceTypes.PIPELINE_RUN, + ) + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.step_run.id, + resource_type=MetadataResourceTypes.STEP_RUN, + ) + if step_context.model_version: + client.create_run_metadata( + metadata=metadata, + resource_id=step_context.model_version.id, + resource_type=MetadataResourceTypes.MODEL_VERSION, + ) + + else: + raise ValueError( + """ + Unsupported way to call the `log_metadata`. Possible combinations " + include: + + # Inside a step + # Logs the metadata to the step, its run and possibly its model + log_metadata(metadata={}) + + # Manually logging for a step + # Logs the metadata to the step, its run and possibly its model + log_metadata(metadata={}, step_name=..., run_id_name_or_prefix=...) + log_metadata(metadata={}, step_id=...) + + # Manually logging for a run + # Logs the metadata to the run, possibly its model + log_metadata(metadata={}, run_id_name_or_prefix=...) + + # Manually logging for a model + log_metadata(metadata={}, model_name=..., model_version=...) + log_metadata(metadata={}, model_version_id=...) + + # Manually logging for an artifact + log_metadata(metadata={}, artifact_name=...) # inside a step + log_metadata(metadata={}, artifact_name=..., artifact_version=...) + log_metadata(metadata={}, artifact_version_id=...) + """ + ) diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 2421f55fe38..5c091e61fb0 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -15,6 +15,7 @@ from zenml import ( load_artifact, log_artifact_metadata, + log_metadata, pipeline, save_artifact, step, @@ -120,23 +121,25 @@ def _load_pipeline(expected_value, name, version): ) -def test_log_artifact_metadata_existing(clean_client): +def test_log_metadata_existing(clean_client): """Test logging artifact metadata for existing artifacts.""" save_artifact(42, "meaning_of_life") - log_artifact_metadata( - {"description": "Aria is great!"}, artifact_name="meaning_of_life" + log_metadata( + metadata={"description": "Aria is great!"}, + artifact_name="meaning_of_life", ) save_artifact(43, "meaning_of_life", version="43") - log_artifact_metadata( - {"description_2": "Blupus is great!"}, artifact_name="meaning_of_life" + log_metadata( + metadata={"description_2": "Blupus is great!"}, + artifact_name="meaning_of_life", ) - log_artifact_metadata( - {"description_3": "Axl is great!"}, + log_metadata( + metadata={"description_3": "Axl is great!"}, artifact_name="meaning_of_life", artifact_version="1", ) - log_artifact_metadata( - { + log_metadata( + metadata={ "float": 1.0, "int": 1, "str": "1.0", @@ -183,11 +186,11 @@ def artifact_metadata_logging_step() -> str: "description": "Aria is great!", "metrics": {"accuracy": 0.9}, } - log_artifact_metadata(output_metadata) + log_artifact_metadata(metadata=output_metadata) return "42" -def test_log_artifact_metadata_single_output(clean_client): +def test_log_metadata_single_output(clean_client): """Test logging artifact metadata for a single output.""" @pipeline @@ -212,11 +215,11 @@ def artifact_multi_output_metadata_logging_step() -> ( "description": "Blupus is great!", "metrics": {"accuracy": 0.9}, } - log_artifact_metadata(metadata=output_metadata, artifact_name="int_output") + log_metadata(metadata=output_metadata, artifact_name="int_output") return "42", 42 -def test_log_artifact_metadata_multi_output(clean_client): +def test_log_metadata_multi_output(clean_client): """Test logging artifact metadata for multiple outputs.""" @pipeline @@ -249,10 +252,10 @@ def wrong_artifact_multi_output_metadata_logging_step() -> ( return "42", 42 -def test_log_artifact_metadata_raises_error_if_output_name_unclear( +def test_log_metadata_raises_error_if_output_name_unclear( clean_client, ): - """Test that `log_artifact_metadata` raises an error if the output name is unclear.""" + """Test that `log_metadata` raises an error if the output name is unclear.""" @pipeline def artifact_metadata_logging_pipeline(): diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index b8cbf95e738..cdf98ac9301 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -19,12 +19,12 @@ from typing_extensions import Annotated from tests.integration.functional.utils import random_str -from zenml import get_step_context, pipeline, step +from zenml import get_step_context, log_metadata, pipeline, step from zenml.artifacts.utils import save_artifact from zenml.client import Client from zenml.enums import ModelStages from zenml.model.model import Model -from zenml.model.utils import link_artifact_to_model, log_model_metadata +from zenml.model.utils import link_artifact_to_model from zenml.models import TagRequest @@ -107,10 +107,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback): @step def step_metadata_logging_functional(mdl_name: str): """Functional logging using implicit Model from context.""" - log_model_metadata({"foo": "bar"}) + log_metadata({"foo": "bar"}) assert get_step_context().model.run_metadata["foo"] == "bar" - log_model_metadata( - {"foo": "bar"}, model_name=mdl_name, model_version="other" + log_metadata( + metadata={"foo": "bar"}, model_name=mdl_name, model_version="other" ) @@ -409,18 +409,22 @@ def test_metadata_logging_functional(self): ) mv._get_or_create_model_version() - log_model_metadata( - {"foo": "bar"}, model_name=mv.name, model_version=mv.number + log_metadata( + metadata={"foo": "bar"}, + model_name=mv.name, + model_version=str(mv.number), ) assert len(mv.run_metadata) == 1 assert mv.run_metadata["foo"] == "bar" with pytest.raises(ValueError): - log_model_metadata({"foo": "bar"}) + log_metadata({"foo": "bar"}) - log_model_metadata( - {"bar": "foo"}, model_name=mv.name, model_version="latest" + log_metadata( + metadata={"bar": "foo"}, + model_name=mv.name, + model_version="latest", ) assert len(mv.run_metadata) == 2 diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 978b0c81844..70e7608f7a8 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -6,11 +6,10 @@ Model, get_pipeline_context, get_step_context, - log_model_metadata, + log_metadata, pipeline, step, ) -from zenml.artifacts.utils import log_artifact_metadata from zenml.client import Client @@ -101,17 +100,25 @@ def test_that_argument_as_get_artifact_of_model_in_pipeline_context_fails_if_not @step def producer() -> Annotated[str, "bar"]: """Produce artifact with metadata and attach metadata to model version.""" - ver = get_step_context().model.version - log_model_metadata(metadata={"foobar": "model_meta_" + ver}) - log_artifact_metadata(metadata={"foobar": "artifact_meta_" + ver}) - return "artifact_data_" + ver + model = get_step_context().model + + log_metadata( + metadata={"foobar": "model_meta_" + model.version}, + model_name=model.name, + model_version=model.version, + ) + log_metadata( + metadata={"foobar": "artifact_meta_" + model.version}, + artifact_name="bar", + ) + return "artifact_data_" + model.version @step def asserter(artifact: str, artifact_metadata: str, model_metadata: str): """Assert that passed in values are loaded in lazy mode. - They do not exists before actual run of the pipeline. + They do not exist before actual run of the pipeline. """ ver = get_step_context().model.version assert artifact == "artifact_data_" + ver diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index 4442f84b08e..d520cfd83a4 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from typing_extensions import Annotated -from zenml import get_step_context, log_artifact_metadata, pipeline, step +from zenml import get_step_context, log_metadata, pipeline, step from zenml.artifacts.artifact_config import ArtifactConfig from zenml.client import Client from zenml.enums import ArtifactType @@ -92,7 +92,9 @@ def _simple_step_pipeline(): @step def output_metadata_logging_step() -> Annotated[int, "my_output"]: - log_artifact_metadata(metadata={"some_key": "some_value"}) + log_metadata( + metadata={"some_key": "some_value"}, artifact_name="my_output" + ) return 42 diff --git a/tests/integration/functional/steps/test_utils.py b/tests/integration/functional/steps/test_utils.py index 7bdff4867e9..ed983547a2f 100644 --- a/tests/integration/functional/steps/test_utils.py +++ b/tests/integration/functional/steps/test_utils.py @@ -14,8 +14,7 @@ """Tests for utility functions and classes to run ZenML steps.""" -from zenml import pipeline, step -from zenml.steps.utils import log_step_metadata +from zenml import log_metadata, pipeline, step @step @@ -31,11 +30,11 @@ def step_metadata_logging_step_inside_run() -> str: "description": "Aria is great!", "metrics": {"accuracy": 0.9}, } - log_step_metadata(metadata=step_metadata) + log_metadata(metadata=step_metadata) return "42" -def test_log_step_metadata_within_step(clean_client): +def test_log_metadata_within_step(clean_client): """Test logging step metadata for the latest run.""" @pipeline @@ -54,7 +53,7 @@ def step_metadata_logging_pipeline(): assert run_metadata["metrics"] == {"accuracy": 0.9} -def test_log_step_metadata_using_latest_run(clean_client): +def test_log_metadata_using_latest_run(clean_client): """Test logging step metadata for the latest run.""" @pipeline @@ -74,10 +73,10 @@ def step_metadata_logging_pipeline(): "description": "Axl is great!", "metrics": {"accuracy": 0.9}, } - log_step_metadata( + log_metadata( metadata=step_metadata, step_name="step_metadata_logging_step", - pipeline_name_id_or_prefix="step_metadata_logging_pipeline", + run_id_name_or_prefix="step_metadata_logging_pipeline", ) run_after_log = step_metadata_logging_pipeline.model.last_run run_metadata_after_log = run_after_log.steps[ @@ -89,7 +88,7 @@ def step_metadata_logging_pipeline(): assert run_metadata_after_log["metrics"] == {"accuracy": 0.9} -def test_log_step_metadata_using_specific_params(clean_client): +def test_log_metadata_using_specific_params(clean_client): """Test logging step metadata for a specific step.""" @pipeline @@ -114,10 +113,9 @@ def step_metadata_logging_pipeline(): "description": "Blupus is great!", "metrics": {"accuracy": 0.9}, } - log_step_metadata( + log_metadata( metadata=step_metadata, - step_name="step_metadata_logging_step", - run_id=step_run_id, + step_id=step_run_id, ) run_after_log = step_metadata_logging_pipeline.model.last_run run_metadata_after_log = run_after_log.steps[ diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index 9daa823a1a7..bd9583a8a1d 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -32,8 +32,7 @@ ExternalArtifact, get_pipeline_context, get_step_context, - log_artifact_metadata, - log_model_metadata, + log_metadata, pipeline, save_artifact, step, @@ -968,16 +967,20 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: """Produce artifact with metadata.""" from zenml.client import Client - log_artifact_metadata(metadata={"some_meta": "meta_new_one"}) + log_metadata( + metadata={"some_meta": "meta_new_one"}, artifact_name="new_one" + ) client = Client() - log_model_metadata( + model = get_step_context().model + + log_metadata( metadata={"some_meta": "meta_new_one"}, + model_name=model.name, + model_version=model.version, ) - model = get_step_context().model - mv = client.create_model_version( model_name_or_id=model.name, name="model_version2", @@ -1132,12 +1135,12 @@ def dummy(): save_artifact( data="body_preexisting", name="preexisting", version="1.2.3" ) - log_artifact_metadata( + log_metadata( metadata={"some_meta": "meta_preexisting"}, artifact_name="preexisting", artifact_version="1.2.3", ) - log_model_metadata( + log_metadata( metadata={"some_meta": "meta_preexisting"}, model_name="aria", model_version="model_version", diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 57a477dea7b..ae4a4011108 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +import json import os import random import time @@ -19,12 +20,12 @@ from datetime import datetime from string import ascii_lowercase from threading import Thread -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import patch from uuid import UUID, uuid4 import pytest -from pydantic import SecretStr +from pydantic import SecretStr, ValidationError from sqlalchemy.exc import IntegrityError from tests.integration.functional.utils import sample_name @@ -46,6 +47,7 @@ from tests.unit.pipelines.test_build_utils import ( StubLocalRepositoryContext, ) +from zenml import Model, log_metadata, pipeline, step from zenml.artifacts.utils import ( _load_artifact_store, ) @@ -2905,6 +2907,64 @@ def test_deleting_run_deletes_steps(): assert store.list_run_steps(filter_model).total == 0 +@step +def step_to_log_metadata(metadata: Union[str, int, bool]) -> int: + log_metadata({"blupus": metadata}) + return 42 + + +@pipeline(name="aria", model=Model(name="axl"), tags=["cats", "squirrels"]) +def pipeline_to_log_metadata(metadata): + step_to_log_metadata(metadata) + + +def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): + store = clean_client.zen_store + + metadata_values = [3, 25, 100, "random_string", True] + + runs = [] + for v in metadata_values: + runs.append(pipeline_to_log_metadata(v)) + + # Test oneof: name filtering + runs_filter = PipelineRunFilter( + name=f"oneof:{json.dumps([r.name for r in runs[:2]])}" + ) + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 2 # The first two runs + + # Test oneof: UUID filtering + runs_filter = PipelineRunFilter( + id=f"oneof:{json.dumps([str(r.id) for r in runs[:2]])}" + ) + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 2 # The first two runs + + # Test oneof: tags filtering + runs_filter = PipelineRunFilter(tag='oneof:["cats", "dogs"]') + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == len(metadata_values) # All runs + + runs_filter = PipelineRunFilter(tag='oneof:["dogs"]') + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 0 # No runs + + # Test oneof: formatting + with pytest.raises(ValidationError): + PipelineRunFilter(name="oneof:random_value") + + # Test metadata filtering + runs_filter = PipelineRunFilter(run_metadata={"blupus": "lt:30"}) + runs = store.list_runs(runs_filter_model=runs_filter) + assert len(runs) == 2 # The run with 3 and 25 + + for r in runs: + assert "blupus" in r.run_metadata + assert isinstance(r.run_metadata["blupus"], int) + assert r.run_metadata["blupus"] < 30 + + # .--------------------. # | Pipeline run steps | # '--------------------' diff --git a/tests/unit/models/test_filter_models.py b/tests/unit/models/test_filter_models.py index 779d6120b01..46b711bb7cc 100644 --- a/tests/unit/models/test_filter_models.py +++ b/tests/unit/models/test_filter_models.py @@ -182,7 +182,7 @@ def test_datetime_filter_model(): filter_class=DatetimeFilter, filter_value=filter_value, expected_value=expected_value, - ignore_operators=[GenericFilterOps.IN], + ignore_operators=[GenericFilterOps.IN, GenericFilterOps.ONEOF], ) @@ -231,6 +231,7 @@ def test_uuid_filter_model(): filter_class=UUIDFilter, filter_value=filter_value, expected_value=str(filter_value).replace("-", ""), + ignore_operators=[GenericFilterOps.ONEOF], ) @@ -245,7 +246,10 @@ def test_uuid_filter_model_succeeds_for_invalid_uuid_on_non_equality(): """Test filtering with other UUID operations is possible with non-UUIDs.""" filter_value = "a92k34" for filter_op in UUIDFilter.ALLOWED_OPS: - if filter_op == GenericFilterOps.EQUALS: + if ( + filter_op == GenericFilterOps.EQUALS + or filter_op == GenericFilterOps.ONEOF + ): continue filter_model = SomeFilterModel( uuid_field=f"{filter_op}:{filter_value}" @@ -264,4 +268,5 @@ def test_string_filter_model(): filter_field="str_field", filter_class=StrFilter, filter_value="a_random_string", + ignore_operators=[GenericFilterOps.ONEOF], )