Skip to content

Commit

Permalink
Support Literal types in script runner (#1249)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [X] Fixes #1173
- [X] Tests added
- [ ] Documentation/examples added
- [X] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, the script runner does not support `Literal` types, as
`origin_type_issubtype` does not special-case them, so
`_is_str_kwarg_of` does not return True.

This PR adds supports for Literals to `origin_type_issubtype`, and
additionally copies their values from the annotation into the `enum`
field on input Parameters. It also combines two paths in
`_get_inputs_from_callable` into one by using
`construct_io_from_annotation` instead of having a fallback branch if
`get_workflow_annotation` returns `None`; this fixes a bug where the
default was being verified as None for optional strings only if there
was no IO annotation.

---------

Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn authored Nov 6, 2024
1 parent 52670c4 commit 4328576
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 37 deletions.
29 changes: 24 additions & 5 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Any,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -82,6 +83,18 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet
return metadata[0]


def set_enum_based_on_type(parameter: Parameter, annotation: Any) -> None:
"""Sets the enum field of a Parameter based on its type annotation.
Currently, only supports Literals.
"""
if parameter.enum:
return
type_ = unwrap_annotation(annotation)
if get_origin(type_) is Literal:
parameter.enum = list(get_args(type_))


def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]:
"""Constructs a Parameter or Artifact object based on annotations.
Expand All @@ -91,13 +104,17 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par
For a function parameter, python_name should be the parameter name.
For a Pydantic Input or Output class, python_name should be the field name.
"""
if annotation := get_workflow_annotation(annotation):
if workflow_annotation := get_workflow_annotation(annotation):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or python_name
return annotation_copy
io = workflow_annotation.copy()
else:
io = Parameter()

io.name = io.name or python_name
if isinstance(io, Parameter):
set_enum_based_on_type(io, annotation)

return Parameter(name=python_name)
return io


def get_unsubscripted_type(t: Any) -> Any:
Expand All @@ -120,6 +137,8 @@ def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]])
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type))
if origin_type is Literal:
return all(isinstance(value, type_) for value in get_args(unwrapped_type))
return isinstance(origin_type, type) and issubclass(origin_type, type_)


Expand Down
46 changes: 16 additions & 30 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
get_workflow_annotation,
is_subscripted,
origin_type_issupertype,
set_enum_based_on_type,
)
from hera.shared.serialization import serialize
from hera.workflows._context import _context
Expand Down Expand Up @@ -379,6 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]:
default = MISSING

param = Parameter(name=p.name, default=default)
set_enum_based_on_type(param, p.annotation)
parameters.append(param)

return parameters
Expand Down Expand Up @@ -495,22 +497,18 @@ class will be used as inputs, rather than the class itself.

artifacts.extend(input_class._get_artifacts(add_missing_path=True))

elif param_or_artifact := get_workflow_annotation(func_param.annotation):
if param_or_artifact.output:
else:
io = construct_io_from_annotation(func_param.name, func_param.annotation)
if io.output:
continue

# Create a new object so we don't modify the Workflow itself
new_object = param_or_artifact.copy()
if not new_object.name:
new_object.name = func_param.name

if isinstance(new_object, Artifact):
if new_object.path is None:
new_object.path = new_object._get_default_inputs_path()
if isinstance(io, Artifact):
if io.path is None:
io.path = io._get_default_inputs_path()

artifacts.append(new_object)
elif isinstance(new_object, Parameter):
if new_object.default is not None:
artifacts.append(io)
elif isinstance(io, Parameter):
if io.default is not None:
# TODO: in 5.18 remove the flag check and `warn`, and raise the ValueError directly (minus "flag" text)
warnings.warn(
"Using the default field for Parameters in Annotations is deprecated since v5.16"
Expand All @@ -524,27 +522,15 @@ class will be used as inputs, rather than the class itself.
)
if func_param.default != inspect.Parameter.empty:
# TODO: remove this check in 5.18:
if new_object.default is not None:
if io.default is not None:
raise ValueError(
"default cannot be set via both the function parameter default and the Parameter's default"
)
new_object.default = serialize(func_param.default)
parameters.append(new_object)
else:
if (
func_param.default != inspect.Parameter.empty
and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
):
default = func_param.default
else:
default = MISSING

if origin_type_issupertype(func_param.annotation, NoneType) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")
io.default = serialize(func_param.default)

parameters.append(Parameter(name=func_param.name, default=default))
if origin_type_issupertype(func_param.annotation, NoneType) and io.default != "null":
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")
parameters.append(io)

