Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CTC models #177

Merged
merged 2 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sherpa/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Please sort the filenames alphabetically
set(sherpa_srcs
ctc_conformer_model.cc
endpoint.cc
fbank_features.cc
file_utils.cc
Expand Down
55 changes: 55 additions & 0 deletions sherpa/csrc/ctc_conformer_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sherpa/csrc/ctc_conformer_model.h"

#include <string>
#include <vector>

namespace sherpa {

/**
 * Load a TorchScript model from `filename` onto `device` and switch it to
 * evaluation mode.
 *
 * @param filename Path of the TorchScript file passed to torch::jit::load().
 * @param device Device handed to torch::jit::load(); it is also remembered
 *               in device_.
 * @param optimize_for_inference If true and the build uses torch >= 1.10,
 *               the loaded module is replaced by the result of
 *               torch::jit::optimize_for_inference(). With older torch
 *               versions the flag is accepted but has no effect.
 */
CtcConformerModel::CtcConformerModel(const std::string &filename,
                                     torch::Device device /*= torch::kCPU*/,
                                     bool optimize_for_inference /*= false*/)
    : device_(device), model_(torch::jit::load(filename, device)) {
  model_.eval();

  // torch::jit::optimize_for_inference is available only in torch>=1.10,
  // hence the compile-time version gate below.
#if (SHERPA_TORCH_VERSION_MAJOR > 1) || \
    (SHERPA_TORCH_VERSION_MAJOR == 1 && SHERPA_TORCH_VERSION_MINOR >= 10)
  if (optimize_for_inference) {
    model_ = torch::jit::optimize_for_inference(model_);
  }
#endif
}

/**
 * Invoke the wrapped TorchScript module on `input`.
 *
 * @param input Positional arguments forwarded unchanged to the module.
 * @return Whatever the module call returns, without any post-processing.
 */
torch::IValue CtcConformerModel::Forward(
    const std::vector<torch::IValue> &input) {
  torch::IValue output = model_(input);
  return output;
}

/**
 * Extract element 0 of the tuple produced by Forward().
 *
 * @param forward_out Value returned by Forward(); it is interpreted as a
 *                    tuple.
 * @return The first tuple element converted to a tensor.
 */
torch::Tensor CtcConformerModel::GetLogSoftmaxOut(
    torch::IValue forward_out) const {
  // Keep the tuple alive in a named handle while its element is read.
  const auto outputs = forward_out.toTuple();
  return outputs->elements()[0].toTensor();
}

/**
 * Extract element 1 of the tuple produced by Forward().
 *
 * @param forward_out Value returned by Forward(); it is interpreted as a
 *                    tuple.
 * @return The second tuple element converted to a tensor.
 */
torch::Tensor CtcConformerModel::GetLogSoftmaxOutLength(
    torch::IValue forward_out) const {
  // Keep the tuple alive in a named handle while its element is read.
  const auto outputs = forward_out.toTuple();
  return outputs->elements()[1].toTensor();
}

} // namespace sherpa
78 changes: 78 additions & 0 deletions sherpa/csrc/ctc_conformer_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_CSRC_CTC_CONFORMER_MODEL_H_
#define SHERPA_CSRC_CTC_CONFORMER_MODEL_H_

#include <string>
#include <vector>

#include "sherpa/csrc/ctc_model.h"
namespace sherpa {

/** CtcModel implementation wrapping the Conformer CTC model from icefall.
 *
 * See
 * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/train.py#L668
 */
class CtcConformerModel : public CtcModel {
 public:
  /** Construct the model from a TorchScript file.
   *
   * @param filename Path name of the torch script model.
   * @param device The model will be moved to this device.
   * @param optimize_for_inference true to invoke
   *                               torch::jit::optimize_for_inference().
   */
  explicit CtcConformerModel(const std::string &filename,
                             torch::Device device = torch::kCPU,
                             bool optimize_for_inference = false);

  ~CtcConformerModel() override = default;

  /// Return the device that was passed to the constructor.
  torch::Device Device() const override { return device_; }

  /** Run the forward method of the model.
   *
   * See
   * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/transformer.py#L162
   * for its documentation in Python.
   *
   * @param input It has two elements. The first element contains the 3-D
   *              features of shape (N, T, C), while the second element
   *              contains the supervision_segments. See the above link
   *              for its format.
   *
   * @return Return a tuple containing 3 elements, of which only the first 2
   *         are used for CTC decoding. The first element contains the
   *         log_softmax output of the model with shape (N, T', C').
   *         The second element contains the number of frames of the first
   *         element before padding.
   */
  torch::IValue Forward(const std::vector<torch::IValue> &input) override;

  /// Return the first element of the tuple returned by Forward().
  torch::Tensor GetLogSoftmaxOut(torch::IValue forward_out) const override;

  /// Return the second element of the tuple returned by Forward().
  torch::Tensor GetLogSoftmaxOutLength(
      torch::IValue forward_out) const override;

 private:
  // Device given at construction time; reported by Device().
  torch::Device device_;
  // The loaded torch script module that Forward() invokes.
  torch::jit::Module model_;
};

} // namespace sherpa

#endif // SHERPA_CSRC_CTC_CONFORMER_MODEL_H_
53 changes: 53 additions & 0 deletions sherpa/csrc/ctc_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SHERPA_CSRC_CTC_MODEL_H_
#define SHERPA_CSRC_CTC_MODEL_H_

#include <vector>

#include "torch/script.h"

namespace sherpa {

/** Abstract interface for models that are decoded with CTC.
 *
 * Implementations run the network via Forward() and expose accessors that
 * pick the log-softmax output and its lengths out of the returned value.
 */
class CtcModel {
 public:
  virtual ~CtcModel() = default;

  /** Subsampling factor of the model.
   *
   * The default implementation returns 4; subclasses may override it.
   */
  virtual int32_t SubsamplingFactor() const { return 4; }

  /** Return the underlying device where computation would happen. */
  virtual torch::Device Device() const = 0;

  /** Run the model with a given input.
   *
   * @param input Arguments for the model; their meaning is defined by the
   *              concrete subclass.
   * @return The model output, to be passed to GetLogSoftmaxOut() and
   *         GetLogSoftmaxOutLength().
   */
  virtual torch::IValue Forward(const std::vector<torch::IValue> &input) = 0;

  /** Get the log softmax output of the network from the output of Forward().
   *
   * @param forward_out The value returned by Forward().
   * @return A tensor of shape (N, T, C).
   */
  virtual torch::Tensor GetLogSoftmaxOut(torch::IValue forward_out) const = 0;

  /** Get the output length before padding from the output of Forward().
   *
   * @param forward_out The value returned by Forward().
   * @return A tensor of shape (N,).
   */
  virtual torch::Tensor GetLogSoftmaxOutLength(
      torch::IValue forward_out) const = 0;
};

} // namespace sherpa

#endif // SHERPA_CSRC_CTC_MODEL_H_
2 changes: 1 addition & 1 deletion sherpa/csrc/rnnt_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class RnntModel {
virtual int32_t ContextSize() const = 0;
virtual int32_t VocabSize() const = 0;

int32_t SubsamplingFactor() const { return 4; }
virtual int32_t SubsamplingFactor() const { return 4; }

/** Run the decoder network.
*
Expand Down