diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 7a1ff10f..3dc7138c 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -83,6 +83,13 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] +def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None: + if not parameter.enum: + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + parameter.enum = list(get_args(type_)) + + def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: """Constructs a Parameter or Artifact object based on annotations. diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index faac6296..0974e488 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -48,6 +48,7 @@ ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator from hera.shared._type_util import ( + add_metadata_from_type, construct_io_from_annotation, get_workflow_annotation, is_subscripted, @@ -379,6 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: default = MISSING param = Parameter(name=p.name, default=default) + add_metadata_from_type(param, p.annotation) parameters.append(param) return parameters diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 30a87143..93fe49e3 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -486,10 +486,12 @@ def test_script_with_param(global_config_fixture, module_name): pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"), ], ) -def test_script_literals(global_config_fixture, module_name, input_name): +@pytest.mark.parametrize("experimental_feature", ["", "script_annotations", "script_pydantic_io"]) +def test_script_literals(global_config_fixture, module_name, input_name, experimental_feature): """Test that Literals work correctly as direct type annotations.""" # GIVEN - global_config_fixture.experimental_features["script_annotations"] = True + if experimental_feature: + global_config_fixture.experimental_features[experimental_feature] = True # Force a reload of the test module, as the runner performs "importlib.import_module", which # may fetch a cached version