Early stopping #453

Open · wants to merge 1 commit into master
train.py: 11 changes (10 additions, 1 deletion)
@@ -57,6 +57,7 @@
 # adamw optimizer
 learning_rate = 6e-4 # max learning rate
 max_iters = 600000 # total number of training iterations
+early_stopping_iters = None # number of iterations to stop after with no validation loss improvement
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
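
The new early_stopping_iters knob defaults to None, which keeps the current behaviour of training until max_iters. A minimal usage sketch, assuming the value is supplied the same way as the other training hyperparameters via a config file passed to train.py (the file name and numbers below are illustrative only):

# config/train_early_stop.py (hypothetical config file)
out_dir = 'out-early-stop'
eval_interval = 250
always_save_checkpoint = False  # the PR asserts this is off when early stopping is on
early_stopping_iters = 5        # give up after 5 evaluations with no val-loss improvement

$ python train.py config/train_early_stop.py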
@@ -78,6 +79,8 @@
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
 
+assert not (early_stopping_iters and always_save_checkpoint), "early_stopping and always_save_checkpoint are mutually exclusive"
+
 # various inits, derived attributes, I/O setup
 ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
 if ddp:
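
For context on the new assert: the patience counter introduced later in this diff is reset in the same branch that saves checkpoints,

        if losses['val'] < best_val_loss or always_save_checkpoint:
            patience = 0

so with always_save_checkpoint = True that condition is always true, patience would be reset at every evaluation, and early stopping could never fire. Declaring the two options mutually exclusive avoids that silent no-op.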
@@ -252,7 +255,8 @@ def get_lr(it):
 local_iter_num = 0 # number of iterations in the lifetime of this process
 raw_model = model.module if ddp else model # unwrap DDP container if needed
 running_mfu = -1.0
-while True:
+patience = 0
+while not early_stopping_iters or patience < early_stopping_iters:
 
     # determine and set the learning rate for this iteration
     lr = get_lr(iter_num) if decay_lr else learning_rate
@@ -272,6 +276,7 @@ def get_lr(it):
                 "mfu": running_mfu*100, # convert to percentage
             })
         if losses['val'] < best_val_loss or always_save_checkpoint:
+            patience = 0
             best_val_loss = losses['val']
             if iter_num > 0:
                 checkpoint = {
@@ -284,6 +289,8 @@
                 }
                 print(f"saving checkpoint to {out_dir}")
                 torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+        else:
+            patience += 1
     if iter_num == 0 and eval_only:
         break
 
@@ -331,6 +338,8 @@ def get_lr(it):
     # termination conditions
     if iter_num > max_iters:
         break
+if iter_num <= max_iters:
+    print(f"early stopping after {patience} iterations with no validation loss improvement")
 
 if ddp:
     destroy_process_group()
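
Taken together, the change implements a standard patience-based early-stopping pattern: the counter only moves inside the evaluation block, so early_stopping_iters effectively counts evaluation intervals (each spanning eval_interval training iterations) rather than individual iterations. A self-contained sketch of the same control flow, with a hypothetical evaluate() standing in for nanoGPT's estimate_loss() and checkpointing omitted:

# illustrative values; in train.py these come from the config globals
max_iters = 1000
eval_interval = 100
early_stopping_iters = 3        # None disables early stopping, as in the diff
best_val_loss = float('inf')

def evaluate(step):
    # hypothetical stand-in for estimate_loss(): a loss that plateaus at 0.5
    return max(0.5, 1.0 / (step + 1))

iter_num = 0
patience = 0
while not early_stopping_iters or patience < early_stopping_iters:
    if iter_num % eval_interval == 0:
        val_loss = evaluate(iter_num)
        if val_loss < best_val_loss:
            patience = 0            # improvement: reset patience (train.py also saves a checkpoint here)
            best_val_loss = val_loss
        else:
            patience += 1           # one more evaluation without improvement
    # ... forward/backward pass and optimizer step happen here in train.py ...
    iter_num += 1
    if iter_num > max_iters:
        break

if iter_num <= max_iters:
    print(f"early stopping after {patience} evaluations with no validation loss improvement")

With these toy numbers the fake loss stops improving after the second evaluation, patience reaches 3 at iteration 400, and the loop exits well before max_iters, printing the early-stopping message.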