diff --git a/tf_keras/callbacks.py b/tf_keras/callbacks.py index 075418b1a..e7041d4c6 100644 --- a/tf_keras/callbacks.py +++ b/tf_keras/callbacks.py @@ -1894,7 +1894,7 @@ def on_train_begin(self, logs=None): # TrainingState is used to manage the training state needed for # failure-recovery of a worker in training. - if self.model._distribution_strategy and not isinstance( + if self.model.distribute_strategy and not isinstance( self.model.distribute_strategy, self._supported_strategies ): raise NotImplementedError(