-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathtrain.py
executable file
·330 lines (277 loc) · 12.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/env python3
# Copyright (c) Yiwen Shao
# Apache 2.0
import argparse
import os
import shutil
import time
import random
import torch
import torch.nn.parallel
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler
from dataset import ChainDataset, AudioDataLoader, BucketSampler
from models import get_model
from pychain.loss import ChainLoss
from pychain.graph import ChainGraph
import simplefst
def _str2bool(value):
    """Parse a command-line boolean option value.

    argparse's ``type=bool`` turns every non-empty string -- including
    'False' -- into True. This converter keeps that result for any string
    except the usual spellings of false ('false', 'f', 'no', 'n', 'off',
    '0' and the empty string), compared case-insensitively, which map to
    False.
    """
    return value.lower() not in ('false', 'f', 'no', 'n', 'off', '0', '')


parser = argparse.ArgumentParser(description='PyChain training')
# Datasets
parser.add_argument('--train', type=str, required=True,
                    help='training set json file')
parser.add_argument('--valid', type=str, required=True,
                    help='valid set json file')
parser.add_argument('--den-fst', type=str, required=True,
                    help='denominator fst path')
# Optimization options
parser.add_argument('--epochs', default=15, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--train-bsz', default=128, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--valid-bsz', default=128, type=int, metavar='N',
                    help='valid batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--dropout', default=0, type=float,
                    metavar='Dropout', help='Dropout ratio')
parser.add_argument('--optimizer', type=str, default='adam',
                    help='optimizer type')
parser.add_argument('--scheduler', type=str, default='plateau',
                    help='Learning rate scheduler')
parser.add_argument('--milestones', type=int, nargs='+', default=[5, 10],
                    help='Decrease learning rate at these epochs.(only for step decay)')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '--pf', default=10, type=int,
                    help='print frequency')
parser.add_argument('--beta1', default=0.9, type=float,
                    help='adam beta1')
parser.add_argument('--beta2', default=0.999, type=float, help='adam beta2')
parser.add_argument('--curriculum', default=-1, type=int,
                    help='curriculum learning epochs that will start from short sequences')
