Skip to content

Commit

Permalink
add check on input gpu indices
Browse files Browse the repository at this point in the history
  • Loading branch information
sroet committed Aug 29, 2024
1 parent b6c4a93 commit 4c85fc7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/pytom_tm/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParseDoseFile,
ParseDefocus,
BetweenZeroAndOne,
ParseGPUIndices,
)
from pytom_tm.tmjob import load_json_to_tmjob
from os import urandom
Expand Down Expand Up @@ -902,6 +903,7 @@ def match_template(argv=None):
"--gpu-ids",
nargs="+",
type=int,
action=ParseGPUIndices,
required=True,
help="GPU indices to run the program on.",
)
Expand Down
27 changes: 27 additions & 0 deletions src/pytom_tm/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,33 @@ def __call__(
parser.error("{0} can only take one or two arguments".format(option_string))


class ParseGPUIndices(argparse.Action):
"""argparse.Action subclass to parse gpu indices. The input can either be an int or
a list of ints that specify the gpu indices to be used.
"""

def __call__(
self,
parser,
namespace,
values: Union[list[int, ...], int],
option_string: Optional[str] = None,
):
import cupy

max_value = cupy.cuda.runtime.getDeviceCount()
if isinstance(values, int):
values = [values]
for val in values:
if val < 0 or val >= max_value:
parser.error(
f"{option_string} all gpu indices should be between 0 "
"and {max_value-1}"
)

setattr(namespace, self.dest, values)


class ParseDoseFile(argparse.Action):
"""argparse.Action subclass to parse a txt file contain information on accumulated
dose per tilt."""
Expand Down

0 comments on commit 4c85fc7

Please sign in to comment.