
[Spec Decode] Introduce DraftModelRunner (vllm-project#5799)
comaniac authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent 4b9894c commit 6664f2a
Showing 15 changed files with 258 additions and 37 deletions.
3 changes: 3 additions & 0 deletions tests/spec_decode/test_multi_step_worker.py
@@ -8,6 +8,7 @@
from tests.nm_utils.utils_skip import should_skip_test_group
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@@ -90,6 +91,7 @@ def test_same_output_for_single_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
Worker,
@@ -173,6 +175,7 @@ def test_same_output_for_multi_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)

worker = create_worker(
5 changes: 4 additions & 1 deletion tests/spec_decode/utils.py
@@ -14,6 +14,7 @@
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker

T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True) -> T:
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)

worker.init_device()
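With the new keyword in place, the updated spec-decode tests construct their draft worker roughly as follows. This is a minimal sketch that mirrors the test hunks above; the positional arguments (model name, block size, etc.) are left as the variables the tests already define.

    # Sketch: create a MultiStepWorker whose model runner is the new draft runner.
    multi_step_worker = create_worker(
        MultiStepWorker,                        # worker class under test
        model_name,                             # draft model, as defined in the test
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,   # new keyword introduced by this commit
    )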
3 changes: 3 additions & 0 deletions vllm/sequence.py
@@ -880,6 +880,8 @@ class ExecuteModelRequest:
running_queue_size: int = 0
# Optional hidden states from prior step.
previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run.
num_steps: int = 1

def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -893,4 +895,5 @@ def clone(
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps,
)
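A minimal sketch of how a proposer is expected to use the new field (mirroring the multi_step_worker.py change further below); the surrounding request handling is illustrative only.

    # Sketch: ask the draft runner to run sample_len consecutive forward steps.
    copied_execute_model_req = execute_model_req.clone(copied_seq_group_metadata_list)
    copied_execute_model_req.num_steps = sample_len   # defaults to 1 everywhere else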
170 changes: 170 additions & 0 deletions vllm/spec_decode/draft_model_runner.py
@@ -0,0 +1,170 @@
from typing import List, Optional

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)

logger = init_logger(__name__)


class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always executes k consecutive forward passes to
generate k speculative tokens in a single speculative decoding step,
we can avoid most of the CPU-GPU synchronization and data-transfer
overhead by keeping the model's input and output tensors on the GPU.
This runner is still under development, so there is no performance gain
yet. For now we adopt a temporary solution that caches the
seq_group_metadata_list for multi-step execution, so that we can reuse
the existing prepare_model_input and stay compatible with the current
execution flow; we plan to remove this cache and avoid calling
prepare_model_input from execute_model entirely.
The detailed development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
return_hidden_states: bool = False,
):
if return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)

super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
return_hidden_states=return_hidden_states,
)

# TODO: Remove this cache when we are able to update model_input
# directly in advance_step.
self.cached_seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None

def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(seq_group_metadata_list)

def update_model_input(
self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model inputs for the next step.
TODO: In-place update model_input instead of calling
prepare_model_input.
"""

# Append the output token to the sequence data.
assert self.cached_seq_group_metadata_list is not None
for seq_group_metadata, sequence_group_outputs in zip(
self.cached_seq_group_metadata_list, last_output.outputs):
seq_group_metadata.is_prompt = False

for seq_output in sequence_group_outputs.samples:
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]

seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)

return self.prepare_model_input(self.cached_seq_group_metadata_list)

@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")

if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)

outputs: List[SamplerOutput] = []
for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
**multi_modal_kwargs,
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)

# Sample the next token.
outputs.append(
self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
))

# Prepare the inputs for the next step.
if step != num_steps - 1:
model_input = self.update_model_input(model_input, outputs[-1])

return outputs
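Taken together, the intended driver-side call pattern for this runner looks roughly like the sketch below; runner, seq_group_metadata_list, kv_caches, and k are assumed to be set up by the surrounding worker code.

    # Sketch: one host-side input preparation, then k consecutive GPU-side steps.
    model_input = runner.prepare_model_input(seq_group_metadata_list)
    outputs = runner.execute_model(model_input, kv_caches, num_steps=k)
    assert len(outputs) == k   # one SamplerOutput per speculative token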
29 changes: 16 additions & 13 deletions vllm/spec_decode/multi_step_worker.py
@@ -6,6 +6,7 @@

from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -67,22 +68,24 @@ def sampler_output(
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)

# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)

# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
if isinstance(self.model_runner, TP1DraftModelRunner):
copied_execute_model_req.num_steps = sample_len
model_outputs = self.execute_model(
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]

self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)
else:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]

self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)

return model_outputs, True

3 changes: 3 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
@@ -11,6 +11,7 @@
HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
@@ -117,6 +118,8 @@ def create_worker(
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size

if draft_tp == 1:
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
11 changes: 8 additions & 3 deletions vllm/worker/cpu_model_runner.py
@@ -351,7 +351,12 @@ def execute_model(
self,
model_input: CPUModelInput,
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")

model_executable = self.model
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
@@ -371,11 +376,11 @@

# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None
return []

# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return output
return [output]
15 changes: 11 additions & 4 deletions vllm/worker/embedding_model_runner.py
@@ -57,7 +57,12 @@ def execute_model(
self,
model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
) -> Optional[PoolerOutput]:
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
raise ValueError(
"EmbeddingModelRunner does not support multi-step execution.")

if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
@@ -91,10 +96,12 @@

# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return None
return []

return self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
return [
self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]

def make_model_input_from_broadcasted_tensor_dict(
self,
10 changes: 7 additions & 3 deletions vllm/worker/model_runner.py
@@ -959,7 +959,11 @@ def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
) -> SamplerOutput:
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")

if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
@@ -992,7 +996,7 @@

# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None
return []

# Sample the next token.
output: SamplerOutput = self.model.sample(
@@ -1011,7 +1015,7 @@

output.hidden_states = hidden_states

return output
return [output]


class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
3 changes: 2 additions & 1 deletion vllm/worker/model_runner_base.py
@@ -150,7 +150,8 @@ def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
) -> Optional[SamplerOutput]:
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
"""
Execute the model on the given input.
"""
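Since every runner now returns a list (empty on non-driver workers, one SamplerOutput per step on the driver), single-step callers are expected to unwrap it. A hedged sketch of the calling convention, with illustrative variable names:

    # Sketch: single-step call against the updated interface.
    outputs = model_runner.execute_model(model_input, kv_caches)   # num_steps defaults to 1
    if outputs:                       # non-driver workers return []
        sampler_output = outputs[0]   # exactly one output for a single step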
(The remaining 5 changed files were not loaded and are not shown.)
