@@ -213,8 +213,6 @@ def __init__(
         self._generator_gradient_accumulator.reset()
         self._discriminator_gradient_accumulator.reset()
 
-
-
     def init_train_eval_metrics(self, list_metrics_name):
         with self._strategy.scope():
             super().init_train_eval_metrics(list_metrics_name)
@@ -706,7 +704,6 @@ def __init__(
         self._gradient_accumulator = GradientAccumulator()
         self._gradient_accumulator.reset()
 
-
     def init_train_eval_metrics(self, list_metrics_name):
         with self._strategy.scope():
             super().init_train_eval_metrics(list_metrics_name)
@@ -833,7 +830,7 @@ def _one_step_forward_per_replica(self, batch):
         if self.config["gradient_accumulation_steps"] == 1:
             gradients, per_replica_losses = self._calculate_gradient_per_batch(batch)
             self._optimizer.apply_gradients(
-                zip(gradients, self._trainable_variables)
+                zip(gradients, self._trainable_variables), 1.0
             )
         else:
             # gradient accumulation here.
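The new `1.0` positional argument implies the optimizer here is not a stock Keras optimizer (whose `apply_gradients(grads_and_vars, name=None)` takes no such argument) but a custom one whose second parameter is a global-norm clipping value. A minimal sketch of that assumed signature follows; the class name, its base class, and the clipping logic are assumptions for illustration, not taken from this diff:

```python
import tensorflow as tf


class AdamWeightDecay(tf.keras.optimizers.Adam):
    """Hypothetical optimizer with the assumed signature
    apply_gradients(grads_and_vars, clip_norm, name=None)."""

    def apply_gradients(self, grads_and_vars, clip_norm, name=None):
        grads, tvars = list(zip(*grads_and_vars))
        # Clip all gradients jointly by their global norm before applying.
        grads, _ = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
        return super().apply_gradients(zip(grads, tvars), name=name)
```

Under that assumption, passing `1.0` clips the combined gradient vector to a global norm of at most 1.0 on every optimizer step.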
@@ -856,7 +853,7 @@ def _one_step_forward_per_replica(self, batch):
 
             gradients = self._gradient_accumulator.gradients
             self._optimizer.apply_gradients(
-                zip(gradients, self._trainable_variables)
+                zip(gradients, self._trainable_variables), 1.0
            )
             self._gradient_accumulator.reset()
 
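The accumulation branch sums per-micro-batch gradients across `gradient_accumulation_steps` steps, applies the summed gradients once, then calls `reset()` to start the next window. A minimal sketch of a `GradientAccumulator` exposing the `gradients` property and `reset()` method used above; this is an assumed reimplementation for illustration, not the library's actual class:

```python
import tensorflow as tf


class GradientAccumulator:
    """Minimal gradient accumulator sketch (assumed, not the library class).

    Sums gradients across micro-batches; `gradients` returns the running
    sums and `reset()` zeroes them for the next accumulation window.
    """

    def __init__(self):
        self._gradients = []

    @property
    def gradients(self):
        return [g.value() for g in self._gradients]

    def __call__(self, gradients):
        if not self._gradients:
            # Lazily create one accumulator variable per gradient tensor.
            self._gradients = [
                tf.Variable(tf.zeros_like(g), trainable=False) for g in gradients
            ]
        for acc, g in zip(self._gradients, gradients):
            acc.assign_add(g)

    def reset(self):
        for acc in self._gradients:
            acc.assign(tf.zeros_like(acc))
```

With this shape, the `else:` branch amounts to calling the accumulator once per micro-batch, then applying `accumulator.gradients` in a single optimizer step and calling `reset()`.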