diff --git a/src/pytom_tm/entry_points.py b/src/pytom_tm/entry_points.py index cfada29..a493f60 100644 --- a/src/pytom_tm/entry_points.py +++ b/src/pytom_tm/entry_points.py @@ -18,6 +18,7 @@ ParseDoseFile, ParseDefocus, BetweenZeroAndOne, + ParseGPUIndices, ) from pytom_tm.tmjob import load_json_to_tmjob from os import urandom @@ -902,6 +903,7 @@ def match_template(argv=None): "--gpu-ids", nargs="+", type=int, + action=ParseGPUIndices, required=True, help="GPU indices to run the program on.", ) diff --git a/src/pytom_tm/io.py b/src/pytom_tm/io.py index d57825c..0a1e8aa 100644 --- a/src/pytom_tm/io.py +++ b/src/pytom_tm/io.py @@ -150,6 +150,31 @@ def __call__( parser.error("{0} can only take one or two arguments".format(option_string)) +class ParseGPUIndices(argparse.Action): + """argparse.Action subclass to parse gpu indices. The input can either be an int or + a list of ints that specify the gpu indices to be used. + """ + + def __call__( + self, + parser, + namespace, + values: list[int, ...], + option_string: Optional[str] = None, + ): + import cupy + + max_value = cupy.cuda.runtime.getDeviceCount() + for val in values: + if val < 0 or val >= max_value: + parser.error( + f"{option_string} all gpu indices should be between 0 " + "and {max_value-1}" + ) + + setattr(namespace, self.dest, values) + + class ParseDoseFile(argparse.Action): """argparse.Action subclass to parse a txt file contain information on accumulated dose per tilt.""" diff --git a/tests/test_entry_points.py b/tests/test_entry_points.py index d52d23d..27557c3 100644 --- a/tests/test_entry_points.py +++ b/tests/test_entry_points.py @@ -1,6 +1,7 @@ import unittest import pathlib import numpy as np +import cupy as cp import logging from shutil import which from contextlib import redirect_stdout, redirect_stderr @@ -43,7 +44,10 @@ def prep_argv(arg_dict): argv = [] - [argv.extend([k, v]) if v != "" else argv.append(k) for k, v in arg_dict.items()] + [ + argv.extend([k] + v.split()) if v != "" else argv.append(k) + for k, v in arg_dict.items() + ] return argv @@ -176,5 +180,20 @@ def start(arg_dict): # simplify run msg="File should exist in debug mode", ) - # rest the log level after the entry point modified it + # reset the log level after the entry point modified it logging.basicConfig(level=LOG_LEVEL, force=True) + + # test providing invalid gpu indices + n_devices = cp.cuda.runtime.getDeviceCount() + for indices in ["-1", f"0 {n_devices}"]: + dump = StringIO() + with ( + self.assertRaises(SystemExit) as ex, + redirect_stdout(dump), + redirect_stderr(dump), + ): + arguments = defaults.copy() + arguments["-g"] = indices + start(arguments) + self.assertIn("gpu indices", dump.getvalue()) + dump.close()