return parameters, artifacts

Expand Down
18 changes: 18 additions & 0 deletions tests/script_annotations/annotated_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Annotated, Literal

from hera.shared import global_config
from hera.workflows import Parameter, Steps, Workflow, script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def literal_str(
my_str: Annotated[Literal["foo", "bar"], Parameter(name="my-str")],
) -> Annotated[Literal[1, 2], Parameter(name="index")]:
return {"foo": 1, "bar": 2}[my_str]


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
13 changes: 13 additions & 0 deletions tests/script_annotations/literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Literal

from hera.workflows import Steps, Workflow, script


@script(constructor="runner")
def literal_str(my_str: Literal["foo", "bar"]) -> Literal[1, 2]:
return {"foo": 1, "bar": 2}[my_str]


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
24 changes: 24 additions & 0 deletions tests/script_annotations/pydantic_io_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Literal

from hera.shared import global_config
from hera.workflows import Input, Output, Steps, Workflow, script

global_config.experimental_features["script_pydantic_io"] = True


class ExampleInput(Input):
my_str: Literal["foo", "bar"]


class ExampleOutput(Output):
index: Literal[1, 2]


@script(constructor="runner")
def literal_str(input: ExampleInput) -> ExampleOutput:
return ExampleOutput(index={"foo": 1, "bar": 2}[input.my_str])


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
12 changes: 11 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Union
from typing import Any, List, Literal, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -53,6 +53,11 @@ def annotated_basic_types_with_other_metadata(
return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)])


@script()
def annotated_str_literal(my_literal: Annotated[Literal["1", "2"], Parameter(name="str-literal")]) -> str:
return f"type given: {type(my_literal).__name__}"


@script()
def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output:
return Output(output=[annotated_input_value])
Expand Down Expand Up @@ -81,6 +86,11 @@ def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_literal(my_literal: Literal["1", "2"]) -> str:
return f"type given: {type(my_literal).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
"type given: int",
id="str-or-int-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_literal",
[{"name": "my_literal", "value": "1"}],
"type given: str",
id="str-literal",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand All @@ -89,6 +95,12 @@
[{"my": "dict"}],
id="str-json-param-as-list",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_literal",
[{"name": "my_literal", "value": "1"}],
"type given: str",
id="annotated-str-literal",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
32 changes: 32 additions & 0 deletions tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,35 @@ def test_script_with_param(global_config_fixture, module_name):
}
]
assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}"


@pytest.mark.parametrize(
("module_name", "input_name"),
[
pytest.param("tests.script_annotations.literals", "my_str", id="bare-type-annotation"),
pytest.param("tests.script_annotations.annotated_literals", "my-str", id="annotated"),
pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"),
],
)
@pytest.mark.parametrize("experimental_feature", ["", "script_annotations", "script_pydantic_io"])
def test_script_literals(global_config_fixture, module_name, input_name, experimental_feature):
"""Test that Literals work correctly as direct type annotations."""
# GIVEN
if experimental_feature:
global_config_fixture.experimental_features[experimental_feature] = True

# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version
module = importlib.import_module(module_name)
importlib.reload(module)
workflow: Workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
(literal_str,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "literal-str")

assert literal_str["inputs"]["parameters"] == [{"name": input_name, "enum": ["foo", "bar"]}]
9 changes: 9 additions & 0 deletions tests/test_unit/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ def unknown_annotations_ignored(my_optional_string: Optional[str] = "123") -> st
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_6():
@script()
def unknown_annotations_ignored(my_optional_string: Annotated[Optional[str], Parameter(name="my-string")]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_multiple_input_workflow_annotations_are_given():
@script()
def invalid_script(a_str: Annotated[str, Artifact(name="a_str"), Parameter(name="a_str")] = "123") -> str:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, NoReturn, Optional, Union
from typing import List, Literal, NoReturn, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
Expand Down Expand Up @@ -151,6 +151,10 @@ def test_get_unsubscripted_type(annotation, expected):
pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
pytest.param(Literal["foo", "bar"], (str, NoneType), True, id="literal-str-is-subtype-of-optional-str"),
pytest.param(Literal["foo", None], (str, NoneType), True, id="literal-none-is-subtype-of-optional-str"),
pytest.param(Literal[1, 2], (str, NoneType), False, id="literal-int-not-subtype-of-optional-str"),
pytest.param(Literal[1, "foo"], (str, NoneType), False, id="mixed-literal-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
Expand Down

0 comments on commit 4328576

Please sign in to comment.