config/default.yaml (4 changes: 3 additions & 1 deletion)
@@ -7,8 +7,10 @@ train:
   num_workers: 32
   batch_size: 16
   optimizer: 'adam'
+  epochs: 800
   adam:
-    lr: 0.0001
+    init_lr: 0.0001
+    final_lr: 0.00001
     beta1: 0.5
     beta2: 0.9
 ---
utils/train.py (34 changes: 28 additions & 6 deletions)
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import itertools
 import traceback

 from model.generator import Generator
@@ -13,18 +12,35 @@
 from .validation import validate


+def cosine_decay(init_val, final_val, step, decay_steps):
+    alpha = final_val / init_val
Review comment:
Shouldn't this be init_val / final_val?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The learning rate decays. You might write a demo for testing.

init_val = 1e-4
final_val = 1e-5

@casper-hansen (Jan 26, 2020):
According to the following source, it's the "Minimum learning rate value as a fraction of learning_rate":
https://docs.w3cub.com/tensorflow~python/tf/train/cosine_decay/

Given these values, it looks correct. The naming is just off: the smallest value should be in the numerator and the largest in the denominator.

+    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
+    decayed = (1 - alpha) * cosine_decay + alpha
+    return init_val * decayed
+
+
+def adjust_learning_rate(optimizer, epoch, hp):
+    init_lr = hp.train.adam.init_lr
+    final_lr = hp.train.adam.final_lr
+    decay_steps = hp.train.epochs
+    lr = cosine_decay(init_lr, final_lr, epoch, decay_steps)
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
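Following up on the thread above, a minimal standalone check of the schedule endpoints (a sketch, not part of this diff; the 1e-4 / 1e-5 / 800 values come from config/default.yaml):

import math

def cosine_decay(init_val, final_val, step, decay_steps):
    # alpha is the schedule floor expressed as a fraction of init_val,
    # as in tf.train.cosine_decay, so final_val / init_val is correct.
    alpha = final_val / init_val
    cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    return init_val * ((1 - alpha) * cosine + alpha)

init_lr, final_lr, epochs = 1e-4, 1e-5, 800
print(cosine_decay(init_lr, final_lr, 0, epochs))    # ~1e-4: starts at init_lr
print(cosine_decay(init_lr, final_lr, 400, epochs))  # ~5.5e-5 halfway through
print(cosine_decay(init_lr, final_lr, 800, epochs))  # ~1e-5: ends at final_lr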


 def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
     model_g = Generator(hp.audio.n_mel_channels).cuda()
     model_d = MultiScaleDiscriminator().cuda()

     optim_g = torch.optim.Adam(model_g.parameters(),
-        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+        lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
     optim_d = torch.optim.Adam(model_d.parameters(),
-        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+        lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))

     githash = get_commit_hash()

-    init_epoch = -1
+    elapsed_epochs = 0
     step = 0

     if chkpt_path is not None:
@@ -35,7 +51,7 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
         optim_g.load_state_dict(checkpoint['optim_g'])
         optim_d.load_state_dict(checkpoint['optim_d'])
         step = checkpoint['step']
-        init_epoch = checkpoint['epoch']
+        elapsed_epochs = checkpoint['epoch']

         if hp_str != checkpoint['hp_str']:
             logger.warning("New hparams is different from checkpoint. Will use new.")
@@ -54,11 +70,14 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
     try:
         model_g.train()
         model_d.train()
-        for epoch in itertools.count(init_epoch+1):
+
+        epochs = hp.train.epochs - elapsed_epochs
+        for epoch in range(epochs):
             if epoch % hp.log.validation_interval == 0:
                 with torch.no_grad():
                     validate(hp, args, model_g, model_d, valloader, writer, step)

+            epoch += elapsed_epochs
             trainloader.dataset.shuffle_mapping()
             loader = tqdm.tqdm(trainloader, desc='Loading train data')
             for (melG, audioG), (melD, audioD) in loader:
@@ -67,6 +86,9 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
                 melD = melD.cuda()
                 audioD = audioD.cuda()

+                adjust_learning_rate(optim_g, epoch, hp)
+                adjust_learning_rate(optim_d, epoch, hp)
+
                 # generator
                 optim_g.zero_grad()
                 fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
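And a rough sketch of how the resume bookkeeping in the last hunks interacts with the schedule (hypothetical numbers, not from the PR): resuming from a checkpoint saved at epoch 300 with hp.train.epochs = 800 leaves 500 epochs to run, and epoch += elapsed_epochs keeps the decay aligned with the absolute epoch index.

import math

# Hypothetical resume: checkpoint['epoch'] was 300, hp.train.epochs is 800.
elapsed_epochs = 300
total_epochs = 800
remaining = total_epochs - elapsed_epochs          # 500 epochs still to run

init_lr, final_lr = 1e-4, 1e-5
alpha = final_lr / init_lr                         # same ratio as in cosine_decay

for epoch in range(remaining):
    epoch += elapsed_epochs                        # absolute epoch: 300 ... 799
    cosine = 0.5 * (1 + math.cos(math.pi * epoch / total_epochs))
    lr = init_lr * ((1 - alpha) * cosine + alpha)
    if epoch in (300, 799):
        print(epoch, lr)                           # ~7.2e-5 at 300, ~1.0e-5 at 799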