Skip to content

Commit

Permalink
support hotwords in C++ (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
HalFTeen authored Aug 31, 2023
1 parent fc4d3bc commit 0f8e46d
Show file tree
Hide file tree
Showing 16 changed files with 473 additions and 31 deletions.
45 changes: 45 additions & 0 deletions .github/scripts/run-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,48 @@ for wave in ${waves[@]}; do
done

rm -rf $repo

log "------------------------------------------------------------"
log "Run hotwords test (Chinese)"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/HalFTeen/sherpa-ncnn-hotwords-test/
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "encoder_jit_trace-pnnx.ncnn.bin"
git lfs pull --include "decoder_jit_trace-pnnx.ncnn.bin"
git lfs pull --include "joiner_jit_trace-pnnx.ncnn.bin"
popd


log "----test $m without hotwords---"
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
$repo/encoder_jit_trace-pnnx.ncnn.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.param \
$repo/joiner_jit_trace-pnnx.ncnn.bin \
$repo/hotwords.wav \
4 \
modified_beam_search


log "----test $m with hotwords---"
time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
$repo/encoder_jit_trace-pnnx.ncnn.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.param \
$repo/joiner_jit_trace-pnnx.ncnn.bin \
$repo/hotwords.wav \
4 \
modified_beam_search \
$repo/hotwords.txt 1.6

rm -rf $repo
16 changes: 12 additions & 4 deletions c-api-examples/decode-file-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const char *kUsage =
"for a list of pre-trained models to download.\n";

int32_t main(int32_t argc, char *argv[]) {
if (argc < 9 || argc > 11) {
if (argc < 9 || argc > 13) {
fprintf(stderr, "%s\n", kUsage);
return -1;
}
Expand All @@ -62,7 +62,7 @@ int32_t main(int32_t argc, char *argv[]) {

config.decoder_config.decoding_method = "greedy_search";

if (argc == 11) {
if (argc >= 11) {
config.decoder_config.decoding_method = argv[10];
}
config.decoder_config.num_active_paths = 4;
Expand All @@ -73,7 +73,16 @@ int32_t main(int32_t argc, char *argv[]) {

config.feat_config.sampling_rate = 16000;
config.feat_config.feature_dim = 80;

if(argc >= 12) {
config.hotwords_file = argv[11];
} else {
config.hotwords_file = "";
}
if(argc == 13) {
config.hotwords_score = atof(argv[12]);
} else {
config.hotwords_score = 1.5;
}
SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);

const char *wav_filename = argv[8];
Expand All @@ -92,7 +101,6 @@ int32_t main(int32_t argc, char *argv[]) {

int16_t buffer[N];
float samples[N];

SherpaNcnnStream *s = CreateStream(recognizer);

SherpaNcnnDisplay *display = CreateDisplay(50);
Expand Down
2 changes: 2 additions & 0 deletions sherpa-ncnn/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ SherpaNcnnRecognizer *CreateRecognizer(
config.decoder_config.method = in_config->decoder_config.decoding_method;
config.decoder_config.num_active_paths =
in_config->decoder_config.num_active_paths;
config.hotwords_file = in_config->hotwords_file;
config.hotwords_score = in_config->hotwords_score;

config.enable_endpoint = in_config->enable_endpoint;

Expand Down
8 changes: 8 additions & 0 deletions sherpa-ncnn/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ SHERPA_NCNN_API typedef struct SherpaNcnnRecognizerConfig {
/// this value.
/// Used only when enable_endpoint is not 0.
float rule3_min_utterance_length;

/// hotwords file, each line is a hotword which is segmented into char by space
/// if language is something like CJK, segment manually,
/// if language is something like English, segment by bpe model.
const char *hotwords_file;

/// scale of hotwords, used only when hotwords_file is not empty
float hotwords_score;
} SherpaNcnnRecognizerConfig;

SHERPA_NCNN_API typedef struct SherpaNcnnResult {
Expand Down
1 change: 1 addition & 0 deletions sherpa-ncnn/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
include_directories(${CMAKE_SOURCE_DIR})

set(sherpa_ncnn_core_srcs
context-graph.cc
conv-emformer-model.cc
decoder.cc
endpoint.cc
Expand Down
95 changes: 95 additions & 0 deletions sherpa-ncnn/csrc/context-graph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// sherpa-ncnn/csrc/context-graph.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-ncnn/csrc/context-graph.h"

#include <cassert>
#include <queue>
#include <utility>

namespace sherpa_ncnn {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? node->node_score + context_score_ : 0, is_end);
}
node = node->next[token].get();
}
}
FillFailOutput();
}

std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
} else {
node = state->fail;
while (0 == node->next.count(token)) {
node = node->fail;
if (-1 == node->token) break; // root
}
if (1 == node->next.count(token)) {
node = node->next.at(token).get();
}
score = node->node_score - state->node_score;
}
return std::make_pair(score + node->output_score, node);
}

std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
return std::make_pair(score, root_.get());
}

void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
kv.second->fail = root_.get();
node_queue.push(kv.second.get());
}
while (!node_queue.empty()) {
auto current_node = node_queue.front();
node_queue.pop();
for (auto &kv : current_node->next) {
auto fail = current_node->fail;
if (1 == fail->next.count(kv.first)) {
fail = fail->next.at(kv.first).get();
} else {
fail = fail->fail;
while (0 == fail->next.count(kv.first)) {
fail = fail->fail;
if (-1 == fail->token) break;
}
if (1 == fail->next.count(kv.first))
fail = fail->next.at(kv.first).get();
}
kv.second->fail = fail;
// fill the output arc
auto output = fail;
while (!output->is_end) {
output = output->fail;
if (-1 == output->token) {
output = nullptr;
break;
}
}
kv.second->output = output;
kv.second->output_score += output == nullptr ? 0 : output->output_score;
node_queue.push(kv.second.get());
}
}
}
} // namespace sherpa_ncnn
65 changes: 65 additions & 0 deletions sherpa-ncnn/csrc/context-graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// sherpa-ncnn/csrc/context-graph.h
//
// Copyright (c) 2023 Xiaomi Corporation

