From 6932aa5d2be0d4da7aa5c4196486fa2d0de3f3ad Mon Sep 17 00:00:00 2001
From: gongshaotian
Date: Tue, 28 Oct 2025 20:57:28 +0800
Subject: [PATCH 1/4] refactor default capture list & refine code

---
 fastdeploy/config.py                   | 26 +++++++++----------
 .../cudagraph_piecewise_backend.py     | 10 +++++--
 .../graph_optimization_backend.py      | 12 ++++++++-
 3 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 5105cc482a8..829ea3f9007 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -631,7 +631,7 @@ def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
             self.real_shape_to_captured_size[bs] = end
         self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size
 
-    def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
+    def _set_cudagraph_sizes(self, max_capture_size: int = 0):
         """
         Calculate a series of candidate capture sizes,
         and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -643,7 +643,7 @@
         # Shape [256, 288, ... 992, 1024]
         draft_capture_sizes += [32 * i for i in range(17, 33)]
 
-        draft_capture_sizes.append(max_num_seqs)
+        draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
 
     def to_json_string(self):
@@ -1148,20 +1148,20 @@ def __init__(
         self.cache_config: CacheConfig = cache_config  # type: ignore
         self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
         # Initialize cuda graph capture list
-        if self.graph_opt_config.cudagraph_capture_sizes is None:
-            self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
-
+        max_capture_shape = self.parallel_config.max_num_seqs
         if self.speculative_config is not None and self.speculative_config.method == "mtp":
-            max_shape = self.parallel_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
-            if max_shape % 2 == 1:
-                max_shape = max_shape + 1
-            self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=min(512, max_shape))
+            max_capture_shape = self.parallel_config.max_num_seqs * (
+                self.speculative_config.num_speculative_tokens + 1
+            )
+            assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
+        if self.graph_opt_config.cudagraph_only_prefill:
+            max_capture_shape = 512
         else:
-            self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
+            max_capture_shape = min(512, max_capture_shape)
 
-        # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
-        if self.graph_opt_config.graph_opt_level == 2:
-            self.graph_opt_config.graph_opt_level = 1
+        if self.graph_opt_config.cudagraph_capture_sizes is None:
+            self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
+        self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
 
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index ce3bedd2e66..2ec2fcce8eb 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -163,7 +163,7 @@ def __call__(self, **kwargs):
             for n in range(entry.num_finished_warmup, self.warm_up_size):
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
-                logger.debug(
+                logger.info(
                     f"[CUDA GRAPH] [ID:{id(self)}] Warm up for real shape {padding_real_shape}, "
                     f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )
@@ -199,7 +199,7 @@ def __call__(self, **kwargs):
 
         # For CUDAGraph debug
         # self._save_cudagrpah_dot_files(entry)
-        logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}")
+        logger.info(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}")
 
         # Replay
         entry.cuda_graph.replay()
@@ -243,3 +243,9 @@ def _save_cudagrpah_dot_files(self, entry):
             f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}_time{time.perf_counter()}",
             1 << 0,
         )
+
+    def check_capture_successful(self):
+        """Check whether the shapes are captured or not"""
+        for shape, entry in self.concrete_size_entries.items():
+            if not entry.captured:
+                raise ValueError(f"[CUDA GRAPH][ID:{id(self)}] Shape {shape} capture failed.")
diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
index 73fae52e929..2fb7355d015 100644
--- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
@@ -21,7 +21,6 @@
 
 from paddle.jit import sot
 from paddle.jit.dy2static.utils import Backend as ToStaticBackend
-from paddleformers.utils.log import logger
 from typing_extensions import ParamSpec
 
 from fastdeploy.config import FDConfig
@@ -35,6 +34,9 @@
 from fastdeploy.model_executor.graph_optimization.utils import (
     in_sot_warmup_mode as in_warmup_mode,
 )
+from fastdeploy.utils import get_logger
+
+logger = get_logger("cudagraph_piecewise_backend", "cudagraph_piecewise_backend.log")
 
 P = ParamSpec("P")
 T = TypeVar("T")
@@ -116,6 +118,9 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
         self.fd_config = fd_config
         self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
 
+        self._debug_count_cudagraph_replay = 0
+        self._debug_count_total_step = 0
+
         if self.fd_config.graph_opt_config.graph_opt_level > 0:
            # 1. Prepare cuda graph input buffers (contain output of subgraphs)
@@ -130,6 +135,7 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
             ).__get__(self.runnable.__self__)
 
     def __call__(self, **kwargs):
+        self._debug_count_total_step += 1
         if not self.fd_config.graph_opt_config.use_cudagraph:
             return self.runnable(**kwargs)
         if self.cudagraph_piecewise_backend is None:
@@ -143,6 +149,10 @@ def __call__(self, **kwargs):
         if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
             return self.runnable(**kwargs)
         else:
+            self._debug_count_cudagraph_replay += 1
+            logger.debug(
+                f"[CUDA GRAPH][ID:{id(self.cudagraph_piecewise_backend)}] Total step count: {self._debug_count_total_step}, CUDAGraph replay count: {self._debug_count_cudagraph_replay}"
+            )
             return self.cudagraph_piecewise_backend.__call__(**kwargs)
 
     def clear_cudagraph_piecewise_backend(self):

From 07b219c279e7b68107548400faf8ed2c1de74746 Mon Sep 17 00:00:00 2001
From: gongshaotian
Date: Tue, 28 Oct 2025 21:30:19 +0800
Subject: [PATCH 2/4] fix bug

---
 fastdeploy/config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 829ea3f9007..2b565d24bb5 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -583,6 +583,7 @@ def __init__(
         only to the layer where CUDA graph functionality is required.
         """
         self.cudagraph_splitting_ops: list[str] = []
+        self.cudagraph_only_prefill: bool = False
         """ Whether to use a full cuda graph for the entire forward pass rather than
         splitting certain operations such as attention into subgraphs.
         Thus this flag cannot be used together with splitting_ops."""

From 65aa83ff2fdffd7a3245db6edb65fa1255b69570 Mon Sep 17 00:00:00 2001
From: gongshaotian
Date: Wed, 29 Oct 2025 10:26:05 +0800
Subject: [PATCH 3/4] fix ci bug

---
 fastdeploy/config.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 2b565d24bb5..8993f4f25af 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -602,13 +602,13 @@ def __init__(
 
         self.check_legality_parameters()
 
-    def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
+    def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
         """
         Initialize cuda graph capture sizes and
         pre-compute the mapping from batch size to padded graph size
         """
         # Regular capture sizes
-        self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
+        self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
         dedup_sizes = list(set(self.cudagraph_capture_sizes))
         if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
             logger.info(

From 2573e41dde676a1ca3b7518586d5d30abf340a54 Mon Sep 17 00:00:00 2001
From: gongshaotian
Date: Wed, 29 Oct 2025 11:04:57 +0800
Subject: [PATCH 4/4] Fix test case

---
 tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py | 4 ++--
 tests/graph_optimization/test_cuda_graph_spec_decode.py      | 4 ++--
 .../graph_optimization/test_static_graph_cuda_graph_split.py | 5 +++--
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
index 82b8a27ace5..458e2a9655d 100644
--- a/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
+++ b/tests/graph_optimization/test_cuda_graph_dynamic_subgraph.py
@@ -156,8 +156,8 @@ def test_cuda_graph_subgraph(self):
         parallel_config.max_num_seqs = 8
         cache_config = CacheConfig({})
         # Initialize cuda graph capture list
-        graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-        graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+        graph_opt_config._set_cudagraph_sizes(max_capture_size=parallel_config.max_num_seqs)
+        graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
         fd_config = FDConfig(
             graph_opt_config=graph_opt_config,
             parallel_config=parallel_config,
diff --git a/tests/graph_optimization/test_cuda_graph_spec_decode.py b/tests/graph_optimization/test_cuda_graph_spec_decode.py
index 2fc685bbf8d..665344ac9cc 100644
--- a/tests/graph_optimization/test_cuda_graph_spec_decode.py
+++ b/tests/graph_optimization/test_cuda_graph_spec_decode.py
@@ -103,8 +103,8 @@ def test_cuda_graph_spec_decode(self):
         parallel_config.max_num_seqs = 1
         cache_config = CacheConfig({})
         # Initialize cuda graph capture list
-        graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-        graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+        graph_opt_config._set_cudagraph_sizes(max_capture_size=parallel_config.max_num_seqs)
+        graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
         fd_config = FDConfig(
             graph_opt_config=graph_opt_config,
             parallel_config=parallel_config,
diff --git a/tests/graph_optimization/test_static_graph_cuda_graph_split.py b/tests/graph_optimization/test_static_graph_cuda_graph_split.py
index 7421333a5d4..2259a522c43 100644
--- a/tests/graph_optimization/test_static_graph_cuda_graph_split.py
+++ b/tests/graph_optimization/test_static_graph_cuda_graph_split.py
@@ -89,8 +89,9 @@ def test(self):
         # Set FastDeploy config
         graph_opt_config = GraphOptimizationConfig({"use_cudagraph": True, "graph_opt_level": 1})
         parallel_config = ParallelConfig({"max_num_seqs": 1})
-        graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-        graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+        graph_opt_config._set_cudagraph_sizes(max_capture_size=parallel_config.max_num_seqs)
+        graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
+
         cache_config = CacheConfig({})
         fd_config = FDConfig(
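
Note on the series: the capture list is now derived from a single max_capture_size
value instead of max_num_seqs. Below is a minimal, self-contained sketch of the
capture-shape derivation added in PATCH 1/4. The SimpleNamespace stand-ins for
ParallelConfig, SpeculativeConfig, and GraphOptimizationConfig are illustrative
assumptions, not FastDeploy APIs; the constants and branch order come from the
diff itself.

    # Sketch of the new capture-shape derivation in FDConfig.__init__ (PATCH 1/4).
    # The config objects below are hypothetical stand-ins for the real classes.
    from types import SimpleNamespace

    parallel_config = SimpleNamespace(max_num_seqs=8)
    speculative_config = SimpleNamespace(method="mtp", num_speculative_tokens=3)
    graph_opt_config = SimpleNamespace(cudagraph_only_prefill=False)

    max_capture_shape = parallel_config.max_num_seqs
    if speculative_config is not None and speculative_config.method == "mtp":
        # MTP emits (num_speculative_tokens + 1) tokens per sequence per step.
        max_capture_shape = parallel_config.max_num_seqs * (speculative_config.num_speculative_tokens + 1)
        assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
    if graph_opt_config.cudagraph_only_prefill:
        max_capture_shape = 512  # prefill-only capture uses a fixed ceiling
    else:
        max_capture_shape = min(512, max_capture_shape)  # decode capture is capped at 512

    print(max_capture_shape)  # 8 * (3 + 1) = 32 for this example

This single value both seeds the candidate list in _set_cudagraph_sizes (which
appends max_capture_size to the draft sizes) and truncates it in
init_with_cudagrpah_size (which drops sizes above max_capture_size), which is
why the tests in PATCH 4/4 only needed the keyword rename.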