wavTranscriber.py
import glob
import webrtcvad
import logging
import wavSplit
from deepspeech import Model
from timeit import default_timer as timer

'''
Load the pre-trained model into memory.
@param models: Output Graph Protocol Buffer file
@param alphabet: alphabet.txt file
@param lm: Language model file
@param trie: Trie file
@Retval
Returns a list [DeepSpeech Object, Model Load Time, LM Load Time]
'''
def load_model(models, alphabet, lm, trie):
    N_FEATURES = 26
    N_CONTEXT = 9
    # Beam width used in the CTC decoder when building candidate transcriptions
    BEAM_WIDTH = 500
    # Language model weight (alpha) and word insertion bonus (beta) for the decoder
    LM_ALPHA = 0.75
    LM_BETA = 1.85

    model_load_start = timer()
    ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
    model_load_end = timer() - model_load_start
    logging.debug("Loaded model in %0.3fs." % (model_load_end))

    lm_load_start = timer()
    ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
    lm_load_end = timer() - lm_load_start
    logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))

    return [ds, model_load_end, lm_load_end]
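
# A minimal usage sketch (the file names below are assumptions; in this repo
# they are normally resolved via resolve_models() further down):
#
#   ds, model_load_time, lm_load_time = load_model('output_graph.pb',
#                                                  'alphabet.txt',
#                                                  'lm.binary', 'trie')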

'''
Run inference on an input audio file.
@param ds: DeepSpeech object
@param audio: Input audio for running inference on
@param fs: Sample rate of the input audio file
@Retval:
Returns a list [Inference, Inference Time]
'''
def stt(ds, audio, fs):
    inference_time = 0.0
    audio_length = len(audio) * (1 / fs)

    # Run DeepSpeech
    logging.debug('Running inference...')
    inference_start = timer()
    output = ds.stt(audio, fs)
    inference_end = timer() - inference_start
    inference_time += inference_end
    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length))

    return [output, inference_time]
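
# A minimal usage sketch (assumes `audio` is an int16 numpy array at 16kHz,
# e.g. one VAD segment from vad_segment_generator() below):
#
#   transcript, inference_time = stt(ds, audio, 16000)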

'''
Resolve the directory path for the models and fetch each of them.
@param dirName: Path to the directory containing the pre-trained models
@Retval:
Returns a tuple containing each of the model files (pb, alphabet, lm and trie)
'''
def resolve_models(dirName):
    pb = glob.glob(dirName + "/*.pb")[0]
    logging.debug("Found Model: %s" % pb)

    alphabet = glob.glob(dirName + "/alphabet.txt")[0]
    logging.debug("Found Alphabet: %s" % alphabet)

    lm = glob.glob(dirName + "/lm.binary")[0]
    trie = glob.glob(dirName + "/trie")[0]
    logging.debug("Found Language Model: %s" % lm)
    logging.debug("Found Trie: %s" % trie)

    return pb, alphabet, lm, trie
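
# A minimal usage sketch (the directory layout is an assumption: exactly one
# *.pb graph alongside alphabet.txt, lm.binary and trie):
#
#   output_graph, alphabet, lm, trie = resolve_models('./models')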

'''
Generate VAD segments. Filters out non-voiced audio frames.
@param wavFile: Input wav file to run VAD on.
@param aggressiveness: How aggressively to filter out non-speech (0-3)
@param frame_duration_ms: Duration of each VAD frame in milliseconds
@param padding_duration_ms: Padding around voiced regions in milliseconds
@Retval:
Returns a tuple of
    segments: a bytearray of multiple smaller audio frames
              (the longer audio split into multiple smaller ones)
    sample_rate: Sample rate of the input audio file
    audio_length: Duration of the input audio file
'''
def vad_segment_generator(wavFile, aggressiveness, frame_duration_ms=30, padding_duration_ms=300):
    logging.debug("Caught the wav file @: %s" % (wavFile))
    audio, sample_rate, audio_length = wavSplit.read_wave(wavFile)
    assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!"

    vad = webrtcvad.Vad(int(aggressiveness))
    frames = wavSplit.frame_generator(frame_duration_ms, audio, sample_rate)
    frames = list(frames)
    segments = wavSplit.vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames)

    return [segment for segment in segments], sample_rate, audio_length
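
# An end-to-end sketch tying the helpers together. The model directory and
# wav path below are assumptions for illustration; the real entry point for
# this module lives in the accompanying transcription script.
if __name__ == '__main__':
    import numpy as np

    logging.basicConfig(level=logging.DEBUG)

    # Hypothetical locations; adjust to your checkout.
    model_dir = './models'
    wav_file = './audio/sample.wav'

    output_graph, alphabet, lm, trie = resolve_models(model_dir)
    ds, _, _ = load_model(output_graph, alphabet, lm, trie)

    segments, sample_rate, audio_length = vad_segment_generator(wav_file, aggressiveness=1)
    for i, segment in enumerate(segments):
        # Each segment is raw 16-bit PCM bytes; DeepSpeech expects int16 samples.
        audio = np.frombuffer(segment, dtype=np.int16)
        transcript, inference_time = stt(ds, audio, sample_rate)
        logging.debug("Segment %02d: %s" % (i, transcript))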