Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to subclass from both Input and Output allowing for passthrough IO #1109

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# New Decorators Passthrough IO






=== "Hera"

```python linenums="1"
from typing_extensions import Annotated

from hera.shared import global_config
from hera.workflows import Artifact, ArtifactLoader, Input, Output, WorkflowTemplate

# The decorator syntax is an experimental feature and must be switched on explicitly.
global_config.experimental_features["decorator_syntax"] = True

w = WorkflowTemplate(name="my-template")


# Inheriting from both Input and Output lets a single class describe the outputs
# of one template and the inputs of the next one ("passthrough" IO).
class PassthroughIO(Input, Output):
    my_str: str
    my_int: int
    my_artifact: Annotated[str, Artifact(name="my-artifact", loader=ArtifactLoader.json)]


@w.script()
def give_output() -> PassthroughIO:
    # Every field without a default is required, so the artifact must be given a
    # value too; omitting it would fail model validation when the script runs.
    # NOTE(review): the value is a JSON-encoded string because the field declares
    # loader=ArtifactLoader.json for the consuming side - confirm against the runner.
    return PassthroughIO(my_str="test", my_int=42, my_artifact='"test"')


@w.script()
def take_input(inputs: PassthroughIO) -> Output:
    return Output(result=f"Got a string: {inputs.my_str}, got an int: {inputs.my_int}")


@w.dag()
def my_dag():
    # The variable name is used as the task name in the generated DAG, so it is
    # part of the rendered YAML and should not be renamed casually.
    output_task = give_output()
    # Passing the task itself forwards all of its outputs as the arguments of
    # the next task, which must take the same PassthroughIO type.
    take_input(output_task)
```

=== "YAML"

```yaml linenums="1"
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: my-template
spec:
templates:
- name: give-output
outputs:
artifacts:
- name: my-artifact
path: /tmp/hera-outputs/artifacts/my-artifact
parameters:
- name: my_str
valueFrom:
path: /tmp/hera-outputs/parameters/my_str
- name: my_int
valueFrom:
path: /tmp/hera-outputs/parameters/my_int
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.new_decorators_passthrough_io:give_output
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__outputs_directory
value: /tmp/hera-outputs
- name: hera__script_pydantic_io
value: ''
image: python:3.8
source: '{{inputs.parameters}}'
- inputs:
artifacts:
- name: my-artifact
path: /tmp/hera-inputs/artifacts/my-artifact
parameters:
- name: my_str
- name: my_int
name: take-input
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.new_decorators_passthrough_io:take_input
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__outputs_directory
value: /tmp/hera-outputs
- name: hera__script_pydantic_io
value: ''
image: python:3.8
source: '{{inputs.parameters}}'
- dag:
tasks:
- name: output_task
template: give-output
- arguments:
artifacts:
- from: '{{tasks.output_task.outputs.artifacts.my-artifact}}'
name: my-artifact
parameters:
- name: my_str
value: '{{tasks.output_task.outputs.parameters.my_str}}'
- name: my_int
value: '{{tasks.output_task.outputs.parameters.my_int}}'
depends: output_task
name: take-input
template: take-input
name: my-dag
```

Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: my-template
spec:
templates:
- name: give-output
outputs:
artifacts:
- name: my-artifact
path: /tmp/hera-outputs/artifacts/my-artifact
parameters:
- name: my_str
valueFrom:
path: /tmp/hera-outputs/parameters/my_str
- name: my_int
valueFrom:
path: /tmp/hera-outputs/parameters/my_int
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.new_decorators_passthrough_io:give_output
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__outputs_directory
value: /tmp/hera-outputs
- name: hera__script_pydantic_io
value: ''
image: python:3.8
source: '{{inputs.parameters}}'
- inputs:
artifacts:
- name: my-artifact
path: /tmp/hera-inputs/artifacts/my-artifact
parameters:
- name: my_str
- name: my_int
name: take-input
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.new_decorators_passthrough_io:take_input
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__outputs_directory
value: /tmp/hera-outputs
- name: hera__script_pydantic_io
value: ''
image: python:3.8
source: '{{inputs.parameters}}'
- dag:
tasks:
- name: output_task
template: give-output
- arguments:
artifacts:
- from: '{{tasks.output_task.outputs.artifacts.my-artifact}}'
name: my-artifact
parameters:
- name: my_str
value: '{{tasks.output_task.outputs.parameters.my_str}}'
- name: my_int
value: '{{tasks.output_task.outputs.parameters.my_int}}'
depends: output_task
name: take-input
template: take-input
name: my-dag
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing_extensions import Annotated

