-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathfeature.py
83 lines (75 loc) · 2.22 KB
/
feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from speechbrain.processing.features import (
STFT,
spectral_magnitude,
Filterbank,
DCT,
Deltas,
ContextWindow,
)
class AccSpec(torch.nn.Module):
    """Generate spectral features from Acc waveforms for the speech pipeline.

    The input is passed through an STFT, converted with
    ``spectral_magnitude``, optionally expanded with a context window, and
    finally the first ``LOW_BIN_CUTOFF`` entries of the last axis are dropped.

    Arguments
    ---------
    context : bool (default: False)
        Whether or not to append forward and backward contexts to
        the features.
    sample_rate : int (default: 500)
        Sampling rate for the input waveforms.
    n_fft : int (default: 80)
        Number of samples to use in each stft.
    left_frames : int (default: 5)
        Number of frames of left context to add.
    right_frames : int (default: 5)
        Number of frames of right context to add.
    win_length : float (default: 80)
        Length (in ms) of the sliding window used to compute the STFT.
    hop_length : float (default: 20)
        Length (in ms) of the hop of the sliding window used to compute
        the STFT.

    Example
    -------
    >>> import torch
    >>> inputs = torch.randn([10, 500])
    >>> feature_maker = AccSpec()
    >>> feats = feature_maker(inputs)
    >>> feats.shape
    torch.Size([10, 51, 31])
    """

    # Number of leading entries of the last axis discarded by ``forward``.
    # Presumably these are the lowest frequency bins of the spectrum --
    # TODO confirm the intended cut-off with the model owner.
    LOW_BIN_CUTOFF = 10

    def __init__(
        self,
        context=False,
        sample_rate=500,
        n_fft=80,
        left_frames=5,
        right_frames=5,
        win_length=80,
        hop_length=20,
    ):
        super().__init__()
        self.context = context
        self.compute_STFT = STFT(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )
        self.context_window = ContextWindow(
            left_frames=left_frames,
            right_frames=right_frames,
        )

    def forward(self, wav):
        """Returns a set of features generated from the input Acc waveforms.

        Arguments
        ---------
        wav : tensor
            A batch of Acc signals to transform to features.

        Returns
        -------
        tensor
            The (optionally context-expanded) spectral features with the
            first ``LOW_BIN_CUTOFF`` entries of the last axis removed.
        """
        # Feature extraction is not meant to be back-propagated through.
        with torch.no_grad():
            # Lower-case local name so the imported ``STFT`` class is not
            # shadowed inside this method.
            stft = self.compute_STFT(wav)
            mag = spectral_magnitude(stft)
            if self.context:
                # NOTE(review): the slice below is applied after the context
                # window, so with context=True it trims the expanded feature
                # axis rather than each frame's spectrum -- confirm this is
                # the intended behaviour.
                mag = self.context_window(mag)
        return mag[:, :, self.LOW_BIN_CUTOFF:]