Skip to content

Commit 83976c9

Browse files
authored
fix bug: dp+tp warmup (#3991)
* fix bug: dp+tp warmup * assert dp size
1 parent d86046a commit 83976c9

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,15 @@ def warmup(self):
395395
with self.all_context():
396396
max_batches = self.cache_config.max_batches
397397
num_tokens = max_batches
398-
398+
dist_ctx = get_dist_manager().current_context()
399+
dp = dist_ctx.dp
399400
# warmup prefill
400401
inputs = self.inputs_strategy.make_dummy(max_batches,
401402
is_decoding=False,
402403
device='cuda',
403404
vocab_size=self.model_config.vocab_size)
405+
if dp > 1:
406+
inputs.build_dp_meta()
404407
self._forward_impl(inputs)
405408

406409
# warmup decoding(with cuda graph)
@@ -411,6 +414,8 @@ def warmup(self):
411414
is_decoding=True,
412415
device='cuda',
413416
vocab_size=self.model_config.vocab_size)
417+
if dp > 1:
418+
inputs.build_dp_meta()
414419
self._forward_impl(inputs)
415420

416421
def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor):

0 commit comments

Comments
 (0)