# Checkpoints
parser.add_argument('--exp', default='exp/tdnn', type=str, metavar='PATH',
                    help='path to save checkpoint and log (default: %(default)s)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# Architecture
parser.add_argument('--arch', '-a', metavar='ARCH', default='TDNN',
                    choices=['TDNN', 'RNN', 'LSTM',
                             'GRU', 'TDNN-LSTM', 'TDNN-MFCC'],
                    help='model architecture: ')
parser.add_argument('--layers', default=5, type=int, help='number of layers')
parser.add_argument('--feat-dim', default=40, type=int,
                    help='number of features for each frame')
parser.add_argument('--hidden-dims', default=[256, 256, 256, 256, 256], type=int, nargs='+',
                    help='output dimensions for each hidden layer')
parser.add_argument('--num-targets', default=100, type=int,
                    help='number of nnet output dimensions (i.e. number of pdf-ids)')
parser.add_argument('--kernel-sizes', default=[3, 3, 3, 3, 3], type=int, nargs='+',
                    help='kernel sizes of TDNN/CNN layers (only required for TDNN)')
parser.add_argument('--dilations', default=[1, 1, 3, 3, 3], type=int, nargs='+',
                    help='dilations for TDNN/CNN kernels (only required for TDNN)')
parser.add_argument('--strides', default=[1, 1, 1, 1, 3], type=int, nargs='+',
                    help='strides for TDNN/CNN kernels (only required for TDNN)')
# Boolean options take an explicit value, e.g. `--residual true` / `--residual false`.
parser.add_argument('--residual', default=False, type=_str2bool,
                    help='residual connection in TDNN (true/false)')
parser.add_argument('--bidirectional', default=False, type=_str2bool,
                    help='bidirectional rnn (true/false)')
# LF-MMI Loss
parser.add_argument('--leaky', default=1e-5, type=float,
                    help='leaky hmm coefficient for the denominator')
# Feature extraction
parser.add_argument('--no-feat', action='store_true',
                    help='not using pre-extracted features but train from raw wav')
# Miscs
parser.add_argument('--seed', type=int, default=0, help='manual seed')
args = parser.parse_args()
print(args)
# Use CUDA
use_cuda = torch.cuda.is_available()
# Random seed: seed the Python and torch (CPU and, when available, all CUDA
# devices) random number generators from --seed.
random.seed(args.seed)
torch.manual_seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed_all(args.seed)
# Best validation loss seen so far. Starting at +inf guarantees the first
# epoch is recorded as the best; a finite sentinel (previously 1000) would
# skip writing model_best.pth.tar while validation loss stays above it.
best_loss = float('inf')
def _build_optimizer(model):
    """Return the optimizer selected by ``--optimizer`` for ``model``'s parameters.

    Raises:
        ValueError: if ``args.optimizer`` is neither 'sgd' nor 'adam'.
    """
    if args.optimizer == 'sgd':
        return optim.SGD(model.parameters(), lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        # --weight-decay is not passed to Adam (unchanged behaviour).
        return optim.Adam(model.parameters(), lr=args.lr,
                          betas=(args.beta1, args.beta2))
    raise ValueError("Unsupported --optimizer '{}'; expected 'sgd' or 'adam'"
                     .format(args.optimizer))


def _build_scheduler(optimizer, start_epoch):
    """Return the learning-rate scheduler selected by ``--scheduler``.

    The 'step' and 'exp' schedulers are created with
    ``last_epoch=start_epoch - 1`` so they continue from a resumed epoch.

    Raises:
        ValueError: if ``args.scheduler`` is not 'step', 'exp' or 'plateau'.
    """
    if args.scheduler == 'step':
        return lr_scheduler.MultiStepLR(
            optimizer, milestones=args.milestones, gamma=args.gamma,
            last_epoch=start_epoch - 1)
    if args.scheduler == 'exp':
        # Per-epoch factor chosen so that final_lr = init_lr * args.gamma
        # after args.epochs steps.
        gamma = args.gamma ** (1.0 / args.epochs)
        return lr_scheduler.ExponentialLR(
            optimizer, gamma=gamma, last_epoch=start_epoch - 1)
    if args.scheduler == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.gamma, patience=1)
    raise ValueError("Unsupported --scheduler '{}'; expected 'step', 'exp' "
                     "or 'plateau'".format(args.scheduler))


def main():
    """Train and validate for ``args.epochs`` epochs, checkpointing each epoch.

    Builds the data loaders, model, ChainLoss criterion, optimizer and LR
    scheduler from the parsed command-line ``args``, optionally resumes from
    ``args.resume``, then alternates ``train`` and ``test``. After every
    epoch it saves a checkpoint into ``args.exp`` (also copied to
    ``model_best.pth.tar`` when the validation loss improves) and steps the
    scheduler.

    Raises:
        FileNotFoundError: if ``args.resume`` is set but is not a file.
        ValueError: if ``--optimizer`` or ``--scheduler`` is unsupported.
    """
    global best_loss
    writer = SummaryWriter(args.exp)
    print('Saving model and logs to {}'.format(args.exp))
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Data
    trainset = ChainDataset(args.train, no_feat=args.no_feat)
    trainsampler = BucketSampler(trainset, args.train_bsz)
    trainloader = AudioDataLoader(
        trainset, batch_sampler=trainsampler, num_workers=4)
    validset = ChainDataset(args.valid, no_feat=args.no_feat)
    validloader = AudioDataLoader(
        validset, batch_size=args.valid_bsz, num_workers=4)

    # Model
    print("==> creating model '{}'".format(args.arch))
    model = get_model(args.feat_dim, args.num_targets, args.layers, args.hidden_dims, args.arch,
                      kernel_sizes=args.kernel_sizes, dilations=args.dilations,
                      strides=args.strides,
                      bidirectional=args.bidirectional,
                      dropout=args.dropout,
                      residual=args.residual)
    print(model)
    if use_cuda:
        model = model.cuda()
    print(' Total params: %.2fM' % (sum(p.numel()
                                        for p in model.parameters()) / 1000000.0))

    # Loss: the denominator graph is built from the FST at --den-fst.
    den_fst = simplefst.StdVectorFst.read(args.den_fst)
    den_graph = ChainGraph(den_fst)
    criterion = ChainLoss(den_graph, args.leaky)

    # Optimizer (fails fast on an unknown --optimizer).
    optimizer = _build_optimizer(model)

    # Resume
    if args.resume:
        print('==> Resuming from checkpoint..')
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                'Error: no checkpoint file found at {}'.format(args.resume))
        args.checkpoint = os.path.dirname(args.resume)
        # NOTE(review): checkpoints pickle `args` (an argparse.Namespace);
        # torch releases whose torch.load defaults to weights_only=True may
        # refuse to load them -- confirm against the installed torch.
        checkpoint = torch.load(args.resume)
        best_loss = checkpoint['best_loss']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    # Learning-rate scheduler (created after resuming so that start_epoch
    # reflects the checkpoint; fails before training on an unknown value).
    scheduler = _build_scheduler(optimizer, start_epoch)

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        # The sampler is only reshuffled from epoch args.curriculum on; the
        # --curriculum help describes the earlier epochs as starting from
        # short sequences.
        if epoch >= args.curriculum:
            trainsampler.shuffle(epoch)
        train(trainloader, model, criterion, optimizer, writer, epoch, use_cuda)
        valid_loss = test(
            validloader, model, criterion, writer, epoch, use_cuda)

        # save model
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'loss': valid_loss,
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
            'args': args,
        }, is_best, exp=args.exp)

        # ReduceLROnPlateau steps on the validation loss; the others per epoch.
        if args.scheduler == 'plateau':
            scheduler.step(valid_loss)
        else:
            scheduler.step()

    print('Best loss:')
    print(best_loss)
