Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for merge] Add full-chunk mode CTC decoding for models from WeNet #872

Open
wants to merge 1 commit into
base: v2.0-pre
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion k2/torch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
add_subdirectory(csrc)

add_subdirectory(bin)
add_subdirectory(sp)
1 change: 1 addition & 0 deletions k2/torch/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ set(k2_torch_test_srcs
dense_fsa_vec_test.cu
deserialization_test.cu
fsa_class_test.cu
utils_test.cu
wave_reader_test.cu
)

Expand Down
28 changes: 28 additions & 0 deletions k2/torch/csrc/utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,32 @@ torch::Tensor TensorToTorch(Tensor &tensor) {
[saved_region = tensor.GetRegion()](void *) {}, options);
}

std::vector<std::string> SplitStringToVector(const std::string &s,
const char *delim) {
std::vector<std::string> fields;
size_t start = 0;
size_t pos = 0;
while ((pos = s.find_first_of(delim, start)) != std::string::npos) {
if (pos != start) {
fields.push_back(s.substr(start, pos - start));
}
start = pos + 1;
}
if (start < s.size()) {
fields.push_back(s.substr(start));
}
return fields;
}

std::vector<std::string> ReadLines(const std::string &filename) {
std::vector<std::string> ans;

std::ifstream is(filename);
std::string line;
while (std::getline(is, line)) {
ans.push_back(std::move(line));
}
return ans;
}

} // namespace k2
7 changes: 7 additions & 0 deletions k2/torch/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ torch::Tensor IndexSelect(torch::Tensor src, torch::Tensor index,
return TensorToTorch(ans);
}

/// Read a file line by line.
std::vector<std::string> ReadLines(const std::string &filename);

/// Split a string by a delimiter. The split parts are returned in a vector.
std::vector<std::string> SplitStringToVector(const std::string &s,
const char *delim);

} // namespace k2

#endif // K2_TORCH_CSRC_UTILS_H_
46 changes: 46 additions & 0 deletions k2/torch/csrc/utils_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "gtest/gtest.h"
#include "k2/torch/csrc/utils.h"

