diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index cfb9c51e..f1c27f40 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -83,11 +83,16 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] -def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None: - if not parameter.enum: - type_ = unwrap_annotation(annotation) - if get_origin(type_) is Literal: - parameter.enum = list(get_args(type_)) +def set_enum_based_on_type(parameter: Parameter, annotation: Any) -> None: + """Sets the enum field of a Parameter based on its type annotation. + + Currently, only supports Literals. + """ + if parameter.enum: + return + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + parameter.enum = list(get_args(type_)) def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: @@ -107,7 +112,7 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par io.name = io.name or python_name if isinstance(io, Parameter): - add_metadata_from_type(io, annotation) + set_enum_based_on_type(io, annotation) return io diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 0974e488..2065c438 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -48,11 +48,11 @@ ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator from hera.shared._type_util import ( - add_metadata_from_type, construct_io_from_annotation, get_workflow_annotation, is_subscripted, origin_type_issupertype, + set_enum_based_on_type, ) from hera.shared.serialization import serialize from hera.workflows._context import _context @@ -380,7 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: default = MISSING param = Parameter(name=p.name, default=default) - add_metadata_from_type(param, p.annotation) + set_enum_based_on_type(param, p.annotation) parameters.append(param) return parameters