Skip to content

Commit

Permalink
Copy Literals to input Parameter enum field
Browse files Browse the repository at this point in the history
Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn committed Oct 28, 2024
1 parent e9e774f commit 3aa8505
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,19 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par
For a function parameter, python_name should be the parameter name.
For a Pydantic Input or Output class, python_name should be the field name.
"""
if annotation := get_workflow_annotation(annotation):
if workflow_annotation := get_workflow_annotation(annotation):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or python_name
return annotation_copy
io = workflow_annotation.copy()
else:
io = Parameter()

return Parameter(name=python_name)
io.name = io.name or python_name
if isinstance(io, Parameter) and not io.enum:
type_ = unwrap_annotation(annotation)
if get_origin(type_) is Literal:
io.enum = list(get_args(type_))

return io


def get_unsubscripted_type(t: Any) -> Any:
Expand Down
18 changes: 18 additions & 0 deletions tests/script_annotations/annotated_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Annotated, Literal

from hera.shared import global_config
from hera.workflows import Parameter, Steps, Workflow, script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def literal_str(
my_str: Annotated[Literal["foo", "bar"], Parameter(name="my-str")],
) -> Annotated[Literal[1, 2], Parameter(name="index")]:
return {"foo": 1, "bar": 2}[my_str]


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
13 changes: 13 additions & 0 deletions tests/script_annotations/literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Literal

from hera.workflows import Steps, Workflow, script


@script(constructor="runner")
def literal_str(my_str: Literal["foo", "bar"]) -> Literal[1, 2]:
return {"foo": 1, "bar": 2}[my_str]


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
24 changes: 24 additions & 0 deletions tests/script_annotations/pydantic_io_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Literal

from hera.shared import global_config
from hera.workflows import Input, Output, Steps, Workflow, script

global_config.experimental_features["script_pydantic_io"] = True


class ExampleInput(Input):
my_str: Literal["foo", "bar"]


class ExampleOutput(Output):
index: Literal[1, 2]


@script(constructor="runner")
def literal_str(input: ExampleInput) -> ExampleOutput:
return ExampleOutput(index={"foo": 1, "bar": 2}[input.my_str])


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
28 changes: 28 additions & 0 deletions tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,31 @@ def test_script_with_param(global_config_fixture, module_name):
}
]
assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}"


@pytest.mark.parametrize(
("module_name", "input_name"),
[
pytest.param("tests.script_annotations.literals", "my_str", id="bare-type-annotation"),
pytest.param("tests.script_annotations.annotated_literals", "my-str", id="annotated"),
pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"),
],
)
def test_script_literals(global_config_fixture, module_name, input_name):
"""Test that Literals work correctly as direct type annotations."""
# GIVEN
# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version
module = importlib.import_module(module_name)
importlib.reload(module)
workflow: Workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
(literal_str,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "literal-str")

assert literal_str["inputs"]["parameters"] == [{"name": input_name, "enum": ["foo", "bar"]}]

0 comments on commit 3aa8505

Please sign in to comment.