Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with with_param #1236

Merged
merged 14 commits into from
Oct 25, 2024
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dynamic-fanout-
spec:
entrypoint: d
templates:
- dag:
tasks:
- name: generate
template: generate
- arguments:
parameters:
- description: this is some value
name: some-value
value: '{{item}}'
depends: generate
name: consume
template: consume
withParam: '{{tasks.generate.outputs.parameters.some-values}}'
name: d
- name: generate
outputs:
parameters:
- name: some-values
valueFrom:
path: /tmp/hera-outputs/parameters/some-values
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.script_annotations_dynamic_fanout:generate
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__outputs_directory
value: /tmp/hera-outputs
image: python:3.9
source: '{{inputs.parameters}}'
- inputs:
parameters:
- description: this is some value
name: some-value
name: consume
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.script_annotations_dynamic_fanout:consume
command:
- python
env:
- name: hera__script_annotations
value: ''
image: python:3.9
source: '{{inputs.parameters}}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dynamic-fanout-
spec:
entrypoint: d
templates:
- dag:
tasks:
- name: generate
template: generate
- arguments:
parameters:
- description: this is some value
name: some-value
value: '{{item}}'
depends: generate
name: consume
template: consume
withParam: '{{tasks.generate.outputs.parameters.some-values}}'
name: d
- name: generate
outputs:
parameters:
- name: some-values
valueFrom:
path: /tmp/hera-outputs/parameters/some-values
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.script_runner_io_dynamic_fanout:generate
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__outputs_directory
value: /tmp/hera-outputs
- name: hera__script_pydantic_io
value: ''
image: python:3.9
source: '{{inputs.parameters}}'
- inputs:
parameters:
- description: this is some value
name: some-value
name: consume
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.experimental.script_runner_io_dynamic_fanout:consume
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__script_pydantic_io
value: ''
image: python:3.9
source: '{{inputs.parameters}}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in
parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities
they may need to process.
"""

from typing import Annotated, List

from hera.shared import global_config
from hera.workflows import DAG, Parameter, Workflow, script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def generate() -> Annotated[List[int], Parameter(name="some-values")]:
return [i for i in range(10)]


@script(constructor="runner")
def consume(some_value: Annotated[int, Parameter(name="some-value", description="this is some value")]):
print("Received value: {value}!".format(value=some_value))


# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted
with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w:
with DAG(name="d"):
g = generate(arguments={})
c = consume(with_param=g.get_parameter("some-values"))
g >> c
38 changes: 38 additions & 0 deletions examples/workflows/experimental/script_runner_io_dynamic_fanout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to add these two examples - they are not demonstrating something unique - script annotations/pydantic IO have their own examples, and the syntax for with_param is more simply shown in the loops examples, where the syntax in the DAG construction is the same.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair, I added them to exercise the code rather than as examples per se. Is there a better place to put this kind of end-to-end YAML output test?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could do a roundtrip to_dict/from_dict like in the test_script_annotations tests

workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)

This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in
parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities
they may need to process.
"""

from typing import Annotated, List

from hera.shared import global_config
from hera.workflows import DAG, Input, Output, Parameter, Workflow, script

global_config.experimental_features["script_pydantic_io"] = True


class GenerateOutput(Output):
some_values: Annotated[List[int], Parameter(name="some-values")]


class ConsumeInput(Input):
some_value: Annotated[int, Parameter(name="some-value", description="this is some value")]


@script(constructor="runner")
def generate() -> GenerateOutput:
return GenerateOutput(some_values=[i for i in range(10)])


@script(constructor="runner")
def consume(input: ConsumeInput) -> None:
print("Received value: {value}!".format(value=input.some_value))


# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted
with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w:
with DAG(name="d"):
g = generate(arguments={})
c = consume(with_param=g.get_parameter("some-values"))
g >> c
18 changes: 18 additions & 0 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@
return metadata[0]


