pick device allows strong preferences now
matsen committed Nov 13, 2024
1 parent 7543ca1 commit e1821d7
Showing 1 changed file with 15 additions and 5 deletions.
netam/common.py (20 changes: 15 additions & 5 deletions)
@@ -198,13 +198,23 @@ def find_least_used_cuda_gpu():
     return utilization.index(min(utilization))
 
 
-def pick_device(gpu_index=None):
+def pick_device(gpu_preference=None):
     """Pick a device for PyTorch to use.
 
-    If CUDA is available, use the least used GPU, and if all are idle use the gpu_index
-    modulo the number of GPUs. If gpu_index is None, then use a random GPU.
+    If gpu_preference is a string, use the device with that name. This is considered a
+    strong preference from a user who knows what they are doing.
+
+    If gpu_preference is an integer, this is a weak preference for a numbered GPU:
+    if CUDA is available, use the least used GPU, and if all are idle use gpu_preference
+    modulo the number of GPUs. If gpu_preference is None, then use a random GPU.
     """
 
+    # Strong preference for a specific device.
+    if gpu_preference is not None and isinstance(gpu_preference, str):
+        return torch.device(gpu_preference)
+
+    # Otherwise, weak preference for a numbered GPU.
+
     # check that CUDA is usable
     def check_CUDA():
         try:
@@ -216,10 +226,10 @@ def check_CUDA():
     if torch.backends.cudnn.is_available() and check_CUDA():
         which_gpu = find_least_used_cuda_gpu()
         if which_gpu is None:
-            if gpu_index is None:
+            if gpu_preference is None:
                 which_gpu = np.random.randint(torch.cuda.device_count())
             else:
-                which_gpu = gpu_index % torch.cuda.device_count()
+                which_gpu = gpu_preference % torch.cuda.device_count()
         print(f"Using CUDA GPU {which_gpu}")
         return torch.device(f"cuda:{which_gpu}")
     elif torch.backends.mps.is_available():
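For context, a minimal usage sketch of the new behavior (not part of the commit; the device strings and indices are illustrative and assume PyTorch is installed and netam.common is importable):

from netam.common import pick_device

# Strong preference: a device string is used verbatim, regardless of GPU load.
device = pick_device("cuda:1")

# Weak preference: an integer is only a hint; the least used CUDA GPU wins when
# one is busier than the rest, otherwise the index modulo the GPU count is used.
device = pick_device(0)

# No preference: the least used GPU, or a random GPU when all are idle.
device = pick_device()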