#ifndef SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_
#define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>


namespace sherpa_ncnn {

class ContextGraph;
using ContextGraphPtr = std::shared_ptr<ContextGraph>;

struct ContextState {
int32_t token;
float token_score;
float node_score;
float output_score;
bool is_end;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;

ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float output_score, bool is_end)
: token(token),
token_score(token_score),
node_score(node_score),
output_score(output_score),
is_end(is_end) {}
};

class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float hotwords_score)
: context_score_(hotwords_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
root_->fail = root_.get();
Build(token_ids);
}

std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;

const ContextState *Root() const { return root_.get(); }

private:
float context_score_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void FillFailOutput() const;
};

} // namespace sherpa_ncnn
#endif // SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_
3 changes: 2 additions & 1 deletion sherpa-ncnn/csrc/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct DecoderResult {
// used only for modified_beam_search
Hypotheses hyps;
};

class Stream;
class Decoder {
public:
virtual ~Decoder() = default;
Expand Down Expand Up @@ -88,6 +88,7 @@ class Decoder {
* and there are no paddings.
*/
virtual void Decode(ncnn::Mat encoder_out, DecoderResult *result) = 0;
virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result){};
};

} // namespace sherpa_ncnn
Expand Down
10 changes: 7 additions & 3 deletions sherpa-ncnn/csrc/hypothesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-ncnn/csrc/context-graph.h"

namespace sherpa_ncnn {

Expand All @@ -37,12 +38,13 @@ struct Hypothesis {

// The total score of ys in log space.
double log_prob = 0;

const ContextState *context_state;
int32_t num_trailing_blanks = 0;

Hypothesis() = default;
Hypothesis(const std::vector<int32_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
Hypothesis(const std::vector<int32_t> &ys, double log_prob,
const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}

// If two Hypotheses have the same `Key`, then they contain
// the same token sequence.
Expand Down Expand Up @@ -104,6 +106,8 @@ class Hypotheses {

const auto begin() const { return hyps_dict_.begin(); }
const auto end() const { return hyps_dict_.end(); }
auto begin() { return hyps_dict_.begin(); }
auto end() { return hyps_dict_.end(); }

void Clear() { hyps_dict_.clear(); }

Expand Down
Loading

0 comments on commit 0f8e46d

Please sign in to comment.