from hera.shared import global_config
from hera.workflows import Artifact, ArtifactLoader, Input, Output, WorkflowTemplate

# The decorator syntax is an experimental feature and must be switched on explicitly.
global_config.experimental_features["decorator_syntax"] = True

w = WorkflowTemplate(name="my-template")


# Inheriting from both Input and Output lets a single class describe the outputs
# of one template and the inputs of the next one ("passthrough" IO).
class PassthroughIO(Input, Output):
    my_str: str
    my_int: int
    my_artifact: Annotated[str, Artifact(name="my-artifact", loader=ArtifactLoader.json)]


@w.script()
def give_output() -> PassthroughIO:
    # Every field without a default is required, so the artifact must be given a
    # value too; omitting it would fail model validation when the script runs.
    # NOTE(review): the value is a JSON-encoded string because the field declares
    # loader=ArtifactLoader.json for the consuming side - confirm against the runner.
    return PassthroughIO(my_str="test", my_int=42, my_artifact='"test"')


@w.script()
def take_input(inputs: PassthroughIO) -> Output:
    return Output(result=f"Got a string: {inputs.my_str}, got an int: {inputs.my_int}")


@w.dag()
def my_dag():
    # The variable name is used as the task name in the generated DAG, so it is
    # part of the rendered YAML and should not be renamed casually.
    output_task = give_output()
    # Passing the task itself forwards all of its outputs as the arguments of
    # the next task, which must take the same PassthroughIO type.
    take_input(output_task)
81 changes: 68 additions & 13 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Output as OutputV1,
)
from hera.workflows.models import (
Arguments as ModelArguments,
Artifact as ModelArtifact,
Parameter as ModelParameter,
TemplateRef,
Expand Down Expand Up @@ -509,6 +510,46 @@ def _get_underlying_type(annotation: Type):
return real_type


def _derive_passthrough_io(
    previous_subnode: Union[Step, Task],
    input_class: Type[Union[InputV1, InputV2]],
) -> ModelArguments:
    """Returns a ModelArguments object derived from the previous_subnode and current node's input_class.

    This function is used where (we assume) a user wants to pass all outputs of the previous step/task
    as arguments to the current step/task. This is only supported if the previous node uses a script-decorated
    function, as it's the only way to go from the Step/Task object to the `source` to get the function signature.
    Therefore we explicitly give helpful error messages if any pre-conditions are broken (compared to the
    rest of the codebase).

    We would need to store metadata in the Step/Task object to support other template types.

    Args:
        previous_subnode: the Step or Task whose outputs should be forwarded.
        input_class: the Input class annotated on the current node's function.

    Returns:
        The arguments produced by `input_class`, built from the attributes of `previous_subnode`
        named after each field of the shared IO class (excluding `exit_code` and `result`).

    Raises:
        ValueError: if the previous node's template is not a Script with a callable source, if its
            return annotation is not an Output class, or if that class differs from `input_class`.
    """
    from hera.workflows.script import Script

    if not isinstance(previous_subnode.template, Script):
        raise ValueError("Only Script template passthrough IO is supported")

    source = previous_subnode.template.source
    if not source or not callable(source):
        raise ValueError("Only Script template passthrough IO is supported")

    prev_node_return_class = inspect.signature(source).return_annotation
    # `issubclass` raises a TypeError for annotations that are not classes (e.g. string or
    # generic annotations), so check that first to keep the error message helpful.
    if not isinstance(prev_node_return_class, type) or not issubclass(
        prev_node_return_class, (OutputV1, OutputV2)
    ):
        raise ValueError("Previous Step or Task must output a hera.workflows.io.Output type")

    if prev_node_return_class != input_class:
        raise ValueError(
            f"Previous Step/Task output type {prev_node_return_class} does not match "
            f"current Step/Task input type {input_class} - the same type must be used"
        )

    # `exit_code` and `result` belong to the Output side only and are never passed on as arguments.
    object_dict = {
        field: getattr(previous_subnode, field)
        for field in get_fields(prev_node_return_class)
        if field not in {"exit_code", "result"}
    }
    # `construct` skips validation: the values come from the previous node, not from user input.
    return input_class.construct(**object_dict)._get_as_arguments()


class TemplateDecoratorFuncsMixin(ContextMixin):
from hera.workflows.container import Container
from hera.workflows.dag import DAG
Expand Down Expand Up @@ -545,13 +586,18 @@ def _create_subnode(
from hera.workflows.task import Task
from hera.workflows.workflow_template import WorkflowTemplate

subnode_args = None
if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)):
subnode_args = args[0]._get_as_arguments()

