diff --git a/docs/examples/workflows/experimental/new_dag_decorator_artifacts.md b/docs/examples/workflows/experimental/new_dag_decorator_artifacts.md index bb5fcaedd..de676f7b2 100644 --- a/docs/examples/workflows/experimental/new_dag_decorator_artifacts.md +++ b/docs/examples/workflows/experimental/new_dag_decorator_artifacts.md @@ -119,9 +119,7 @@ inputs: artifacts: - name: artifact_a - path: /tmp/hera-inputs/artifacts/artifact_a - name: artifact_b - path: /tmp/hera-inputs/artifacts/artifact_b name: worker outputs: artifacts: diff --git a/examples/workflows/experimental/new-dag-decorator-artifacts.yaml b/examples/workflows/experimental/new-dag-decorator-artifacts.yaml index 03c8a59f9..527aa2428 100644 --- a/examples/workflows/experimental/new-dag-decorator-artifacts.yaml +++ b/examples/workflows/experimental/new-dag-decorator-artifacts.yaml @@ -55,9 +55,7 @@ spec: inputs: artifacts: - name: artifact_a - path: /tmp/hera-inputs/artifacts/artifact_a - name: artifact_b - path: /tmp/hera-inputs/artifacts/artifact_b name: worker outputs: artifacts: diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index 2ad3778fe..1bca8002e 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -698,7 +698,7 @@ def container_decorator(func: Callable[FuncIns, FuncR]) -> Callable: if len(func_inputs) >= 1: input_arg = list(func_inputs.values())[0].annotation if issubclass(input_arg, (InputV1, InputV2)): - inputs = input_arg._get_inputs() + inputs = input_arg._get_inputs(add_missing_path=True) func_return = signature.return_annotation outputs = [] diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index 6d7f2d767..938f0054b 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -106,19 +106,19 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet return parameters @classmethod - def _get_artifacts(cls) -> List[Artifact]: + def _get_artifacts(cls, add_missing_path: bool = False) -> List[Artifact]: artifacts = [] for _, _, artifact in _construct_io_from_fields(cls): if isinstance(artifact, Artifact): - if artifact.path is None: + if add_missing_path and artifact.path is None: artifact.path = artifact._get_default_inputs_path() artifacts.append(artifact) return artifacts @classmethod - def _get_inputs(cls) -> List[Union[Artifact, Parameter]]: - return cls._get_artifacts() + cls._get_parameters() + def _get_inputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]: + return cls._get_artifacts(add_missing_path) + cls._get_parameters() @classmethod def _get_as_templated_arguments(cls) -> Self: diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 058f5a1b9..15cf5a8e9 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -494,7 +494,7 @@ class will be used as inputs, rather than the class itself. else: parameters.extend(input_class._get_parameters()) - artifacts.extend(input_class._get_artifacts()) + artifacts.extend(input_class._get_artifacts(add_missing_path=True)) elif param_or_artifact := get_workflow_annotation(func_param.annotation): if param_or_artifact.output: diff --git a/tests/test_unit/test_decorators.py b/tests/test_unit/test_decorators.py index 6ba17e21d..42c76bfd5 100644 --- a/tests/test_unit/test_decorators.py +++ b/tests/test_unit/test_decorators.py @@ -36,26 +36,24 @@ def test_dag_io_declaration(): assert len(model_workflow.spec.templates) == 1 - template = model_workflow.spec.templates[0] + dag_template = model_workflow.spec.templates[0] - assert template.inputs - assert len(template.inputs.parameters) == 2 - assert template.inputs.parameters == [ + assert dag_template.inputs + assert len(dag_template.inputs.parameters) == 2 + assert dag_template.inputs.parameters == [ ModelParameter(name="basic_input_parameter"), ModelParameter(name="my-input-param"), ] - assert len(template.inputs.artifacts) == 1 - assert template.inputs.artifacts == [ - ModelArtifact(name="my-input-artifact", path="/tmp/hera-inputs/artifacts/my-input-artifact"), - ] + assert len(dag_template.inputs.artifacts) == 1 + assert dag_template.inputs.artifacts == [ModelArtifact(name="my-input-artifact")] - assert template.outputs - assert len(template.outputs.parameters) == 2 - assert template.outputs.parameters == [ + assert dag_template.outputs + assert len(dag_template.outputs.parameters) == 2 + assert dag_template.outputs.parameters == [ ModelParameter(name="basic_output_parameter"), ModelParameter(name="my-output-param"), ] - assert template.outputs.artifacts == [ + assert dag_template.outputs.artifacts == [ ModelArtifact(name="my-output-artifact"), ] diff --git a/tests/test_unit/test_io_mixins.py b/tests/test_unit/test_io_mixins.py index f4f3e1161..16e2d59b9 100644 --- a/tests/test_unit/test_io_mixins.py +++ b/tests/test_unit/test_io_mixins.py @@ -79,7 +79,7 @@ class Foo(Input): foo: int bar: str = "a default" - assert Foo._get_artifacts() == [] + assert Foo._get_artifacts(add_missing_path=True) == [] def test_get_artifacts_with_pydantic_annotations(): @@ -87,7 +87,7 @@ class Foo(Input): foo: Annotated[int, Field(gt=0)] bar: Annotated[str, Field(max_length=10)] = "a default" - assert Foo._get_artifacts() == [] + assert Foo._get_artifacts(add_missing_path=True) == [] def test_get_artifacts_annotated_with_name(): @@ -96,7 +96,7 @@ class Foo(Input): bar: Annotated[str, Parameter(name="b_ar")] = "a default" baz: Annotated[str, Artifact(name="b_az")] - assert Foo._get_artifacts() == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")] + assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")] def test_get_artifacts_annotated_with_description(): @@ -105,7 +105,7 @@ class Foo(Input): bar: Annotated[str, Parameter(description="param bar")] = "a default" baz: Annotated[str, Artifact(description="artifact baz")] - assert Foo._get_artifacts() == [ + assert Foo._get_artifacts(add_missing_path=True) == [ Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz", description="artifact baz") ] @@ -114,7 +114,9 @@ def test_get_artifacts_annotated_with_path(): class Foo(Input): baz: Annotated[str, Artifact(path="/tmp/hera-inputs/artifacts/bishbosh")] - assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh")] + assert Foo._get_artifacts(add_missing_path=True) == [ + Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh") + ] def test_get_artifacts_with_multiple_annotations(): @@ -123,7 +125,7 @@ class Foo(Input): bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" baz: Annotated[str, Field(max_length=15), Artifact()] - assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")] + assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")] def test_get_as_arguments_unannotated():