From 768b246ef5a7ffd0a774c0ba57aefabb9ec5dcd0 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:40:27 +0800 Subject: [PATCH 01/12] Update kenlm.pyx add function: "predict next --- python/kenlm.pyx | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/python/kenlm.pyx b/python/kenlm.pyx index a6984a889..c7bc10edc 100644 --- a/python/kenlm.pyx +++ b/python/kenlm.pyx @@ -150,6 +150,38 @@ cdef class Model: def __get__(self): return self.model.Order() + def predict_next(self, context): + """ + 预测给定上下文后的下一个最可能的词。 + + :param context: 上下文句子 + :return: 预测的下一个词和其对应的概率 + """ + cdef State state = State() + cdef State out_state = State() + cdef float max_prob = float('-inf') + cdef str best_word = "" + + # 初始化状态 + words = as_str(context).split() + self.BeginSentenceWrite(state) + + # 处理上下文 + for word in words: + self.BaseScore(state, word, out_state) + state = out_state + + # 对词表中的每个词计算概率 + for i in range(self.vocab.Bound()): + word = self.vocab.Word(i) + prob = self.BaseScore(state, word, out_state) + + if prob > max_prob: + max_prob = prob + best_word = word + + return best_word, max_prob + def score(self, sentence, bos = True, eos = True): """ Return the log10 probability of a string. By default, the string is From b85f582632c2260d45d2d5d662233859f19ec9b4 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:45:11 +0800 Subject: [PATCH 02/12] Update model.hh --- lm/model.hh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lm/model.hh b/lm/model.hh index db43d8b5e..113e42718 100644 --- a/lm/model.hh +++ b/lm/model.hh @@ -35,6 +35,9 @@ template class GenericModel : public base::Mod static const ModelType kModelType; static const unsigned int kVersion = Search::kVersion; + + // 新增的 predict_next 方法声明 + std::pair predict_next(const std::string &context) const; /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class From 27c1334d71eb68c81480a9f15a83fed1e2d787b2 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:49:15 +0800 Subject: [PATCH 03/12] Update model.cc --- lm/model.cc | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/lm/model.cc b/lm/model.cc index b968edd9c..5370fc7ff 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -323,6 +323,37 @@ template class GenericModel, template class GenericModel, SortedVocabulary>; template class GenericModel, SortedVocabulary>; +//predict_next方法 +#include +template +std::pair GenericModel::predict_next(const std::string &context) const { + // 将 context 转换为 WordIndex 序列 + std::vector context_words; + std::istringstream iss(context); + std::string word; + while (iss >> word) { + context_words.push_back(vocab_.Index(word)); + } + + // 初始化状态 + State state; + GetState(context_words.rbegin(), context_words.rend(), state); + + float max_prob = -std::numeric_limits::infinity(); + WordIndex best_word = vocab_.Index("UNKNOWN"); + for (WordIndex i = 0; i < vocab_.Size(); ++i) { + State out_state; + FullScoreReturn ret = FullScore(state, i, out_state); + if (ret.prob > max_prob) { + max_prob = ret.prob; + best_word = i; + } + } + + return {vocab_.Word(best_word), max_prob}; +} + + } // namespace detail base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) { From 0440739b5132454121c72134bf5cf146b79e35b7 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:58:49 +0800 Subject: [PATCH 04/12] Update kenlm.pyx --- python/kenlm.pyx | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/kenlm.pyx b/python/kenlm.pyx index c7bc10edc..f6f446cc2 100644 --- a/python/kenlm.pyx +++ b/python/kenlm.pyx @@ -1,6 +1,12 @@ import os cimport _kenlm +# 在 cdef extern from 块中声明 C++ 方法(predict_next) +cdef extern from "lm/model.hh": + cdef cppclass Model: + pair[string, float] predict_next(const string &context) const + + cdef bytes as_str(data): if isinstance(data, bytes): return data @@ -306,6 +312,12 @@ cdef class Model: cdef _kenlm.FullScoreReturn ret = self.model.BaseFullScore(&in_state._c_state, wid, &out_state._c_state) return FullScoreReturn(ret.prob, ret.ngram_length, wid == 0) + # 增加predict_next方法 + def predict_next(self, context): + cdef string cpp_context = context.encode('utf-8') + result = self.thisptr.predict_next(cpp_context) + return result.first.decode('utf-8'), result.second + def __contains__(self, word): cdef bytes w = as_str(word) return (self.vocab.Index(w) != 0) From 924ee25c8716d474df0f3082659144732e4dd584 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 14:14:51 +0800 Subject: [PATCH 05/12] Update kenlm.cpp --- python/kenlm.cpp | 90 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/python/kenlm.cpp b/python/kenlm.cpp index a45fcb519..e8dc49d0a 100644 --- a/python/kenlm.cpp +++ b/python/kenlm.cpp @@ -11481,5 +11481,95 @@ static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) { return PyInt_FromSize_t(ival); } +//predict_next方法对应的变更(包括:-a 到 -f 共6个改变) +// 定义 PyModel 结构体 -a +typedef struct { + PyObject_HEAD + lm::ngram::Model* model; +} PyModel; + +// 定义新的 Python 方法 -b +static PyObject* PyModel_predict_next(PyObject* self, PyObject* args) { + const char* context; + if (!PyArg_ParseTuple(args, "s", &context)) { + return NULL; + } + + std::pair result = ((PyModel*)self)->model->predict_next(context); + return Py_BuildValue("sf", result.first.c_str(), result.second); +} + +// 方法定义数组 -c +static PyMethodDef PyModel_methods[] = { + {"predict_next", (PyCFunction)PyModel_predict_next, METH_VARARGS, "Predict the next word given a context"}, + {NULL} /* Sentinel */ +}; + +// 类型定义 -d +static PyTypeObject PyModelType = { + PyVarObject_HEAD_INIT(NULL, 0) + "kenlm.Model", /* tp_name */ + sizeof(PyModel), /* tp_basicsize */ + 0, /* tp_itemsize */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Model objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + PyModel_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +// 模块定义 -e +static PyModuleDef kenlmmodule = { + PyModuleDef_HEAD_INIT, + "kenlm", + "KenLM Python bindings", + -1, + NULL, NULL, NULL, NULL, NULL +}; + +// 模块初始化函数 -f +PyMODINIT_FUNC PyInit_kenlm(void) { + PyObject* m; + + if (PyType_Ready(&PyModelType) < 0) + return NULL; + + m = PyModule_Create(&kenlmmodule); + if (m == NULL) + return NULL; + + Py_INCREF(&PyModelType); + PyModule_AddObject(m, "Model", (PyObject *)&PyModelType); + return m; +} #endif /* Py_PYTHON_H */ From 4418e74f601a737c236af5550adf5a6f4affc210 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 14:58:44 +0800 Subject: [PATCH 06/12] Update model.cc --- lm/model.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lm/model.cc b/lm/model.cc index 5370fc7ff..e8f8eee5e 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -12,7 +12,10 @@ #include #include #include + +//predict_next函数所需包含的头文件 #include +#include namespace lm { namespace ngram { @@ -324,7 +327,6 @@ template class GenericModel, SortedVocabulary>; //predict_next方法 -#include template std::pair GenericModel::predict_next(const std::string &context) const { // 将 context 转换为 WordIndex 序列 @@ -332,25 +334,25 @@ std::pair GenericModel::predict_next(co std::istringstream iss(context); std::string word; while (iss >> word) { - context_words.push_back(vocab_.Index(word)); + context_words.push_back(this->vocab_.Index(word)); } // 初始化状态 State state; - GetState(context_words.rbegin(), context_words.rend(), state); + this->GetState(context_words.rbegin().base(), context_words.rend().base(), state); float max_prob = -std::numeric_limits::infinity(); - WordIndex best_word = vocab_.Index("UNKNOWN"); - for (WordIndex i = 0; i < vocab_.Size(); ++i) { + WordIndex best_word = this->vocab_.Index(""); + for (WordIndex i = 0; i < this->vocab_.size(); ++i) { State out_state; - FullScoreReturn ret = FullScore(state, i, out_state); + FullScoreReturn ret = this->FullScore(state, i, out_state); if (ret.prob > max_prob) { max_prob = ret.prob; best_word = i; } } - return {vocab_.Word(best_word), max_prob}; + return {this->vocab_.word(best_word), max_prob}; } From cc9d722762f4b258bf7cad293c0600b7be1d29db Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:34:13 +0800 Subject: [PATCH 07/12] Update model.cc --- lm/model.cc | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/lm/model.cc b/lm/model.cc index e8f8eee5e..b3c13cdcb 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -339,11 +339,11 @@ std::pair GenericModel::predict_next(co // 初始化状态 State state; - this->GetState(context_words.rbegin().base(), context_words.rend().base(), state); + this->GetState(&context_words[0], &context_words[0] + context_words.size(), state); float max_prob = -std::numeric_limits::infinity(); WordIndex best_word = this->vocab_.Index(""); - for (WordIndex i = 0; i < this->vocab_.size(); ++i) { + for (WordIndex i = 0; i < this->vocab_.Bound(); ++i) { State out_state; FullScoreReturn ret = this->FullScore(state, i, out_state); if (ret.prob > max_prob) { @@ -352,7 +352,16 @@ std::pair GenericModel::predict_next(co } } - return {this->vocab_.word(best_word), max_prob}; + // 查找 best_word 的字符串表示 + std::string best_word_str; + for (const auto& entry : this->vocab_) { + if (entry.second == best_word) { + best_word_str = entry.first; + break; + } + } + + return {best_word_str, max_prob}; } From a3f20305bef134d453601244657effc89f93287c Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:46:07 +0800 Subject: [PATCH 08/12] Update model.cc --- lm/model.cc | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/lm/model.cc b/lm/model.cc index b3c13cdcb..5c7e9ff97 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -16,6 +16,8 @@ //predict_next函数所需包含的头文件 #include #include +#include +#include "enumerate_vocab.hh" // Include EnumerateVocab header namespace lm { namespace ngram { @@ -327,9 +329,28 @@ template class GenericModel, SortedVocabulary>; //predict_next方法 + +// Define a class for enumerating the vocabulary +class VocabEnumerator : public lm::EnumerateVocab { +public: + VocabEnumerator(const lm::ngram::Vocabulary &vocab) : vocab_(vocab) {} + + void Add(lm::WordIndex index, const lm::StringPiece &str) override { + vocab_map_[index] = str.as_string(); + } + + const std::unordered_map& GetVocabMap() const { + return vocab_map_; + } + +private: + const lm::ngram::Vocabulary &vocab_; + std::unordered_map vocab_map_; +}; + template std::pair GenericModel::predict_next(const std::string &context) const { - // 将 context 转换为 WordIndex 序列 + // Convert context to WordIndex sequence std::vector context_words; std::istringstream iss(context); std::string word; @@ -337,13 +358,14 @@ std::pair GenericModel::predict_next(co context_words.push_back(this->vocab_.Index(word)); } - // 初始化状态 + // Initialize state State state; this->GetState(&context_words[0], &context_words[0] + context_words.size(), state); float max_prob = -std::numeric_limits::infinity(); WordIndex best_word = this->vocab_.Index(""); - for (WordIndex i = 0; i < this->vocab_.Bound(); ++i) { + + for (WordIndex i = 0; i < this->vocab_.Size(); ++i) { State out_state; FullScoreReturn ret = this->FullScore(state, i, out_state); if (ret.prob > max_prob) { @@ -352,14 +374,12 @@ std::pair GenericModel::predict_next(co } } - // 查找 best_word 的字符串表示 - std::string best_word_str; - for (const auto& entry : this->vocab_) { - if (entry.second == best_word) { - best_word_str = entry.first; - break; - } - } + // Enumerate vocabulary to find the best_word's string representation + VocabEnumerator enumerator(this->vocab_); + this->Enumerate(enumerator); + + auto vocab_map = enumerator.GetVocabMap(); + std::string best_word_str = vocab_map[best_word]; return {best_word_str, max_prob}; } From a41032b938d0d2556397673a8500f37957c1fb54 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:53:52 +0800 Subject: [PATCH 09/12] Update model.cc --- lm/model.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm/model.cc b/lm/model.cc index 5c7e9ff97..2b4620e87 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -376,10 +376,10 @@ std::pair GenericModel::predict_next(co // Enumerate vocabulary to find the best_word's string representation VocabEnumerator enumerator(this->vocab_); - this->Enumerate(enumerator); + this->vocab_.Enumerate(enumerator); auto vocab_map = enumerator.GetVocabMap(); - std::string best_word_str = vocab_map[best_word]; + std::string best_word_str = vocab_map.at(best_word); return {best_word_str, max_prob}; } From f5c6687d900c63c5227b22c40e11479f6ac4437e Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:56:46 +0800 Subject: [PATCH 10/12] Update model.cc --- lm/model.cc | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/lm/model.cc b/lm/model.cc index 2b4620e87..62a9ed763 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -18,6 +18,8 @@ #include #include #include "enumerate_vocab.hh" // Include EnumerateVocab header +#include "vocab.hh" // Include Vocabulary header +#include "string_piece.hh" // Include StringPiece header namespace lm { namespace ngram { @@ -349,7 +351,7 @@ class VocabEnumerator : public lm::EnumerateVocab { }; template -std::pair GenericModel::predict_next(const std::string &context) const { +std::unordered_map GenericModel::predict_next(const std::string &context) const { // Convert context to WordIndex sequence std::vector context_words; std::istringstream iss(context); @@ -360,31 +362,21 @@ std::pair GenericModel::predict_next(co // Initialize state State state; - this->GetState(&context_words[0], &context_words[0] + context_words.size(), state); + this->GetState(&context_words[0], &context_words[0] + context_words.size(), &state); - float max_prob = -std::numeric_limits::infinity(); - WordIndex best_word = this->vocab_.Index(""); + std::unordered_map word_probs; + // Calculate the score for each word in the vocabulary for (WordIndex i = 0; i < this->vocab_.Size(); ++i) { State out_state; FullScoreReturn ret = this->FullScore(state, i, out_state); - if (ret.prob > max_prob) { - max_prob = ret.prob; - best_word = i; - } + std::string word = this->vocab_.Word(i); + word_probs[word] = ret.prob; } - // Enumerate vocabulary to find the best_word's string representation - VocabEnumerator enumerator(this->vocab_); - this->vocab_.Enumerate(enumerator); - - auto vocab_map = enumerator.GetVocabMap(); - std::string best_word_str = vocab_map.at(best_word); - - return {best_word_str, max_prob}; + return word_probs; } - } // namespace detail base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) { From a6038e54709d060b4ae71dbd66b6f2621bb54f94 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Wed, 24 Jul 2024 10:13:22 +0800 Subject: [PATCH 11/12] Update kenlm.pyx --- python/kenlm.pyx | 72 +++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/python/kenlm.pyx b/python/kenlm.pyx index f6f446cc2..65c4e8d07 100644 --- a/python/kenlm.pyx +++ b/python/kenlm.pyx @@ -133,6 +133,10 @@ cdef class Model: cdef public bytes path cdef _kenlm.const_Vocabulary* vocab + // predict_next + cdef cppclass GenericModel: + unordered_map[string, float] predict_next(vector[string] context) # Adjusted for predict_next + def __init__(self, path, Config config = Config()): """ Load the language model. @@ -156,37 +160,37 @@ cdef class Model: def __get__(self): return self.model.Order() - def predict_next(self, context): - """ - 预测给定上下文后的下一个最可能的词。 - - :param context: 上下文句子 - :return: 预测的下一个词和其对应的概率 - """ - cdef State state = State() - cdef State out_state = State() - cdef float max_prob = float('-inf') - cdef str best_word = "" + # def predict_next(self, context): + # """ + # 预测给定上下文后的下一个最可能的词。 + + # :param context: 上下文句子 + # :return: 预测的下一个词和其对应的概率 + # """ + # cdef State state = State() + # cdef State out_state = State() + # cdef float max_prob = float('-inf') + # cdef str best_word = "" - # 初始化状态 - words = as_str(context).split() - self.BeginSentenceWrite(state) + # # 初始化状态 + # words = as_str(context).split() + # self.BeginSentenceWrite(state) - # 处理上下文 - for word in words: - self.BaseScore(state, word, out_state) - state = out_state + # # 处理上下文 + # for word in words: + # self.BaseScore(state, word, out_state) + # state = out_state - # 对词表中的每个词计算概率 - for i in range(self.vocab.Bound()): - word = self.vocab.Word(i) - prob = self.BaseScore(state, word, out_state) + # # 对词表中的每个词计算概率 + # for i in range(self.vocab.Bound()): + # word = self.vocab.Word(i) + # prob = self.BaseScore(state, word, out_state) - if prob > max_prob: - max_prob = prob - best_word = word + # if prob > max_prob: + # max_prob = prob + # best_word = word - return best_word, max_prob + # return best_word, max_prob def score(self, sentence, bos = True, eos = True): """ @@ -312,11 +316,17 @@ cdef class Model: cdef _kenlm.FullScoreReturn ret = self.model.BaseFullScore(&in_state._c_state, wid, &out_state._c_state) return FullScoreReturn(ret.prob, ret.ngram_length, wid == 0) - # 增加predict_next方法 - def predict_next(self, context): - cdef string cpp_context = context.encode('utf-8') - result = self.thisptr.predict_next(cpp_context) - return result.first.decode('utf-8'), result.second + # # 增加predict_next方法 + # def predict_next(self, context): + # cdef string cpp_context = context.encode('utf-8') + # result = self.thisptr.predict_next(cpp_context) + # return result.first.decode('utf-8'), result.second + + // predict_next + def predict_next_python(self, context): + cdef list context_words = context.split() + cdef unordered_map[string, float] probs = self.c_model.predict_next([word.encode() for word in context_words]) + return {word.decode(): prob for word, prob in probs.items()} def __contains__(self, word): cdef bytes w = as_str(word) From 74a66b2a3d42fd4987cb71a97b6e4331c3ccf8b7 Mon Sep 17 00:00:00 2001 From: Jiang Nan <91359667+NeuroSymbol@users.noreply.github.com> Date: Wed, 24 Jul 2024 10:27:08 +0800 Subject: [PATCH 12/12] Update model.cc --- lm/model.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm/model.cc b/lm/model.cc index 62a9ed763..e75986f9a 100644 --- a/lm/model.cc +++ b/lm/model.cc @@ -19,7 +19,7 @@ #include #include "enumerate_vocab.hh" // Include EnumerateVocab header #include "vocab.hh" // Include Vocabulary header -#include "string_piece.hh" // Include StringPiece header +#include "../util/string_piece.hh" // Include StringPiece header namespace lm { namespace ngram {