Skip to content

Commit 6b731c5

Browse files
authored
scripts: Check .is_cuda only in non-C++ files (#7561)
The check-torchcuda.py today will search for all occurrences of .is_cuda in the repository when a commit only modifies C++ headers and sources, which I believe is not intended. Check usage of .is_cuda only when a commit modifies any non-C++ file. Signed-off-by: Junjie Mao <[email protected]>
1 parent 2585881 commit 6b731c5

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

scripts/check-torchcuda.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,24 @@ def err(s: str) -> None:
5757

5858
files = []
5959
for file in sys.argv[1:]:
60-
if not file.endswith(".cpp"):
60+
if file.endswith(".py"):
6161
files.append(file)
6262

63-
res = subprocess.run(
64-
["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
65-
capture_output=True,
66-
)
67-
if res.returncode == 0:
68-
err('''
63+
if len(files) > 0:
64+
res = subprocess.run(
65+
["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
66+
capture_output=True,
67+
)
68+
if res.returncode == 0:
69+
err('''
6970
Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor.
7071
Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)",
7172
and add the following import line:
7273
'from deepspeed.accelerator import get_accelerator'
7374
''')
74-
err(res.stdout.decode("utf-8"))
75-
sys.exit(1)
76-
elif res.returncode == 2:
77-
err(f"Error invoking grep on {', '.join(files)}:")
78-
err(res.stderr.decode("utf-8"))
79-
sys.exit(2)
75+
err(res.stdout.decode("utf-8"))
76+
sys.exit(1)
77+
elif res.returncode == 2:
78+
err(f"Error invoking grep on {', '.join(files)}:")
79+
err(res.stderr.decode("utf-8"))
80+
sys.exit(2)

0 commit comments

Comments
 (0)