[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (vllm-project#6963)
cadedaniel authored Aug 5, 2024
1 parent c0d8f16 commit 82a1b1a
Showing 5 changed files with 125 additions and 35 deletions.
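With stat logging enabled, this change makes the worker periodically emit a stage-time line of the following form (numbers illustrative; the format matches the logger.info call added in spec_decode_worker.py below):

SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=1.38 scoring_time_ms=22.40 verification_time_ms=1.75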
68 changes: 48 additions & 20 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
set_random_seed(1)

worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()

vocab_size = 32_000
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,

set_random_seed(1)

worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()

proposal_token_ids = torch.randint(low=0,
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,

set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
metrics_collector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()

proposal_token_ids = torch.randint(low=0,
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
set_random_seed(1)

worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)

seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
set_random_seed(1)

worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), False,
metrics_collector)
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)

seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
False, metrics_collector)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device()

draft_worker.init_device.assert_called_once()
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)

worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)

kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs,
k=k)
k=k,
stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
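(Note: the stage_times tuple threaded through the tests above is (average_time_per_proposal_tok_ms, scoring_time_ms, verification_time_ms); the zeros are placeholders, as these assertions do not depend on timing. The spec_decode_worker.py hunks below show how the real values are produced.)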
8 changes: 7 additions & 1 deletion vllm/config.py
@@ -907,6 +907,7 @@ def maybe_create_spec_config(
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
@@ -1095,7 +1096,8 @@ def maybe_create_spec_config(
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)

@staticmethod
@@ -1189,6 +1191,7 @@ def __init__(
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
):
"""Create a SpeculativeConfig object.
@@ -1221,6 +1224,8 @@ def __init__(
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
@@ -1235,6 +1240,7 @@ def __init__(
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats

self._verify_args()

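Together with the arg_utils.py change below, the existing disable_log_stats engine argument now flows EngineArgs → SpeculativeConfig → SpecDecodeWorker, so a single flag controls both the usual engine stats and the new stage-time logs.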
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -792,6 +792,7 @@ def create_engine_config(self, ) -> EngineConfig:
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
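For context, a minimal offline-inference sketch showing where this flag lands (model names and the draft-token count are illustrative placeholders, not part of this commit; LLM forwards keyword arguments to EngineArgs):

from vllm import LLM, SamplingParams

# Sketch under assumptions: any compatible draft/target pair works; k=5 is arbitrary.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    disable_log_stats=False,  # leave stats on to see the periodic stage-time lines
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))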
68 changes: 54 additions & 14 deletions vllm/spec_decode/spec_decode_worker.py
@@ -27,7 +27,7 @@
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output,
from vllm.spec_decode.util import (Timer, create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs)
disable_logprobs=speculative_config.disable_logprobs,
disable_log_stats=speculative_config.disable_log_stats,
)

return spec_decode_worker

@@ -116,6 +118,7 @@ def create_worker(
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
) -> "SpecDecodeWorker":

allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@
proposer_worker,
scorer_worker,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ def __init__(
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool,
disable_logprobs: bool = False,
disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ def __init__(
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_log_stats: If set to True, disable periodic printing of
speculative stage times.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ def __init__(
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
self._disable_log_stats = disable_log_stats

def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ def _run_speculative_decoding_step(
execute_model_req.previous_hidden_states = self.previous_hidden_states
self.previous_hidden_states = None

# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)
with Timer() as proposal_timer:
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
execute_model_req, self._seq_with_bonus_token_in_last_step)

if not self._allow_zero_draft_token_step and proposals.no_proposals:
#TODO: Fix it #5814
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")

proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
with Timer() as scoring_timer:
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)

with Timer() as verification_timer:
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)

stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
scoring_timer.elapsed_time_ms,
verification_timer.elapsed_time_ms)

return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=execute_model_req.num_lookahead_slots)
k=execute_model_req.num_lookahead_slots,
stage_times=stage_times)

@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
@@ -645,6 +662,7 @@ def _create_output_sampler_list(
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int,
stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
@@ -722,8 +740,30 @@ def _create_output_sampler_list(
if maybe_rejsample_metrics is not None:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics

# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self._maybe_log_stage_times(*stage_times)

return sampler_output_list

def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
scoring_time_ms: float,
verification_time_ms: float) -> None:
"""Log the speculative stage times. If stat logging is disabled, do
nothing.
"""
if self._disable_log_stats:
return

logger.info(
"SpecDecodeWorker stage times: "
"average_time_per_proposal_tok_ms=%.02f "
"scoring_time_ms=%.02f verification_time_ms=%.02f",
average_time_per_proposal_tok_ms, scoring_time_ms,
verification_time_ms)

def _create_dummy_logprob_lists(
self,
batch_size: int,
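A worked sketch of the stage_times normalization above, with made-up timings (proposal time is reported per speculative token, scoring and verification per step):

num_lookahead_slots = 5  # k speculative tokens proposed this step
proposal_ms, scoring_ms, verification_ms = 10.0, 22.4, 1.8  # illustrative

stage_times = (
    proposal_ms / num_lookahead_slots,  # average_time_per_proposal_tok_ms -> 2.00
    scoring_ms,                         # scoring_time_ms
    verification_ms,                    # verification_time_ms
)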
15 changes: 15 additions & 0 deletions vllm/spec_decode/util.py
@@ -1,3 +1,4 @@
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple

@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
yield
finally:
torch.cuda.nvtx.range_pop()


class Timer:
"""Basic timer context manager for measuring CPU time.
"""

def __enter__(self):
self.start_time = time.time()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.end_time = time.time()
self.elapsed_time_s = self.end_time - self.start_time
self.elapsed_time_ms = self.elapsed_time_s * 1000
