Skip to content

Commit 9d8fdde

Browse files
authored
Support resampling (#77)
1 parent 5f31b22 commit 9d8fdde

10 files changed

+96
-26
lines changed

python-api-examples/decode-file.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ def get_args():
7878

7979

8080
def main():
81-
sample_rate = 16000
82-
8381
args = get_args()
8482
assert_file_exists(args.encoder)
8583
assert_file_exists(args.decoder)
@@ -95,12 +93,16 @@ def main():
9593
decoder=args.decoder,
9694
joiner=args.joiner,
9795
num_threads=args.num_threads,
98-
sample_rate=sample_rate,
96+
sample_rate=16000,
9997
feature_dim=80,
10098
decoding_method=args.decoding_method,
10199
)
102100
with wave.open(args.wave_filename) as f:
103-
assert f.getframerate() == sample_rate, f.getframerate()
101+
# If the wave file has a different sampling rate from the one
102+
# expected by the model (16 kHz in our case), we will do
103+
# resampling inside sherpa-onnx
104+
wave_file_sample_rate = f.getframerate()
105+
104106
assert f.getnchannels() == 1, f.getnchannels()
105107
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
106108
num_samples = f.getnframes()
@@ -110,17 +112,17 @@ def main():
110112

111113
samples_float32 = samples_float32 / 32768
112114

113-
duration = len(samples_float32) / sample_rate
115+
duration = len(samples_float32) / wave_file_sample_rate
114116

115117
start_time = time.time()
116118
print("Started!")
117119

118120
stream = recognizer.create_stream()
119121

120-
stream.accept_waveform(sample_rate, samples_float32)
122+
stream.accept_waveform(wave_file_sample_rate, samples_float32)
121123

122-
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
123-
stream.accept_waveform(sample_rate, tail_paddings)
124+
tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32)
125+
stream.accept_waveform(wave_file_sample_rate, tail_paddings)
124126

125127
stream.input_finished()
126128

python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def main():
100100
recognizer = create_recognizer()
101101
print("Started! Please speak")
102102

103-
sample_rate = 16000
103+
# The model uses 16 kHz; we use 48 kHz here to demonstrate that
104+
# sherpa-onnx will do resampling inside.
105+
sample_rate = 48000
104106
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
105107
last_result = ""
106108
stream = recognizer.create_stream()

python-api-examples/speech-recognition-from-microphone.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ def create_recognizer():
9292

9393

9494
def main():
95-
print("Started! Please speak")
9695
recognizer = create_recognizer()
97-
sample_rate = 16000
96+
print("Started! Please speak")
97+
98+
# The model uses 16 kHz; we use 48 kHz here to demonstrate that
99+
# sherpa-onnx will do resampling inside.
100+
sample_rate = 48000
98101
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
99102
last_result = ""
100103
stream = recognizer.create_stream()

sherpa-onnx/c-api/c-api.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
115115
/// decoding.
116116
///
117117
/// @param stream A pointer returned by CreateOnlineStream().
118-
/// @param sample_rate Sampler rate of the input samples. It has to be 16 kHz
119-
/// for models from icefall.
118+
/// @param sample_rate Sample rate of the input samples. If it is different
119+
/// from config.feat_config.sample_rate, we will do
120+
/// resampling inside sherpa-onnx.
120121
/// @param samples A pointer to a 1-D array containing audio samples.
121122
/// The range of samples has to be normalized to [-1, 1].
122123
/// @param n Number of elements in the samples array.

sherpa-onnx/csrc/features.cc

+43
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <vector>
1212

1313
#include "kaldi-native-fbank/csrc/online-feature.h"
14+
#include "sherpa-onnx/csrc/macros.h"
15+
#include "sherpa-onnx/csrc/resample.h"
1416

