-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
473 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// sherpa-ncnn/csrc/context-graph.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#include "sherpa-ncnn/csrc/context-graph.h" | ||
|
||
#include <cassert> | ||
#include <queue> | ||
#include <utility> | ||
|
||
namespace sherpa_ncnn { | ||
void ContextGraph::Build( | ||
const std::vector<std::vector<int32_t>> &token_ids) const { | ||
for (int32_t i = 0; i < token_ids.size(); ++i) { | ||
auto node = root_.get(); | ||
for (int32_t j = 0; j < token_ids[i].size(); ++j) { | ||
int32_t token = token_ids[i][j]; | ||
if (0 == node->next.count(token)) { | ||
bool is_end = j == token_ids[i].size() - 1; | ||
node->next[token] = std::make_unique<ContextState>( | ||
token, context_score_, node->node_score + context_score_, | ||
is_end ? node->node_score + context_score_ : 0, is_end); | ||
} | ||
node = node->next[token].get(); | ||
} | ||
} | ||
FillFailOutput(); | ||
} | ||
|
||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | ||
const ContextState *state, int32_t token) const { | ||
const ContextState *node; | ||
float score; | ||
if (1 == state->next.count(token)) { | ||
node = state->next.at(token).get(); | ||
score = node->token_score; | ||
} else { | ||
node = state->fail; | ||
while (0 == node->next.count(token)) { | ||
node = node->fail; | ||
if (-1 == node->token) break; // root | ||
} | ||
if (1 == node->next.count(token)) { | ||
node = node->next.at(token).get(); | ||
} | ||
score = node->node_score - state->node_score; | ||
} | ||
return std::make_pair(score + node->output_score, node); | ||
} | ||
|
||
std::pair<float, const ContextState *> ContextGraph::Finalize( | ||
const ContextState *state) const { | ||
float score = -state->node_score; | ||
return std::make_pair(score, root_.get()); | ||
} | ||
|
||
void ContextGraph::FillFailOutput() const { | ||
std::queue<const ContextState *> node_queue; | ||
for (auto &kv : root_->next) { | ||
kv.second->fail = root_.get(); | ||
node_queue.push(kv.second.get()); | ||
} | ||
while (!node_queue.empty()) { | ||
auto current_node = node_queue.front(); | ||
node_queue.pop(); | ||
for (auto &kv : current_node->next) { | ||
auto fail = current_node->fail; | ||
if (1 == fail->next.count(kv.first)) { | ||
fail = fail->next.at(kv.first).get(); | ||
} else { | ||
fail = fail->fail; | ||
while (0 == fail->next.count(kv.first)) { | ||
fail = fail->fail; | ||
if (-1 == fail->token) break; | ||
} | ||
if (1 == fail->next.count(kv.first)) | ||
fail = fail->next.at(kv.first).get(); | ||
} | ||
kv.second->fail = fail; | ||
// fill the output arc | ||
auto output = fail; | ||
while (!output->is_end) { | ||
output = output->fail; | ||
if (-1 == output->token) { | ||
output = nullptr; | ||
break; | ||
} | ||
} | ||
kv.second->output = output; | ||
kv.second->output_score += output == nullptr ? 0 : output->output_score; | ||
node_queue.push(kv.second.get()); | ||
} | ||
} | ||
} | ||
} // namespace sherpa_ncnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// sherpa-ncnn/csrc/context-graph.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ | ||
#define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ | ||
|
||
#include <memory> | ||
#include <unordered_map> | ||
#include <utility> | ||
#include <vector> | ||
|
||
|
||
namespace sherpa_ncnn { | ||
|
||
class ContextGraph; | ||
using ContextGraphPtr = std::shared_ptr<ContextGraph>; | ||
|
||
struct ContextState { | ||
int32_t token; | ||
float token_score; | ||
float node_score; | ||
float output_score; | ||
bool is_end; | ||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; | ||
const ContextState *fail = nullptr; | ||
const ContextState *output = nullptr; | ||
|
||
ContextState() = default; | ||
ContextState(int32_t token, float token_score, float node_score, | ||
float output_score, bool is_end) | ||
: token(token), | ||
token_score(token_score), | ||
node_score(node_score), | ||
output_score(output_score), | ||
is_end(is_end) {} | ||
}; | ||
|
||
class ContextGraph { | ||
public: | ||
ContextGraph() = default; | ||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, | ||
float hotwords_score) | ||
: context_score_(hotwords_score) { | ||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false); | ||
root_->fail = root_.get(); | ||
Build(token_ids); | ||
} | ||
|
||
std::pair<float, const ContextState *> ForwardOneStep( | ||
const ContextState *state, int32_t token_id) const; | ||
std::pair<float, const ContextState *> Finalize( | ||
const ContextState *state) const; | ||
|
||
const ContextState *Root() const { return root_.get(); } | ||
|
||
private: | ||
float context_score_; | ||
std::unique_ptr<ContextState> root_; | ||
void Build(const std::vector<std::vector<int32_t>> &token_ids) const; | ||
void FillFailOutput() const; | ||
}; | ||
|
||
} // namespace sherpa_ncnn | ||
#endif // SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.