4
4
5
5
#include " sherpa-ncnn/csrc/context-graph.h"
6
6
7
+ #include < algorithm>
7
8
#include < cassert>
8
9
#include < queue>
10
+ #include < string>
11
+ #include < tuple>
9
12
#include < utility>
10
13
11
14
namespace sherpa_ncnn {
12
- void ContextGraph::Build (
13
- const std::vector<std::vector<int32_t >> &token_ids) const {
15
+ void ContextGraph::Build (const std::vector<std::vector<int32_t >> &token_ids,
16
+ const std::vector<float > &scores,
17
+ const std::vector<std::string> &phrases,
18
+ const std::vector<float > &ac_thresholds) const {
19
+ if (!scores.empty ()) {
20
+ assert (token_ids.size () == scores.size ());
21
+ }
22
+ if (!phrases.empty ()) {
23
+ assert (token_ids.size () == phrases.size ());
24
+ }
25
+ if (!ac_thresholds.empty ()) {
26
+ assert (token_ids.size () == ac_thresholds.size ());
27
+ }
14
28
for (int32_t i = 0 ; i < token_ids.size (); ++i) {
15
29
auto node = root_.get ();
30
+ float score = scores.empty () ? 0 .0f : scores[i];
31
+ score = score == 0 .0f ? context_score_ : score;
32
+ float ac_threshold = ac_thresholds.empty () ? 0 .0f : ac_thresholds[i];
33
+ ac_threshold = ac_threshold == 0 .0f ? ac_threshold_ : ac_threshold;
34
+ std::string phrase = phrases.empty () ? std::string () : phrases[i];
35
+
16
36
for (int32_t j = 0 ; j < token_ids[i].size (); ++j) {
17
37
int32_t token = token_ids[i][j];
18
38
if (0 == node->next .count (token)) {
19
39
bool is_end = j == token_ids[i].size () - 1 ;
20
40
node->next [token] = std::make_unique<ContextState>(
21
- token, context_score_, node->node_score + context_score_,
22
- is_end ? node->node_score + context_score_ : 0 , is_end);
41
+ token, score, node->node_score + score,
42
+ is_end ? node->node_score + score : 0 , j + 1 ,
43
+ is_end ? ac_threshold : 0 .0f , is_end,
44
+ is_end ? phrase : std::string ());
45
+ } else {
46
+ float token_score = std::max (score, node->next [token]->token_score );
47
+ node->next [token]->token_score = token_score;
48
+ float node_score = node->node_score + token_score;
49
+ node->next [token]->node_score = node_score;
50
+ bool is_end =
51
+ (j == token_ids[i].size () - 1 ) || node->next [token]->is_end ;
52
+ node->next [token]->output_score = is_end ? node_score : 0 .0f ;
53
+ node->next [token]->is_end = is_end;
54
+ if (j == token_ids[i].size () - 1 ) {
55
+ node->next [token]->phrase = phrase;
56
+ node->next [token]->ac_threshold = ac_threshold;
57
+ }
23
58
}
24
59
node = node->next [token].get ();
25
60
}
26
61
}
27
62
FillFailOutput ();
28
63
}
29
64
30
- std::pair<float , const ContextState *> ContextGraph::ForwardOneStep (
31
- const ContextState *state, int32_t token) const {
65
+ std::tuple<float , const ContextState *, const ContextState *>
66
+ ContextGraph::ForwardOneStep (const ContextState *state, int32_t token,
67
+ bool strict_mode /* = true*/ ) const {
32
68
const ContextState *node;
33
69
float score;
34
70
if (1 == state->next .count (token)) {
@@ -45,7 +81,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
45
81
}
46
82
score = node->node_score - state->node_score ;
47
83
}
48
- return std::make_pair (score + node->output_score , node);
84
+
85
+ assert (nullptr != node);
86
+
87
+ const ContextState *matched_node =
88
+ node->is_end ? node : (node->output != nullptr ? node->output : nullptr );
89
+
90
+ if (!strict_mode && node->output_score != 0 ) {
91
+ assert (nullptr != matched_node);
92
+ float output_score =
93
+ node->is_end ? node->node_score
94
+ : (node->output != nullptr ? node->output ->node_score
95
+ : node->node_score );
96
+ return std::make_tuple (score + output_score - node->node_score , root_.get (),
97
+ matched_node);
98
+ }
99
+ return std::make_tuple (score + node->output_score , node, matched_node);
49
100
}
50
101
51
102
std::pair<float , const ContextState *> ContextGraph::Finalize (
@@ -54,6 +105,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
54
105
return std::make_pair (score, root_.get ());
55
106
}
56
107
108
+ std::pair<bool , const ContextState *> ContextGraph::IsMatched (
109
+ const ContextState *state) const {
110
+ bool status = false ;
111
+ const ContextState *node = nullptr ;
112
+ if (state->is_end ) {
113
+ status = true ;
114
+ node = state;
115
+ } else {
116
+ if (state->output != nullptr ) {
117
+ status = true ;
118
+ node = state->output ;
119
+ }
120
+ }
121
+ return std::make_pair (status, node);
122
+ }
123
+
57
124
void ContextGraph::FillFailOutput () const {
58
125
std::queue<const ContextState *> node_queue;
59
126
for (auto &kv : root_->next ) {
0 commit comments