From b89b8e366ea1b323e47311598d611c90d9a1f3af Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 16:08:17 -0700 Subject: [PATCH 01/12] Add `get_num_gpus_available_isolated` --- tests/conftest.py | 15 ++---------- vllm/config.py | 5 ++-- .../device_communicators/custom_all_reduce.py | 3 ++- .../custom_all_reduce_utils.py | 3 ++- vllm/executor/multiproc_gpu_executor.py | 5 ++-- vllm/utils.py | 23 +++++++++++++++++++ 6 files changed, 33 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e0680467d78b..554a11a74521 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu +from vllm.utils import is_cpu, lazy_num_gpus_available logger = init_logger(__name__) @@ -537,15 +537,4 @@ def num_gpus_available(): """Get number of GPUs without initializing the CUDA context in current process.""" - try: - out = subprocess.run([ - sys.executable, "-c", - "import torch; print(torch.cuda.device_count())" - ], - capture_output=True, - check=True, - text=True) - except subprocess.CalledProcessError as e: - logger.warning("Failed to get number of GPUs.", exc_info=e) - return 0 - return int(out.stdout.strip()) + return lazy_num_gpus_available() diff --git a/vllm/config.py b/vllm/config.py index 2513d43ce8e6..89d256ed9826 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu +from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu, get_num_gpus_available_isolated if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -605,12 +605,11 @@ def __init__( if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. 
- from torch.cuda import device_count from vllm.executor import ray_utils backend = "mp" ray_found = ray_utils.ray is not None - if device_count() < self.world_size: + if get_num_gpus_available_isolated() < self.world_size: if not ray_found: raise ValueError("Unable to load Ray which is " "required for multi-node inference") diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index bbc2284f8a36..2f3f77787d61 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -12,6 +12,7 @@ from vllm.distributed.parallel_state import ( get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node) from vllm.logger import init_logger +from vllm.utils import get_num_gpus_available_isolated try: import pynvml @@ -149,7 +150,7 @@ def __init__(self, if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) else: - device_ids = list(range(torch.cuda.device_count())) + device_ids = list(range(get_num_gpus_available_isolated())) physical_device_id = device_ids[device.index] tensor = torch.tensor([physical_device_id], diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index 4b89a23dfc46..3f6816445eb8 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -13,6 +13,7 @@ import vllm.envs as envs from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger +from vllm.utils import get_num_gpus_available_isolated logger = init_logger(__name__) @@ -153,7 +154,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: is_distributed = dist.is_initialized() - num_dev = torch.cuda.device_count() + num_dev = get_num_gpus_available_isolated() cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices is None: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 99c9e52034cc..86869b7e621c 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) + get_vllm_instance_id, make_async, get_num_gpus_available_isolated) logger = init_logger(__name__) @@ -33,8 +33,7 @@ def _init_executor(self) -> None: # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - from torch.cuda import device_count - assert world_size <= device_count(), ( + assert world_size <= get_num_gpus_available_isolated(), ( "please set tensor_parallel_size to less than max local gpu count") distributed_init_method = get_distributed_init_method( diff --git a/vllm/utils.py b/vllm/utils.py index af585929d1a0..4b64465ef9c9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -693,3 +693,26 @@ def inner(*args, **kwargs): return inner # type: ignore return wrapper + + +@lru_cache(maxsize=None) +def get_num_gpus_available_isolated() -> int: + """Get number of GPUs without initializing the CUDA context + in current process. 
+ + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + try: + out = subprocess.run([ + sys.executable, "-c", + "import torch; print(torch.cuda.device_count())" + ], + capture_output=True, + check=True, + text=True) + except subprocess.CalledProcessError as e: + logger.warning("Failed to get number of GPUs.", exc_info=e) + return 0 + return int(out.stdout.strip()) From babcfe4c3fbf40081b604d0c0101991053613e0e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 16:10:06 -0700 Subject: [PATCH 02/12] Apply suggestions from code review --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 554a11a74521..dd5ba0dcddae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu, lazy_num_gpus_available +from vllm.utils import is_cpu, get_num_gpus_available_isolated logger = init_logger(__name__) @@ -537,4 +537,4 @@ def num_gpus_available(): """Get number of GPUs without initializing the CUDA context in current process.""" - return lazy_num_gpus_available() + return get_num_gpus_available_isolated() From bc7afc32fb42ec6883725b1f77a691167b54e410 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 16:12:16 -0700 Subject: [PATCH 03/12] Lint --- tests/conftest.py | 4 +--- vllm/config.py | 3 ++- vllm/executor/multiproc_gpu_executor.py | 5 +++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index dd5ba0dcddae..0a598ef8bd17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,6 @@ import contextlib import gc import os -import subprocess -import sys from typing import Any, Dict, List, Optional, Tuple, TypeVar import pytest @@ -21,7 +19,7 @@ from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu, get_num_gpus_available_isolated +from vllm.utils import get_num_gpus_available_isolated, is_cpu logger = init_logger(__name__) diff --git a/vllm/config.py b/vllm/config.py index 89d256ed9826..e052d8e3ccf3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu, get_num_gpus_available_isolated +from vllm.utils import (get_cpu_memory, get_num_gpus_available_isolated, + is_cpu, is_hip, is_neuron, is_tpu) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 86869b7e621c..ba5b9c78871d 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -9,8 +9,9 @@ ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async, get_num_gpus_available_isolated) +from vllm.utils import (get_distributed_init_method, get_ip, + 
get_num_gpus_available_isolated, get_open_port, + get_vllm_instance_id, make_async) logger = init_logger(__name__) From dfddcaeb30a04da37e16475d0b12c7256f638e73 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 16:13:41 -0700 Subject: [PATCH 04/12] Lint --- vllm/utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 4b64465ef9c9..ad1a712579e6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -705,13 +705,16 @@ def get_num_gpus_available_isolated() -> int: value.""" try: - out = subprocess.run([ - sys.executable, "-c", - "import torch; print(torch.cuda.device_count())" - ], - capture_output=True, - check=True, - text=True) + out = subprocess.run( + [ + sys.executable, + "-c", + "import torch; print(torch.cuda.device_count())", + ], + capture_output=True, + check=True, + text=True, + ) except subprocess.CalledProcessError as e: logger.warning("Failed to get number of GPUs.", exc_info=e) return 0 From 3c7690e03e0a56b3ff9ba6218ee67f6eeea0c3aa Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 16:35:32 -0700 Subject: [PATCH 05/12] Tweak --- vllm/utils.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index ad1a712579e6..f3adac9f5038 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -696,14 +696,7 @@ def inner(*args, **kwargs): @lru_cache(maxsize=None) -def get_num_gpus_available_isolated() -> int: - """Get number of GPUs without initializing the CUDA context - in current process. - - This should be used instead of torch.cuda.device_count() - unless CUDA_VISIBLE_DEVICES has already been set to the desired - value.""" - +def _get_num_gpus_available_isolated() -> int: try: out = subprocess.run( [ @@ -719,3 +712,24 @@ def get_num_gpus_available_isolated() -> int: logger.warning("Failed to get number of GPUs.", exc_info=e) return 0 return int(out.stdout.strip()) + + +_LAST_CUDA_VISIBLE_DEVICES = None + + +def get_num_gpus_available_isolated() -> int: + """Get number of GPUs without initializing the CUDA context + in current process. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + global _LAST_CUDA_VISIBLE_DEVICES + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices != _LAST_CUDA_VISIBLE_DEVICES: + _get_num_gpus_available_isolated.cache_clear() + _LAST_CUDA_VISIBLE_DEVICES = cuda_visible_devices + + return _get_num_gpus_available_isolated() From 268430534080617f1bd76a8195221b90f5eaf903 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 16:38:22 -0700 Subject: [PATCH 06/12] Tweak --- vllm/utils.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index f3adac9f5038..c05ba7e4994d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -695,8 +695,11 @@ def inner(*args, **kwargs): return wrapper -@lru_cache(maxsize=None) -def _get_num_gpus_available_isolated() -> int: +@lru_cache(maxsize=5) +def _get_num_gpus_available_isolated( + cuda_visible_devices: Optional[str] = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. 
try: out = subprocess.run( [ @@ -714,9 +717,6 @@ def _get_num_gpus_available_isolated() -> int: return int(out.stdout.strip()) -_LAST_CUDA_VISIBLE_DEVICES = None - - def get_num_gpus_available_isolated() -> int: """Get number of GPUs without initializing the CUDA context in current process. @@ -724,12 +724,4 @@ def get_num_gpus_available_isolated() -> int: This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" - - global _LAST_CUDA_VISIBLE_DEVICES - - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES - if cuda_visible_devices != _LAST_CUDA_VISIBLE_DEVICES: - _get_num_gpus_available_isolated.cache_clear() - _LAST_CUDA_VISIBLE_DEVICES = cuda_visible_devices - - return _get_num_gpus_available_isolated() + return _get_num_gpus_available_isolated(envs.CUDA_VISIBLE_DEVICES) From 85a24ef28323ed59f53ac9fa5b4ecd9d284ce7d8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 17:06:37 -0700 Subject: [PATCH 07/12] Add comment --- vllm/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/utils.py b/vllm/utils.py index c05ba7e4994d..a10a4fcb1fe1 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -724,4 +724,8 @@ def get_num_gpus_available_isolated() -> int: This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _get_num_gpus_available_isolated(envs.CUDA_VISIBLE_DEVICES) From a507a11f0589d0b7add45decfaf1a510b65536f4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 17:19:49 -0700 Subject: [PATCH 08/12] Replace subprocess with direct call --- vllm/utils.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index a10a4fcb1fe1..8bb7a64da1eb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -700,21 +700,20 @@ def _get_num_gpus_available_isolated( cuda_visible_devices: Optional[str] = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. 
- try: - out = subprocess.run( - [ - sys.executable, - "-c", - "import torch; print(torch.cuda.device_count())", - ], - capture_output=True, - check=True, - text=True, - ) - except subprocess.CalledProcessError as e: - logger.warning("Failed to get number of GPUs.", exc_info=e) + + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + if not torch.cuda._is_compiled(): return 0 - return int(out.stdout.strip()) + # bypass _device_count_nvml() if rocm (not supported) + nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count + return r def get_num_gpus_available_isolated() -> int: From d13dcc0ac87f92c49318e4ed5fd68f64c53f9db0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 17:20:20 -0700 Subject: [PATCH 09/12] Tweak docstring --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 8bb7a64da1eb..32ab9a979e9a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -717,7 +717,7 @@ def _get_num_gpus_available_isolated( def get_num_gpus_available_isolated() -> int: - """Get number of GPUs without initializing the CUDA context + """Get number of GPUs without caching the number of devices in current process. This should be used instead of torch.cuda.device_count() From 6dc50f05450e40a89e82dc94c428ac059d834d92 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 17:28:08 -0700 Subject: [PATCH 10/12] Review feedback --- tests/conftest.py | 4 ++-- vllm/config.py | 6 +++--- .../distributed/device_communicators/custom_all_reduce.py | 4 ++-- .../device_communicators/custom_all_reduce_utils.py | 4 ++-- vllm/executor/multiproc_gpu_executor.py | 6 +++--- vllm/utils.py | 8 ++++---- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0a598ef8bd17..5e482466e1c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs -from vllm.utils import get_num_gpus_available_isolated, is_cpu +from vllm.utils import cuda_device_count_stateless, is_cpu logger = init_logger(__name__) @@ -535,4 +535,4 @@ def num_gpus_available(): """Get number of GPUs without initializing the CUDA context in current process.""" - return get_num_gpus_available_isolated() + return cuda_device_count_stateless() diff --git a/vllm/config.py b/vllm/config.py index e052d8e3ccf3..a0bd6b0975a1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,8 +11,8 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (get_cpu_memory, get_num_gpus_available_isolated, - is_cpu, is_hip, is_neuron, is_tpu) +from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, + is_hip, is_neuron, is_tpu) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -610,7 +610,7 @@ def __init__( from vllm.executor import ray_utils backend = "mp" ray_found = ray_utils.ray is not None - if get_num_gpus_available_isolated() < self.world_size: + if cuda_device_count_stateless() < self.world_size: if not ray_found: raise ValueError("Unable 
to load Ray which is " "required for multi-node inference") diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 2f3f77787d61..2f8ffe87d480 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -12,7 +12,7 @@ from vllm.distributed.parallel_state import ( get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node) from vllm.logger import init_logger -from vllm.utils import get_num_gpus_available_isolated +from vllm.utils import cuda_device_count_stateless try: import pynvml @@ -150,7 +150,7 @@ def __init__(self, if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) else: - device_ids = list(range(get_num_gpus_available_isolated())) + device_ids = list(range(cuda_device_count_stateless())) physical_device_id = device_ids[device.index] tensor = torch.tensor([physical_device_id], diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index 3f6816445eb8..b3d397de72cc 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -13,7 +13,7 @@ import vllm.envs as envs from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger -from vllm.utils import get_num_gpus_available_isolated +from vllm.utils import cuda_device_count_stateless logger = init_logger(__name__) @@ -154,7 +154,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: is_distributed = dist.is_initialized() - num_dev = get_num_gpus_available_isolated() + num_dev = cuda_device_count_stateless() cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices is None: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index ba5b9c78871d..8385e56f88b3 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -9,8 +9,8 @@ ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_ip, - get_num_gpus_available_isolated, get_open_port, +from vllm.utils import (cuda_device_count_stateless, + get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) logger = init_logger(__name__) @@ -34,7 +34,7 @@ def _init_executor(self) -> None: # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - assert world_size <= get_num_gpus_available_isolated(), ( + assert world_size <= cuda_device_count_stateless(), ( "please set tensor_parallel_size to less than max local gpu count") distributed_init_method = get_distributed_init_method( diff --git a/vllm/utils.py b/vllm/utils.py index 32ab9a979e9a..268d724feaef 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -695,8 +695,8 @@ def inner(*args, **kwargs): return wrapper -@lru_cache(maxsize=5) -def _get_num_gpus_available_isolated( +@lru_cache(maxsize=8) +def _cuda_device_count_stateless( cuda_visible_devices: Optional[str] = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. 
@@ -716,7 +716,7 @@ def _get_num_gpus_available_isolated( return r -def get_num_gpus_available_isolated() -> int: +def cuda_device_count_stateless() -> int: """Get number of GPUs without caching the number of devices in current process. @@ -727,4 +727,4 @@ def get_num_gpus_available_isolated() -> int: # This can be removed and simply replaced with torch.cuda.get_device_count # after https://github.com/pytorch/pytorch/pull/122815 is released. - return _get_num_gpus_available_isolated(envs.CUDA_VISIBLE_DEVICES) + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) From fdadaf08571e0388a9b67952a94bbaca0bcc0453 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 12 Jun 2024 17:29:54 -0700 Subject: [PATCH 11/12] Clarify --- vllm/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 268d724feaef..b5c42605ba35 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -717,8 +717,8 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: - """Get number of GPUs without caching the number of devices - in current process. + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired From bf43ffbe1ab6579302774ef81d74867b45597c48 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 13 Jun 2024 11:02:38 -0700 Subject: [PATCH 12/12] Add test --- .buildkite/test-pipeline.yaml | 1 + tests/distributed/test_utils.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 tests/distributed/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6b12d19ba611..6a2932db9f2d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -48,6 +48,7 @@ steps: - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - pytest -v -s spec_decode/e2e/test_integration_dist.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py - label: Distributed Tests (Multiple Groups) #mirror_hardwares: [amd] diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py new file mode 100644 index 000000000000..b7ec59c7a2cc --- /dev/null +++ b/tests/distributed/test_utils.py @@ -0,0 +1,31 @@ +import os + +import ray + +from vllm.utils import cuda_device_count_stateless + + +@ray.remote +class _CUDADeviceCountStatelessTestActor(): + + def get_count(self): + return cuda_device_count_stateless() + + def set_cuda_visible_devices(self, cuda_visible_devices: str): + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + + def get_cuda_visible_devices(self): + return os.environ["CUDA_VISIBLE_DEVICES"] + + +def test_cuda_device_count_stateless(): + """Test that cuda_device_count_stateless changes return value if + CUDA_VISIBLE_DEVICES is changed.""" + + actor = _CUDADeviceCountStatelessTestActor.options(num_gpus=2).remote() + assert ray.get(actor.get_cuda_visible_devices.remote()) == "0,1" + assert ray.get(actor.get_count.remote()) == 2 + ray.get(actor.set_cuda_visible_devices.remote("0")) + assert ray.get(actor.get_count.remote()) == 1 + ray.get(actor.set_cuda_visible_devices.remote("")) + assert ray.get(actor.get_count.remote()) == 0
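
Usage sketch: the behaviour this series converges on with cuda_device_count_stateless() can be illustrated in a few lines. This is a minimal sketch that mirrors the new test in tests/distributed/test_utils.py and assumes a machine with at least two CUDA devices; the printed values are illustrative and it is not part of the patches above.

    import os

    from vllm.utils import cuda_device_count_stateless

    # With two visible GPUs, the helper reports 2 without initializing a CUDA
    # context in the current process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    print(cuda_device_count_stateless())  # expected: 2

    # Restricting visibility after the first call is picked up, because the
    # lru_cache is keyed on the current value of CUDA_VISIBLE_DEVICES rather
    # than on a count captured once at import time.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print(cuda_device_count_stateless())  # expected: 1

    # Hiding all devices yields 0, matching the assertion in the new test.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    print(cuda_device_count_stateless())  # expected: 0

Keying the cache on CUDA_VISIBLE_DEVICES is what makes the helper "stateless": repeated calls with an unchanged environment stay cheap, while a change to the variable selects a fresh cache entry instead of returning a stale device count.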