Skip to content

Commit

Permalink
Don't add path for Steps/DAG artifact inputs
Browse files Browse the repository at this point in the history
* Refactors _get_artifacts and _get_inputs to with add_missing_path input var,
with default = False to match the _get_outputs equivalent

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton committed Sep 30, 2024
1 parent 7a79a66 commit 19be598
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@
inputs:
artifacts:
- name: artifact_a
path: /tmp/hera-inputs/artifacts/artifact_a
- name: artifact_b
path: /tmp/hera-inputs/artifacts/artifact_b
name: worker
outputs:
artifacts:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ spec:
inputs:
artifacts:
- name: artifact_a
path: /tmp/hera-inputs/artifacts/artifact_a
- name: artifact_b
path: /tmp/hera-inputs/artifacts/artifact_b
name: worker
outputs:
artifacts:
Expand Down
2 changes: 1 addition & 1 deletion src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def container_decorator(func: Callable[FuncIns, FuncR]) -> Callable:
if len(func_inputs) >= 1:
input_arg = list(func_inputs.values())[0].annotation
if issubclass(input_arg, (InputV1, InputV2)):
inputs = input_arg._get_inputs()
inputs = input_arg._get_inputs(add_missing_path=True)

func_return = signature.return_annotation
outputs = []
Expand Down
8 changes: 4 additions & 4 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,19 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
def _get_artifacts(cls, add_missing_path: bool = False) -> List[Artifact]:
artifacts = []

for _, _, artifact in _construct_io_from_fields(cls):
if isinstance(artifact, Artifact):
if artifact.path is None:
if add_missing_path and artifact.path is None:
artifact.path = artifact._get_default_inputs_path()
artifacts.append(artifact)
return artifacts

@classmethod
def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
return cls._get_artifacts() + cls._get_parameters()
def _get_inputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]:
return cls._get_artifacts(add_missing_path) + cls._get_parameters()

@classmethod
def _get_as_templated_arguments(cls) -> Self:
Expand Down
2 changes: 1 addition & 1 deletion src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ class will be used as inputs, rather than the class itself.
else:
parameters.extend(input_class._get_parameters())

artifacts.extend(input_class._get_artifacts())
artifacts.extend(input_class._get_artifacts(add_missing_path=True))

elif param_or_artifact := get_workflow_annotation(func_param.annotation):
if param_or_artifact.output:
Expand Down
22 changes: 10 additions & 12 deletions tests/test_unit/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,24 @@ def test_dag_io_declaration():

assert len(model_workflow.spec.templates) == 1

template = model_workflow.spec.templates[0]
dag_template = model_workflow.spec.templates[0]

assert template.inputs
assert len(template.inputs.parameters) == 2
assert template.inputs.parameters == [
assert dag_template.inputs
assert len(dag_template.inputs.parameters) == 2
assert dag_template.inputs.parameters == [
ModelParameter(name="basic_input_parameter"),
ModelParameter(name="my-input-param"),
]
assert len(template.inputs.artifacts) == 1
assert template.inputs.artifacts == [
ModelArtifact(name="my-input-artifact", path="/tmp/hera-inputs/artifacts/my-input-artifact"),
]
assert len(dag_template.inputs.artifacts) == 1
assert dag_template.inputs.artifacts == [ModelArtifact(name="my-input-artifact")]

assert template.outputs
assert len(template.outputs.parameters) == 2
assert template.outputs.parameters == [
assert dag_template.outputs
assert len(dag_template.outputs.parameters) == 2
assert dag_template.outputs.parameters == [
ModelParameter(name="basic_output_parameter"),
ModelParameter(name="my-output-param"),
]
assert template.outputs.artifacts == [
assert dag_template.outputs.artifacts == [
ModelArtifact(name="my-output-artifact"),
]

Expand Down
14 changes: 8 additions & 6 deletions tests/test_unit/test_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ class Foo(Input):
foo: int
bar: str = "a default"

assert Foo._get_artifacts() == []
assert Foo._get_artifacts(add_missing_path=True) == []


def test_get_artifacts_with_pydantic_annotations():
class Foo(Input):
foo: Annotated[int, Field(gt=0)]
bar: Annotated[str, Field(max_length=10)] = "a default"

assert Foo._get_artifacts() == []
assert Foo._get_artifacts(add_missing_path=True) == []


def test_get_artifacts_annotated_with_name():
Expand All @@ -96,7 +96,7 @@ class Foo(Input):
bar: Annotated[str, Parameter(name="b_ar")] = "a default"
baz: Annotated[str, Artifact(name="b_az")]

assert Foo._get_artifacts() == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")]
assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")]


def test_get_artifacts_annotated_with_description():
Expand All @@ -105,7 +105,7 @@ class Foo(Input):
bar: Annotated[str, Parameter(description="param bar")] = "a default"
baz: Annotated[str, Artifact(description="artifact baz")]

assert Foo._get_artifacts() == [
assert Foo._get_artifacts(add_missing_path=True) == [
Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz", description="artifact baz")
]

Expand All @@ -114,7 +114,9 @@ def test_get_artifacts_annotated_with_path():
class Foo(Input):
baz: Annotated[str, Artifact(path="/tmp/hera-inputs/artifacts/bishbosh")]

assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh")]
assert Foo._get_artifacts(add_missing_path=True) == [
Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh")
]


def test_get_artifacts_with_multiple_annotations():
Expand All @@ -123,7 +125,7 @@ class Foo(Input):
bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default"
baz: Annotated[str, Field(max_length=15), Artifact()]

assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")]
assert Foo._get_artifacts(add_missing_path=True) == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")]


def test_get_as_arguments_unannotated():
Expand Down

0 comments on commit 19be598

Please sign in to comment.