Skip to content

Commit

Permalink
[Core] Refactor Worker and ModelRunner to consolidate control plane communication (vllm-project#5408)
Browse files Browse the repository at this point in the history

Signed-off-by: Stephanie Wang <[email protected]>
Signed-off-by: Stephanie <[email protected]>
Co-authored-by: Stephanie <[email protected]>
  • Loading branch information
2 people authored and prashantgupta24 committed Jun 27, 2024
1 parent b57155e commit 9504961
Show file tree
Hide file tree
Showing 29 changed files with 1,108 additions and 575 deletions.
152 changes: 152 additions & 0 deletions tests/worker/test_model_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import dataclasses
from typing import List, Tuple, Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class MockAttentionBackend(AttentionBackend):
    """Stand-in attention backend used by the model-input tests below.

    Only ``get_metadata_cls`` returns a usable value. The other hooks either
    raise ``NotImplementedError`` or do nothing.
    """

    @staticmethod
    def get_name() -> str:
        """Not provided by the mock; always raises NotImplementedError."""
        raise NotImplementedError()

    @staticmethod
    def get_impl_cls():
        """Not provided by the mock; always raises NotImplementedError."""
        raise NotImplementedError()

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the plain ``AttentionMetadata`` class."""
        return AttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Not provided by the mock; always raises NotImplementedError."""
        raise NotImplementedError()

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None


def test_model_runner_input():
    """Round-trip ModelInputForGPUWithSamplingMetadata via a tensor dict.

    Builds a model input, converts it with ``as_broadcastable_tensor_dict``,
    rebuilds it with ``from_broadcasted_tensor_dict`` and compares the
    rebuilt copy against the original.
    """
    # The string arguments are placeholders; the test only compares them
    # for equality after the round trip.
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata,
    )

    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    wire_dict = model_input.as_broadcastable_tensor_dict()
    restored = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            wire_dict, attn_backend=MockAttentionBackend()))

    # Check that received copy has correct values.
    assert isinstance(restored, ModelInputForGPUWithSamplingMetadata)

    # Tensor-valued inputs must be present and equal element-wise.
    for tensor_attr in ("input_tokens", "input_positions"):
        restored_tensor = getattr(restored, tensor_attr)
        assert restored_tensor is not None
        assert (restored_tensor == getattr(model_input, tensor_attr)).all()

    # These attributes were never set on the original, so they stay None.
    for unset_attr in ("multi_modal_kwargs", "lora_requests",
                       "lora_mapping"):
        restored_value = getattr(restored, unset_attr)
        assert restored_value is None
        assert restored_value == getattr(model_input, unset_attr)

    # Attributes missing on either side are treated as None.
    for meta_field in dataclasses.fields(AttentionMetadata):
        restored_value = getattr(restored.attn_metadata, meta_field.name,
                                 None)
        original_value = getattr(attn_metadata, meta_field.name, None)
        assert restored_value == original_value

    # For sampling metadata, only selected_token_indices is copied.
    restored_sampling = restored.sampling_metadata
    assert (restored_sampling.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert restored_sampling.seq_groups is None


def test_embedding_model_runner_input():
    """Round-trip ModelInputForGPUWithPoolingMetadata via a tensor dict.

    Builds a model input, converts it with ``as_broadcastable_tensor_dict``,
    rebuilds it with ``from_broadcasted_tensor_dict`` and compares the
    rebuilt copy against the original.
    """
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata,
    )

    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    wire_dict = model_input.as_broadcastable_tensor_dict()
    restored = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            wire_dict, attn_backend=MockAttentionBackend()))

    # Check that received copy has correct values.
    assert isinstance(restored, ModelInputForGPUWithPoolingMetadata)

    # Tensor-valued inputs must be present and equal element-wise.
    for tensor_attr in ("input_tokens", "input_positions"):
        restored_tensor = getattr(restored, tensor_attr)
        assert restored_tensor is not None
        assert (restored_tensor == getattr(model_input, tensor_attr)).all()

    # These attributes were never set on the original, so they stay None.
    for unset_attr in ("multi_modal_kwargs", "lora_requests",
                       "lora_mapping"):
        restored_value = getattr(restored, unset_attr)
        assert restored_value is None
        assert restored_value == getattr(model_input, unset_attr)

    # Attributes missing on either side are treated as None.
    for meta_field in dataclasses.fields(AttentionMetadata):
        restored_value = getattr(restored.attn_metadata, meta_field.name,
                                 None)
        original_value = getattr(attn_metadata, meta_field.name, None)
        assert restored_value == original_value

    # Pooling metadata is not broadcast.
    assert restored.pooling_metadata is None
57 changes: 30 additions & 27 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
slot_mapping = attn_metadata.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)

Expand Down Expand Up @@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)

model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
Expand Down Expand Up @@ -259,32 +261,29 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0

model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0

model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_seq_lens) == 0
assert return_seq_lens is None


@pytest.fixture
Expand Down Expand Up @@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)

(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)

prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
Expand All @@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):

# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
attn_metadata = model_runner._prepare_model_input(
attn_metadata = model_runner._prepare_model_input_tensors(
seq_group_metadata_list).attn_metadata

for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
Expand Down
6 changes: 5 additions & 1 deletion vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:

@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError

@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)

@staticmethod
@abstractmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
return BlocksparseFlashAttentionImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
return BlocksparseFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
return IpexAttnBackendImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
return IpexAttnMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "PallasMetadata":
return PallasMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl

@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata

@staticmethod
def get_kv_cache_shape(
Expand Down
16 changes: 8 additions & 8 deletions vllm/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks=num_cpu_blocks)

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
Expand All @@ -79,7 +79,7 @@ def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return

self._driver_execute_model()
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
Expand Down Expand Up @@ -123,13 +123,13 @@ def save_sharded_state(

@abstractmethod
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
Passing None will cause the driver to stop the model execution loop
running in each of the remote workers. In this case, this method
returns None. Otherwise, this method returns the model output.
"""
raise NotImplementedError

Expand Down
Loading

0 comments on commit 9504961

Please sign in to comment.