diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 5105cc482a8..387d8355451 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -591,6 +591,9 @@ def __init__(
         """ Whether to use shared memory pool for multi capture_size """
         self.use_unique_memory_pool: bool = False
 
+        """ Whether to use cudagraph for draft model."""
+        self.draft_model_use_cudagraph: bool = False
+
         self.max_capture_size: int = None
         self.real_shape_to_captured_size: dict[int, int] = None
         # CINN Config ...
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 2f292e6acf1..0cb501f3823 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -82,6 +82,7 @@ def __init__(
         self._init_model_inputs()
 
         # CUDA Graph
+        self.draft_model_use_cudagraph = self.graph_opt_config.draft_model_use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
 
@@ -591,7 +592,7 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
             attn_backend.init_attention_metadata(self.forward_meta)
 
         # TODO(gongshaotian): Use CUDAGraph with Draft Model
-        self.forward_meta.step_use_cudagraph = step_use_cudagraph
+        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
 
     def exist_prefill(self):
         """
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 9e0f053634c..e0a936de50a 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1371,31 +1371,35 @@ def capture_model(self) -> None:
                     expected_decode_len=1,
                 )
                 logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
-            # Capture Draft Model without bsz 1
-            # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                else:
-                    assert batch_size % 2 == 0
+
+            if self.graph_opt_config.draft_model_use_cudagraph:
+                # Capture Draft Model without bsz 1
+                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
+                for batch_size in sorted(capture_sizes, reverse=True):
+                    if batch_size == 1:
+                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
+                    else:
+                        assert batch_size % 2 == 0
+                        self._dummy_run(
+                            num_tokens=self.parallel_config.max_num_batched_tokens,
+                            batch_size=int(batch_size / 2),
+                            in_capturing=True,
+                            expected_decode_len=3,
+                            accept_all_drafts=True,
+                        )
+                        logger.info(
+                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
+                        )
+                # Capture Draft Model with bsz 1
+                if 1 in capture_sizes:
                     self._dummy_run(
                         num_tokens=self.parallel_config.max_num_batched_tokens,
-                        batch_size=int(batch_size / 2),
+                        batch_size=int(1),
                         in_capturing=True,
                         expected_decode_len=3,
-                        accept_all_drafts=True,
+                        accept_all_drafts=False,
                     )
                     logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
-            # Capture Draft Model with bsz 1
-            if 1 in capture_sizes:
-                self._dummy_run(
-                    num_tokens=self.parallel_config.max_num_batched_tokens,
-                    batch_size=int(1),
-                    in_capturing=True,
-                    expected_decode_len=3,
-                    accept_all_drafts=False,
-                )
-                logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(
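
For reviewers, a minimal standalone sketch of the gating behavior this patch introduces: `draft_model_use_cudagraph` defaults to `False`, so the draft (MTP) model keeps running in eager mode and skips draft-model graph capture unless the flag is explicitly enabled; only then does a capturable step's `step_use_cudagraph` propagate into `forward_meta`. The `DummyGraphOptConfig` and `DummyProposer` names below are illustrative stand-ins, not FastDeploy classes.

```python
from dataclasses import dataclass


@dataclass
class DummyGraphOptConfig:
    # Mirrors the new field added to fastdeploy/config.py; off by default.
    draft_model_use_cudagraph: bool = False


class DummyProposer:
    """Illustrative stand-in for the MTP proposer's cudagraph gating."""

    def __init__(self, graph_opt_config: DummyGraphOptConfig):
        # Same wiring as mtp.py __init__: copy the flag off the graph-opt config.
        self.draft_model_use_cudagraph = graph_opt_config.draft_model_use_cudagraph

    def resolve_step_use_cudagraph(self, step_use_cudagraph: bool) -> bool:
        # Same logic as _initialize_forward_meta: the per-step request to replay
        # a captured graph is honoured only when the draft-model flag is enabled.
        return step_use_cudagraph and self.draft_model_use_cudagraph


# Default config: the draft model stays on the eager path even for capturable steps.
assert DummyProposer(DummyGraphOptConfig()).resolve_step_use_cudagraph(True) is False
# Opt-in config: capturable decode steps replay the captured draft-model graph.
assert DummyProposer(DummyGraphOptConfig(draft_model_use_cudagraph=True)).resolve_step_use_cudagraph(True) is True
```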