Skip to content

Commit

Permalink
Merge pull request ReaLLMASIC#114 from mmoffatt2/master
Browse files: browse the repository at this point in the history
Argparse argument for "patience" -- early exit after "patience" number of steps without updates
  • Loading branch information
gkielian authored Mar 11, 2024
2 parents 7a88463 + f47c035 commit c3259c2
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def parse_args():
# Checkpoint args
training_group.add_argument('--only_save_checkpoint_at_end', action='store_true')
training_group.add_argument('--always_save_checkpoint', action='store_true')
training_group.add_argument('--patience', default=None, type=int)
training_group.add_argument('--init_from', default='scratch', choices=['scratch', 'prev_run', 'resume', 'gpt2*'], type=str)
training_group.add_argument('--prev_run_ckpt', default='', type=str)
training_group.add_argument('--csv_ckpt_dir', default='', type=str)
Expand Down Expand Up @@ -415,6 +416,7 @@ def train(self):
t0 = time.time()
local_iter_num = 0
running_mfu = -1.0
num_steps_with_worse_loss = 0

while True:
lr = self.get_lr(self.iter_num) if self.args.decay_lr else self.args.learning_rate
Expand All @@ -427,7 +429,9 @@ def train(self):
self.log_metrics(losses, lr, running_mfu, self.iter_num)

if losses['val'] < self.best_val_loss or self.args.always_save_checkpoint:
self.best_val_loss = losses['val']
if losses['val'] < self.best_val_loss:
self.best_val_loss = losses['val']
num_steps_with_worse_loss = 0
if self.iter_num > 0:
checkpoint = {
'model': self.raw_model.state_dict(),
Expand All @@ -439,6 +443,11 @@ def train(self):
}
print(f"saving checkpoint to {self.args.out_dir}")
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
if self.args.patience is not None and num_steps_with_worse_loss >= self.args.patience:
print(f"Early Stopping: loss has not decreased in {self.args.patience + 1} steps")
break
if losses['val'] > self.best_val_loss:
num_steps_with_worse_loss += 1

if self.iter_num == 0 and self.args.eval_only:
break
Expand Down

0 comments on commit c3259c2

Please sign in to comment.