def train(trainloader, model, criterion, optimizer, writer, epoch, use_cuda):
    """Run one training epoch and return the frame-weighted average loss.

    Args:
        trainloader: iterable yielding ``(inputs, input_lengths, utt_ids, graphs)``.
        model: called as ``model(inputs, input_lengths)``; returns
            ``(outputs, output_lengths)``.
        criterion: called as ``criterion(outputs, output_lengths, graphs)``;
            must return a scalar loss tensor.
        optimizer: stepped once per batch; its first param group's ``lr`` is
            logged.
        writer: TensorBoard writer; receives ``lr`` and ``train_loss`` at
            step ``epoch``.
        epoch: current epoch index, used in logs and as the TensorBoard step.
        use_cuda: when True, ``inputs`` are moved to the GPU.

    Returns:
        The epoch's average loss as a Python float, each batch weighted by
        ``output_lengths.sum()``. It is 0 if ``trainloader`` is empty.
    """
    # switch to train mode
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()
    lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('lr', lr, epoch)
    for batch_idx, (inputs, input_lengths, utt_ids, graphs) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)
        if use_cuda:
            inputs = inputs.cuda()

        # compute output, then take one optimizer step
        outputs, output_lengths = model(inputs, input_lengths)
        loss = criterion(outputs, output_lengths, graphs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the loss weighted by the batch's total output frames. Both
        # values are converted to Python numbers so the meter never
        # accumulates (possibly GPU-resident) tensors.
        losses.update(loss.item(), int(output_lengths.sum()))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print progress
        if batch_idx % args.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'Lr: {lr}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch, batch_idx, len(trainloader), lr=lr, batch_time=batch_time,
                      loss=losses))

    # log to TensorBoard
    writer.add_scalar('train_loss', losses.avg, epoch)
    return losses.avg
def test(testloader, model, criterion, writer, epoch, use_cuda):
    """Evaluate ``model`` on ``testloader`` and return the frame-weighted average loss.

    The model is put in eval mode and the loop runs under ``torch.no_grad()``,
    since no gradients are needed for validation. Progress is printed for
    every batch, and the epoch average is logged to TensorBoard as
    ``valid_loss`` at step ``epoch``.

    Args:
        testloader: iterable yielding ``(inputs, input_lengths, utt_ids, graphs)``.
        model: called as ``model(inputs, input_lengths)``; returns
            ``(outputs, output_lengths)``.
        criterion: called as ``criterion(outputs, output_lengths, graphs)``;
            must return a scalar loss tensor.
        writer: TensorBoard writer.
        epoch: current epoch index, used in logs and as the TensorBoard step.
        use_cuda: when True, ``inputs`` are moved to the GPU.

    Returns:
        The average loss as a Python float, each batch weighted by
        ``output_lengths.sum()``. It is 0 if ``testloader`` is empty.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    # Validation never calls backward(), so skip building autograd graphs.
    with torch.no_grad():
        for batch_idx, (inputs, input_lengths, utt_ids, graphs) in enumerate(testloader):
            # measure data loading time
            data_time.update(time.time() - end)
            if use_cuda:
                inputs = inputs.cuda()

            # compute output
            outputs, output_lengths = model(inputs, input_lengths)
            loss = criterion(outputs, output_lengths, graphs)

            # Record the loss weighted by total output frames, as Python numbers.
            losses.update(loss.item(), int(output_lengths.sum()))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print each batch stats since validset is small
            print('Validation: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch, batch_idx, len(testloader), batch_time=batch_time,
                      loss=losses))

    # log to TensorBoard
    writer.add_scalar('valid_loss', losses.avg, epoch)
    return losses.avg
class AverageMeter:
    """Keep the latest value together with a weighted running mean.

    Attributes:
        val: value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count`` after the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
def save_checkpoint(state, is_best, exp='exp', filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``exp/filename`` with ``torch.save``.

    When ``is_best`` is true, the written file is also copied to
    ``exp/model_best.pth.tar``. The directory ``exp`` must already exist.
    """
    checkpoint_path = os.path.join(exp, filename)
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    best_path = os.path.join(exp, 'model_best.pth.tar')
    shutil.copyfile(checkpoint_path, best_path)
# Script entry point: start training only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()