Skip to content

Commit

Permalink
Add gpu-id flag in the cli
Browse files Browse the repository at this point in the history
  • Loading branch information
namannimmo10 committed Jun 12, 2023
1 parent ef30e0e commit a0e6e7c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 1 deletion.
22 changes: 22 additions & 0 deletions httomo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def check(yaml_config: Path, in_data: Path = None):
default=1,
help=" The number of the CPU cores per process.",
)
@click.option(
"--gpu-id",
type=click.INT,
default=0,
help="The GPU ID of the device to use. Default is 0.",
)
@click.option(
"--save-all",
is_flag=True,
Expand Down Expand Up @@ -99,6 +105,7 @@ def run(
dimension: int,
pad: int,
ncore: int,
gpu_id: int,
save_all: bool,
file_based_reslice: bool,
reslice_dir: Path,
Expand All @@ -120,6 +127,21 @@ def run(
# Copy YAML pipeline file to output directory
copy(yaml_config, httomo.globals.run_out_dir)

# try to access the GPU with the ID given
try:
import cupy as cp

gpu_count = cp.cuda.runtime.getDeviceCount()
if gpu_id not in range(0, gpu_count):
raise ValueError(
f"GPU Device not available for access. Use a GPU ID in the range: 0 to {gpu_count} (exclusive)"
)
cp.cuda.Device(gpu_id).use()
httomo.globals.gpu_id = gpu_id

except ImportError:
pass # silently pass and run the CPU pipeline

return run_tasks(
in_file,
yaml_config,
Expand Down
1 change: 1 addition & 0 deletions httomo/globals.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
run_out_dir = None
logger = None
gpu_id = -1
3 changes: 2 additions & 1 deletion httomo/wrappers_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from mpi4py.MPI import Comm

import httomo.globals
from httomo.data import mpiutil
from httomo.utils import Colour, log_once

Expand Down Expand Up @@ -41,7 +42,7 @@ def __init__(
self.comm = comm
if gpu_enabled:
self.num_GPUs = xp.cuda.runtime.getDeviceCount()
self.gpu_id = mpiutil.local_rank % self.num_GPUs
self.gpu_id = httomo.globals.gpu_id

def _transfer_data(self, *args) -> Union[tuple, xp.ndarray, np.ndarray]:
"""Transfer the data between the host and device for the GPU-enabled method
Expand Down
11 changes: 11 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,14 @@ def test_cli_pass_output_folder(
]
subprocess.check_output(cmd)
assert Path(custom_output_dir, "user.log").exists()


@pytest.mark.cupy
def test_cli_pass_gpu_id(cmd, standard_data, standard_loader, output_folder):
cmd.insert(7, standard_data)
cmd.insert(8, standard_loader)
cmd.insert(4, "--gpu-id")
cmd.insert(5, "10")

result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert "GPU Device not available for access." in result.stderr

0 comments on commit a0e6e7c

Please sign in to comment.