-
Notifications
You must be signed in to change notification settings - Fork 0
/
libpreparedata.lua
123 lines (110 loc) · 4.02 KB
/
libpreparedata.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
require "libhtktoth"
require "nn"
-- th preparedata scpfile globalnorm mlffile outputdir
--require 'torch'
--- Reads a label file and collects its "key value" pairs.
-- Every line is split on single spaces; the first field becomes the key and
-- the second field the value. Lines that yield no first field are skipped.
-- NOTE(review): `string.split` is not part of stock Lua; it is presumably
-- provided by the Torch environment this file runs in -- confirm.
-- @param filename path of the label file to read
-- @return the table of labels; the same table is also published as the
--         global `label`
function readmlf(filename)
  -- Fail with the OS error message instead of indexing a nil handle later.
  local fin = assert(io.open(filename, 'r'))
  local labels = {}
  for line in fin:lines() do
    local fields = line:split(' ')
    -- A line without any field would otherwise use nil as a table key,
    -- which raises an error.
    if fields[1] ~= nil then
      labels[fields[1]] = fields[2]
    end
  end
  fin:close()
  -- Deliberately kept as a global: callers that ignore the return value can
  -- only reach the result through `label`.
  label = labels
  return labels
end
--- Splits a string into its whitespace-separated tokens.
-- @param str the string to tokenize
-- @return an array with the non-empty runs of non-whitespace characters of
--         `str`, in order of appearance (empty when there are none)
function parseline(str)
  local tokens = {}
  -- "%S+" only matches non-empty runs; "%S*" would also match the empty
  -- string between tokens and require filtering those matches out again.
  for token in string.gmatch(str, "%S+") do
    tokens[#tokens + 1] = token
  end
  return tokens
end
-- Reads the global transform, which has already been preprocessed in bash.
-- The global transform already contains the calculated parameters -\mu and 1/\sigma
-- as the mean and the inverse covariance respectively
-- Returns the value fields of one vector of the transform file.
-- `line` is the vector's header line, already read from `fin`. When it has
-- more than two space-separated fields, the values are on that same line and
-- start at field 4 (no linebreak after the header); otherwise the values are
-- all fields of the next line of `fin`.
local function read_vector_fields(fin, line, filename)
  local fields = line:split(' ')
  if #fields > 2 then
    local values = {}
    for i = 4, #fields do
      values[#values + 1] = fields[i]
    end
    return values
  end
  local nextline = assert(fin:read(), "Unexpected end of file in " .. filename)
  return nextline:split(' ')
end

-- Converts an array of numeric strings into a torch.DoubleTensor.
local function fields_to_tensor(fields)
  local tensor = torch.DoubleTensor(#fields)
  for i = 1, #fields do
    tensor[i] = tonumber(fields[i])
  end
  return tensor
end

-- Parses both vectors from an already opened transform file.
local function parse_globalnorm(fin, filename)
  -- Line 1 is skipped; line 2 is the header of the first vector.
  fin:read()
  local line = assert(fin:read(), "Unexpected end of file in " .. filename)
  local means = fields_to_tensor(read_vector_fields(fin, line, filename))

  -- Two lines are consumed between the vectors. The original code read a
  -- dimension from the second one but never used it, so it is only skipped.
  fin:read()
  fin:read()

  -- Header of the second vector.
  line = assert(fin:read(), "Unexpected end of file in " .. filename)
  local window = fields_to_tensor(read_vector_fields(fin, line, filename))
  return means, window
end

--- Reads the two vectors of a global transform file.
-- According to the comments in this file, the first vector holds the
-- means (-mu) and the second one the window / inverse variances (1/sigma).
-- NOTE(review): `string.split` is not part of stock Lua; it is presumably
-- provided by the Torch environment -- confirm.
-- @param filename path of the transform file
-- @return means   torch.DoubleTensor built from the first vector
-- @return window  torch.DoubleTensor built from the second vector
-- Raises an error if the file cannot be opened, ends prematurely, or the two
-- vectors differ in length.
function readglobalnorm(filename)
  local fin = assert(io.open(filename, 'r'))
  -- Parse under pcall so the handle is closed on malformed input as well.
  local ok, means, window = pcall(parse_globalnorm, fin, filename)
  fin:close()
  if not ok then
    -- `means` carries the error message here; level 0 keeps it unchanged.
    error(means, 0)
  end
  assert(means:size(1) == window:size(1),"Error when loading the global.trans file, meansize "..means:size(1).." does not match variance size ".. window:size(1))
  return means, window
end
--- Loads the features and the header of one feature file.
-- `loadhtk` and `loadheader` are not defined in this file; presumably they
-- come from the `libhtktoth` module required at the top -- confirm.
-- Reads the global `means` (not set in this function; expected to hold the
-- tensor returned by `readglobalnorm` -- verify against the caller) and only
-- prints a warning when the feature dimension differs from its length.
-- @param inputfile path of the feature file
-- @param extframe  passed through unchanged to `loadhtk`
-- @return htkfeat    the value returned by `loadhtk(inputfile, extframe)`
-- @return featheader the value returned by `loadheader(inputfile)`
function readfile(inputfile,extframe)
  -- Opened only to fail early with a clear message when the file is missing
  -- or unreadable; the handle is closed again right away instead of being
  -- left open in a global.
  local fin = assert(io.open(inputfile,'r'))
  fin:close()
  local htkfeat = loadhtk(inputfile,extframe)
  local featheader = loadheader(inputfile)
  -- Kept as a global: code outside this function may read `dim_feat`.
  dim_feat = htkfeat:size(2)
  if (dim_feat ~= means:size(1)) then
    print("Feature dimension "..dim_feat.." does not match globalnorm dimension "..means:size(1))
  end
  return htkfeat,featheader
end
--- Normalizes the frames and writes them to `outputfile` via `writehtk`.
-- Every element x of `frame` is replaced by (x + mean) * window, using the
-- globals `means` and `window` (not set in this function; per the comments
-- in this file they hold -mu and 1/sigma, so this corresponds to
-- (x - mu) / sigma). `writehtk` is not defined in this file; presumably it
-- comes from `libhtktoth` -- confirm.
-- NOTE: `frame` is modified by this call (normalized via map2 and then
-- resized to one dimension).
-- @param outputfile path handed to `writehtk`
-- @param frame      2-D tensor; size(1) is used as the number of samples and
--                   size(2) as the (extended) frame width
-- @param extframe   number of context frames on each side, used to compute
--                   the sample count passed to `writehtk`
-- @param htkheader  table with the fields nsamples, sampleperiod, samplesize
--                   and parmkind
function writefile(outputfile, frame, extframe,htkheader)
  local nsamples = frame:size(1)
  -- Width of one frame after it was extended to the left and right.
  local nextframes = frame:size(2)

  -- Replicate repeats the vectors along a new first dimension so they have
  -- the same size as `frame`, as required by map2. Each vector gets its own
  -- module: forward() hands back the module's output tensor, and in nn
  -- versions that reuse this tensor a single shared module would make both
  -- results refer to the data of the last call.
  local replicated_means = nn.Replicate(nsamples):forward(means)
  local replicated_window = nn.Replicate(nsamples):forward(window)

  -- Normalization: x_new = (x_old + mean) * window
  frame:map2(replicated_means, replicated_window, function(old_frame, mean, cov)
    return (old_frame + mean) * cov
  end)

  local out_feat = torch.totable(frame:resize(nsamples*nextframes))
  local newsamplesize = htkheader.nsamples*(2*extframe) + htkheader.nsamples
  -- Remove the crc checksum qualifier from the parameter kind
  local newparmkind = string.gsub(htkheader.parmkind,"_K","")
  writehtk(outputfile,newsamplesize,htkheader.sampleperiod,htkheader.samplesize/4,newparmkind,out_feat)
end