From f4f921b7f12c67d3c4b7575caf5ddd9bd4b0b787 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 29 Apr 2024 13:52:22 -0700 Subject: [PATCH] [Core][Distributed] use cpu group to broadcast metadata in cpu (#4444) --- .../tensorize_vllm_model_for_testing.py | 6 +- tests/worker/test_model_runner.py | 23 ++++--- vllm/distributed/communication_op.py | 69 +++++++++++++------ 3 files changed, 63 insertions(+), 35 deletions(-) diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py index e4b15fd57ad..0e113ab647e 100644 --- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py +++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py @@ -6,14 +6,14 @@ from functools import partial from typing import Type -import torch import torch.nn as nn from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, TensorSerializer, stream_io) from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor from transformers import AutoConfig, PretrainedConfig -from vllm.distributed import initialize_model_parallel +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.model_executor.model_loader.tensorizer import TensorizerArgs @@ -226,7 +226,7 @@ def deserialize(): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "8080" -torch.distributed.init_process_group(world_size=1, rank=0) +init_distributed_environment(world_size=1, rank=0, local_rank=0) initialize_model_parallel() keyfile = args.keyfile if args.keyfile else None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index abb401f25c1..56fe6db589f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,8 +2,10 @@ import torch from vllm.config import ModelConfig, SchedulerConfig +from vllm.distributed.parallel_state import init_distributed_environment from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -249,19 +251,18 @@ def test_empty_seq_group(): assert len(return_prompt_lens) == 0 -@pytest.mark.parametrize("batch_size", list(range(2, 128))) -@pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): - - def get_world_size(group=None): - return 1 +@pytest.fixture +def distributed_init(): + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", + local_rank=0) - def mock_get_process_group_ranks(group=None): - return [0] - monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) - monkeypatch.setattr(torch.distributed, "get_process_group_ranks", - mock_get_process_group_ranks) +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_config = ModelConfig( "facebook/opt-125m", diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index a3e93691a1e..8b2c26c3a8a 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -4,7 +4,8 @@ import torch from torch.distributed import ProcessGroup -from .parallel_state import (get_tensor_model_parallel_group, +from .parallel_state import (get_cpu_world_group, + get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) @@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any], TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) +def _split_tensor_dict( + tensor_dict: Dict[Any, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note(youkaichao): currently this only supports broadcasting + # tensors on cuda. In the future, we can add device as a field in + # TensorMetadata to support broadcasting tensors on different + # devices. + assert value.is_cuda, ( + f"Tensor {key}: {value} is not on cuda. Currently we only " + f"support broadcasting tensors on cuda.") + metadata_list.append((key, TensorMetadata(value.dtype, + value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary.""" + """Broadcast the input tensor dictionary. + `group` is used to broadcast the tensors, while `metadata_group` is used + to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, + dtypes). + """ group = group or torch.distributed.group.WORLD + metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" @@ -161,27 +195,20 @@ def broadcast_tensor_dict( assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - assert value.is_cuda, ( - f"Tensor {key}: {value} is not on cuda. Currently we only " - f"support broadcasting tensors on cuda.") - metadata_list.append( - (key, TensorMetadata(value.dtype, value.size()))) - else: - metadata_list.append((key, value)) + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` involves serialization and deserialization, + # all happening on CPU. Therefore, we can use the CPU group. torch.distributed.broadcast_object_list([metadata_list], src=src, - group=group) + group=metadata_group) async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = tensor_dict[key] - async_handles.append( - torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True)) + for tensor in tensor_list: + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) for async_handle in async_handles: async_handle.wait() @@ -189,7 +216,7 @@ def broadcast_tensor_dict( recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, src=src, - group=group) + group=metadata_group) assert recv_metadata_list[0] is not None tensor_dict = {} async_handles = []