Skip to content

Commit 55b060d

Browse files
authored
Fix module for script runners generated by CLI and from __main__ (#1445)
**Pull Request Checklist** - [x] Fixes #573, Fixes #1161 - [x] Tests added - [x] Documentation/examples added - [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR title **Description of PR** Currently, Script Runners cannot be exported to YAML correctly if they are contained in the same module as `__main__`, i.e. when trying to export a workflow using the file it's written in, as the `transform_values` sees `__main__` as the module. The CLI also ignores the full path spec of the function, only using the stem. This PR fixes both cases by using a utility function to construct a valid module path. --------- Signed-off-by: Elliot Gunton <[email protected]>
1 parent 24dfddc commit 55b060d

File tree

8 files changed

+270
-45
lines changed

8 files changed

+270
-45
lines changed

src/hera/_cli/generate/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Code generation CLI functions."""
2+
13
from hera._cli.generate import yaml
24

35
__all__ = [

src/hera/_cli/generate/yaml.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""The main entrypoint for hera CLI."""
2-
31
from __future__ import annotations
42

53
import importlib.util
@@ -8,6 +6,7 @@
86

97
from hera._cli.base import GenerateYaml
108
from hera._cli.generate.util import YAML_EXTENSIONS, convert_code, expand_paths, write_output
9+
from hera.workflows._runner.util import create_module_string
1110
from hera.workflows.workflow import Workflow
1211

1312
DEFAULT_EXTENSION = ".yaml"
@@ -47,8 +46,8 @@ def load_workflows_from_module(path: Path) -> list[Workflow]:
4746
Returns:
4847
A list containing all `Workflow` objects defined within that module.
4948
"""
50-
module_name = path.stem
51-
spec = importlib.util.spec_from_file_location(module_name, path, submodule_search_locations=[str(path.parent)])
49+
module_name = create_module_string(path)
50+
spec = importlib.util.spec_from_file_location(module_name, path)
5251
assert spec
5352

5453
module = importlib.util.module_from_spec(spec)

src/hera/workflows/_runner/util.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,30 @@ def _run() -> None:
297297
exit(result.exit_code)
298298

299299
print(serialize(result))
300+
301+
302+
def create_module_string(path: Path) -> str:
303+
"""Create a Python module path from the given path.
304+
305+
We find the most specific sys.path to create a valid, importable module path to the given path.
306+
307+
e.g. if sys.path contains "/project" and the file is "/project/workflows/wf_a.py", then the returned string will be
308+
"workflows.wf_a"
309+
310+
If we cannot find a valid sys.path, we simply use the file stem, e.g. for the
311+
file "/project/workflows/wf_a.py", return `wf_a`.
312+
"""
313+
path = path.resolve()
314+
315+
# find the most specific sys.path that contains the given path
316+
candidates = []
317+
for base in map(lambda p: Path(p).resolve(), sys.path + [os.getcwd()]):
318+
if path.is_relative_to(base):
319+
candidates.append(base)
320+
321+
if not candidates:
322+
return path.stem
323+
324+
# use the most specific sys.path to construct a valid module path to import
325+
base_path = max(candidates, key=lambda p: len(str(p)))
326+
return ".".join(str(path.resolve().relative_to(base_path)).replace(".py", "").split("/"))

src/hera/workflows/script.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import textwrap
1212
from abc import abstractmethod
1313
from functools import wraps
14+
from pathlib import Path
1415
from typing import (
1516
Any,
1617
Callable,
@@ -852,11 +853,19 @@ def transform_values(self, cls: Type[Script], values: Any) -> Any:
852853

853854
if values.get("args") is not None:
854855
raise ValueError("Cannot specify args when callable is True")
856+
857+
module = values["source"].__module__
858+
859+
if module == "__main__":
860+
from hera.workflows._runner.util import create_module_string
861+
862+
module = create_module_string(Path(values["source"].__globals__["__file__"]))
863+
855864
values["args"] = [
856865
"-m",
857866
"hera.workflows.runner",
858867
"-e",
859-
f"{values['source'].__module__}:{values['source'].__name__}",
868+
f"{module}:{values['source'].__name__}",
860869
]
861870

862871
return values
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from hera.workflows import Workflow, script
2+
3+
4+
@script(constructor="runner")
5+
def hello():
6+
pass
7+
8+
9+
with Workflow(
10+
generate_name="runner-workflow-",
11+
entrypoint="hello",
12+
) as w:
13+
hello()

tests/cli/test_generate_yaml.py

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import shutil
12
import sys
23
from pathlib import Path
34
from textwrap import dedent
@@ -23,47 +24,70 @@ def patch_open():
2324
return patch("io.open", new=mock_open())
2425

2526

26-
single_workflow_output = dedent("""\
27-
apiVersion: argoproj.io/v1alpha1
28-
kind: Workflow
29-
metadata:
30-
name: single
31-
spec: {}
32-
""")
33-
34-
workflow_template_output = dedent("""\
35-
apiVersion: argoproj.io/v1alpha1
36-
kind: WorkflowTemplate
37-
metadata:
38-
name: workflow-template
39-
spec: {}
40-
""")
41-
42-
cluster_workflow_template_output = dedent("""\
43-
apiVersion: argoproj.io/v1alpha1
44-
kind: ClusterWorkflowTemplate
45-
metadata:
46-
name: cluster-workflow-template
47-
spec: {}
48-
""")
49-
50-
multiple_workflow_output = dedent("""\
51-
apiVersion: argoproj.io/v1alpha1
52-
kind: Workflow
53-
metadata:
54-
name: one
55-
spec: {}
56-
---
57-
apiVersion: argoproj.io/v1alpha1
58-
kind: Workflow
59-
metadata:
60-
name: two
61-
spec: {}
62-
""")
27+
single_workflow_output = """\
28+
apiVersion: argoproj.io/v1alpha1
29+
kind: Workflow
30+
metadata:
31+
name: single
32+
spec: {}
33+
"""
34+
35+
runner_workflow_output = """\
36+
apiVersion: argoproj.io/v1alpha1
37+
kind: Workflow
38+
metadata:
39+
generateName: runner-workflow-
40+
spec:
41+
entrypoint: hello
42+
templates:
43+
- name: hello
44+
script:
45+
image: python:3.9
46+
source: '{{inputs.parameters}}'
47+
args:
48+
- -m
49+
- hera.workflows.runner
50+
- -e
51+
- tests.cli.examples.runner_workflow:hello
52+
command:
53+
- python
54+
"""
55+
56+
57+
workflow_template_output = """\
58+
apiVersion: argoproj.io/v1alpha1
59+
kind: WorkflowTemplate
60+
metadata:
61+
name: workflow-template
62+
spec: {}
63+
"""
64+
65+
cluster_workflow_template_output = """\
66+
apiVersion: argoproj.io/v1alpha1
67+
kind: ClusterWorkflowTemplate
68+
metadata:
69+
name: cluster-workflow-template
70+
spec: {}
71+
"""
72+
73+
multiple_workflow_output = """\
74+
apiVersion: argoproj.io/v1alpha1
75+
kind: Workflow
76+
metadata:
77+
name: one
78+
spec: {}
79+
---
80+
apiVersion: argoproj.io/v1alpha1
81+
kind: Workflow
82+
metadata:
83+
name: two
84+
spec: {}
85+
"""
6386

6487
whole_folder_output = join_output(
6588
cluster_workflow_template_output,
6689
multiple_workflow_output,
90+
runner_workflow_output,
6791
single_workflow_output,
6892
workflow_template_output,
6993
)
@@ -89,6 +113,24 @@ def test_single_workflow(capsys):
89113
assert output == single_workflow_output
90114

91115

116+
@pytest.mark.cli
117+
def test_runner_workflow(capsys):
118+
runner.invoke("tests/cli/examples/runner_workflow.py")
119+
120+
output = get_stdout(capsys)
121+
assert output == runner_workflow_output
122+
123+
124+
@pytest.mark.cli
125+
def test_runner_workflow_not_in_cwd(capsys, tmp_path):
126+
shutil.copy("tests/cli/examples/runner_workflow.py", tmp_path)
127+
runner.invoke(str(tmp_path / "runner_workflow.py"))
128+
129+
output = get_stdout(capsys)
130+
# The module is not in sys.path so we just use the stem of the workflow (i.e. best guess)
131+
assert output == runner_workflow_output.replace("tests.cli.examples.runner_workflow", "runner_workflow")
132+
133+
92134
@pytest.mark.cli
93135
def test_multiple_workflow(capsys):
94136
runner.invoke("tests/cli/examples/multiple_workflow.py")
@@ -308,7 +350,11 @@ def test_exclude_one(capsys):
308350
runner.invoke("tests/cli/examples", "--exclude=*/examples/*template*")
309351

310352
output = get_stdout(capsys)
311-
assert output == join_output(multiple_workflow_output, single_workflow_output)
353+
assert output == join_output(
354+
multiple_workflow_output,
355+
runner_workflow_output,
356+
single_workflow_output,
357+
)
312358

313359

314360
@pytest.mark.cli
@@ -320,7 +366,7 @@ def test_exclude_two(capsys):
320366
)
321367

322368
output = get_stdout(capsys)
323-
assert output == multiple_workflow_output
369+
assert output == join_output(multiple_workflow_output, runner_workflow_output)
324370

325371

326372
@pytest.mark.cli

tests/test_runner.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import tests.helper as test_module
2323
from hera.shared._pydantic import _PYDANTIC_VERSION
2424
from hera.shared.serialization import serialize
25-
from hera.workflows._runner.util import _run, _runner
25+
from hera.workflows._runner.util import _run, _runner, create_module_string
2626
from hera.workflows.io.v1 import Output as OutputV1
2727

2828
try:
@@ -1079,3 +1079,63 @@ def test_script_partially_annotated_tuple_should_raise_an_error():
10791079
),
10801080
):
10811081
_runner(entrypoint, kwargs_list)
1082+
1083+
1084+
@pytest.mark.parametrize(
1085+
"sys_path_relatives,file_rel_path,expected",
1086+
[
1087+
pytest.param(["project"], "project/wf_a.py", "wf_a", id="Exact direct match in sys.path"),
1088+
pytest.param(["project"], "project/workflows/wf_a.py", "workflows.wf_a", id="Submodule match in sys.path"),
1089+
pytest.param(
1090+
["project"],
1091+
"project/workflows/subpackage/another/wf_a.py",
1092+
"workflows.subpackage.another.wf_a",
1093+
id="Deep submodule match in sys.path",
1094+
),
1095+
pytest.param(
1096+
["project", "project/src"],
1097+
"project/src/workflows/wf_b.py",
1098+
"workflows.wf_b",
1099+
id="More specific match (src dir) in sys.path",
1100+
),
1101+
pytest.param([], "project/workflows/wf_c.py", "wf_c", id="No match, fallback to stem"),
1102+
pytest.param(
1103+
[""],
1104+
"project/workflows/wf_d.py",
1105+
"project.workflows.wf_d",
1106+
id="sys.path contains root, nested module path is full path",
1107+
),
1108+
],
1109+
)
1110+
def test_create_module_string(
1111+
tmp_path,
1112+
monkeypatch,
1113+
sys_path_relatives: list[str],
1114+
file_rel_path: str,
1115+
expected: str,
1116+
):
1117+
# GIVEN
1118+
# Create file structure
1119+
file_path = tmp_path / file_rel_path
1120+
file_path.parent.mkdir(parents=True, exist_ok=True)
1121+
1122+
# Set up sys.path using tmp_path as root
1123+
mock_sys_path = [str(tmp_path / rel) for rel in sys_path_relatives]
1124+
monkeypatch.setattr(sys, "path", mock_sys_path)
1125+
1126+
# THEN
1127+
assert create_module_string(file_path) == expected
1128+
1129+
1130+
def test_symlinked_sys_path(tmp_path, monkeypatch):
1131+
real_dir = tmp_path / "real_project"
1132+
real_dir.mkdir()
1133+
file_path = real_dir / "wf.py"
1134+
1135+
# Create a symlink pointing to real_project
1136+
symlink_path = tmp_path / "link_project"
1137+
symlink_path.symlink_to(real_dir)
1138+
1139+
monkeypatch.setattr(sys, "path", [str(symlink_path)])
1140+
1141+
assert create_module_string(file_path) == "wf"

0 commit comments

Comments
 (0)