diff --git a/src/pytom_tm/entry_points.py b/src/pytom_tm/entry_points.py index cfada29..a493f60 100644 --- a/src/pytom_tm/entry_points.py +++ b/src/pytom_tm/entry_points.py @@ -18,6 +18,7 @@ ParseDoseFile, ParseDefocus, BetweenZeroAndOne, + ParseGPUIndices, ) from pytom_tm.tmjob import load_json_to_tmjob from os import urandom @@ -902,6 +903,7 @@ def match_template(argv=None): "--gpu-ids", nargs="+", type=int, + action=ParseGPUIndices, required=True, help="GPU indices to run the program on.", ) diff --git a/src/pytom_tm/io.py b/src/pytom_tm/io.py index d57825c..22a2cdc 100644 --- a/src/pytom_tm/io.py +++ b/src/pytom_tm/io.py @@ -150,6 +150,33 @@ def __call__( parser.error("{0} can only take one or two arguments".format(option_string)) +class ParseGPUIndices(argparse.Action): + """argparse.Action subclass to parse gpu indices. The input can either be an int or + a list of ints that specify the gpu indices to be used. + """ + + def __call__( + self, + parser, + namespace, + values: Union[list[int, ...], int], + option_string: Optional[str] = None, + ): + import cupy + + max_value = cupy.cuda.runtime.getDeviceCount() + if isinstance(values, int): + values = [values] + for val in values: + if val < 0 or val >= max_value: + parser.error( + f"{option_string} all gpu indices should be between 0 " + "and {max_value-1}" + ) + + setattr(namespace, self.dest, values) + + class ParseDoseFile(argparse.Action): """argparse.Action subclass to parse a txt file contain information on accumulated dose per tilt."""