def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]:
"""Constructs a Parameter or Artifact object based on annotations.

If a field has a Parameter or Artifact annotation, a copy will be returned, with missing
fields filled out based on other metadata. Otherwise, a Parameter object will be constructed.

For a function parameter, python_name should be the parameter name.
For a Pydantic Input or Output class, python_name should be the field name.
"""
if annotation := get_workflow_annotation(annotation):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or python_name
return annotation_copy

Check warning on line 98 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L96-L98

Added lines #L96 - L98 were not covered by tests

return Parameter(name=python_name)

Check warning on line 100 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L100

Added line #L100 was not covered by tests


def get_unsubscripted_type(t: Any) -> Any:
"""Return the origin of t, if subscripted, or t itself.

Expand Down
43 changes: 31 additions & 12 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from hera.shared import BaseMixin, global_config
from hera.shared._global_config import _DECORATOR_SYNTAX_FLAG, _flag_enabled
from hera.shared._pydantic import BaseModel, get_fields, root_validator
from hera.shared._type_util import get_annotated_metadata
from hera.shared._type_util import construct_io_from_annotation, get_annotated_metadata, unwrap_annotation
from hera.workflows._context import _context
from hera.workflows.exceptions import InvalidTemplateCall
from hera.workflows.io.v1 import (
Expand Down Expand Up @@ -263,6 +263,18 @@
return output


def _get_pydantic_input_type(source: Callable) -> Union[None, Type[InputV1], Type[InputV2]]:
"""Returns a Pydantic Input type for the source, if it is using Pydantic IO."""
function_parameters = inspect.signature(source).parameters

Check warning on line 268 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L268

Added line #L268 was not covered by tests
if len(function_parameters) != 1:
return None
parameter = next(iter(function_parameters.values()))
parameter_type = unwrap_annotation(parameter.annotation)

Check warning on line 272 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L270-L272

Added lines #L270 - L272 were not covered by tests
if not isinstance(parameter_type, type) or not issubclass(parameter_type, (InputV1, InputV2)):
return None
return parameter_type

Check warning on line 275 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L274-L275

Added lines #L274 - L275 were not covered by tests


def _get_param_items_from_source(source: Callable) -> List[Parameter]:
"""Returns a list (possibly empty) of `Parameter` from the specified `source`.

