Skip to content

Commit

Permalink
Add get_num_gpus_available_isolated
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored Jun 12, 2024
1 parent 7d19de2 commit b89b8e3
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 21 deletions.
15 changes: 2 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from vllm.utils import is_cpu, lazy_num_gpus_available

logger = init_logger(__name__)

Expand Down Expand Up @@ -537,15 +537,4 @@ def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""

try:
out = subprocess.run([
sys.executable, "-c",
"import torch; print(torch.cuda.device_count())"
],
capture_output=True,
check=True,
text=True)
except subprocess.CalledProcessError as e:
logger.warning("Failed to get number of GPUs.", exc_info=e)
return 0
return int(out.stdout.strip())
return lazy_num_gpus_available()
5 changes: 2 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu, get_num_gpus_available_isolated

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
Expand Down Expand Up @@ -605,12 +605,11 @@ def __init__(
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from torch.cuda import device_count

from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray is not None
if device_count() < self.world_size:
if get_num_gpus_available_isolated() < self.world_size:
if not ray_found:
raise ValueError("Unable to load Ray which is "
"required for multi-node inference")
Expand Down
3 changes: 2 additions & 1 deletion vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.distributed.parallel_state import (
get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
from vllm.logger import init_logger
from vllm.utils import get_num_gpus_available_isolated

try:
import pynvml
Expand Down Expand Up @@ -149,7 +150,7 @@ def __init__(self,
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
device_ids = list(range(get_num_gpus_available_isolated()))

physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import vllm.envs as envs
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import get_num_gpus_available_isolated

logger = init_logger(__name__)

Expand Down Expand Up @@ -153,7 +154,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:

is_distributed = dist.is_initialized()

num_dev = torch.cuda.device_count()
num_dev = get_num_gpus_available_isolated()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
Expand Down
5 changes: 2 additions & 3 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
get_vllm_instance_id, make_async, get_num_gpus_available_isolated)

logger = init_logger(__name__)

Expand All @@ -33,8 +33,7 @@ def _init_executor(self) -> None:
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

from torch.cuda import device_count
assert world_size <= device_count(), (
assert world_size <= get_num_gpus_available_isolated(), (
"please set tensor_parallel_size to less than max local gpu count")

distributed_init_method = get_distributed_init_method(
Expand Down
23 changes: 23 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,3 +693,26 @@ def inner(*args, **kwargs):
return inner # type: ignore

return wrapper


@lru_cache(maxsize=None)
def get_num_gpus_available_isolated() -> int:
    """Get number of GPUs without initializing the CUDA context
    in current process.

    The count is obtained by running ``torch.cuda.device_count()`` in a
    child interpreter (``sys.executable``) and parsing what it prints.
    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value.

    The result is memoized by ``lru_cache``, so the child process is
    spawned at most once per process; a fallback value of 0 is memoized
    as well.

    Returns:
        The device count printed by the child process, or 0 when the
        child cannot be launched, exits with a non-zero status, or prints
        output whose last line is not an integer.
    """
    probe_cmd = [
        sys.executable, "-c",
        "import torch; print(torch.cuda.device_count())"
    ]
    try:
        out = subprocess.run(probe_cmd,
                             capture_output=True,
                             check=True,
                             text=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: the child ran but exited non-zero.
        # OSError: the interpreter could not be launched at all (e.g.
        # sys.executable is empty or points to a missing file).
        logger.warning("Failed to get number of GPUs.", exc_info=e)
        return 0

    # The count is the last thing the child prints; tolerate unrelated
    # lines emitted to stdout before it rather than failing the parse.
    stdout_lines = out.stdout.strip().splitlines()
    try:
        return int(stdout_lines[-1])
    except (IndexError, ValueError) as e:
        # Empty or non-numeric output: stay best-effort like the
        # failure path above instead of raising to the caller.
        logger.warning("Failed to get number of GPUs.", exc_info=e)
        return 0

0 comments on commit b89b8e3

Please sign in to comment.