From bd938f7c7c34eda5e04425090b7955694f1b8603 Mon Sep 17 00:00:00 2001
From: Derek Hyatt
Date: Sat, 9 Mar 2024 12:35:12 -0500
Subject: [PATCH] early stopping (default: off)

---
 train.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 951bda9914..a651d173c0 100644
--- a/train.py
+++ b/train.py
@@ -57,6 +57,7 @@
 # adamw optimizer
 learning_rate = 6e-4 # max learning rate
 max_iters = 600000 # total number of training iterations
+early_stopping_iters = None # stop after this many evals with no val loss improvement (None = off)
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
@@ -78,6 +79,8 @@
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
 
+assert not (early_stopping_iters and always_save_checkpoint), "early_stopping and always_save_checkpoint are mutually exclusive"
+
 # various inits, derived attributes, I/O setup
 ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
 if ddp:
@@ -252,7 +255,8 @@ def get_lr(it):
 local_iter_num = 0 # number of iterations in the lifetime of this process
 raw_model = model.module if ddp else model # unwrap DDP container if needed
 running_mfu = -1.0
-while True:
+patience = 0 # evals since the last improvement in validation loss
+while not early_stopping_iters or patience < early_stopping_iters:
 
     # determine and set the learning rate for this iteration
     lr = get_lr(iter_num) if decay_lr else learning_rate
@@ -272,6 +276,7 @@ def get_lr(it):
                 "mfu": running_mfu*100, # convert to percentage
             })
         if losses['val'] < best_val_loss or always_save_checkpoint:
+            patience = 0
             best_val_loss = losses['val']
             if iter_num > 0:
                 checkpoint = {
@@ -284,6 +289,8 @@ def get_lr(it):
                 }
                 print(f"saving checkpoint to {out_dir}")
                 torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+        else:
+            patience += 1
     if iter_num == 0 and eval_only:
         break
 
@@ -331,6 +338,8 @@ def get_lr(it):
     # termination conditions
     if iter_num > max_iters:
         break
 
+if early_stopping_iters and patience >= early_stopping_iters:
+    print(f"early stopping after {patience} evals with no validation loss improvement")
 if ddp:
     destroy_process_group()
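
Usage note: assuming nanoGPT's configurator.py-style command-line overrides
(any top-level config variable can be set via --name=value), the new flag can
be enabled like this; the config file name here is just an example:

    python train.py config/train_shakespeare_char.py \
        --early_stopping_iters=10 --always_save_checkpoint=False

Since patience advances once per evaluation (every eval_interval iterations),
--early_stopping_iters=10 stops training after 10 consecutive evaluations with
no new best validation loss. always_save_checkpoint must stay False, or the
assert added above fires: with it True the checkpoint branch runs on every
eval, patience would be reset each time, and early stopping could never
trigger.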