1517
namespace sherpa_onnx {
1618

@@ -50,6 +52,46 @@ class FeatureExtractor::Impl {
5052

5153
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
5254
std::lock_guard<std::mutex> lock(mutex_);
55+
56+
if (resampler_) {
57+
if (sampling_rate != resampler_->GetInputSamplingRate()) {
58+
SHERPA_ONNX_LOGE(
59+
"You changed the input sampling rate!! Expected: %d, given: "
60+
"%d",
61+
resampler_->GetInputSamplingRate(), sampling_rate);
62+
exit(-1);
63+
}
64+
65+
std::vector<float> samples;
66+
resampler_->Resample(waveform, n, false, &samples);
67+
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
68+
samples.size());
69+
return;
70+
}
71+
72+
if (sampling_rate != opts_.frame_opts.samp_freq) {
73+
SHERPA_ONNX_LOGE(
74+
"Creating a resampler:\n"
75+
" in_sample_rate: %d\n"
76+
" output_sample_rate: %d\n",
77+
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
78+
79+
float min_freq =
80+
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
81+
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
82+
83+
int32_t lowpass_filter_width = 6;
84+
resampler_ = std::make_unique<LinearResample>(
85+
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
86+
lowpass_filter_width);
87+
88+
std::vector<float> samples;
89+
resampler_->Resample(waveform, n, false, &samples);
90+
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
91+
samples.size());
92+
return;
93+
}
94+
5395
fbank_->AcceptWaveform(sampling_rate, waveform, n);
5496
}
5597

@@ -100,6 +142,7 @@ class FeatureExtractor::Impl {
100142
std::unique_ptr<knf::OnlineFbank> fbank_;
101143
knf::FbankOptions opts_;
102144
mutable std::mutex mutex_;
145+
std::unique_ptr<LinearResample> resampler_;
103146
};
104147

105148
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)

sherpa-onnx/csrc/features.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ class FeatureExtractor {
2929
~FeatureExtractor();
3030

3131
/**
32-
@param sampling_rate The sampling_rate of the input waveform. Should match
33-
the one expected by the feature extractor.
34-
@param waveform Pointer to a 1-D array of size n
32+
@param sampling_rate The sampling_rate of the input waveform. If it does
33+
not equal config.sampling_rate, we will do
34+
resampling inside.
35+
@param waveform Pointer to a 1-D array of size n. It must be normalized to
36+
the range [-1, 1].
3537
@param n Number of entries in waveform
3638
*/
3739
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);

sherpa-onnx/csrc/online-stream.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class OnlineStream::Impl {
1616
explicit Impl(const FeatureExtractorConfig &config)
1717
: feat_extractor_(config) {}
1818

19-
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
19+
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
2020
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
2121
}
2222

@@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
6767

6868
OnlineStream::~OnlineStream() = default;
6969

70-
void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform,
70+
void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
7171
int32_t n) {
7272
impl_->AcceptWaveform(sampling_rate, waveform, n);
7373
}

sherpa-onnx/csrc/online-stream.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ class OnlineStream {
2020
~OnlineStream();
2121

2222
/**
23-
@param sampling_rate The sampling_rate of the input waveform. Should match
24-
the one expected by the feature extractor.
25-
@param waveform Pointer to a 1-D array of size n
23+
@param sampling_rate The sampling_rate of the input waveform. If it does
24+
not equal config.sampling_rate, we will do
25+
resampling inside.
26+
@param waveform Pointer to a 1-D array of size n. It must be normalized to
27+
the range [-1, 1].
2628
@param n Number of entries in waveform
2729
*/
28-
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
30+
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
2931

3032
/**
3133
* InputFinished() tells the class you won't be providing any

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
7676
std::vector<int64_t> blanks(context_size, blank_id);
7777
Hypotheses blank_hyp({{blanks, 0}});
7878
r.hyps = std::move(blank_hyp);
79+
r.tokens = std::move(blanks);
7980
return r;
8081
}
8182

sherpa-onnx/python/csrc/online-stream.cc

+18-4
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,27 @@
88

99
namespace sherpa_onnx {
1010

11+
constexpr const char *kAcceptWaveformUsage = R"(
12+
Process audio samples.
13+
14+
Args:
15+
sample_rate:
16+
Sample rate of the input samples. If it is different from the one
17+
expected by the model, we will do resampling inside.
18+
waveform:
19+
A 1-D float32 tensor containing audio samples. It must be normalized
20+
to the range [-1, 1].
21+
)";
22+
1123
void PybindOnlineStream(py::module *m) {
1224
using PyClass = OnlineStream;
1325
py::class_<PyClass>(*m, "OnlineStream")
14-
.def("accept_waveform",
15-
[](PyClass &self, float sample_rate, py::array_t<float> waveform) {
16-
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
17-
})
26+
.def(
27+
"accept_waveform",
28+
[](PyClass &self, float sample_rate, py::array_t<float> waveform) {
29+
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
30+
},
31+
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
1832
.def("input_finished", &PyClass::InputFinished);
1933
}
2034

0 commit comments

Comments
 (0)