Skip to content

Commit

Permalink
Support Literals in origin_type_issubclass
Browse files Browse the repository at this point in the history
Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn committed Oct 28, 2024
1 parent 326e542 commit e9e774f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Any,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -120,6 +121,8 @@ def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]])
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type))
if origin_type is Literal:
return all(isinstance(value, type_) for value in get_args(unwrapped_type))
return isinstance(origin_type, type) and issubclass(origin_type, type_)


Expand Down
12 changes: 11 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Union
from typing import Any, List, Literal, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -53,6 +53,11 @@ def annotated_basic_types_with_other_metadata(
return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)])


@script()
def annotated_str_literal(my_literal: Annotated[Literal["1", "2"], Parameter(name="str-literal")]) -> str:
return f"type given: {type(my_literal).__name__}"


@script()
def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output:
return Output(output=[annotated_input_value])
Expand Down Expand Up @@ -81,6 +86,11 @@ def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_literal(my_literal: Literal["1", "2"]) -> str:
return f"type given: {type(my_literal).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
"type given: int",
id="str-or-int-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_literal",
[{"name": "my_literal", "value": "1"}],
"type given: str",
id="str-literal",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand All @@ -89,6 +95,12 @@
[{"my": "dict"}],
id="str-json-param-as-list",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_literal",
[{"name": "my_literal", "value": "1"}],
"type given: str",
id="annotated-str-literal",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
6 changes: 5 additions & 1 deletion tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, NoReturn, Optional, Union
from typing import List, Literal, NoReturn, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
Expand Down Expand Up @@ -151,6 +151,10 @@ def test_get_unsubscripted_type(annotation, expected):
pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
pytest.param(Literal["foo", "bar"], (str, NoneType), True, id="literal-str-is-subtype-of-optional-str"),
pytest.param(Literal["foo", None], (str, NoneType), True, id="literal-none-is-subtype-of-optional-str"),
pytest.param(Literal[1, 2], (str, NoneType), False, id="literal-int-not-subtype-of-optional-str"),
pytest.param(Literal[1, "foo"], (str, NoneType), False, id="mixed-literal-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
Expand Down

0 comments on commit e9e774f

Please sign in to comment.