Skip to content

Commit

Permalink
Add cuda_device_count_stateless (vllm-project#5473)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored and robertgshaw2-neuralmagic committed Jun 16, 2024
1 parent ac8c1a5 commit 64e03e9
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 23 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
Expand Down
17 changes: 2 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import gc
import logging
import os
import subprocess
import sys
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import pytest
Expand All @@ -24,7 +22,7 @@
from vllm.multimodal import MultiModalData
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from vllm.utils import cuda_device_count_stateless, is_cpu

logger = init_logger(__name__)

Expand Down Expand Up @@ -769,18 +767,7 @@ def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""

try:
out = subprocess.run([
sys.executable, "-c",
"import torch; print(torch.cuda.device_count())"
],
capture_output=True,
check=True,
text=True)
except subprocess.CalledProcessError as e:
logger.warning("Failed to get number of GPUs.", exc_info=e)
return 0
return int(out.stdout.strip())
return cuda_device_count_stateless()


@pytest.fixture(scope="session")
Expand Down
31 changes: 31 additions & 0 deletions tests/distributed/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os

import ray

from vllm.utils import cuda_device_count_stateless


@ray.remote
class _CUDADeviceCountStatelessTestActor():

def get_count(self):
return cuda_device_count_stateless()

def set_cuda_visible_devices(self, cuda_visible_devices: str):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

def get_cuda_visible_devices(self):
return os.environ["CUDA_VISIBLE_DEVICES"]


def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed."""

actor = _CUDADeviceCountStatelessTestActor.options(num_gpus=2).remote()
assert ray.get(actor.get_cuda_visible_devices.remote()) == "0,1"
assert ray.get(actor.get_count.remote()) == 2
ray.get(actor.set_cuda_visible_devices.remote("0"))
assert ray.get(actor.get_count.remote()) == 1
ray.get(actor.set_cuda_visible_devices.remote(""))
assert ray.get(actor.get_count.remote()) == 0
6 changes: 3 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_tpu)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
Expand Down Expand Up @@ -637,12 +638,11 @@ def __init__(
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from torch.cuda import device_count

from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray is not None
if device_count() < self.world_size:
if cuda_device_count_stateless() < self.world_size:
if not ray_found:
raise ValueError("Unable to load Ray which is "
"required for multi-node inference")
Expand Down
3 changes: 2 additions & 1 deletion vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
gpu_p2p_access_check)
from vllm.distributed.parallel_state import is_in_the_same_node
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless

try:
import pynvml
Expand Down Expand Up @@ -144,7 +145,7 @@ def __init__(self,
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
device_ids = list(range(cuda_device_count_stateless()))

physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless

logger = init_logger(__name__)

Expand Down Expand Up @@ -152,7 +153,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:

is_distributed = dist.is_initialized()

num_dev = torch.cuda.device_count()
num_dev = cuda_device_count_stateless()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
Expand Down
6 changes: 3 additions & 3 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
from vllm.utils import (cuda_device_count_stateless,
get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)

logger = init_logger(__name__)
Expand All @@ -33,8 +34,7 @@ def _init_executor(self) -> None:
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

from torch.cuda import device_count
assert world_size <= device_count(), (
assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")

distributed_init_method = get_distributed_init_method(
Expand Down
35 changes: 35 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,3 +693,38 @@ def inner(*args, **kwargs):
return inner # type: ignore

return wrapper


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
cuda_visible_devices: Optional[str] = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.

# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
import torch.version

if not torch.cuda._is_compiled():
return 0
# bypass _device_count_nvml() if rocm (not supported)
nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
return r


def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""

# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.

return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)

0 comments on commit 64e03e9

Please sign in to comment.