diff --git a/sherpa/cpp_api/autocast.h b/sherpa/cpp_api/autocast.h
new file mode 100644
index 000000000..8c33572d7
--- /dev/null
+++ b/sherpa/cpp_api/autocast.h
@@ -0,0 +1,57 @@
+// sherpa/cpp_api/autocast.h
+//
+// Copyright (c) 2022 Xiaomi Corporation
+#include "ATen/autocast_mode.h"
+#include "torch/script.h"
+
+#ifndef SHERPA_CPP_API_AUTO_CAST_H_
+#define SHERPA_CPP_API_AUTO_CAST_H_
+
+namespace sherpa {
+// This is an RAII class to simulate the context manager torch.autocast()
+// from Python.
+//
+// This class is not intended to be called in a nested environment.
+class AutoCast {
+ public:
+  /**
+   * @param use_amp true to use amp; false to disable amp
+   * @param use_gpu Ignored if use_amp is false.
+   *                true to set amp for CUDA.
+   *                false to set amp for CPU.
+   */
+  AutoCast(bool use_amp, bool use_gpu) : use_amp_(use_amp), use_gpu_(use_gpu) {
+    if (!use_amp_) return;
+
+    if (use_gpu_) {
+      at::autocast::set_enabled(true);
+    } else {
+      at::autocast::set_cpu_enabled(true);
+    }
+  }
+  ~AutoCast() {
+    if (!use_amp_) return;
+
+    // by default, the cache for autocast is enabled.
+    at::autocast::clear_cache();
+
+    if (use_gpu_) {
+      at::autocast::set_enabled(false);
+    } else {
+      at::autocast::set_cpu_enabled(false);
+    }
+  }
+
+ private:
+  // true to enable amp. false to disable it.
+  bool use_amp_;
+
+  // ignored if use_amp_ is false.
+  // true to set amp for cuda.
+  // false to set amp for cpu.
+  bool use_gpu_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CPP_API_AUTO_CAST_H_
diff --git a/sherpa/cpp_api/offline-recognizer-ctc-impl.h b/sherpa/cpp_api/offline-recognizer-ctc-impl.h
index 8463be6fe..e0f2c31cb 100644
--- a/sherpa/cpp_api/offline-recognizer-ctc-impl.h
+++ b/sherpa/cpp_api/offline-recognizer-ctc-impl.h
@@ -10,6 +10,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa/cpp_api/autocast.h"
 #include "sherpa/cpp_api/feature-config.h"
 #include "sherpa/cpp_api/offline-recognizer-impl.h"
 #include "sherpa/csrc/log.h"
@@ -136,8 +137,13 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
 
     auto features_length = torch::tensor(features_length_vec);
 
-    torch::IValue ivalue = model_->Forward(features, features_length);
-    torch::Tensor log_prob = model_->GetLogSoftmaxOut(ivalue);
+    torch::IValue ivalue;
+    {
+      AutoCast autocast(config_.use_amp, config_.use_gpu);
+      ivalue = model_->Forward(features, features_length);
+    }
+
+    torch::Tensor log_prob = model_->GetLogSoftmaxOut(ivalue).to(torch::kFloat);
     torch::Tensor log_prob_len = model_->GetLogSoftmaxOutLength(ivalue);
 
     auto results =
@@ -161,6 +167,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
     auto features_length = torch::tensor({features.size(0)});
     features = features.unsqueeze(0);
 
+    features = features.to(device_);
+    features_length = features_length.to(device_);
+
     model_->WarmUp(features, features_length);
     SHERPA_LOG(INFO) << "WarmUp ended";
   }
diff --git a/sherpa/cpp_api/offline-recognizer-transducer-impl.h b/sherpa/cpp_api/offline-recognizer-transducer-impl.h
index 71d15a7b2..3f5d72339 100644
--- a/sherpa/cpp_api/offline-recognizer-transducer-impl.h
+++ b/sherpa/cpp_api/offline-recognizer-transducer-impl.h
@@ -10,6 +10,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa/cpp_api/autocast.h"
 #include "sherpa/cpp_api/feature-config.h"
 #include "sherpa/cpp_api/offline-recognizer-impl.h"
 #include "sherpa/csrc/offline-conformer-transducer-model.h"
@@ -107,8 +108,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
 
     torch::Tensor encoder_out;
     torch::Tensor encoder_out_length;
-    std::tie(encoder_out, encoder_out_length) =
-        model_->RunEncoder(features, features_length);
+    {
+      // Note: We only use AMP for running the encoder.
+      AutoCast autocast(config_.use_amp, config_.use_gpu);
+      std::tie(encoder_out, encoder_out_length) =
+          model_->RunEncoder(features, features_length);
+    }
+    encoder_out = encoder_out.to(torch::kFloat);
 
     encoder_out_length = encoder_out_length.cpu();
     auto results = decoder_->Decode(encoder_out, encoder_out_length);
@@ -131,7 +137,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
     s->AcceptSamples(samples.data(), samples.size());
     auto features = s->GetFeatures();
     auto features_length = torch::tensor({features.size(0)});
-    features = features.unsqueeze(0);
+
+    features = features.unsqueeze(0).to(device_);
+    features_length = features_length.to(device_);
 
     model_->WarmUp(features, features_length);
     SHERPA_LOG(INFO) << "WarmUp ended";
diff --git a/sherpa/cpp_api/offline-recognizer.cc b/sherpa/cpp_api/offline-recognizer.cc
index ca8b73c4a..6984f65a2 100644
--- a/sherpa/cpp_api/offline-recognizer.cc
+++ b/sherpa/cpp_api/offline-recognizer.cc
@@ -94,6 +94,10 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
                "If true, it uses the first device. You can use the environment "
                "variable CUDA_VISIBLE_DEVICES to select which device to use.");
 
