diff --git a/lm/model.cc b/lm/model.cc
index b968edd9..e75986f9 100644
--- a/lm/model.cc
+++ b/lm/model.cc
@@ -12,7 +12,14 @@
 #include <algorithm>
 #include <functional>
 #include <numeric>
+
+// Headers required by the predict_next function
 #include <cmath>
+#include <sstream>
+#include <unordered_map>
+#include "enumerate_vocab.hh"      // Include EnumerateVocab header
+#include "vocab.hh"                // Include Vocabulary header
+#include "../util/string_piece.hh" // Include StringPiece header
 
 namespace lm {
 namespace ngram {
@@ -323,6 +330,53 @@ template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
 template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>;
 template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
 
+// The predict_next method
+
+// Define a class for enumerating the vocabulary.  KenLM's vocabularies keep
+// no index -> string mapping, so this collects one as the vocab is loaded.
+class VocabEnumerator : public lm::EnumerateVocab {
+ public:
+  explicit VocabEnumerator(const lm::ngram::Vocabulary &vocab) : vocab_(vocab) {}
+
+  void Add(lm::WordIndex index, const StringPiece &str) override {
+    vocab_map_[index] = str.as_string();
+  }
+
+  const std::unordered_map<lm::WordIndex, std::string> &GetVocabMap() const {
+    return vocab_map_;
+  }
+
+ private:
+  const lm::ngram::Vocabulary &vocab_;
+  std::unordered_map<lm::WordIndex, std::string> vocab_map_;
+};
+
+// NOTE: this definition must appear before the explicit instantiations above
+// if those instantiations are to emit predict_next.
+template <class Search, class VocabularyT>
+std::unordered_map<std::string, float> GenericModel<Search, VocabularyT>::predict_next(const std::string &context) const {
+  // Convert the context string to a WordIndex sequence
+  std::vector<WordIndex> context_words;
+  std::istringstream iss(context);
+  std::string word;
+  while (iss >> word) {
+    context_words.push_back(this->vocab_.Index(word));
+  }
+  // GetState expects the context reversed (most recent word first)
+  std::reverse(context_words.begin(), context_words.end());
+
+  // Initialize the state from the context
+  State state;
+  this->GetState(&context_words[0], &context_words[0] + context_words.size(), state);
+
+  std::unordered_map<std::string, float> word_probs;
+
+  // Score each word in the vocabulary against the context state.
+  // vocab_.Word(i) assumes the reverse lookup that VocabEnumerator builds.
+  for (WordIndex i = 0; i < this->vocab_.Bound(); ++i) {
+    State out_state;
+    FullScoreReturn ret = this->FullScore(state, i, out_state);
+    std::string word = this->vocab_.Word(i);
+    word_probs[word] = ret.prob;
+  }
+
+  return word_probs;
+}
+
 } // namespace detail
 
 base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
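For context, a minimal usage sketch of predict_next from C++ (not part of the
patch): it assumes the method lands on the concrete probing typedef
lm::ngram::Model, and "test.arpa" is a placeholder model file.

// usage_sketch.cc -- hypothetical example, not in the patch
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include "lm/model.hh"

int main() {
  lm::ngram::Model model("test.arpa");  // placeholder path
  std::unordered_map<std::string, float> probs = model.predict_next("the quick brown");

  // KenLM scores are log10 probabilities; take the argmax as the prediction.
  std::string best_word;
  float best_prob = -std::numeric_limits<float>::infinity();
  for (const auto &entry : probs) {
    if (entry.second > best_prob) {
      best_prob = entry.second;
      best_word = entry.first;
    }
  }
  std::cout << best_word << " " << best_prob << std::endl;
}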
diff --git a/lm/model.hh b/lm/model.hh
index db43d8b5..113e4271 100644
--- a/lm/model.hh
+++ b/lm/model.hh
@@ -35,6 +35,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
     static const ModelType kModelType;
 
     static const unsigned int kVersion = Search::kVersion;
+
+    // Declaration of the new predict_next method (returns the full
+    // word -> log10 probability map; must match the definition in model.cc)
+    std::unordered_map<std::string, float> predict_next(const std::string &context) const;
 
     /* Get the size of memory that will be mapped given ngram counts.  This
      * does not include small non-mapped control structures, such as this class
diff --git a/python/kenlm.cpp b/python/kenlm.cpp
index a45fcb51..e8dc49d0 100644
--- a/python/kenlm.cpp
+++ b/python/kenlm.cpp
@@ -11481,5 +11481,95 @@ static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) {
     return PyInt_FromSize_t(ival);
 }
 
+// Changes for the predict_next method (six parts, -a through -f)
+#include <limits>  // for std::numeric_limits, used in -b below
+
+// Define the PyModel struct -a
+typedef struct {
+    PyObject_HEAD
+    lm::ngram::Model* model;
+} PyModel;
+
+// Define the new Python method -b
+static PyObject* PyModel_predict_next(PyObject* self, PyObject* args) {
+    const char* context;
+    if (!PyArg_ParseTuple(args, "s", &context)) {
+        return NULL;
+    }
+
+    // predict_next returns log10 probabilities for the whole vocabulary;
+    // expose the highest-scoring entry as a (word, prob) pair.
+    std::unordered_map<std::string, float> probs =
+        ((PyModel*)self)->model->predict_next(context);
+    std::string best_word;
+    float best_prob = -std::numeric_limits<float>::infinity();
+    for (const auto &entry : probs) {
+        if (entry.second > best_prob) {
+            best_prob = entry.second;
+            best_word = entry.first;
+        }
+    }
+    return Py_BuildValue("sf", best_word.c_str(), best_prob);
+}
+
+// Method definition table -c
+static PyMethodDef PyModel_methods[] = {
+    {"predict_next", (PyCFunction)PyModel_predict_next, METH_VARARGS, "Predict the next word given a context"},
+    {NULL} /* Sentinel */
+};
+
+// Type definition -d
+static PyTypeObject PyModelType = {
+    PyVarObject_HEAD_INIT(NULL, 0)
+    "kenlm.Model",             /* tp_name */
+    sizeof(PyModel),           /* tp_basicsize */
+    0,                         /* tp_itemsize */
+    0,                         /* tp_dealloc */
+    0,                         /* tp_print */
+    0,                         /* tp_getattr */
+    0,                         /* tp_setattr */
+    0,                         /* tp_reserved */
+    0,                         /* tp_repr */
+    0,                         /* tp_as_number */
+    0,                         /* tp_as_sequence */
+    0,                         /* tp_as_mapping */
+    0,                         /* tp_hash */
+    0,                         /* tp_call */
+    0,                         /* tp_str */
+    0,                         /* tp_getattro */
+    0,                         /* tp_setattro */
+    0,                         /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+    "Model objects",           /* tp_doc */
+    0,                         /* tp_traverse */
+    0,                         /* tp_clear */
+    0,                         /* tp_richcompare */
+    0,                         /* tp_weaklistoffset */
+    0,                         /* tp_iter */
+    0,                         /* tp_iternext */
+    PyModel_methods,           /* tp_methods */
+    0,                         /* tp_members */
+    0,                         /* tp_getset */
+    0,                         /* tp_base */
+    0,                         /* tp_dict */
+    0,                         /* tp_descr_get */
+    0,                         /* tp_descr_set */
+    0,                         /* tp_dictoffset */
+    0,                         /* tp_init */
+    0,                         /* tp_alloc */
+    0,                         /* tp_new */
+};
+
+// Module definition -e
+static PyModuleDef kenlmmodule = {
+    PyModuleDef_HEAD_INIT,
+    "kenlm",
+    "KenLM Python bindings",
+    -1,
+    NULL, NULL, NULL, NULL, NULL
+};
+
+// Module initialization function -f
+PyMODINIT_FUNC PyInit_kenlm(void) {
+    PyObject* m;
+
+    if (PyType_Ready(&PyModelType) < 0)
+        return NULL;
+
+    m = PyModule_Create(&kenlmmodule);
+    if (m == NULL)
+        return NULL;
+
+    Py_INCREF(&PyModelType);
+    PyModule_AddObject(m, "Model", (PyObject *)&PyModelType);
+    return m;
+}
 #endif /* Py_PYTHON_H */
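Editor's note: as written, the hand-rolled type above leaves tp_new, tp_init,
and tp_dealloc as 0, so kenlm.Model could never be instantiated with a loaded
model. A minimal sketch of the missing hooks, under the assumption that the
wrapper owns the C++ model (names here are illustrative, not part of the patch):

// Hypothetical lifecycle hooks for PyModelType, to be wired into the
// tp_init, tp_dealloc, and tp_new (e.g. PyType_GenericNew) slots above.
static int PyModel_init(PyModel* self, PyObject* args, PyObject* kwds) {
    const char* path;
    if (!PyArg_ParseTuple(args, "s", &path)) return -1;
    try {
        self->model = new lm::ngram::Model(path);  // load the binary/ARPA file
    } catch (const std::exception &e) {
        PyErr_SetString(PyExc_IOError, e.what());
        return -1;
    }
    return 0;
}

static void PyModel_dealloc(PyModel* self) {
    delete self->model;                       // release the C++ model
    Py_TYPE(self)->tp_free((PyObject*)self);  // free the Python object
}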
diff --git a/python/kenlm.pyx b/python/kenlm.pyx
index a6984a88..65c4e8d0 100644
--- a/python/kenlm.pyx
+++ b/python/kenlm.pyx
@@ -1,6 +1,12 @@
 import os
 cimport _kenlm
 
+# Declare the C++ method (predict_next) in a cdef extern from block.
+# The cppclass gets the Cython name CppModel so it does not collide with
+# the cdef class Model below.
+from libcpp.string cimport string
+from libcpp.unordered_map cimport unordered_map
+
+cdef extern from "lm/model.hh" namespace "lm::ngram":
+    cdef cppclass CppModel "lm::ngram::Model":
+        unordered_map[string, float] predict_next(const string &context) const
+
+
 cdef bytes as_str(data):
     if isinstance(data, bytes):
         return data
@@ -127,6 +133,10 @@ cdef class Model:
     cdef public bytes path
     cdef _kenlm.const_Vocabulary* vocab
 
+    # predict_next: the underlying C++ model is reached through self.model,
+    # cast to the extern CppModel in predict_next_python below
+
     def __init__(self, path, Config config = Config()):
         """
         Load the language model.
@@ -150,6 +160,38 @@ cdef class Model:
         def __get__(self):
             return self.model.Order()
 
+    # def predict_next(self, context):
+    #     """
+    #     Predict the most likely next word for the given context.
+    #
+    #     :param context: the context sentence
+    #     :return: the predicted next word and its probability
+    #     """
+    #     cdef State state = State()
+    #     cdef State out_state = State()
+    #     cdef float max_prob = float('-inf')
+    #     cdef str best_word = ""
+
+    #     # Initialize the state
+    #     words = as_str(context).split()
+    #     self.BeginSentenceWrite(state)
+
+    #     # Feed the context through the model
+    #     for word in words:
+    #         self.BaseScore(state, word, out_state)
+    #         state = out_state
+
+    #     # Compute a probability for every word in the vocabulary
+    #     for i in range(self.vocab.Bound()):
+    #         word = self.vocab.Word(i)
+    #         prob = self.BaseScore(state, word, out_state)
+
+    #         if prob > max_prob:
+    #             max_prob = prob
+    #             best_word = word
+
+    #     return best_word, max_prob
+
     def score(self, sentence, bos = True, eos = True):
         """
         Return the log10 probability of a string.  By default, the string is
@@ -274,6 +316,18 @@ cdef class Model:
         cdef _kenlm.FullScoreReturn ret = self.model.BaseFullScore(&in_state._c_state, wid, &out_state._c_state)
         return FullScoreReturn(ret.prob, ret.ngram_length, wid == 0)
 
+    # # An earlier predict_next attempt, kept for reference
+    # def predict_next(self, context):
+    #     cdef string cpp_context = context.encode('utf-8')
+    #     result = self.thisptr.predict_next(cpp_context)
+    #     return result.first.decode('utf-8'), result.second
+
+    # predict_next
+    def predict_next_python(self, context):
+        """Return a dict mapping each vocabulary word to its log10 probability."""
+        cdef string cpp_context = as_str(context)
+        cdef unordered_map[string, float] probs = (<CppModel*> self.model).predict_next(cpp_context)
+        cdef dict py_probs = probs  # Cython converts the map to a Python dict
+        return {word.decode('utf-8'): prob for word, prob in py_probs.items()}
+
     def __contains__(self, word):
         cdef bytes w = as_str(word)
         return (self.vocab.Index(w) != 0)
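Editor's note: since FullScore returns log10 probabilities, a quick way to
validate any predict_next implementation is to exponentiate the returned
distribution for a fixed context and check that it sums to roughly one. A
hedged C++ sketch ("test.arpa" and the context are placeholders):

// predict_next_check.cc -- hypothetical sanity check, not in the patch
#include <cmath>
#include <cstdio>
#include <string>
#include <unordered_map>
#include "lm/model.hh"

int main() {
  lm::ngram::Model model("test.arpa");  // placeholder path
  std::unordered_map<std::string, float> probs = model.predict_next("the quick brown");

  // Sum 10^score over the vocabulary; large deviations from 1.0 point at
  // state-initialization or vocabulary-lookup bugs.
  double total = 0.0;
  for (const auto &entry : probs) total += std::pow(10.0, entry.second);
  std::printf("sum of next-word probabilities: %f\n", total);
}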