@@ -418,6 +418,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps over which to accumulate
+            gradients before performing the optimizer update. Default is 1 (no accumulation).
+        should_scale_losses (bool): whether to scale accumulated losses by
+            gradient_accumulation_steps. Default is False.
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +442,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        should_scale_losses: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -503,6 +509,11 @@ def __init__(
             dmp_collection_sync_interval_batches
         )

+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = gradient_accumulation_steps - 1
+        self._should_scale_losses: bool = should_scale_losses
+        self._is_first_step: bool = True
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])

-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and (
+            self._is_first_step or self._accumulation_step_count == 0
+        ):
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()

@@ -696,35 +710,57 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)

-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        # NOTE: the first step cannot be no_sync when DDP.static_graph = True,
+        # due to an unfortunate restriction in torch.distributed
+        no_sync = not self._is_first_step and (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])

-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)

-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])

-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
-
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                if (
+                    self._should_scale_losses
+                    and self._accumulation_steps > 1
+                    and not self._is_first_step
+                ):
+                    losses = losses / self._accumulation_steps
+                self._backward(losses)
+
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    self._accumulation_step_count = 0
+                if self._is_first_step:
+                    self._is_first_step = False

         self.dequeue_batch()
         return output
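
A hedged usage sketch of the new parameters (not part of the diff): `model`, `optimizer`, `device`, and `train_dataloader` are assumed to come from the usual TorchRec setup, with the model already wrapped in DistributedModelParallel so that no_sync() is available on the wrapped module.

# Usage sketch (assumed setup): accumulate gradients over 4 batches per update.
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model,                          # DistributedModelParallel-wrapped model (assumed)
    optimizer,
    device,
    gradient_accumulation_steps=4,  # optimizer.step() every 4th progress() call
    should_scale_losses=True,       # scale losses by 1/4 before backward
)

dataloader_iter = iter(train_dataloader)
while True:
    try:
        # Forward/backward run on every call; the embedding sync and optimizer.step()
        # only run when the accumulation window closes (and on the first call).
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break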