 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
+from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
 
 import numpy as np
@@ -2469,7 +2470,9 @@ def execute_model(
             num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
         )
         batch_descriptor = BatchDescriptor(
-            num_tokens=num_input_tokens, uniform_decode=uniform_decode
+            num_tokens=num_input_tokens,
+            uniform_decode=uniform_decode,
+            has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
         )
         cudagraph_runtime_mode, batch_descriptor = (
             self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
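
For context on the `execute_model` change above: `BatchDescriptor` is the hashable key the CUDA graph dispatcher uses to look up a previously captured graph, so adding `has_lora` splits batches with and without active LoRA adapters into distinct graphs. A minimal sketch of the keying idea (field names follow the diff; the dispatch table itself is illustrative, not vLLM's actual dispatcher):

from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    """Hashable key identifying a class of batches for CUDA graph reuse."""
    num_tokens: int
    uniform_decode: bool = False
    has_lora: bool = False


# Illustrative dispatch table: one captured graph per descriptor.
captured_graphs: dict[BatchDescriptor, object] = {}

desc_no_lora = BatchDescriptor(num_tokens=256, uniform_decode=True, has_lora=False)
desc_lora = BatchDescriptor(num_tokens=256, uniform_decode=True, has_lora=True)

# Same shape, different LoRA state -> distinct keys, hence distinct graphs.
assert desc_no_lora != desc_lora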
@@ -3193,6 +3196,7 @@ def _dummy_run(
         is_profile: bool = False,
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
+        activate_lora: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -3215,6 +3219,7 @@ def _dummy_run(
             create_mixed_batch: If True, create a mixed batch with both decode
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
+            activate_lora: If False, the dummy run is performed without LoRAs.
         """
         assert (
             cudagraph_runtime_mode is None
@@ -3364,7 +3369,7 @@ def _dummy_run(
                 attn_metadata[layer_name] = attn_metadata_i
 
         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, remove_lora
+            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
         ):
             # Make sure padding doesn't exceed max_num_tokens
             assert num_tokens_after_padding <= self.max_num_tokens
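
The call above now threads `activate_lora` into `maybe_dummy_run_with_lora`. The helper's body is not part of this diff; the following is only a rough sketch of the shape it plausibly takes, inferred from the call site (the two stub functions are hypothetical stand-ins, not vLLM APIs):

from collections.abc import Iterator
from contextlib import contextmanager


def _add_dummy_loras(lora_config, num_scheduled_tokens) -> None:
    """Hypothetical stand-in for registering dummy LoRA adapters."""
    print("activating dummy LoRAs")


def _remove_all_loras() -> None:
    """Hypothetical stand-in for tearing the dummy adapters down."""
    print("removing dummy LoRAs")


@contextmanager
def maybe_dummy_run_with_lora(
    lora_config, num_scheduled_tokens, activate_lora: bool, remove_lora: bool
) -> Iterator[None]:
    # Only touch the LoRA path when LoRA is configured AND this particular
    # dummy run is meant to exercise it (activate_lora=True).
    use_lora = lora_config is not None and activate_lora
    if use_lora:
        _add_dummy_loras(lora_config, num_scheduled_tokens)
    try:
        yield
    finally:
        if use_lora and remove_lora:
            _remove_all_loras()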
@@ -3411,6 +3416,7 @@ def _dummy_run(
                     BatchDescriptor(
                         num_tokens=num_tokens_after_padding,
                         uniform_decode=uniform_decode,
+                        has_lora=activate_lora and self.lora_config is not None,
                     )
                 )
                 if not is_profile
@@ -3769,10 +3775,21 @@ def freeze_gc():
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None
+
+        if self.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
 
-            compilation_cases = list(reversed(self.cudagraph_batch_sizes))
+            compilation_cases = list(
+                product(reversed(self.cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
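
Since `compilation_cases` is now a `product` over batch sizes and LoRA states, every captured size is enumerated once per LoRA case. A quick illustration of the enumeration, assuming LoRA is configured and `cudagraph_specialize_lora` is enabled:

from itertools import product

cudagraph_batch_sizes = [1, 2, 4, 8]
lora_cases = [True, False]  # LoRA configured, specialization enabled

compilation_cases = list(product(reversed(cudagraph_batch_sizes), lora_cases))
# [(8, True), (8, False), (4, True), (4, False),
#  (2, True), (2, False), (1, True), (1, False)]
# i.e. len(batch_sizes) * len(lora_cases) graphs get captured; with
# specialization disabled, lora_cases collapses to a single entry.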
@@ -3793,7 +3810,9 @@ def freeze_gc():
                 for x in self.cudagraph_batch_sizes
                 if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
-            compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
+            compilation_cases_decode = list(
+                product(reversed(decode_cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases=compilation_cases_decode,
                 cudagraph_runtime_mode=CUDAGraphMode.FULL,
@@ -3823,7 +3842,7 @@ def freeze_gc():
 
     def _capture_cudagraphs(
         self,
-        compilation_cases: list[int],
+        compilation_cases: list[tuple[int, bool]],
        cudagraph_runtime_mode: CUDAGraphMode,
         uniform_decode: bool,
     ):
@@ -3844,7 +3863,7 @@ def _capture_cudagraphs(
         )
 
         # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens in compilation_cases:
+        for num_tokens, activate_lora in compilation_cases:
             # We currently only capture ubatched graphs when it's a FULL
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
@@ -3875,6 +3894,7 @@ def _capture_cudagraphs(
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
+                    activate_lora=activate_lora,
                 )
             self._dummy_run(
                 num_tokens,
@@ -3883,6 +3903,7 @@ def _capture_cudagraphs(
                 allow_microbatching=allow_microbatching,
                 skip_eplb=True,
                 remove_lora=False,
+                activate_lora=activate_lora,
             )
         self.maybe_remove_all_loras(self.lora_config)
 
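Taken together, the capture loop warms up and then captures one graph per `(num_tokens, activate_lora)` pair, keeps dummy LoRAs alive across cases (`remove_lora=False`), and tears everything down once at the end via `maybe_remove_all_loras`. A toy trace of that control flow (the `_dummy_run` below is a stand-in, not the real method):

from itertools import product


def _dummy_run(num_tokens: int, activate_lora: bool, capture: bool) -> None:
    """Stand-in for the real dummy run; just traces the call pattern."""
    phase = "capture" if capture else "warmup"
    print(f"{phase}: num_tokens={num_tokens}, activate_lora={activate_lora}")


num_warmups = 1
compilation_cases = list(product(reversed([1, 2, 4]), [True, False]))

for num_tokens, activate_lora in compilation_cases:
    for _ in range(num_warmups):
        _dummy_run(num_tokens, activate_lora, capture=False)
    _dummy_run(num_tokens, activate_lora, capture=True)

# Dummy LoRAs persist across all cases above and are removed once here,
# mirroring maybe_remove_all_loras(self.lora_config) in the diff.
print("removing all dummy LoRAs")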