Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpu-id flag in the cli #139

Merged
merged 1 commit into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions httomo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def check(yaml_config: Path, in_data: Path = None):
default=1,
help=" The number of the CPU cores per process.",
)
@click.option(
"--gpu-id",
type=click.INT,
default=-1,
help="The GPU ID of the device to use.",
)
@click.option(
"--save-all",
is_flag=True,
Expand Down Expand Up @@ -99,6 +105,7 @@ def run(
dimension: int,
pad: int,
ncore: int,
gpu_id: int,
save_all: bool,
file_based_reslice: bool,
reslice_dir: Path,
Expand All @@ -120,6 +127,25 @@ def run(
# Copy YAML pipeline file to output directory
copy(yaml_config, httomo.globals.run_out_dir)

# try to access the GPU with the ID given
try:
import cupy as cp

gpu_count = cp.cuda.runtime.getDeviceCount()

if gpu_id != -1:
if gpu_id not in range(0, gpu_count):
raise ValueError(
f"GPU Device not available for access. Use a GPU ID in the range: 0 to {gpu_count} (exclusive)"
)

cp.cuda.Device(gpu_id).use()

httomo.globals.gpu_id = gpu_id

except ImportError:
pass # silently pass and run the CPU pipeline

return run_tasks(
in_file,
yaml_config,
Expand Down
1 change: 1 addition & 0 deletions httomo/globals.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
run_out_dir = None
logger = None
gpu_id = -1
5 changes: 2 additions & 3 deletions httomo/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
MAX_SWEEPS = 1



@dataclass
class MethodFunc:
"""
Expand Down Expand Up @@ -450,9 +449,9 @@ def _get_method_funcs(yaml_config: Path, comm: MPI.Comm) -> List[MethodFunc]:
method_func=wrapper_func,
wrapper_func=wrapper_method,
parameters=method_conf,
cpu=True, # get cpu/gpu meta data info from httomolib methods
cpu=True, # get cpu/gpu meta data info from httomolib methods
gpu=False,
calc_max_slices=None, # call calc_max_slices function in wrappers
calc_max_slices=None, # call calc_max_slices function in wrappers
reslice_ahead=False,
pattern=Pattern.all,
is_loader=False,
Expand Down
5 changes: 4 additions & 1 deletion httomo/wrappers_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from mpi4py.MPI import Comm

import httomo.globals
from httomo.data import mpiutil
from httomo.utils import Colour, log_once

Expand Down Expand Up @@ -41,7 +42,9 @@ def __init__(
self.comm = comm
if gpu_enabled:
self.num_GPUs = xp.cuda.runtime.getDeviceCount()
self.gpu_id = mpiutil.local_rank % self.num_GPUs
_id = httomo.globals.gpu_id
# if gpu-id was specified in the CLI, use that
self.gpu_id = mpiutil.local_rank % self.num_GPUs if _id == -1 else _id

def _transfer_data(self, *args) -> Union[tuple, xp.ndarray, np.ndarray]:
"""Transfer the data between the host and device for the GPU-enabled method
Expand Down
13 changes: 13 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,16 @@ def test_cli_pass_output_folder(
]
subprocess.check_output(cmd)
assert Path(custom_output_dir, "user.log").exists()


@pytest.mark.cupy
def test_cli_pass_gpu_id(cmd, standard_data, standard_loader, output_folder):
cmd.insert(7, standard_data)
cmd.insert(8, standard_loader)
cmd.insert(4, "--gpu-id")
cmd.insert(5, "10")

result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
assert "GPU Device not available for access." in result.stderr
Loading