Skip to content

Commit

Permalink
add check on input gpu indices (#228)
Browse files Browse the repository at this point in the history
* add check on input gpu indices

* add test

* actually call the entry point...

* allow entry_point tests to deal with multi input values

* negative numbers are reated as flags...

* actuall catch the correct error

* remove false comments

* nargs=+ always feeds a list
  • Loading branch information
sroet authored Aug 29, 2024
1 parent b6c4a93 commit 9f1e4e7
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/pytom_tm/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParseDoseFile,
ParseDefocus,
BetweenZeroAndOne,
ParseGPUIndices,
)
from pytom_tm.tmjob import load_json_to_tmjob
from os import urandom
Expand Down Expand Up @@ -902,6 +903,7 @@ def match_template(argv=None):
"--gpu-ids",
nargs="+",
type=int,
action=ParseGPUIndices,
required=True,
help="GPU indices to run the program on.",
)
Expand Down
25 changes: 25 additions & 0 deletions src/pytom_tm/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,31 @@ def __call__(
parser.error("{0} can only take one or two arguments".format(option_string))


class ParseGPUIndices(argparse.Action):
"""argparse.Action subclass to parse gpu indices. The input can either be an int or
a list of ints that specify the gpu indices to be used.
"""

def __call__(
self,
parser,
namespace,
values: list[int, ...],
option_string: Optional[str] = None,
):
import cupy

max_value = cupy.cuda.runtime.getDeviceCount()
for val in values:
if val < 0 or val >= max_value:
parser.error(
f"{option_string} all gpu indices should be between 0 "
"and {max_value-1}"
)

setattr(namespace, self.dest, values)


class ParseDoseFile(argparse.Action):
"""argparse.Action subclass to parse a txt file contain information on accumulated
dose per tilt."""
Expand Down
23 changes: 21 additions & 2 deletions tests/test_entry_points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import pathlib
import numpy as np
import cupy as cp
import logging
from shutil import which
from contextlib import redirect_stdout, redirect_stderr
Expand Down Expand Up @@ -43,7 +44,10 @@

def prep_argv(arg_dict):
argv = []
[argv.extend([k, v]) if v != "" else argv.append(k) for k, v in arg_dict.items()]
[
argv.extend([k] + v.split()) if v != "" else argv.append(k)
for k, v in arg_dict.items()
]
return argv


Expand Down Expand Up @@ -176,5 +180,20 @@ def start(arg_dict): # simplify run
msg="File should exist in debug mode",
)

# rest the log level after the entry point modified it
# reset the log level after the entry point modified it
logging.basicConfig(level=LOG_LEVEL, force=True)

# test providing invalid gpu indices
n_devices = cp.cuda.runtime.getDeviceCount()
for indices in ["-1", f"0 {n_devices}"]:
dump = StringIO()
with (
self.assertRaises(SystemExit) as ex,
redirect_stdout(dump),
redirect_stderr(dump),
):
arguments = defaults.copy()
arguments["-g"] = indices
start(arguments)
self.assertIn("gpu indices", dump.getvalue())
dump.close()

0 comments on commit 9f1e4e7

Please sign in to comment.