diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py new file mode 100644 index 000000000000..ae818ee360f1 --- /dev/null +++ b/tests/worker/test_model_input.py @@ -0,0 +1,152 @@ +import dataclasses +from typing import List, Tuple, Type + +import torch + +from vllm.attention import AttentionMetadata +from vllm.attention.backends.abstract import AttentionBackend +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.worker.embedding_model_runner import ( + ModelInputForGPUWithPoolingMetadata) +from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class MockAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + def get_impl_cls(): + raise NotImplementedError + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return AttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + pass + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + pass + + +def test_model_runner_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = ( + ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + # Check that received copy has correct values. + assert isinstance(received_model_input, + ModelInputForGPUWithSamplingMetadata) + assert received_model_input.input_tokens is not None + assert ( + received_model_input.input_tokens == model_input.input_tokens).all() + assert received_model_input.input_positions is not None + assert (received_model_input.input_positions == model_input.input_positions + ).all() + assert received_model_input.multi_modal_kwargs is None + assert (received_model_input.multi_modal_kwargs == + model_input.multi_modal_kwargs) + assert received_model_input.lora_requests is None + assert received_model_input.lora_requests == model_input.lora_requests + assert received_model_input.lora_mapping is None + assert received_model_input.lora_mapping == model_input.lora_mapping + for field in dataclasses.fields(AttentionMetadata): + assert getattr(received_model_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. 
+ assert (received_model_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert received_model_input.sampling_metadata.seq_groups is None + + +def test_embedding_model_runner_input(): + pooling_metadata = PoolingMetadata( + seq_groups=[[0]], + seq_data={}, + prompt_lens=[1], + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + model_input = ModelInputForGPUWithPoolingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + pooling_metadata=pooling_metadata, + attn_metadata=attn_metadata) + + assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = ( + ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + # Check that received copy has correct values. + assert isinstance(received_model_input, + ModelInputForGPUWithPoolingMetadata) + assert received_model_input.input_tokens is not None + assert ( + received_model_input.input_tokens == model_input.input_tokens).all() + assert received_model_input.input_positions is not None + assert (received_model_input.input_positions == model_input.input_positions + ).all() + assert received_model_input.multi_modal_kwargs is None + assert (received_model_input.multi_modal_kwargs == + model_input.multi_modal_kwargs) + assert received_model_input.lora_requests is None + assert received_model_input.lora_requests == model_input.lora_requests + assert received_model_input.lora_mapping is None + assert received_model_input.lora_mapping == model_input.lora_mapping + for field in dataclasses.fields(AttentionMetadata): + assert getattr(received_model_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # Pooling metadata is not broadcast. 
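As a usage note on the mock backend above: a backend now only needs to implement get_metadata_cls(), and the base AttentionBackend.make_metadata() classmethod (updated later in this diff) instantiates that class. A minimal sketch reusing the field values from these tests, relying on the MockAttentionBackend, AttentionMetadata, and torch imports already present in this file:

# make_metadata() is now a concrete classmethod on AttentionBackend that simply
# instantiates whatever get_metadata_cls() returns, so the mock backend can
# construct metadata without overriding it.
rebuilt_metadata = MockAttentionBackend.make_metadata(
    num_prefills=1,
    num_prefill_tokens=2,
    num_decode_tokens=3,
    slot_mapping=torch.zeros(1),
)
assert isinstance(rebuilt_metadata, AttentionMetadata)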
+ assert received_model_input.pooling_metadata is None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index dd0d3bf5082d..e1775790c0a0 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len - model_input = model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens - slot_mapping = model_input.slot_mapping + slot_mapping = attn_metadata.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) @@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - model_input = model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) input_tokens, input_positions, attn_metadata, slot_mapping = ( model_input.input_tokens, model_input.input_positions, - model_input.attn_metadata, model_input.slot_mapping) + model_input.attn_metadata, model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) @@ -259,32 +261,29 @@ def test_empty_seq_group(): enforce_eager=False, ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input(seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) + input_tokens, input_positions, attn_metadata = ( model_input.input_tokens, model_input.input_positions, model_input.attn_metadata, - model_input.slot_mapping, ) - assert len(input_tokens) == 0 - assert len(input_positions) == 0 + assert input_tokens is None + assert input_positions is None assert attn_metadata is None - assert len(slot_mapping) == 0 - - model_input = model_runner._prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata, slot_mapping, - return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.slot_mapping, - model_input.seq_lens, - ) - assert len(input_tokens) == 0 - assert len(input_positions) == 0 + + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.seq_lens, + ) + assert input_tokens is None + assert input_positions is None assert attn_metadata is None - assert len(slot_mapping) == 0 - assert len(return_seq_lens) == 0 + assert return_seq_lens is None @pytest.fixture @@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): seq_group_metadata_list.append(seq_group_metadata) decode_metadata_list.append(seq_group_metadata) - (input_tokens, input_positions, attn_metadata, _, _, _, - _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + model_input = 
model_runner.prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + ) prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata @@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. - attn_metadata = model_runner._prepare_model_input( + attn_metadata = model_runner._prepare_model_input_tensors( seq_group_metadata_list).attn_metadata for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 6396103bf5ef..40768532f59c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -21,9 +21,13 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadata": + def get_metadata_cls() -> Type["AttentionMetadata"]: raise NotImplementedError + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + @staticmethod @abstractmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index dce2b83615b7..7b4578fcd8b9 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -90,8 +90,8 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: return BlocksparseFlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata": - return BlocksparseFlashAttentionMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 1c48e2a0bb33..8cb5c3101a80 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata": - return FlashAttentionMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7b7959d257fa..535d30b55bc9 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -22,8 +22,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl @staticmethod - def make_metadata(*args, **kwargs) -> "FlashInferMetadata": - return FlashInferMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index f09b24f2a030..5114bfa6e158 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]: return IpexAttnBackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "IpexAttnMetadata": - return IpexAttnMetadata(*args, 
**kwargs) + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b203c5ec54c9..62b4a144fc44 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -16,8 +16,8 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: return PallasAttentionBackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "PallasMetadata": - return PallasMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9294068c64d1..81fabdbdfc83 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": - return ROCmFlashAttentionMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c01e0a0a3a19..63f8466da931 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -31,8 +31,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": - return TorchSDPAMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0fecd9f6e610..ff449c3ff74f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -28,8 +28,8 @@ def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl @staticmethod - def make_metadata(*args, **kwargs) -> "XFormersMetadata": - return XFormersMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 235b5bc47021..d8693e636ac8 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -79,7 +79,7 @@ def stop_remote_worker_execution_loop(self) -> None: if self.parallel_worker_tasks is None: return - self._driver_execute_model() + self._driver_execute_model(execute_model_req=None) parallel_worker_tasks = self.parallel_worker_tasks self.parallel_worker_tasks = None # Ensure that workers exit model loop cleanly @@ -123,13 +123,13 @@ def save_sharded_state( @abstractmethod def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + self, execute_model_req: 
Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Run execute_model in the driver worker. - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. + Passing None will cause the driver to stop the model execution loop + running in each of the remote workers. In this case, this method + returns None. Otherwise, this method returns the model output. """ raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 7c2520b5a64f..d7c19622e270 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -69,8 +69,8 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 0a654200ed79..5522b5322e66 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -87,7 +87,7 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> List[Union[SamplerOutput, PoolerOutput]]: + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: output = self.driver_worker.execute_model(execute_model_req) return output diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index e63e5a3a027f..101443f23f20 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -76,16 +76,14 @@ def shutdown(self): worker_monitor.close() def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ - return self.driver_worker.execute_model( - execute_model_req=execute_model_req) + return self.driver_worker.execute_model(execute_model_req) def _run_workers( self, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index c5e2fb0f6773..720af0bc7929 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -55,8 +55,7 @@ def execute_model( assert execute_model_req.num_lookahead_slots == 0, ( "lookahead not supported for Neuron backend.") - output = self.driver_worker.execute_model( - execute_model_req.seq_group_metadata_list) + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index fc83c552888a..faa500c2d79c 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -190,9 +190,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_parallel_loading_workers) def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Run execute_model in the driver worker. 
Passing None will cause the driver to stop the model execution diff --git a/vllm/sequence.py b/vllm/sequence.py index 287e1b9df616..0925d15461fd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -887,7 +887,8 @@ def prune(self, @dataclass class ExecuteModelRequest: - """The model execution request.""" + """The model execution request, containing CPU metadata only. The LLM + engine should create an instance of this class for each request batch.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] # Blocks to swap in. List of CPU -> GPU block number. diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 0926e13bedab..6c1c8da57d18 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -7,7 +7,6 @@ SequenceGroupMetadata) from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase -from vllm.worker.model_runner import ModelInput class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): @@ -56,7 +55,7 @@ def _prepare_input_tensors( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, List[int], List[int]]: if not seq_group_metadata_list: - return ModelInput.empty(self.device) + return torch.empty(0, device=self.device), [], [] input_tokens: List[int] = [] seq_lens: List[int] = [] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d539f56937be..e3464c0d3900 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,5 +1,6 @@ from collections import defaultdict -from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch from torch import nn @@ -8,20 +9,64 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) _PAD_SLOT_ID = -1 -class CPUModelRunner: +@dataclass(frozen=True) +class CPUModelInput(ModelRunnerInputBase): + """ + Used by the CPUModelRunner. 
+ """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["CPUModelInput"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> "CPUModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class CPUModelRunner(ModelRunnerBase[CPUModelInput]): def __init__( self, @@ -270,86 +315,70 @@ def _prepare_decode( attn_metadata, ) - def prepare_input_tensors( + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> CPUModelInput: + return CPUModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Optional[Dict[str, torch.Tensor]]]: + ) -> CPUModelInput: multi_modal_kwargs = None - if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - pin_memory=False) - # Broadcast the metadata. - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": - sampling_metadata.selected_token_indices, - } - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
+ if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs + ) = self._prepare_prompt(seq_group_metadata_list) else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - selected_token_indices = metadata_dict.pop( - "selected_token_indices") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) - sampling_metadata = SamplingMetadata( - seq_groups=None, - seq_data=None, - seq_lens=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - generators=None, - ) - - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, multi_modal_kwargs) + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + seq_lens = [] + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + pin_memory=False) + return CPUModelInput( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + ) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: CPUModelInput, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) - model_executable = self.model execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, + "input_ids": model_input.input_tokens, + "positions": model_input.input_positions, "kv_caches": kv_caches, - "attn_metadata": attn_metadata, + "attn_metadata": model_input.attn_metadata, } - if self.vision_language_config and multi_modal_input is not None: - execute_model_kwargs.update(multi_modal_input) + if (self.vision_language_config + and model_input.multi_modal_kwargs is not None): + execute_model_kwargs.update(model_input.multi_modal_kwargs) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -358,6 +387,6 @@ def execute_model( # Sample the next token. 
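With this change the CPU runner no longer broadcasts tensors inside execute_model; the driver builds a CPUModelInput first and then executes it. A minimal driver-side sketch of the new two-step call, assuming runner, seq_group_metadata_list, and kv_caches are set up as in CPUWorker:

def run_cpu_step(runner, seq_group_metadata_list, kv_caches):
    # Step 1 (driver only): build input tensors, attention metadata, and
    # sampling metadata into an immutable CPUModelInput.
    model_input = runner.prepare_model_input(seq_group_metadata_list)
    # Step 2: run the forward pass; sampling is performed only when
    # runner.is_driver_worker is True, otherwise None is returned.
    return runner.execute_model(model_input, kv_caches)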
output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 914df0c7df0e..30ee262c7a8b 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.distributed @@ -8,15 +8,15 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase, WorkerInput) logger = init_logger(__name__) @@ -110,7 +110,7 @@ def get_cache_block_size( return dtype_size * total -class CPUWorker(LoraNotSupportedWorkerBase): +class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -154,7 +154,7 @@ def __init__( # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = CPUModelRunner( + self.model_runner: CPUModelRunner = CPUModelRunner( model_config, parallel_config, scheduler_config, @@ -255,54 +255,37 @@ def _init_cache_engine(self) -> None: for layer_cache in self.cpu_cache: layer_cache.fill_(0) - def cache_copy( + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[torch.Tensor]]: + return self.cpu_cache + + def execute_worker( self, - blocks_to_copy: torch.Tensor, + worker_input: WorkerInput, ) -> None: - if blocks_to_copy.numel() > 0: - self.cache_engine.copy(blocks_to_copy) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine.copy(worker_input.blocks_to_copy) @torch.inference_mode() - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: - - if execute_model_req is None: - seq_group_metadata_list = None - else: - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - - if self.is_driver_worker: - assert seq_group_metadata_list is not None - num_seq_groups: int = len(seq_group_metadata_list) - assert execute_model_req is not None - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device="cpu", - dtype=torch.int64).view(-1, 2) - assert len(execute_model_req.blocks_to_swap_in) == 0 - assert len(execute_model_req.blocks_to_swap_out) == 0 - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_copy": execute_model_req.blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) - else: - data = broadcast_tensor_dict(src=0) - 
num_seq_groups = data["num_seq_groups"] - blocks_to_copy = data["blocks_to_copy"] - - self.cache_copy(blocks_to_copy) - - # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: - return [] - - output = self.model_runner.execute_model(seq_group_metadata_list, - self.cpu_cache) - - # CPU worker only supports single-step execution. - return [output] + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + assert execute_model_req is not None + num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) + blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_copy=blocks_to_copy, + ) def init_distributed_environment(self) -> None: """Initialize the distributed environment.""" diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 465130d10e2f..3c8dfa2c6d8d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,24 +1,32 @@ -from typing import Dict, List, Optional, Set, Tuple +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type import torch -from vllm.attention import AttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU logger = init_logger(__name__) -class EmbeddingModelRunner(ModelRunner): +@dataclasses.dataclass(frozen=True) +class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): + """ + Used by the EmbeddingModelRunner. + """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +class EmbeddingModelRunner( + GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): + _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( + ModelInputForGPUWithPoolingMetadata) def __init__( self, @@ -47,21 +55,22 @@ def __init__( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + model_input: ModelInputForGPUWithPoolingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: - (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) - if self.lora_config: - self.set_active_loras(lora_requests, lora_mapping) + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. 
- prefill_meta = attn_metadata.prefill_metadata - decode_meta = attn_metadata.decode_metadata + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata if prefill_meta is None and decode_meta.use_cuda_graph: - graph_batch_size = input_tokens.shape[0] + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model @@ -70,13 +79,14 @@ def execute_model( kv_caches = [None] * num_layers execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, + "input_ids": model_input.input_tokens, + "positions": model_input.input_positions, "kv_caches": kv_caches, - "attn_metadata": attn_metadata, + "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + execute_model_kwargs.update({"image_input": multi_modal_kwargs}) hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. @@ -84,66 +94,31 @@ def execute_model( return None return self.model.pooler(hidden_states=hidden_states, - pooling_metadata=pooling_metadata) + pooling_metadata=model_input.pooling_metadata) + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForGPUWithPoolingMetadata: + return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) - def prepare_input_tensors( + def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: - if self.is_driver_worker: - assert seq_group_metadata_list is not None - # Prepare input tensors. 
- ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - _, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) - # Prepare PoolingMetadata - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - seq_lens) - - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "lora_requests": lora_requests, - "lora_mapping": lora_mapping, - "multi_modal_kwargs": multi_modal_kwargs, - "num_prefill_tokens": num_prefill_tokens, - "num_decode_tokens": num_decode_tokens, - "slot_mapping": slot_mapping, - "num_prefills": num_prefills, - } - if attn_metadata: - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - lora_mapping = metadata_dict.pop("lora_mapping") - lora_requests = metadata_dict.pop("lora_requests") - multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - if metadata_dict: - attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - else: - attn_metadata = None - pooling_metadata = PoolingMetadata(seq_groups=None, - seq_data=None, - prompt_lens=None) - - return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_kwargs) + ) -> ModelInputForGPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list) + # Prepare PoolingMetadata. + assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a321eafce1a2..9fdb2ea5dd4e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,8 +1,10 @@ +import dataclasses import gc import time import warnings from collections import defaultdict -from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -12,7 +14,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -26,6 +27,15 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -39,40 +49,90 @@ ] _NUM_WARMUP_ITERS = 2 +TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") -class ModelInput(NamedTuple): - 
input_tokens: torch.Tensor - input_positions: torch.Tensor - attn_metadata: Optional[AttentionMetadata] - seq_lens: List[int] - query_lens: List[int] - lora_mapping: Optional[LoRAMapping] - lora_requests: Set[LoRARequest] - multi_modal_kwargs: Dict[str, torch.Tensor] - slot_mapping: torch.Tensor - num_prefill_tokens: int - num_decode_tokens: int - num_prefills: int - @classmethod - def empty(cls, device): - return ModelInput( - input_tokens=torch.empty(0, device=device), - input_positions=torch.empty(0, device=device), - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_mapping=None, - lora_requests=set(), - multi_modal_kwargs={}, - slot_mapping=torch.empty(0, device=device), - num_prefill_tokens=0, - num_decode_tokens=0, - num_prefills=0, - ) +@dataclasses.dataclass(frozen=True) +class ModelInputForGPU(ModelRunnerInputBase): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForGPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForGPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. 
+ is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict -class ModelRunner: + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForGPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): + """ + Helper class for shared methods between GPU model runners. + """ + _model_input_cls: Type[TModelInputForGPU] def __init__( self, @@ -241,11 +301,13 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_model_input( + def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInput: - """Prepare the model input based on a given sequence group. + ) -> TModelInputForGPU: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. The API assumes seq_group_metadata_list is sorted by prefill -> decode. @@ -296,7 +358,7 @@ def _prepare_model_input( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return ModelInput.empty(self.device) + return self._model_input_cls() if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window + self.block_size - @@ -646,7 +708,7 @@ def _prepare_model_input( for k, v in multi_modal_kwargs_list.items() } - return ModelInput( + return self._model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, @@ -655,132 +717,8 @@ def _prepare_model_input( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - ) - - def prepare_input_tensors( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: - if self.is_driver_worker: - assert seq_group_metadata_list is not None - # Prepare input tensors. 
- ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - query_lens, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.pin_memory) - - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": - sampling_metadata.selected_token_indices, - "lora_requests": lora_requests, - "lora_mapping": lora_mapping, - "multi_modal_kwargs": multi_modal_kwargs, - "num_prefill_tokens": num_prefill_tokens, - "num_decode_tokens": num_decode_tokens, - "slot_mapping": slot_mapping, - "num_prefills": num_prefills, - } - if attn_metadata: - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - selected_token_indices = metadata_dict.pop( - "selected_token_indices") - lora_mapping = metadata_dict.pop("lora_mapping") - lora_requests = metadata_dict.pop("lora_requests") - multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - if metadata_dict: - attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - else: - attn_metadata = None - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, lora_requests, lora_mapping, - multi_modal_kwargs) - - @torch.inference_mode() - def execute_model( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - kv_caches: List[torch.Tensor], - ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_kwargs - ) = self.prepare_input_tensors(seq_group_metadata_list) - - if self.lora_config: - self.set_active_loras(lora_requests, lora_mapping) - - # Currently cuda graph is only supported by the decode phase. - prefill_meta = attn_metadata.prefill_metadata - decode_meta = attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - graph_batch_size = input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] - else: - model_executable = self.model - - hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - **multi_modal_kwargs, - ) - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) - - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return None - - # Sample the next token. - output: SamplerOutput = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, ) - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert seq_group_metadata_list is not None - if seq_group_metadata_list[0].is_prompt: - hidden_states = hidden_states.index_select( - 0, sampling_metadata.selected_token_indices) - output.hidden_states = hidden_states - - return output - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. 
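The prepare_input_tensors/execute_model pair removed above is replaced by ModelRunner.prepare_model_input and ModelRunner.execute_model defined further below, with the tensor-dict broadcast moved out of the model runner. A rough sketch of how a driver and a non-driver worker can now share a model input; the actual broadcast wiring lives in the worker base classes, so the helper below is only an approximation:

from vllm.distributed import broadcast_tensor_dict


def run_gpu_step(runner, seq_group_metadata_list, kv_caches,
                 is_driver_worker: bool):
    if is_driver_worker:
        # Driver: build the typed model input locally, then broadcast only
        # its broadcastable fields (input tensors, selected_token_indices,
        # attention metadata, LoRA state).
        model_input = runner.prepare_model_input(seq_group_metadata_list)
        broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(),
                              src=0)
    else:
        # Non-driver: rebuild the typed model input from the broadcast dict,
        # reattaching AttentionMetadata via the runner's attention backend.
        tensor_dict = broadcast_tensor_dict(src=0)
        model_input = runner.make_model_input_from_broadcasted_tensor_dict(
            tensor_dict)
    return runner.execute_model(model_input, kv_caches)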
@@ -853,7 +791,8 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - self.execute_model(seqs, kv_caches) + model_input = self.prepare_model_input(seqs) + self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return @@ -986,6 +925,110 @@ def vocab_size(self) -> int: return self.model_config.get_vocab_size() +class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): + """ + GPU model runner with sampling step. + """ + _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( + ModelInputForGPUWithSamplingMetadata) + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForGPUWithSamplingMetadata: + return ( + ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> ModelInputForGPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + ) -> SamplerOutput: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + **multi_modal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + + # Sample the next token. 
+ output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + if model_input.is_prompt: + assert model_input.sampling_metadata is not None + hidden_states = hidden_states.index_select( + 0, model_input.sampling_metadata.selected_token_indices) + output.hidden_states = hidden_states + + return output + + class CUDAGraphRunner: def __init__(self, model: nn.Module): diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py new file mode 100644 index 000000000000..9b1706035a33 --- /dev/null +++ b/vllm/worker/model_runner_base.py @@ -0,0 +1,157 @@ +import dataclasses +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, + TypeVar) + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.model_executor import SamplingMetadata + +T = TypeVar('T', bound="ModelRunnerInputBase") + + +def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Any], + attn_metadata: Optional["AttentionMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + AttentionMetadata fields. + """ + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + +def _init_attn_metadata_from_tensor_dict( + attn_backend: "AttentionBackend", + tensor_dict: Dict[str, Any], +) -> Dict[str, Any]: + """ + Helper method to initialize AttentionMetadata based on an + AttentionBackend and broadcastable AttentionMetadata fields. + """ + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + tensor_dict["attn_metadata"] = attn_metadata + return tensor_dict + + +def _init_sampling_metadata_from_tensor_dict( # type: ignore + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper method to initialize SamplingMetadata based on broadcastable + SamplingMetadata fields. + """ + from vllm.model_executor import SamplingMetadata + + selected_token_indices = tensor_dict.pop("selected_token_indices", None) + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + if selected_token_indices is not None: + tensor_dict["sampling_metadata"] = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + return tensor_dict + + +def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Any], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + SamplingMetadata fields. + """ + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) + + +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(ABC): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. 
+ + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def from_broadcasted_tensor_dict( + cls: Type[T], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> T: + """ + Pop fields from the given tensor_dict and populate a new instance of + ModelRunnerInputBase. + """ + raise NotImplementedError + + +class ModelRunnerBase(ABC, Generic[T]): + """ + Model runner interface that abstracts a particular hardware and/or type of + model. Model execution may communicate data with model runners in other + processes, but it should not include control plane metadata communication. + + Each ModelRunnerBase subclass should define a corresponding + ModelRunnerInputBase subclass. + """ + + @abstractmethod + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> T: + """ + Make an instance of a ModelRunnerInputBase from the broadcasted tensor + dict. + """ + raise NotImplementedError + + @abstractmethod + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> T: + """ + Prepare the inputs to ModelRunnerBase.execute_model from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @torch.inference_mode() + def execute_model( + self, + model_input: T, + kv_caches: Optional[List[torch.Tensor]], + ) -> Optional[SamplerOutput]: + """ + Execute the model on the given input. + """ + raise NotImplementedError diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a336be04e124..fec2c97e7388 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -10,11 +11,39 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) -class NeuronModelRunner: +@dataclass(frozen=True) +class ModelInputForNeuron(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. 
+ """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + input_block_ids: Optional[torch.Tensor] = None + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForNeuron": + assert attn_backend is None + return cls.from_broadcasted_tensor_dict(tensor_dict) + + +class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): def __init__( self, @@ -139,10 +168,14 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def prepare_input_tensors( + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: + return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) + + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: + ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt @@ -164,30 +197,31 @@ def prepare_input_tensors( self.device, self.pin_memory) - return (input_tokens, input_positions, input_block_ids, - sampling_metadata) + return ModelInputForNeuron(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, input_block_ids, sampling_metadata - ) = self.prepare_input_tensors(seq_group_metadata_list) - hidden_states = self.model( - input_ids=input_tokens, - positions=input_positions, - input_block_ids=input_block_ids, + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Sample the next token. 
output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index d0e6aaed180e..307c107ddef7 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed @@ -7,12 +7,13 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase, WorkerInput) -class NeuronWorker(LoraNotSupportedWorkerBase): +class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ @@ -34,8 +35,9 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = NeuronModelRunner(model_config, parallel_config, - scheduler_config, device_config) + self.model_runner: NeuronModelRunner = NeuronModelRunner( + model_config, parallel_config, scheduler_config, device_config) + self.is_driver_worker = True def init_device(self) -> None: # Set random seed. @@ -73,22 +75,19 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - @torch.inference_mode() - def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[SamplerOutput]: - num_seq_groups = len(seq_group_metadata_list) + @property + def do_metadata_broadcast(self) -> bool: + return False - # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: - return [] + @property + def kv_cache(self) -> Optional[List[torch.Tensor]]: + return None - output = self.model_runner.execute_model(seq_group_metadata_list) - - # Neuron worker only supports single-step output. Wrap the output in a - # list to conform to interface. - return [output] + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + return WorkerInput(num_seq_groups=len( + execute_model_req.seq_group_metadata_list), ) def get_cache_block_size_bytes(self) -> int: """Determine the size in bytes of a cache block. 
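Because do_metadata_broadcast is False and kv_cache is None for this backend, the shared execute_model added to worker_base.py below never calls broadcast_tensor_dict on the Neuron path. The following is a rough, illustrative sketch of what one step effectively reduces to for such a single-device worker; the authoritative control flow is LocalOrDistributedWorkerBase.execute_model further down, and treating execute_worker as a no-op here is an assumption.

# Rough sketch only, not vLLM code: the effective single-device path through
# the default execute_model when do_metadata_broadcast is False.
from typing import List, Optional

from vllm.sequence import ExecuteModelRequest, SamplerOutput


def single_device_step(worker,
                       execute_model_req: Optional[ExecuteModelRequest]
                       ) -> Optional[List[SamplerOutput]]:
    if execute_model_req is None:
        # Nothing to run, and no "stop" broadcast is needed without TP peers.
        return None
    worker_input = worker.prepare_worker_input(execute_model_req)
    model_input = worker.model_runner.prepare_model_input(
        execute_model_req.seq_group_metadata_list)
    worker.execute_worker(worker_input)  # assumed no-op: no cache ops here
    if worker_input.num_seq_groups == 0:
        return []
    # kv_cache is None for the Neuron worker; its model runner ignores it.
    return [worker.model_runner.execute_model(model_input, worker.kv_cache)]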
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c60764ef1bed..e1944a4f1d63 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Type import torch import torch.distributed @@ -9,21 +9,20 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.model_runner import ModelRunner -from vllm.worker.worker_base import WorkerBase +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput -class Worker(WorkerBase): +class Worker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single GPU. The worker is responsible for @@ -78,9 +77,10 @@ def __init__( or (speculative_config.draft_model_config.hf_config.model_type != "mlp_speculator") else {"return_hidden_states": True} - ModelRunnerClass = (EmbeddingModelRunner if - self.model_config.embedding_mode else ModelRunner) - self.model_runner = ModelRunnerClass( + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -225,40 +225,18 @@ def _warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) - def cache_swap( - self, - blocks_to_swap_in: torch.Tensor, - blocks_to_swap_out: torch.Tensor, - blocks_to_copy: torch.Tensor, - ) -> None: - # Issue cache operations. - if blocks_to_swap_in.numel() > 0: - self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out.numel() > 0: - self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy.numel() > 0: - self.cache_engine.copy(blocks_to_copy) + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[torch.Tensor]]: + return self.gpu_cache @torch.inference_mode() - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[Union[SamplerOutput, PoolerOutput]]: - if not self.is_driver_worker: - self._execute_model_non_driver() - return [] - - if execute_model_req is None: - # This signals that there's no more requests to process for now. - # All workers are running infinite loop with broadcast_tensor_dict, - # and it stops the loop when the driver broadcasts an empty input. - # Send an empty input to notify all other workers to stop their - # execution loop. 
- broadcast_tensor_dict({}, src=0) - return [] - - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - num_seq_groups = len(seq_group_metadata_list) + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, @@ -273,59 +251,26 @@ def execute_model( blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) - - self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) - - # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: - return [] - output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) - - # Worker only supports single-step execution. Wrap the output in a list - # to conform to interface. - return [output] + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) @torch.inference_mode() - def start_worker_execution_loop(self) -> None: - """Execute model loop in parallel worker. - - You can stop the loop by executing a driver worker with an empty output. - See `stop_remote_worker_execution_loop` for more details. - """ - while self._execute_model_non_driver(): - pass - - def _execute_model_non_driver(self) -> bool: - """Execute model in parallel worker. - - Returns True iff there are remaining sequences to process. - """ - assert not self.is_driver_worker - data = broadcast_tensor_dict(src=0) - if not data: - return False - - num_seq_groups = data.get("num_seq_groups", 0) - blocks_to_swap_in = data.get("blocks_to_swap_in") - blocks_to_swap_out = data.get("blocks_to_swap_out") - blocks_to_copy = data.get("blocks_to_copy") - self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) - - # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: - return False - - self.model_runner.execute_model(None, self.gpu_cache) - return True + def execute_worker(self, worker_input: WorkerInput) -> None: + # Issue cache operations. 
+ if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine.swap_in(worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine.swap_out(worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine.copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index dc09718de4a3..1f1ce4e7b114 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,20 +1,26 @@ +import dataclasses import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +import torch + +from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase logger = init_logger(__name__) class WorkerBase(ABC): """Worker interface that allows vLLM to cleanly separate implementations for - different hardware. + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. """ @abstractmethod @@ -46,13 +52,23 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + while True: + output = self.execute_model(execute_model_req=None) + if output is None: + return None + @abstractmethod def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" + ) -> Optional[List[SamplerOutput]]: raise NotImplementedError @abstractmethod @@ -98,6 +114,150 @@ def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") +@dataclasses.dataclass(frozen=True) +class WorkerInput: + """Local inputs to each worker. May contain device-specific data. These + fields should be broadcastable to other workers. + """ + + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["WorkerInput"], + tensor_dict: Dict[str, Any], + ) -> "WorkerInput": + """ + Pop fields from the given tensor_dict and populate a new instance of + WorkerInput. + """ + return cls( + num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + ) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. 
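WorkerInput is deliberately flat, so its tensor-dict form is a field-for-field mapping. A quick round-trip sanity check; the concrete values are made up for illustration.

# Round-trip check for WorkerInput; values are illustrative only.
import torch

from vllm.worker.worker_base import WorkerInput

worker_input = WorkerInput(
    num_seq_groups=2,
    blocks_to_swap_in=torch.empty((0, 2), dtype=torch.int64),
    blocks_to_swap_out=torch.empty((0, 2), dtype=torch.int64),
    blocks_to_copy=torch.tensor([[0, 1]], dtype=torch.int64),
)
tensor_dict = worker_input.as_broadcastable_tensor_dict()
# In distributed mode this dict is merged with the model runner input's dict
# and sent through broadcast_tensor_dict; the receiver pops the fields back.
received = WorkerInput.from_broadcasted_tensor_dict(tensor_dict)
assert received.num_seq_groups == 2
assert torch.equal(received.blocks_to_copy, worker_input.blocks_to_copy)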
+ """ + tensor_dict = { + "num_seq_groups": self.num_seq_groups, + "blocks_to_swap_in": self.blocks_to_swap_in, + "blocks_to_swap_out": self.blocks_to_swap_out, + "blocks_to_copy": self.blocks_to_copy, + } + + return tensor_dict + + +class LocalOrDistributedWorkerBase(WorkerBase): + """ + Partial implementation of WorkerBase that has a default `execute_model` + definition to perform metadata transfer between workers when in distributed + mode. Subclasses of this interface should use model runners that inherit + from ModelRunnerBase, and should only need to implement worker-local logic. + If custom control plane logic is needed to transfer metadata, or if the + model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. + """ + is_driver_worker: bool + model_runner: ModelRunnerBase + + @property + @abstractmethod + def do_metadata_broadcast(self) -> bool: + """ + Used by the default `execute_model` to check whether broadcast is + needed to transfer request inputs from the driver worker to other + workers in the TP group. If WorkerBase subclass only supports + single-worker execution, then this method should return False. + """ + raise NotImplementedError + + @property + @abstractmethod + def kv_cache(self) -> Optional[List[torch.Tensor]]: + """ + Get the kv cache to pass to the worker's model runner. Used by the + default `execute_model`. If the worker's model runner does not follow + the ModelRunnerBase interface, then inherit from WorkerBase instead. + """ + raise NotImplementedError + + @abstractmethod + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + """ + Prepare the inputs to WorkerBase.execute_worker from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @abstractmethod + def execute_worker(self, worker_input: WorkerInput) -> None: + """ + Process an execution request. + """ + raise NotImplementedError + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + assert self.do_metadata_broadcast + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. 
+ if worker_input.num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(model_input, self.kv_cache) + # Worker only supports single-step execution. Wrap the output in a + # list to conform to interface. + return [output] + + class WorkerWrapperBase: """ The whole point of this class is to lazily initialize the worker. diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f30de703e805..d9124a788a69 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -14,6 +15,15 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -24,7 +34,42 @@ ] -class XPUModelRunner: +@dataclass(frozen=True) +class ModelInputForXPU(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_input: Optional[Dict[str, torch.Tensor]] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForXPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForXPU": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): def __init__( self, @@ -130,15 +175,22 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - self.execute_model(seqs, kv_caches) + model_input = self.prepare_model_input(seqs) + self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return - def prepare_input_tensors( + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU: + return (ModelInputForXPU.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Optional[torch.Tensor]]: + ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or @@ -185,8 +237,11 @@ def prepare_input_tensors( num_prompts=0, ) - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, multi_modal_input) + return ModelInputForXPU(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + multi_modal_input=multi_modal_input) def _prepare_decode( self, @@ -277,27 +332,25 @@ def _prepare_decode( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: ModelInputForXPU, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) - model_executable = self.model execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, + "input_ids": model_input.input_tokens, + "positions": model_input.input_positions, "kv_caches": kv_caches, - "attn_metadata": attn_metadata, + "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) + execute_model_kwargs.update( + {"image_input": model_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -306,7 +359,7 @@ def execute_model( # Sample the next token. output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output
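For TP > 1 backends such as the GPU Worker, the non-driver side of this protocol is the mirror image: each worker sits in WorkerBase.start_worker_execution_loop and, per step, rebuilds both the WorkerInput and the model runner input from the single broadcast payload. Below is a simplified sketch of one such step; it is illustrative only, the authoritative logic is the else branch of LocalOrDistributedWorkerBase.execute_model above, and `worker` is assumed to be a subclass with do_metadata_broadcast enabled.

# Simplified sketch of one non-driver step; not vLLM code.
from vllm.distributed import broadcast_tensor_dict
from vllm.worker.worker_base import WorkerInput


def non_driver_step(worker) -> bool:
    broadcast_data = broadcast_tensor_dict(src=0)
    if not broadcast_data:
        return False  # the driver broadcast an empty dict: stop the loop
    # WorkerInput pops its own fields; the model runner consumes the rest.
    worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
    model_input = (
        worker.model_runner.make_model_input_from_broadcasted_tensor_dict(
            broadcast_data))
    worker.execute_worker(worker_input)  # issue KV-cache swap / copy ops
    if worker_input.num_seq_groups != 0:
        worker.model_runner.execute_model(model_input, worker.kv_cache)
    return True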