resume training, functional lr scheduler
SiLiKhon committed Oct 14, 2020
1 parent ea4c2d2 commit 43701d0
Showing 5 changed files with 51 additions and 24 deletions.
20 changes: 15 additions & 5 deletions models/callbacks.py
@@ -40,14 +40,24 @@ def __call__(self, step):
 
 
 class ScheduleLRCallback:
-    def __init__(self, model, decay_rate, writer):
+    def __init__(self, model, func_gen, func_disc, writer):
         self.model = model
-        self.decay_rate = decay_rate
+        self.func_gen = func_gen
+        self.func_disc = func_disc
         self.writer = writer
 
     def __call__(self, step):
-        self.model.disc_opt.lr.assign(self.model.disc_opt.lr * self.decay_rate)
-        self.model.gen_opt.lr.assign(self.model.gen_opt.lr * self.decay_rate)
+        self.model.disc_opt.lr.assign(self.func_disc(step))
+        self.model.gen_opt.lr.assign(self.func_gen(step))
         with self.writer.as_default():
             tf.summary.scalar("discriminator learning rate", self.model.disc_opt.lr, step)
-            tf.summary.scalar("generator learning rate", self.model.gen_opt.lr, step)
+            tf.summary.scalar("generator learning rate", self.model.gen_opt.lr, step)
+
+
+def get_scheduler(lr, lr_decay):
+    if isinstance(lr_decay, str):
+        return eval(lr_decay)
+
+    def schedule_lr(step):
+        return lr * lr_decay**step
+    return schedule_lr
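
For reference, a minimal sketch of how the new get_scheduler is meant to behave, based only on the code above; the numbers and the lambda string are made-up examples, not values from any repository config. A numeric lr_decay gives an exponential schedule lr * lr_decay**step, while a string is eval'd and must itself evaluate to a step -> learning-rate callable that ScheduleLRCallback can then query each epoch.

from models.callbacks import get_scheduler

# Numeric decay: get_scheduler returns step -> lr * lr_decay**step.
sched = get_scheduler(lr=1e-4, lr_decay=0.98)
print(sched(0))    # 1e-4
print(sched(10))   # 1e-4 * 0.98**10

# String decay: the string is eval'd into the schedule itself, so it has to be
# an expression producing a callable (the lr argument is unused in this case).
sched = get_scheduler(lr=1e-4, lr_decay="lambda step: 1e-4 if step < 50 else 1e-5")
print(sched(10))   # 1e-4
print(sched(60))   # 1e-5
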
5 changes: 3 additions & 2 deletions models/model_v4.py
@@ -59,8 +59,8 @@ def gen_loss_js(d_real, d_fake):
 
 class Model_v4:
     def __init__(self, config):
-        self.disc_opt = tf.keras.optimizers.RMSprop(config['lr'])
-        self.gen_opt = tf.keras.optimizers.RMSprop(config['lr'])
+        self.disc_opt = tf.keras.optimizers.RMSprop(config['lr_disc'])
+        self.gen_opt = tf.keras.optimizers.RMSprop(config['lr_gen'])
         self.gp_lambda = config['gp_lambda']
         self.gpdata_lambda = config['gpdata_lambda']
         self.num_disc_updates = config['num_disc_updates']
@@ -113,6 +113,7 @@ def _load_weights(self, checkpoint, gen_or_disc):
         network.load_weights(str(checkpoint))
 
         if 'optimizer_weights' in model_file:
+            print('Also recovering the optimizer state')
             opt_weight_values = hdf5_format.load_optimizer_weights_from_hdf5_group(
                 model_file
             )
4 changes: 2 additions & 2 deletions models/training.py
@@ -5,12 +5,12 @@
 
 def train(data_train, data_val, train_step_fn, loss_eval_fn, num_epochs, batch_size,
           train_writer=None, val_writer=None, callbacks=[], features_train=None, features_val=None,
-          features_noise=None):
+          features_noise=None, first_epoch=0):
     if not ((features_train is None) or (features_val is None)):
         assert features_train is not None, 'train: features should be provided for both train and val'
         assert features_val is not None, 'train: features should be provided for both train and val'
 
-    for i_epoch in range(num_epochs):
+    for i_epoch in range(first_epoch, num_epochs):
         print("Working on epoch #{}".format(i_epoch), flush=True)
 
         tf.keras.backend.set_learning_phase(1)  # training
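
A small illustration of what the new first_epoch argument buys (a stand-in toy loop, not the repository's train function): resuming with first_epoch set keeps the epoch numbering, and hence logging and anything keyed on i_epoch, continuous with the interrupted run.

def run_epochs(num_epochs, first_epoch=0):
    # Same pattern as the loop above: epochs resume from first_epoch.
    for i_epoch in range(first_epoch, num_epochs):
        print("Working on epoch #{}".format(i_epoch), flush=True)

run_epochs(num_epochs=5)                 # fresh run: epochs 0, 1, 2, 3, 4
run_epochs(num_epochs=5, first_epoch=3)  # resumed run: epochs 3, 4
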
4 changes: 3 additions & 1 deletion models/utils.py
@@ -31,4 +31,6 @@ def load_weights(model, model_path, epoch=None):
     disc_checkpoint = model_path / f"discriminator_{epoch:05d}.h5"
 
     model.load_generator(gen_checkpoint)
-    model.load_discriminator(disc_checkpoint)
+    model.load_discriminator(disc_checkpoint)
+
+    return epoch
42 changes: 28 additions & 14 deletions run_model_v4.py
@@ -12,7 +12,7 @@
 from data import preprocessing
 from models.utils import latest_epoch, load_weights
 from models.training import train
-from models.callbacks import SaveModelCallback, WriteHistSummaryCallback, ScheduleLRCallback
+from models.callbacks import SaveModelCallback, WriteHistSummaryCallback, ScheduleLRCallback, get_scheduler
 from models.model_v4 import Model_v4
 from metrics import evaluate_model
 import cuda_gpu_config
@@ -52,6 +52,11 @@ def load_config(file):
         (config['feature_noise_decay'] is None)
     ), 'Noise power and decay must be both provided'
 
+    if 'lr_disc' not in config: config['lr_disc'] = config['lr']
+    if 'lr_gen' not in config: config['lr_gen' ] = config['lr']
+    if 'lr_schedule_rate_disc' not in config: config['lr_schedule_rate_disc'] = config['lr_schedule_rate']
+    if 'lr_schedule_rate_gen' not in config: config['lr_schedule_rate_gen' ] = config['lr_schedule_rate']
+
     return config


@@ -62,26 +67,29 @@ def main():
 
     model_path = Path('saved_models') / args.checkpoint_name
 
+    config_path = str(model_path / 'config.yaml')
+    continue_training = False
     if args.prediction_only:
         assert model_path.exists(), "Couldn't find model directory"
         assert not args.config, "Config should be read from model path when doing prediction"
         args.config = str(model_path / 'config.yaml')
     else:
-        assert not model_path.exists(), "Model directory already exists"
-        assert args.config, "No config provided"
-
-        model_path.mkdir(parents=True)
-        config_destination = str(model_path / 'config.yaml')
-        shutil.copy(args.config, config_destination)
-
-        args.config = config_destination
+        if not args.config:
+            assert model_path.exists(), "Couldn't find model directory"
+            continue_training = True
+        else:
+            assert not model_path.exists(), "Model directory already exists"
+            model_path.mkdir(parents=True)
+            shutil.copy(args.config, config_path)
+            args.config = config_path
 
     config = load_config(args.config)
 
     model = Model_v4(config)
 
-    if args.prediction_only:
-        load_weights(model, model_path)
+    next_epoch = 0
+    if args.prediction_only or continue_training:
+        next_epoch = load_weights(model, model_path) + 1
 
     preprocessing._VERSION = model.data_version
     data, features = preprocessing.read_csv_2d(pad_range=model.pad_range, time_range=model.time_range)
@@ -131,12 +139,18 @@ def features_noise(epoch):
         save_period=config['save_every'], writer=writer_val
     )
     schedule_lr = ScheduleLRCallback(
-        model, decay_rate=config['lr_schedule_rate'], writer=writer_val
+        model, writer=writer_val,
+        func_gen=get_scheduler(config['lr_gen'], config['lr_schedule_rate_gen']),
+        func_disc=get_scheduler(config['lr_disc'], config['lr_schedule_rate_disc'])
     )
+    if continue_training:
+        schedule_lr(next_epoch - 1)
 
     train(Y_train, Y_test, model.training_step, model.calculate_losses, config['num_epochs'],
           config['batch_size'], train_writer=writer_train, val_writer=writer_val,
-          callbacks=[write_hist_summary, save_model, schedule_lr],
-          features_train=X_train, features_val=X_test, features_noise=features_noise)
+          callbacks=[schedule_lr, save_model, write_hist_summary],
+          features_train=X_train, features_val=X_test, features_noise=features_noise,
+          first_epoch=next_epoch)


if __name__ == '__main__':
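
Putting the run_model_v4.py pieces together, a sketch of the resume path under assumed placeholder values (the config numbers, the epoch index, and the standalone use of get_scheduler are illustrative, not taken from a real run): load_config back-fills the per-network keys from the legacy single-rate keys, load_weights reports the checkpointed epoch, and calling schedule_lr(next_epoch - 1) before train() fast-forwards the learning rates to where the interrupted run left off.

from models.callbacks import get_scheduler

# Equivalent to the "if ... not in config" defaults added in load_config:
config = {'lr': 1e-4, 'lr_schedule_rate': 0.999}   # placeholder legacy config
config.setdefault('lr_gen', config['lr'])
config.setdefault('lr_disc', config['lr'])
config.setdefault('lr_schedule_rate_gen', config['lr_schedule_rate'])
config.setdefault('lr_schedule_rate_disc', config['lr_schedule_rate'])

func_gen = get_scheduler(config['lr_gen'], config['lr_schedule_rate_gen'])
next_epoch = 42                                    # e.g. load_weights(model, model_path) + 1
print(func_gen(next_epoch - 1))                    # rate epoch 41 ended with: 1e-4 * 0.999**41
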
