Commit 81ce250

set default value as false

1 parent f8bc992

2 files changed, +24 -20 lines

fastdeploy/config.py (1 addition, 1 deletion)

@@ -592,7 +592,7 @@ def __init__(
         self.use_unique_memory_pool: bool = False
 
         """ Whether to use cudagraph for draft model."""
-        self.draft_model_use_cudagraph: bool = True
+        self.draft_model_use_cudagraph: bool = False
 
         self.max_capture_size: int = None
         self.real_shape_to_captured_size: dict[int, int] = None
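
For illustration, a minimal sketch of what the flipped default means for callers; the GraphOptConfig dataclass here is a hypothetical stand-in, not FastDeploy's actual config class. Draft-model CUDA graph capture is now opt-in rather than on by default.

# Minimal sketch of the behavior change; this GraphOptConfig is a stand-in,
# not FastDeploy's real class.
from dataclasses import dataclass

@dataclass
class GraphOptConfig:
    # Whether to use cudagraph for the draft model; default flipped by 81ce250.
    draft_model_use_cudagraph: bool = False  # was True before this commit

# Default: draft-model cudagraph capture is skipped.
assert GraphOptConfig().draft_model_use_cudagraph is False

# Callers who want the previous behavior must now enable it explicitly.
cfg = GraphOptConfig(draft_model_use_cudagraph=True)
assert cfg.draft_model_use_cudagraph is True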

fastdeploy/worker/gpu_model_runner.py (23 additions, 19 deletions)

@@ -1371,31 +1371,35 @@ def capture_model(self) -> None:
                     expected_decode_len=1,
                 )
                 logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
-            # Capture Draft Model without bsz 1
-            # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-            for batch_size in sorted(capture_sizes, reverse=True):
-                if batch_size == 1:
-                    logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                else:
-                    assert batch_size % 2 == 0
+
+            if self.graph_opt_config.draft_model_use_cudagraph:
+                # Capture Draft Model without bsz 1
+                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
+                for batch_size in sorted(capture_sizes, reverse=True):
+                    if batch_size == 1:
+                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
+                    else:
+                        assert batch_size % 2 == 0
+                        self._dummy_run(
+                            num_tokens=self.parallel_config.max_num_batched_tokens,
+                            batch_size=int(batch_size / 2),
+                            in_capturing=True,
+                            expected_decode_len=3,
+                            accept_all_drafts=True,
+                        )
+                        logger.info(
+                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
+                        )
+                # Capture Draft Model with bsz 1
+                if 1 in capture_sizes:
                     self._dummy_run(
                         num_tokens=self.parallel_config.max_num_batched_tokens,
-                        batch_size=int(batch_size / 2),
+                        batch_size=int(1),
                         in_capturing=True,
                         expected_decode_len=3,
-                        accept_all_drafts=True,
+                        accept_all_drafts=False,
                     )
-                    logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
-            # Capture Draft Model with bsz 1
-            if 1 in capture_sizes:
-                self._dummy_run(
-                    num_tokens=self.parallel_config.max_num_batched_tokens,
-                    batch_size=int(1),
-                    in_capturing=True,
-                    expected_decode_len=3,
-                    accept_all_drafts=False,
-                )
-                logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(
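
Read together with the config change above, the new control flow can be summarized in a standalone sketch; the capture_draft_model function boundary and the dummy_run stub are assumptions for illustration, since the real logic lives inline in GPUModelRunner.capture_model.

# Standalone sketch of the draft-model capture flow after this commit.
# `dummy_run` stands in for GPUModelRunner._dummy_run; other arguments
# (num_tokens, in_capturing) are omitted for brevity.
def capture_draft_model(capture_sizes, draft_model_use_cudagraph, dummy_run):
    if not draft_model_use_cudagraph:
        return  # new default (False): draft-model cudagraph capture is skipped
    # Capture each batch size > 1 at half width; expected_decode_len=3 and
    # accept_all_drafts=True mirror the MTP warm-up in the diff above.
    for batch_size in sorted(capture_sizes, reverse=True):
        if batch_size == 1:
            continue  # bsz 1 is handled separately below
        assert batch_size % 2 == 0
        dummy_run(batch_size=batch_size // 2, expected_decode_len=3, accept_all_drafts=True)
    # bsz 1 is captured last, with accept_all_drafts=False.
    if 1 in capture_sizes:
        dummy_run(batch_size=1, expected_decode_len=3, accept_all_drafts=False)

# Example: capture_draft_model([8, 4, 2, 1], True, dummy_run=print)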
