-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathevaluate.py
113 lines (93 loc) · 3.87 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import json
import time
import argparse
from tqdm import tqdm
from collections import defaultdict
import math
import argparse
import importlib
import data
from utils import fix_seeds, remove_from_dict, prepare_bpe
from data.collate import collate_fn, gpu_collate, no_pad_collate
from data.transforms import (
Compose, AddLengths, AudioSqueeze, TextPreprocess,
MaskSpectrogram, ToNumpy, BPEtexts, MelSpectrogram,
ToGpu, Pad, NormalizedMelSpectrogram
)
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, ConcatDataset
# from tensorboardX import SummaryWriter
import numpy as np
from functools import partial
# model:
from model import configs as quartznet_configs
from model.quartznet import QuartzNet
# utils:
import yaml
from easydict import EasyDict as edict
from utils import fix_seeds, remove_from_dict, prepare_bpe
import wandb
from decoder import GreedyDecoder, BeamCTCDecoder
def evaluate(config):
    """Evaluate a QuartzNet ASR model on the validation split.

    Builds the validation dataset/dataloader from ``config``, instantiates
    the model (optionally restoring a checkpoint), and for every sample
    computes the CTC loss plus WER/CER using a beam-search CTC decoder.
    Prints the averaged statistics together with a ``wandb.Table`` of
    (ground-truth, predicted) text pairs from the last batch.

    Args:
        config: easydict-style configuration with ``dataset``, ``model``
            and ``train`` sections (loaded from a YAML file, see __main__).
    """
    fix_seeds(seed=config.train.get('seed', 42))
    dataset_module = importlib.import_module(
        f'.{config.dataset.name}', data.__name__)
    bpe = prepare_bpe(config)

    # Per-sample (CPU-side) transforms.
    transforms_val = Compose([
        TextPreprocess(),
        ToNumpy(),
        BPEtexts(bpe=bpe),
        AudioSqueeze()
    ])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Per-batch (GPU-side) transforms: features, lengths, padding.
    batch_transforms_val = Compose([
        ToGpu(device),
        NormalizedMelSpectrogram(
            sample_rate=config.dataset.get('sample_rate', 16000),  # for LJspeech
            n_mels=config.model.feat_in,
            normalize=config.dataset.get('normalize', None)
        ).to(device),
        AddLengths(),
        Pad()
    ])

    val_dataset = dataset_module.get_dataset(
        config, transforms=transforms_val, part='val')
    val_dataloader = DataLoader(
        val_dataset, num_workers=config.train.get('num_workers', 4),
        batch_size=1, collate_fn=no_pad_collate)

    # BUG FIX: getattr's fallback must be the config *object*, not the
    # string '_quartznet5x5_config' — a str default would be passed to
    # QuartzNet as model_config when config.model.name is unknown.
    model = QuartzNet(
        model_config=getattr(quartznet_configs, config.model.name,
                             quartznet_configs._quartznet5x5_config),
        **remove_from_dict(config.model, ['name'])
    )
    print(model)

    if config.train.get('from_checkpoint', None) is not None:
        model.load_weights(config.train.from_checkpoint)
    if torch.cuda.is_available():
        model = model.cuda()

    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    decoder = BeamCTCDecoder(bpe=bpe)

    model.eval()
    val_stats = defaultdict(list)
    # Pre-initialize so the sample table below is valid (empty) even when
    # the validation dataloader yields no batches.
    target_strings, decoded_output = [], []
    for batch_idx, batch in enumerate(val_dataloader):
        batch = batch_transforms_val(batch)
        with torch.no_grad():
            logits = model(batch['audio'])
            # Output length after the model's temporal downsampling.
            output_length = torch.ceil(
                batch['input_lengths'].float() / model.stride).int()
            # CTC expects (T, N, C) log-probabilities.
            loss = criterion(
                logits.permute(2, 0, 1).log_softmax(dim=2),
                batch['text'], output_length, batch['target_lengths'])
        target_strings = decoder.convert_to_strings(batch['text'])
        decoded_output = decoder.decode(logits.permute(0, 2, 1).softmax(dim=2))
        wer = np.mean([decoder.wer(true, pred)
                       for true, pred in zip(target_strings, decoded_output)])
        cer = np.mean([decoder.cer(true, pred)
                       for true, pred in zip(target_strings, decoded_output)])
        val_stats['val_loss'].append(loss.item())
        val_stats['wer'].append(wer)
        val_stats['cer'].append(cer)
    for k, v in val_stats.items():
        val_stats[k] = np.mean(v)
    # BUG FIX: wandb.Table requires a concrete list of rows; a lazy zip
    # iterator is not accepted as table data.
    val_stats['val_samples'] = wandb.Table(
        columns=['gt_text', 'pred_text'],
        data=list(zip(target_strings, decoded_output)))
    print(val_stats)
if __name__ == '__main__':
    # CLI entry point: parse the config path, load the YAML config as an
    # attribute-accessible dict, and run the evaluation.
    cli = argparse.ArgumentParser(description='Evaluation model.')
    cli.add_argument('--config', default='configs/train_LJSpeech.yml',
                     help='path to config file')
    parsed = cli.parse_args()
    with open(parsed.config, 'r') as cfg_file:
        cfg = edict(yaml.safe_load(cfg_file))
    evaluate(cfg)