From 82a1b1a82b1fbb454c82a9ef95730b929c9b270c Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 5 Aug 2024 01:46:44 -0700 Subject: [PATCH] [Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963) --- tests/spec_decode/test_spec_decode_worker.py | 68 ++++++++++++++------ vllm/config.py | 8 ++- vllm/engine/arg_utils.py | 1 + vllm/spec_decode/spec_decode_worker.py | 68 ++++++++++++++++---- vllm/spec_decode/util.py | 15 +++++ 5 files changed, 125 insertions(+), 35 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 671c9bef294f9..9ae1b4bc40f0f 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() vocab_size = 32_000 @@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int, set_random_seed(1) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str): spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - False, metrics_collector) + worker = SpecDecodeWorker( + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector, + ) worker.init_device() draft_worker.init_device.assert_called_once() @@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method): target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + metrics_collector=metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens(): seq_group_metadata_list=seq_group_metadata_list, accepted_token_ids=accepted_token_ids, target_logprobs=target_token_logprobs, - k=k) + k=k, + stage_times=(0, 0, 0)) # Verify that _seq_with_bonus_token_in_last_step contains the following: # 1. Sequence IDs that were already present in # _seq_with_bonus_token_in_last_step but were not part of the current diff --git a/vllm/config.py b/vllm/config.py index 35945e34452d2..bec0b63197ef4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -907,6 +907,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + disable_log_stats: bool, speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1095,7 +1096,8 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha, - disable_logprobs=disable_logprobs + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, ) @staticmethod @@ -1189,6 +1191,7 @@ def __init__( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ): """Create a SpeculativeConfig object. @@ -1221,6 +1224,8 @@ def __init__( sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be returned. + disable_log_stats: Whether to disable periodic printing of stage + times in speculative decoding. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1235,6 +1240,7 @@ def __init__( self.typical_acceptance_sampler_posterior_alpha = \ typical_acceptance_sampler_posterior_alpha self.disable_logprobs = disable_logprobs + self.disable_log_stats = disable_log_stats self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2737b50927f6b..acc0551af0154 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -792,6 +792,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, draft_token_acceptance_method=\ diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ad8c0cee0b5b6..690aad505e215 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -27,7 +27,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. typical_acceptance_sampler_posterior_alpha, - disable_logprobs=speculative_config.disable_logprobs) + disable_logprobs=speculative_config.disable_logprobs, + disable_log_stats=speculative_config.disable_log_stats, + ) return spec_decode_worker @@ -116,6 +118,7 @@ def create_worker( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True @@ -171,6 +174,7 @@ def create_worker( proposer_worker, scorer_worker, disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step) @@ -180,7 +184,8 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, - disable_logprobs: bool, + disable_logprobs: bool = False, + disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, @@ -203,6 +208,8 @@ def __init__( disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -240,6 +247,7 @@ def __init__( # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -525,28 +533,37 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None - # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals( - execute_model_req, self._seq_with_bonus_token_in_last_step) + with Timer() as proposal_timer: + # Generate proposals using draft worker. + proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") - proposal_scores = self.scorer.score_proposals( - execute_model_req, - proposals, - ) - accepted_token_ids, target_logprobs = self._verify_tokens( - execute_model_req.seq_group_metadata_list, proposal_scores, - proposals, execute_model_req.num_lookahead_slots) + with Timer() as scoring_timer: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + + with Timer() as verification_timer: + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) + + stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, + scoring_timer.elapsed_time_ms, + verification_timer.elapsed_time_ms) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, - k=execute_model_req.num_lookahead_slots) + k=execute_model_req.num_lookahead_slots, + stage_times=stage_times) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -645,6 +662,7 @@ def _create_output_sampler_list( accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, + stage_times: Tuple[float, float, float], ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. @@ -722,8 +740,30 @@ def _create_output_sampler_list( if maybe_rejsample_metrics is not None: sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + # Log time spent in each stage periodically. + # This is periodic because the rejection sampler emits metrics + # periodically. + self._maybe_log_stage_times(*stage_times) + return sampler_output_list + def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, + scoring_time_ms: float, + verification_time_ms: float) -> None: + """Log the speculative stage times. If stat logging is disabled, do + nothing. + """ + if self._disable_log_stats: + return + + logger.info( + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) + def _create_dummy_logprob_lists( self, batch_size: int, diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index ade546eef264e..c6223a97dba10 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,3 +1,4 @@ +import time from contextlib import contextmanager from typing import Dict, List, Optional, Tuple @@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs): yield finally: torch.cuda.nvtx.range_pop() + + +class Timer: + """Basic timer context manager for measuring CPU time. + """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000