
Commit 3c7724c

Add customized score for hotwords & add Finalize to stream (#281)

1 parent 884ce6d commit 3c7724c

12 files changed: +316 -109 lines

.github/scripts/run-test.sh (+16 -1)

@@ -52,6 +52,21 @@ for wave in ${waves[@]}; do
   done
 done
 
+log "Start testing ${repo_url} with hotwords"
+
+time $EXE \
+  $repo/tokens.txt \
+  $repo/encoder_jit_trace-pnnx.ncnn.param \
+  $repo/encoder_jit_trace-pnnx.ncnn.bin \
+  $repo/decoder_jit_trace-pnnx.ncnn.param \
+  $repo/decoder_jit_trace-pnnx.ncnn.bin \
+  $repo/joiner_jit_trace-pnnx.ncnn.param \
+  $repo/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/1.wav \
+  2 \
+  modified_beam_search \
+  $repo/test_wavs/hotwords.txt
+
 rm -rf $repo
 
 log "------------------------------------------------------------"
@@ -588,4 +603,4 @@ time $EXE \
   modified_beam_search \
   $repo/hotwords.txt 1.6
 
-rm -rf $repo
+rm -rf $repo

sherpa-ncnn/csrc/CMakeLists.txt (+2 -0)

@@ -77,4 +77,6 @@ endif()
 if(SHERPA_NCNN_ENABLE_TEST)
   add_executable(test-resample test-resample.cc)
   target_link_libraries(test-resample sherpa-ncnn-core)
+  add_executable(test-context-graph test-context-graph.cc)
+  target_link_libraries(test-context-graph sherpa-ncnn-core)
 endif()
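
The new test target builds test-context-graph.cc, which is among the changed files but is not shown in this excerpt. Purely as a hypothetical sketch (not the file from the commit; token IDs and the phrase are invented), a test of this kind could exercise the API declared in context-graph.h below:

// Hypothetical sketch, NOT the test-context-graph.cc from the commit.
#include <cassert>
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"

int main() {
  using sherpa_ncnn::ContextGraph;

  // One hotword "AB" encoded as token IDs {1, 2}, with a custom score.
  std::vector<std::vector<int32_t>> token_ids = {{1, 2}};
  std::vector<float> scores = {3.0f};
  std::vector<std::string> phrases = {"AB"};
  ContextGraph graph(token_ids, /*context_score=*/1.5f, scores, phrases);

  auto state = graph.Root();
  state = std::get<1>(graph.ForwardOneStep(state, 1));  // partial match "A"
  assert(!graph.IsMatched(state).first);

  state = std::get<1>(graph.ForwardOneStep(state, 2));  // full match "AB"
  assert(graph.IsMatched(state).first);
  assert(graph.IsMatched(state).second->phrase == "AB");
  return 0;
}

Building such a test only needs a link against sherpa-ncnn-core, exactly as the two added lines above do for the real test.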

sherpa-ncnn/csrc/context-graph.cc (+74 -7)

@@ -4,31 +4,67 @@
 
 #include "sherpa-ncnn/csrc/context-graph.h"
 
+#include <algorithm>
 #include <cassert>
 #include <queue>
+#include <string>
+#include <tuple>
 #include <utility>
 
 namespace sherpa_ncnn {
-void ContextGraph::Build(
-    const std::vector<std::vector<int32_t>> &token_ids) const {
+void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
+                         const std::vector<float> &scores,
+                         const std::vector<std::string> &phrases,
+                         const std::vector<float> &ac_thresholds) const {
+  if (!scores.empty()) {
+    assert(token_ids.size() == scores.size());
+  }
+  if (!phrases.empty()) {
+    assert(token_ids.size() == phrases.size());
+  }
+  if (!ac_thresholds.empty()) {
+    assert(token_ids.size() == ac_thresholds.size());
+  }
   for (int32_t i = 0; i < token_ids.size(); ++i) {
     auto node = root_.get();
+    float score = scores.empty() ? 0.0f : scores[i];
+    score = score == 0.0f ? context_score_ : score;
+    float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
+    ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
+    std::string phrase = phrases.empty() ? std::string() : phrases[i];
+
     for (int32_t j = 0; j < token_ids[i].size(); ++j) {
       int32_t token = token_ids[i][j];
       if (0 == node->next.count(token)) {
         bool is_end = j == token_ids[i].size() - 1;
         node->next[token] = std::make_unique<ContextState>(
-            token, context_score_, node->node_score + context_score_,
-            is_end ? node->node_score + context_score_ : 0, is_end);
+            token, score, node->node_score + score,
+            is_end ? node->node_score + score : 0, j + 1,
+            is_end ? ac_threshold : 0.0f, is_end,
+            is_end ? phrase : std::string());
+      } else {
+        float token_score = std::max(score, node->next[token]->token_score);
+        node->next[token]->token_score = token_score;
+        float node_score = node->node_score + token_score;
+        node->next[token]->node_score = node_score;
+        bool is_end =
+            (j == token_ids[i].size() - 1) || node->next[token]->is_end;
+        node->next[token]->output_score = is_end ? node_score : 0.0f;
+        node->next[token]->is_end = is_end;
+        if (j == token_ids[i].size() - 1) {
+          node->next[token]->phrase = phrase;
+          node->next[token]->ac_threshold = ac_threshold;
+        }
       }
       node = node->next[token].get();
     }
   }
   FillFailOutput();
 }
 
-std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
-    const ContextState *state, int32_t token) const {
+std::tuple<float, const ContextState *, const ContextState *>
+ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
+                             bool strict_mode /*= true*/) const {
   const ContextState *node;
   float score;
   if (1 == state->next.count(token)) {
@@ -45,7 +81,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
     }
     score = node->node_score - state->node_score;
   }
-  return std::make_pair(score + node->output_score, node);
+
+  assert(nullptr != node);
+
+  const ContextState *matched_node =
+      node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
+
+  if (!strict_mode && node->output_score != 0) {
+    assert(nullptr != matched_node);
+    float output_score =
+        node->is_end ? node->node_score
+                     : (node->output != nullptr ? node->output->node_score
+                                                : node->node_score);
+    return std::make_tuple(score + output_score - node->node_score, root_.get(),
+                           matched_node);
+  }
+  return std::make_tuple(score + node->output_score, node, matched_node);
 }
 
 std::pair<float, const ContextState *> ContextGraph::Finalize(
@@ -54,6 +105,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
   return std::make_pair(score, root_.get());
 }
 
+std::pair<bool, const ContextState *> ContextGraph::IsMatched(
+    const ContextState *state) const {
+  bool status = false;
+  const ContextState *node = nullptr;
+  if (state->is_end) {
+    status = true;
+    node = state;
+  } else {
+    if (state->output != nullptr) {
+      status = true;
+      node = state->output;
+    }
+  }
+  return std::make_pair(status, node);
+}
+
 void ContextGraph::FillFailOutput() const {
   std::queue<const ContextState *> node_queue;
   for (auto &kv : root_->next) {
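
Note the fallback convention in Build(): a per-phrase score of 0.0f falls back to the graph-wide context_score_, and a per-phrase ac_threshold of 0.0f falls back to ac_threshold_. A small sketch (values invented for illustration) of how a caller might mix custom and default values through the six-argument constructor declared in context-graph.h:

#include <cstdint>
#include <string>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"

int main() {
  // Phrase 0 carries a custom boost (2.5) and a custom ac_threshold (0.4);
  // phrase 1 passes 0.0f for both and therefore falls back to the
  // graph-wide context_score (1.5) and ac_threshold (0.25).
  std::vector<std::vector<int32_t>> token_ids = {{5, 6, 7}, {8, 9}};
  std::vector<float> scores = {2.5f, 0.0f};
  std::vector<std::string> phrases = {"FOO BAR", "BAZ"};
  std::vector<float> ac_thresholds = {0.4f, 0.0f};

  sherpa_ncnn::ContextGraph graph(token_ids, /*context_score=*/1.5f,
                                  /*ac_threshold=*/0.25f, scores, phrases,
                                  ac_thresholds);
  (void)graph;  // in the decoder, the graph is reached via s->GetContextGraph()
  return 0;
}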

sherpa-ncnn/csrc/context-graph.h (+36 -10)

@@ -6,11 +6,12 @@
 #define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_
 
 #include <memory>
+#include <string>
+#include <tuple>
 #include <unordered_map>
 #include <utility>
 #include <vector>
 
-
 namespace sherpa_ncnn {
 
 class ContextGraph;
@@ -21,43 +22,68 @@ struct ContextState {
   float token_score;
   float node_score;
   float output_score;
+  int32_t level;
+  float ac_threshold;
   bool is_end;
+  std::string phrase;
   std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
   const ContextState *fail = nullptr;
   const ContextState *output = nullptr;
 
   ContextState() = default;
   ContextState(int32_t token, float token_score, float node_score,
-               float output_score, bool is_end)
+               float output_score, int32_t level = 0, float ac_threshold = 0.0f,
+               bool is_end = false, const std::string &phrase = {})
       : token(token),
         token_score(token_score),
         node_score(node_score),
         output_score(output_score),
-        is_end(is_end) {}
+        level(level),
+        ac_threshold(ac_threshold),
+        is_end(is_end),
+        phrase(phrase) {}
 };
 
 class ContextGraph {
  public:
  ContextGraph() = default;
  ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
-               float hotwords_score)
-      : context_score_(hotwords_score) {
-    root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
+               float context_score, float ac_threshold,
+               const std::vector<float> &scores = {},
+               const std::vector<std::string> &phrases = {},
+               const std::vector<float> &ac_thresholds = {})
+      : context_score_(context_score), ac_threshold_(ac_threshold) {
+    root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
    root_->fail = root_.get();
-    Build(token_ids);
+    Build(token_ids, scores, phrases, ac_thresholds);
  }
 
-  std::pair<float, const ContextState *> ForwardOneStep(
-      const ContextState *state, int32_t token_id) const;
+  ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
+               float context_score, const std::vector<float> &scores = {},
+               const std::vector<std::string> &phrases = {})
+      : ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
+                     std::vector<float>()) {}
+
+  std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
+      const ContextState *state, int32_t token_id,
+      bool strict_mode = true) const;
+
+  std::pair<bool, const ContextState *> IsMatched(
+      const ContextState *state) const;
+
  std::pair<float, const ContextState *> Finalize(
      const ContextState *state) const;
 
  const ContextState *Root() const { return root_.get(); }
 
 private:
  float context_score_;
+  float ac_threshold_;
  std::unique_ptr<ContextState> root_;
-  void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
+  void Build(const std::vector<std::vector<int32_t>> &token_ids,
+             const std::vector<float> &scores,
+             const std::vector<std::string> &phrases,
+             const std::vector<float> &ac_thresholds) const;
  void FillFailOutput() const;
 };
 
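
Putting the header together, here is a minimal usage sketch of the extended API (not code from the commit; token IDs, scores, and phrases are invented): it builds a graph with per-phrase scores and steps tokens through it in non-strict mode, the mode the decoder below uses.

#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"

int main() {
  using sherpa_ncnn::ContextGraph;
  using sherpa_ncnn::ContextState;

  // Two hotwords given as token-ID sequences. A per-phrase score of 0.0f
  // falls back to the graph-wide context_score.
  std::vector<std::vector<int32_t>> token_ids = {{10, 11, 12}, {10, 20}};
  std::vector<float> scores = {2.0f, 0.0f};
  std::vector<std::string> phrases = {"HELLO WORLD", "HEY"};

  ContextGraph graph(token_ids, /*context_score=*/1.5f, scores, phrases);

  const ContextState *state = graph.Root();
  float total = 0.0f;
  for (int32_t token : {10, 11, 12}) {
    auto res = graph.ForwardOneStep(state, token, /*strict_mode=*/false);
    total += std::get<0>(res);  // incremental boost for this step
    state = std::get<1>(res);   // in non-strict mode this is the root again
                                // as soon as a phrase completes
    if (const ContextState *matched = std::get<2>(res)) {
      std::cout << "matched: " << matched->phrase << "\n";
    }
  }
  std::cout << "total context score: " << total << "\n";
  return 0;
}

With strict_mode = true the returned state stays on the node that was reached instead of being reset to the root after a match.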

sherpa-ncnn/csrc/modified-beam-search-decoder.cc (+5 -80)

@@ -117,82 +117,7 @@ ncnn::Mat ModifiedBeamSearchDecoder::BuildDecoderInput(
 
 void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
                                        DecoderResult *result) {
-  int32_t context_size = model_->ContextSize();
-  Hypotheses cur = std::move(result->hyps);
-  /* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */
-  for (int32_t t = 0; t != encoder_out.h; ++t) {
-    std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
-    cur.Clear();
-
-    ncnn::Mat decoder_input = BuildDecoderInput(prev);
-    ncnn::Mat decoder_out;
-    if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
-        !result->decoder_out.empty()) {
-      // When an endpoint is detected, we keep the decoder_out
-      decoder_out = result->decoder_out;
-    } else {
-      decoder_out = RunDecoder2D(model_, decoder_input);
-    }
-
-    // decoder_out.w == decoder_dim
-    // decoder_out.h == num_active_paths
-    ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
-    // Note: encoder_out_t.h == 1, we rely on the binary op broadcasting
-    // in ncnn
-    // See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
-    // broadcast B for outer axis, type 14
-    ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
-
-    // joiner_out.w == vocab_size
-    // joiner_out.h == num_active_paths
-    LogSoftmax(&joiner_out);
-
-    float *p_joiner_out = joiner_out;
-
-    for (int32_t i = 0; i != joiner_out.h; ++i) {
-      float prev_log_prob = prev[i].log_prob;
-      for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
-        *p_joiner_out += prev_log_prob;
-      }
-    }
-
-    auto topk = TopkIndex(static_cast<float *>(joiner_out),
-                          joiner_out.w * joiner_out.h, num_active_paths_);
-
-    int32_t frame_offset = result->frame_offset;
-    for (auto i : topk) {
-      int32_t hyp_index = i / joiner_out.w;
-      int32_t new_token = i % joiner_out.w;
-
-      const float *p = joiner_out.row(hyp_index);
-
-      Hypothesis new_hyp = prev[hyp_index];
-
-      // blank id is fixed to 0
-      if (new_token != 0 && new_token != 2) {
-        new_hyp.ys.push_back(new_token);
-        new_hyp.num_trailing_blanks = 0;
-        new_hyp.timestamps.push_back(t + frame_offset);
-      } else {
-        ++new_hyp.num_trailing_blanks;
-      }
-      // We have already added prev[hyp_index].log_prob to p[new_token]
-      new_hyp.log_prob = p[new_token];
-
-      cur.Add(std::move(new_hyp));
-    }
-  }
-
-  result->hyps = std::move(cur);
-  result->frame_offset += encoder_out.h;
-  auto hyp = result->hyps.GetMostProbable(true);
-
-  // set decoder_out in case of endpointing
-  ncnn::Mat decoder_input = BuildDecoderInput({hyp});
-  result->decoder_out = model_->RunDecoder(decoder_input);
-
-  result->tokens = std::move(hyp.ys);
-  result->num_trailing_blanks = hyp.num_trailing_blanks;
+  Decode(encoder_out, nullptr, result);
 }
 
 void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
@@ -252,10 +177,10 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
         new_hyp.num_trailing_blanks = 0;
         new_hyp.timestamps.push_back(t + frame_offset);
         if (s && s->GetContextGraph()) {
-          auto context_res =
-              s->GetContextGraph()->ForwardOneStep(context_state, new_token);
-          context_score = context_res.first;
-          new_hyp.context_state = context_res.second;
+          auto context_res = s->GetContextGraph()->ForwardOneStep(
+              context_state, new_token, false /*strict_mode*/);
+          context_score = std::get<0>(context_res);
+          new_hyp.context_state = std::get<1>(context_res);
        }
      } else {
        ++new_hyp.num_trailing_blanks;
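
The commit title also adds Finalize to the stream; the stream-side changes live in files not shown in this excerpt. As a hypothetical sketch only (the helper name is invented; Hypothesis, its log_prob and context_state fields, and ContextGraph::Finalize are taken from the diffs above, while the hypothesis.h include path is assumed), finalizing could look like this when a stream ends:

#include "sherpa-ncnn/csrc/context-graph.h"
#include "sherpa-ncnn/csrc/hypothesis.h"  // assumed include path for Hypothesis

namespace sherpa_ncnn {

// Hypothetical helper, not part of the commit: when a stream finishes,
// remove any pending boost of a partially matched hotword and put the
// context state back at the root of the graph.
void FinalizeContextState(const ContextGraph &graph, Hypothesis *hyp) {
  auto res = graph.Finalize(hyp->context_state);
  hyp->log_prob += res.first;       // score adjustment returned by Finalize()
  hyp->context_state = res.second;  // Finalize() returns the root state
}

}  // namespace sherpa_ncnn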