+  po->Register("use-amp", &use_amp,
+               "true to use automatic-mixed-precision (amp) in neural network "
+               "computation.");
+
   po->Register("decoding-method", &decoding_method,
                "Decoding method to use. Possible values are: greedy_search, "
                "modified_beam_search, and fast_beam_search");
diff --git a/sherpa/cpp_api/offline-recognizer.h b/sherpa/cpp_api/offline-recognizer.h
index db7109c91..b4698c02e 100644
--- a/sherpa/cpp_api/offline-recognizer.h
+++ b/sherpa/cpp_api/offline-recognizer.h
@@ -58,6 +58,10 @@ struct OfflineRecognizerConfig {
   /// GPU for computation
   bool use_gpu = false;
 
+  // true to use automatic-mixed-precision (amp) in neural
+  // network computation.
+  bool use_amp = false;
+
   std::string decoding_method = "greedy_search";
 
   /// used only for modified_beam_search
diff --git a/sherpa/python/csrc/offline-recognizer.cc b/sherpa/python/csrc/offline-recognizer.cc
index d2b2b0fab..43d7bc7bb 100644
--- a/sherpa/python/csrc/offline-recognizer.cc
+++ b/sherpa/python/csrc/offline-recognizer.cc
@@ -102,6 +102,9 @@ Constructor for the offline recognizer configuration.
     Used only when the passed ``nn_model`` is a transducer model.
     Valid values are: ``greedy_search``, ``modified_beam_search``, and
     ``fast_beam_search``.
+  use_amp:
+    ``True`` to use automatic-mixed-precision (amp) during neural network
+    computation.
 )doc";
 
 static void PybindOfflineCtcDecoderConfig(py::module &m) {  // NOLINT
@@ -148,8 +151,9 @@ static void PybindOfflineRecognizerConfig(py::module &m) {  // NOLINT
                        const OfflineCtcDecoderConfig &ctc_decoder_config = {},
                        const FeatureConfig &feat_config = {},
                        const FastBeamSearchConfig &fast_beam_search_config = {},
-                       const std::string &decoding_method = "greedy_search")
-          -> std::unique_ptr<OfflineRecognizerConfig> {
+                       const std::string &decoding_method = "greedy_search",
+                       bool use_amp =
+                           false) -> std::unique_ptr<OfflineRecognizerConfig> {
             auto config = std::make_unique<OfflineRecognizerConfig>();
 
             config->ctc_decoder_config = ctc_decoder_config;
@@ -158,6 +162,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) {  // NOLINT
             config->nn_model = nn_model;
             config->tokens = tokens;
             config->use_gpu = use_gpu;
+            config->use_amp = use_amp;
             config->decoding_method = decoding_method;
             config->num_active_paths = num_active_paths;
 
@@ -169,7 +174,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) {  // NOLINT
            py::arg("feat_config") = FeatureConfig(),
            py::arg("fast_beam_search_config") = FastBeamSearchConfig(),
            py::arg("decoding_method") = "greedy_search",
-           kOfflineRecognizerConfigInitDoc)
+           py::arg("use_amp") = false, kOfflineRecognizerConfigInitDoc)
       .def("__str__",
            [](const PyClass &self) -> std::string { return self.ToString(); })
       .def_readwrite("ctc_decoder_config", &PyClass::ctc_decoder_config)
@@ -179,6 +184,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) {  // NOLINT
       .def_readwrite("nn_model", &PyClass::nn_model)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("use_gpu", &PyClass::use_gpu)
+      .def_readwrite("use_amp", &PyClass::use_amp)
      .def_readwrite("decoding_method", &PyClass::decoding_method)
      .def_readwrite("num_active_paths", &PyClass::num_active_paths)
      .def("validate", &PyClass::Validate);