1 file changed: +6 -1 lines changed
Columns: original file line number | diff line number | diff line change
@@ -395,12 +395,15 @@ def warmup(self):
395395 with self .all_context ():
396396 max_batches = self .cache_config .max_batches
397397 num_tokens = max_batches
398-
398+ dist_ctx = get_dist_manager ().current_context ()
399+ dp = dist_ctx .dp
399400 # warmup prefill
400401 inputs = self .inputs_strategy .make_dummy (max_batches ,
401402 is_decoding = False ,
402403 device = 'cuda' ,
403404 vocab_size = self .model_config .vocab_size )
405+ if dp > 1 :
406+ inputs .build_dp_meta ()
404407 self ._forward_impl (inputs )
405408
406409 # warmup decoding(with cuda graph)
@@ -411,6 +414,8 @@ def warmup(self):
411414 is_decoding = True ,
412415 device = 'cuda' ,
413416 vocab_size = self .model_config .vocab_size )
417+ if dp > 1 :
418+ inputs .build_dp_meta ()
414419 self ._forward_impl (inputs )
415420
416421 def _slice_outs (self , inputs : torch .Tensor , seq_length : torch .LongTensor ):
You can’t perform that action at this time.
0 commit comments