namespace k2 {
TEST(SplitStringToVector, DelimIsSpace) {
std::string s = "ab c d e";
auto ans = SplitStringToVector(s, " ");
EXPECT_EQ(ans.size(), 4u);
EXPECT_EQ(ans[0], "ab");
EXPECT_EQ(ans[1], "c");
EXPECT_EQ(ans[2], "d");
EXPECT_EQ(ans[3], "e");
}

TEST(SplitStringToVector, EmptyInput) {
std::string s = "";
auto ans = SplitStringToVector(s, " ");
EXPECT_EQ(ans.size(), 0u);
}

TEST(SplitStringToVector, OnlyOneField) {
std::string s = "abc";
auto ans = SplitStringToVector(s, " ");
EXPECT_EQ(ans.size(), 1u);
EXPECT_EQ(ans[0], "abc");
}

} // namespace k2
38 changes: 23 additions & 15 deletions k2/torch/csrc/wave_reader.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ struct WaveHeader {
static_assert(sizeof(WaveHeader) == 44, "");

// Read a wave file of mono-channel.
// Return its samples in a 1-D torch.float32 tensor, normalized
// by dividing 32768.
// Return its samples in a 1-D torch.float32 tensor, divided by the given
// normalizer.
// Also, it returns the sample rate.
std::pair<torch::Tensor, float> ReadWaveImpl(std::istream &is) {
std::pair<torch::Tensor, float> ReadWaveImpl(std::istream &is,
float normalizer) {
WaveHeader header;
is.read(reinterpret_cast<char *>(&header), sizeof(header));
K2_CHECK((bool)is) << "Failed to read wave header";
Expand All @@ -91,34 +92,41 @@ std::pair<torch::Tensor, float> ReadWaveImpl(std::istream &is) {
header.subchunk2_size);

K2_CHECK((bool)is) << "Failed to read wave samples";
data = (data / 32768.).to(torch::kFloat32);
data = (data / normalizer).to(torch::kFloat32);
return {data, sample_rate};
}

} // namespace

WaveReader::WaveReader(const std::string &filename) {
WaveReader::WaveReader(const std::string &filename,
float normalizer /*=32768*/) {
std::ifstream is(filename, std::ifstream::binary);
std::tie(data_, sample_rate_) = ReadWaveImpl(is);
std::tie(data_, sample_rate_) = ReadWaveImpl(is, normalizer);
}

WaveReader::WaveReader(std::istream &is) {
std::tie(data_, sample_rate_) = ReadWaveImpl(is);
WaveReader::WaveReader(std::istream &is, float normalizer /*=32768*/) {
std::tie(data_, sample_rate_) = ReadWaveImpl(is, normalizer);
}

torch::Tensor ReadWave(const std::string &filename,
float expected_sample_rate) {
WaveReader reader(filename);
K2_CHECK_EQ(reader.SampleRate(), expected_sample_rate);
return reader.Data();
torch::Tensor ReadWave(const std::string &filename, float expected_sample_rate,
float normalizer /*=32768*/) {
try {
WaveReader reader(filename, normalizer);
K2_CHECK_EQ(reader.SampleRate(), expected_sample_rate);
return reader.Data();
} catch (const std::runtime_error &) {
K2_LOG(INFO) << "Failed to read " << filename;
throw;
}
}

std::vector<torch::Tensor> ReadWave(const std::vector<std::string> &filenames,
float expected_sample_rate) {
float expected_sample_rate,
float normalizer /*=32768*/) {
std::vector<torch::Tensor> ans;
ans.reserve(filenames.size());
for (const auto &path : filenames) {
ans.emplace_back(ReadWave(path, expected_sample_rate));
ans.emplace_back(ReadWave(path, expected_sample_rate, normalizer));
}
return ans;
}
Expand Down
19 changes: 10 additions & 9 deletions k2/torch/csrc/wave_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ class WaveReader {
/** Construct a wave reader from a wave filename, encoded in PCM format.

@param filename Path to a wave file. Must be mono and PCM encoded.
Note: Samples are divided by 32768 so that they are
in the range [-1, 1)
@param normalizer Divide audio samples by this number.
*/
explicit WaveReader(const std::string &filename);
explicit WaveReader(const std::string &filename, float normalizer = 32768);

/** Construct a wave reader from a input stream.
*
See the help in the above function. You can open a file
with a std::ifstream and pass it to this function.
*/
explicit WaveReader(std::istream &is);
explicit WaveReader(std::istream &is, float normalizer = 32768);

/// Return a 1-D tensor with dtype torch.float32
const torch::Tensor &Data() const { return data_; }
Expand All @@ -57,7 +57,6 @@ class WaveReader {
private:
/// A 1-D tensor with dtype torch.float32
torch::Tensor data_;

float sample_rate_;
};

Expand All @@ -66,15 +65,17 @@ class WaveReader {
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
@param expected_sample_rate Expected sample rate of the wave file. If the
sample rate don't match, it throws an exception.
@param normalizer Divide audio samples by this number.

@return Return a 1-D torch tensor with dtype torch.float32. Samples are
normalized to the range [-1, 1).
@return Return a 1-D torch tensor with dtype torch.float32.
*/
torch::Tensor ReadWave(const std::string &filename, float expected_sample_rate);
torch::Tensor ReadWave(const std::string &filename, float expected_sample_rate,
float normalizer = 32768);

/// Same `ReadWave` above. It supports reading a list of wave files.
std::vector<torch::Tensor> ReadWave(const std::vector<std::string> &filenames,
float expected_sample_rate);
float expected_sample_rate,
float normalizer = 32768);

} // namespace k2

Expand Down
19 changes: 19 additions & 0 deletions k2/torch/sp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# it is located in k2/csrc/cmake/transform.cmake
include(transform)

set(bin_dep_libs
${TORCH_LIBRARIES}
k2_torch
sentencepiece-static # see cmake/sentencepiece.cmake
)

#----------------------------------------
# CTC decoding
#----------------------------------------
set(sp_ctc_decode_srcs sp_ctc_decode.cu)
if(NOT K2_WITH_CUDA)
transform(OUTPUT_VARIABLE sp_ctc_decode_srcs SRCS ${sp_ctc_decode_srcs})
endif()
add_executable(sp_ctc_decode ${sp_ctc_decode_srcs})
set_property(TARGET sp_ctc_decode PROPERTY CXX_STANDARD 14)
target_link_libraries(sp_ctc_decode ${bin_dep_libs})
Loading