
Commit f9698c7

nelyahu and loadams authored
pipe engine eval_batch: add option to disable loss broadcast (#4326)
It is sometimes unnecessary to broadcast the loss to all ranks after an evaluation cycle; only some ranks may need it, and the broadcast adds communication overhead between ranks. Setting bcast_loss=False (the default is True, which retains the previous behavior) skips the broadcast. If the monitor is enabled, the loss is broadcast regardless.

Co-authored-by: Logan Adams <[email protected]>
1 parent 8e64c3b commit f9698c7
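
For context, here is a minimal sketch of the communication the new flag avoids, assuming a torch.distributed process group. The helper name and details are illustrative only, not DeepSpeed's actual `_bcast_pipe_scalar` implementation:

```python
import torch
import torch.distributed as dist

def bcast_scalar_from_last_stage(value: float, last_stage_rank: int, device) -> float:
    # Illustrative only: broadcast a scalar loss from the last pipeline
    # stage to every rank. Every rank must enter this collective, which
    # is the per-eval-cycle overhead that bcast_loss=False skips.
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    dist.broadcast(tensor, src=last_stage_rank)
    return tensor.item()
```

On ranks other than the last stage, `value` can be a dummy; the broadcast overwrites the tensor in place.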

File tree

1 file changed: +2, -2 lines


deepspeed/runtime/pipe/engine.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -386,7 +386,7 @@ def train_batch(self, data_iter=None):
         # TODO: should return precisely what loss returned and allow others to be queried?
         return self.agg_train_loss
 
-    def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg'):
+    def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg', bcast_loss=True):
         """Evaluate the pipeline on a batch of data from ``data_iter``. The
         engine will evaluate ``self.train_batch_size()`` total samples
         collectively across all workers.
@@ -449,7 +449,7 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o
         if self.is_last_stage():
             eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)
 
-        if compute_loss:
+        if compute_loss and (bcast_loss or self.monitor.enabled):
             eval_output = self._bcast_pipe_scalar(eval_output)
 
         if self.global_rank == 0 and self.monitor.enabled:
```
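
With the new flag, a caller that only needs the loss on the last pipeline stage can skip the broadcast. A hedged usage sketch (the evaluation loop and variable names are illustrative):

```python
# `engine` is assumed to be an initialized DeepSpeed pipeline engine
# and `data_iter` an evaluation data iterator.
loss = engine.eval_batch(data_iter, bcast_loss=False)

# Without the broadcast, only the last pipeline stage holds the
# aggregated loss; other ranks should not rely on the return value.
if engine.is_last_stage():
    print(f"eval loss: {loss}")
```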