Expand All @@ -275,17 +287,24 @@
List[Parameter]
A list of identified parameters (possibly empty).
"""
source_signature: List[str] = []
for p in inspect.signature(source).parameters.values():
if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
# only add positional or keyword arguments that are not set to a default value
# as the default value ones are captured by the automatically generated `Parameter` fields for positional
# kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field
source_signature.append(p.name)

if len(source_signature) == 1:
return [Parameter(name=n, value="{{item}}") for n in source_signature]
return [Parameter(name=n, value=f"{{{{item.{n}}}}}") for n in source_signature]
non_default_parameters: List[Parameter] = []

Check warning on line 290 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L290

Added line #L290 was not covered by tests
if pydantic_input := _get_pydantic_input_type(source):
for parameter in pydantic_input._get_parameters():
if parameter.default is None:
non_default_parameters.append(parameter)

Check warning on line 294 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L294

Added line #L294 was not covered by tests
else:
for p in inspect.signature(source).parameters.values():
if p.default == inspect.Parameter.empty and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
# only add positional or keyword arguments that are not set to a default value
# as the default value ones are captured by the automatically generated `Parameter` fields for positional
# kwargs. Otherwise, we assume that the user sets the value of the parameter via the `with_param` field
io = construct_io_from_annotation(p.name, p.annotation)

Check warning on line 301 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L301

Added line #L301 was not covered by tests
if isinstance(io, Parameter) and io.default is None and not io.output:
alicederyn marked this conversation as resolved.
Show resolved Hide resolved
non_default_parameters.append(io)

Check warning on line 303 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L303

Added line #L303 was not covered by tests

for param in non_default_parameters:
param.value = "{{" + ("item" if len(non_default_parameters) == 1 else f"item.{param.name}") + "}}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems harder to read/understand IMO

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can't be as succinct as before because we're doing a side-effect. A couple of options spring to mind, which do you prefer?

# Just move the if/else out into a separate line
for param in non_default_parameters:
    param_label = "item" if len(non_default_parameters) == 1 else f"item.{param.name}"
    param.value = "{{" + param_label + "}}"
# Move the if/else back out to the top
if len(non_default_parameters) == 1:
    non_default_parameters[0].value = "{{item}}"
else:
    for param in non_default_parameters:
        param.value = "{{{item." + param.name + "))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second please! 🙏 Makes it much more obvious that we're doing a very different thing if len == 1

return non_default_parameters

Check warning on line 307 in src/hera/workflows/_meta_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_meta_mixins.py#L306-L307

Added lines #L306 - L307 were not covered by tests


def _get_params_from_items(with_items: List[Any]) -> Optional[List[Parameter]]:
Expand Down
13 changes: 4 additions & 9 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from hera.shared import BaseMixin, global_config
from hera.shared._pydantic import PrivateAttr, get_field_annotations, get_fields, root_validator, validator
from hera.shared._type_util import get_workflow_annotation
from hera.shared._type_util import construct_io_from_annotation
from hera.shared.serialization import serialize
from hera.workflows._context import SubNodeMixin, _context
from hera.workflows._meta_mixins import CallableTemplateMixin, HeraBuildObj, HookMixin
Expand Down Expand Up @@ -738,14 +738,9 @@
result_templated_str = f"{{{{{subnode_type}.{subnode_name}.outputs.result}}}}"
return result_templated_str

if param_or_artifact := get_workflow_annotation(annotations[name]):
output_name = param_or_artifact.name or name
if isinstance(param_or_artifact, Parameter):
return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{output_name}" + "}}"
else:
return "{{" + f"{subnode_type}.{subnode_name}.outputs.artifacts.{output_name}" + "}}"

return "{{" + f"{subnode_type}.{subnode_name}.outputs.parameters.{name}" + "}}"
param_or_artifact = construct_io_from_annotation(name, annotations[name])
output_type = "parameters" if isinstance(param_or_artifact, Parameter) else "artifacts"
return "{{" + f"{subnode_type}.{subnode_name}.outputs.{output_type}.{param_or_artifact.name}" + "}}"

Check warning on line 743 in src/hera/workflows/_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_mixins.py#L741-L743

Added lines #L741 - L743 were not covered by tests

return super().__getattribute__(name)

Expand Down
14 changes: 4 additions & 10 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from hera.shared import global_config
from hera.shared._global_config import _SUPPRESS_PARAMETER_DEFAULT_ERROR_FLAG
from hera.shared._pydantic import _PYDANTIC_VERSION, FieldInfo, get_field_annotations, get_fields
from hera.shared._type_util import get_workflow_annotation
from hera.shared._type_util import construct_io_from_annotation, get_workflow_annotation
from hera.shared.serialization import MISSING, serialize
from hera.workflows._context import _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -45,18 +45,12 @@
def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]:
"""Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations.

If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing.
Otherwise, a Parameter object will be constructed.
If a field has a Parameter or Artifact annotation, a copy will be returned, with missing
fields filled out based on other metadata. Otherwise, a Parameter object will be constructed.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I changed this docstring because the core logic is now in a different function, so it will be easy to miss changes in future and bitrot. I intend to change the behaviour as part of fixing #1173, for instance, to set the enum field for Literals.

"""
annotations = get_field_annotations(cls)
for field, field_info in get_fields(cls).items():
if annotation := get_workflow_annotation(annotations[field]):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or field
yield field, field_info, annotation_copy
else:
yield field, field_info, Parameter(name=field)
yield field, field_info, construct_io_from_annotation(field, annotations[field])

Check warning on line 53 in src/hera/workflows/io/_io_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/io/_io_mixins.py#L53

Added line #L53 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍



class InputMixin(BaseModel):
Expand Down
Loading
Loading