@@ -1371,31 +1371,35 @@ def capture_model(self) -> None:
                     expected_decode_len=1,
                 )
                 logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
-            # Capture Draft Model without bsz 1
-            # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                else:
-                    assert batch_size % 2 == 0
+
+            if self.graph_opt_config.draft_model_use_cudagraph:
+                # Capture Draft Model without bsz 1
+                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
+                for batch_size in sorted(capture_sizes, reverse=True):
+                    if batch_size == 1:
+                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
+                    else:
+                        assert batch_size % 2 == 0
+                        self._dummy_run(
+                            num_tokens=self.parallel_config.max_num_batched_tokens,
+                            batch_size=int(batch_size / 2),
+                            in_capturing=True,
+                            expected_decode_len=3,
+                            accept_all_drafts=True,
+                        )
+                        logger.info(
+                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
+                        )
+                # Capture Draft Model with bsz 1
+                if 1 in capture_sizes:
                     self._dummy_run(
                         num_tokens=self.parallel_config.max_num_batched_tokens,
-                        batch_size=int(batch_size / 2),
+                        batch_size=int(1),
                         in_capturing=True,
                         expected_decode_len=3,
-                        accept_all_drafts=True,
+                        accept_all_drafts=False,
                     )
                     logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
-            # Capture Draft Model with bsz 1
-            if 1 in capture_sizes:
-                self._dummy_run(
-                    num_tokens=self.parallel_config.max_num_batched_tokens,
-                    batch_size=int(1),
-                    in_capturing=True,
-                    expected_decode_len=3,
-                    accept_all_drafts=False,
-                )
-                logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(
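
For context, here is a minimal standalone sketch of the capture-size gating this hunk introduces: when the new `draft_model_use_cudagraph` flag is set, every capture size other than 1 must be even (the draft model is warmed up at half the target batch size, with `accept_all_drafts=True`), and batch size 1 gets its own dedicated capture with `accept_all_drafts=False`. The `plan_draft_captures` helper below is hypothetical, written only to illustrate the branch structure of the diff; it is not part of FastDeploy.

```python
def plan_draft_captures(capture_sizes: list[int]) -> list[dict]:
    """Return the dummy-run configs the draft model would be captured with.

    Mirrors the gating in the diff above: sizes are visited largest-first,
    size 1 is skipped in the main loop and handled by a dedicated capture,
    and all other sizes must be even because the draft pass runs at half
    the target batch size.
    """
    plans = []
    for batch_size in sorted(capture_sizes, reverse=True):
        if batch_size == 1:
            continue  # bsz 1 is captured separately below
        assert batch_size % 2 == 0, "draft capture sizes other than 1 must be even"
        plans.append(
            {"batch_size": batch_size // 2, "expected_decode_len": 3, "accept_all_drafts": True}
        )
    if 1 in capture_sizes:
        # Dedicated bsz-1 graph: drafts are not force-accepted here.
        plans.append({"batch_size": 1, "expected_decode_len": 3, "accept_all_drafts": False})
    return plans


if __name__ == "__main__":
    # Example: with capture sizes [1, 2, 4, 8], the draft model is captured
    # at batch sizes 4, 2, 1 (halved), then at the dedicated bsz-1 config.
    for plan in plan_draft_captures([1, 2, 4, 8]):
        print(plan)
```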