
Commit 717cb3c

Fix UnboundLocalError: local variable 'checkpoint_manager' referenced before assignment.

PiperOrigin-RevId: 634067308
liangyaning33 authored and t5-copybara committed May 15, 2024
1 parent e84aab9 commit 717cb3c
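
For context, a minimal sketch of the failure mode the commit title describes and the initialize-then-guard pattern the diff applies. This is not t5x code; _FakeManager and the run_* functions are hypothetical stand-ins for the real checkpoint manager and training loop.

class _FakeManager:
  """Hypothetical stand-in for the real t5x checkpoint manager."""

  def close(self):
    print('closed')


def run_buggy(save_requested):
  # 'checkpoint_manager' is only bound when the branch is taken...
  if save_requested:
    checkpoint_manager = _FakeManager()
  # ...so this read raises UnboundLocalError when save_requested is False.
  checkpoint_manager.close()


def run_fixed(save_requested):
  checkpoint_manager = None  # bind the name on every code path
  if save_requested:
    checkpoint_manager = _FakeManager()
  if checkpoint_manager:  # guard each use, as the diff below does
    checkpoint_manager.close()


run_fixed(False)  # no-op, no error
try:
  run_buggy(False)
except UnboundLocalError as e:
  print(e)  # local variable 'checkpoint_manager' referenced before assignment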
Showing 1 changed file with 11 additions and 7 deletions: t5x/train.py
@@ -366,6 +366,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):
   # Skip initialization if neither save nor restore is requested.
   train_state = None
+  checkpoint_manager = None
   if valid_restore_cfg or checkpoint_period or checkpoint_steps:
     train_state, checkpoint_manager = (
         utils.create_checkpoint_manager_and_restore(
@@ -554,7 +555,7 @@ def _run_inference_eval():
     _run_inference_eval()
 
   # Save checkpoints before the training loop starts.
-  if checkpoint_period:
+  if checkpoint_period and checkpoint_manager:
     # If not using Orbax, always save checkpoint, otherwise, only save a
     # checkpoint if a checkpoint does not already exist for that step. This is
     # because Orbax will error out if we try to save a checkpoint that already
@@ -680,7 +681,7 @@ def _as_gda(spec):
     logging.info('Training for %d steps.', epoch_end_step - host_step)
     while host_step < epoch_end_step:
       if trainer.stop_training:
-        if checkpoint_period:
+        if checkpoint_period and checkpoint_manager:
           logging.info('Saving a checkpoint before early stopping...')
           checkpoint_manager.save(
               trainer.train_state,
@@ -732,7 +733,7 @@ def _as_gda(spec):
             {TRAIN_METRIC_KEY: train_summary.result()},
         )
 
-      if is_checkpoint_step:
+      if is_checkpoint_step and checkpoint_manager:
         logging.info('Saving a checkpoint at specified checkpoint step')
         checkpoint_manager.save(
             trainer.train_state,
@@ -747,7 +748,7 @@ def _as_gda(spec):
       host_step += inner_num_steps
     logging.info('END Train loop.')
   except trainer_lib.PreemptionError as e:
-    if checkpoint_period:
+    if checkpoint_period and checkpoint_manager:
       logging.info('Saving emergency checkpoint.')
       checkpoint_manager.save(
           trainer.train_state,
@@ -763,8 +764,10 @@ def _as_gda(spec):
       gc.collect()
 
     # Maybe save a checkpoint if step is at period.
-    if checkpoint_period and (
-        final_epoch or step_offset % checkpoint_period == 0
+    if (
+        checkpoint_period
+        and (final_epoch or step_offset % checkpoint_period == 0)
+        and checkpoint_manager
     ):
       train_summary.result()
       logging.info('Saving checkpoint.')
@@ -797,7 +800,8 @@ def _as_gda(spec):
     # Inference Evaluation (i.e., with decoding or scoring).
     if is_eval_epoch and evaluator is not None:
       _run_inference_eval()
-  checkpoint_manager.close()
+  if checkpoint_manager:
+    checkpoint_manager.close()
 
   # Wait until computations are done before exiting
   _cleanup()
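
One design note on the new guards: they test truthiness (if checkpoint_manager:) rather than identity (is not None). Assuming the t5x checkpoint manager type defines neither __bool__ nor __len__, the two checks are equivalent here, since the only falsy value the variable can hold is the None it is initialized to:

checkpoint_manager = None
assert bool(checkpoint_manager) == (checkpoint_manager is not None)  # both False
manager = object()  # a plain object is always truthy
assert bool(manager) == (manager is not None)  # both True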
