Skip to content

Commit

Permalink
Rename new function to set_enum_based_on_type
Browse files Browse the repository at this point in the history
This function only support setting the enum field, and it doesn't appear
likely that we will want to set other fields in future, so make the name
more specific. Additionally, add a docstring, and refactor the function
to use the early-return pattern now adding other metadata is ruled out.

Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn committed Nov 6, 2024
1 parent fa6a37f commit dbb4a66
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
17 changes: 11 additions & 6 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,16 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet
return metadata[0]


def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None:
if not parameter.enum:
type_ = unwrap_annotation(annotation)
if get_origin(type_) is Literal:
parameter.enum = list(get_args(type_))
def set_enum_based_on_type(parameter: Parameter, annotation: Any) -> None:
"""Sets the enum field of a Parameter based on its type annotation.
Currently, only supports Literals.
"""
if parameter.enum:
return
type_ = unwrap_annotation(annotation)
if get_origin(type_) is Literal:
parameter.enum = list(get_args(type_))


def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]:
Expand All @@ -107,7 +112,7 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par

io.name = io.name or python_name
if isinstance(io, Parameter):
add_metadata_from_type(io, annotation)
set_enum_based_on_type(io, annotation)

return io

Expand Down
4 changes: 2 additions & 2 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import (
add_metadata_from_type,
construct_io_from_annotation,
get_workflow_annotation,
is_subscripted,
origin_type_issupertype,
set_enum_based_on_type,
)
from hera.shared.serialization import serialize
from hera.workflows._context import _context
Expand Down Expand Up @@ -380,7 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]:
default = MISSING

param = Parameter(name=p.name, default=default)
add_metadata_from_type(param, p.annotation)
set_enum_based_on_type(param, p.annotation)
parameters.append(param)

return parameters
Expand Down

0 comments on commit dbb4a66

Please sign in to comment.