signature = inspect.signature(func)
function_inputs = list(signature.parameters.values())
input_class = function_inputs[0].annotation if len(function_inputs) == 1 else None
output_class = signature.return_annotation

subnode_args = None
if len(args) == 1:
if isinstance(args[0], (InputV1, InputV2)):
subnode_args = args[0]._get_as_arguments()
elif isinstance(args[0], (Step, Task)) and issubclass(input_class, (InputV1, InputV2)):
subnode_args = _derive_passthrough_io(args[0], input_class)

subnode: Union[Step, Task]

assert _context.pieces
Expand Down Expand Up @@ -819,15 +865,24 @@ def call_wrapper(*args, **kwargs):
with self, template:
if len(func_inputs) == 1:
arg_class = list(func_inputs.values())[0].annotation
if issubclass(arg_class, (InputV1, InputV2)):
input_obj = arg_class._get_as_templated_arguments()
# "run" the dag/steps function to collect the tasks/steps
_context.declaring = True
func_return = func(input_obj)
_context.declaring = False

if func_return and isinstance(func_return, (OutputV1, OutputV2)):
template.outputs = func_return._get_as_invocator_output()
if not issubclass(arg_class, (InputV1, InputV2)):
raise ValueError(
"Function must take no arguments or a single argument that is a subclass of hera.workflows.io.Input"
)
input_obj = arg_class._get_as_templated_arguments()
# "run" the dag/steps function to collect the tasks/steps
_context.declaring = True
func_return = func(input_obj)
_context.declaring = False

elif len(func_inputs) == 0:
# "run" the dag/steps function to collect the tasks/steps
_context.declaring = True
func_return = func()
_context.declaring = False

if func_return and isinstance(func_return, (OutputV1, OutputV2)):
template.outputs = func_return._get_as_invocator_output()

return call_wrapper

Expand Down
16 changes: 13 additions & 3 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
annotations = get_field_annotations(cls)

for field, field_info in get_fields(cls).items():
if issubclass(cls, OutputMixin) and field in {"exit_code", "result"}:
# Skip OutputMixin's fields so users can subclass from both Input and Output
# to create a "passthrough" IO object
continue
if get_origin(annotations[field]) is Annotated:
# Copy so as to not modify the Input fields themselves
param = get_args(annotations[field])[1].copy()
Expand Down Expand Up @@ -147,8 +151,12 @@ def _get_as_arguments(self) -> ModelArguments:
self_dict = self.model_dump()

for field in get_fields(type(self)):
# The value may be a static value (of any time) if it has a default value, so we need to serialize it
# If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")``
if issubclass(type(self), OutputMixin) and field in {"exit_code", "result"}:
# Skip OutputMixin's fields so users can subclass from both Input and Output
# to create a "passthrough" IO object
continue
# The dict value may be of any type if it was a default value, so we need to serialize it.
# If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")`
templated_value = serialize(self_dict[field])

if get_origin(annotations[field]) is Annotated:
Expand Down Expand Up @@ -213,7 +221,9 @@ def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Pa
if field in {"exit_code", "result"}:
continue
if get_origin(annotations[field]) is Annotated:
annotation = get_args(annotations[field])[1]
# Copy annotation to avoid modifying it, as it may be used in a "passthrough" field
# (where it is both an output and an input)
annotation = get_args(annotations[field])[1].copy()
if isinstance(annotation, Parameter):
if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None):
annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}")
Expand Down
Loading