Skip to content

Commit

Permalink
Add gpu-id flag in the cli
Browse files Browse the repository at this point in the history
  • Loading branch information
namannimmo10 committed Jun 8, 2023
1 parent 0c01293 commit fb7069c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
21 changes: 21 additions & 0 deletions httomo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def check(yaml_config: Path, in_data: Path = None):
default=1,
help=" The number of the CPU cores per process.",
)
@click.option(
"--gpu-id",
type=click.INT,
default=0,
help="The GPU ID of the device to use.",
)
@click.option(
"--save-all",
is_flag=True,
Expand Down Expand Up @@ -99,6 +105,7 @@ def run(
dimension: int,
pad: int,
ncore: int,
gpu_id: int,
save_all: bool,
file_based_reslice: bool,
reslice_dir: Path,
Expand All @@ -120,6 +127,20 @@ def run(
# Copy YAML pipeline file to output directory
copy(yaml_config, httomo.globals.run_out_dir)

# try to access the GPU with the ID given
try:
import cupy as cp

gpu_count = cp.cuda.runtime.getDeviceCount()
if gpu_id not in range(0, gpu_count):
raise ValueError(
f"GPU Device not available for access. Use a GPU ID in the range: 0 to {gpu_count} (exclusive)"
)
cp.cuda.Device(gpu_id).use()

except ImportError:
pass # silently pass and run the CPU pipeline

return run_tasks(
in_file,
yaml_config,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,14 @@ def test_cli_pass_output_folder(
]
subprocess.check_output(cmd)
assert Path(custom_output_dir, "user.log").exists()


@pytest.mark.cupy
def test_cli_pass_gpu_id(cmd, standard_data, standard_loader, output_folder):
cmd.insert(7, standard_data)
cmd.insert(8, standard_loader)
cmd.insert(4, "--gpu-id")
cmd.insert(5, "10")

result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert "GPU Device not available for access." in result.stderr

0 comments on commit fb7069c

Please sign in to comment.