Update kenlm.pyx #452

Open · wants to merge 12 commits into master
54 changes: 54 additions & 0 deletions lm/model.cc
@@ -12,7 +12,14 @@
#include <functional>
#include <numeric>
#include <cmath>

// Headers required by the predict_next function
#include <algorithm>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "enumerate_vocab.hh" // Include EnumerateVocab header
#include "vocab.hh" // Include Vocabulary header
#include "../util/string_piece.hh" // Include StringPiece header

namespace lm {
namespace ngram {
@@ -323,6 +330,53 @@ template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>,
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;

// predict_next method

// Define a class for enumerating the vocabulary
class VocabEnumerator : public lm::EnumerateVocab {
public:
VocabEnumerator(const lm::ngram::Vocabulary &vocab) : vocab_(vocab) {}

void Add(lm::WordIndex index, const lm::StringPiece &str) override {
vocab_map_[index] = str.as_string();
}

const std::unordered_map<lm::WordIndex, std::string>& GetVocabMap() const {
return vocab_map_;
}

private:
const lm::ngram::Vocabulary &vocab_;
std::unordered_map<lm::WordIndex, std::string> vocab_map_;
};
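// Note: an EnumerateVocab subclass such as VocabEnumerator only receives Add()
// callbacks if it is attached to Config::enumerate_vocab before the model is
// loaded; the model then calls Add() once per vocabulary entry while loading.
// It is sketched here as one way to recover index -> word strings, whereas
// predict_next below calls vocab_.Word() directly, which assumes the
// vocabulary type exposes such a lookup.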

template <class Search, class VocabularyT>
std::unordered_map<std::string, float> GenericModel<Search, VocabularyT>::predict_next(const std::string &context) const {
// Convert context to WordIndex sequence
std::vector<WordIndex> context_words;
std::istringstream iss(context);
std::string word;
while (iss >> word) {
context_words.push_back(this->vocab_.Index(word));
}

// Initialize the state. GetState expects the context in reverse order
// (most recent word first), and an empty context would dereference an
// empty vector, so fall back to the begin-of-sentence state in that case.
State state;
if (context_words.empty()) {
state = this->BeginSentenceState();
} else {
std::reverse(context_words.begin(), context_words.end());
this->GetState(&context_words[0], &context_words[0] + context_words.size(), state);
}

std::unordered_map<std::string, float> word_probs;

// Score each vocabulary word against the context state.
// FullScore returns log10 probabilities, so larger (less negative) is better.
for (WordIndex i = 0; i < this->vocab_.Size(); ++i) {
State out_state;
FullScoreReturn ret = this->FullScore(state, i, out_state);
std::string word = this->vocab_.Word(i);
word_probs[word] = ret.prob;
}

return word_probs;
}
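
// Usage (sketch): with a loaded model m, e.g. lm::ngram::Model m("test.arpa")
// where "test.arpa" is a hypothetical model path,
//   auto probs = m.predict_next("the quick brown");
// returns one log10 probability per vocabulary word; taking the max_element
// of the map gives the single most likely continuation.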

} // namespace detail

base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
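For comparison, the loop above can already be approximated from Python with the released kenlm binding, though only over an explicit candidate list, because the published wrapper does not expose vocabulary iteration. A minimal sketch, assuming the released kenlm module and a local model file test.arpa (hypothetical path):

import kenlm

def score_candidates(model, context, candidates):
    # Returns {word: log10 probability} for each candidate after the context.
    state, out_state = kenlm.State(), kenlm.State()
    model.BeginSentenceWrite(state)
    for word in context.split():
        model.BaseScore(state, word, out_state)
        state, out_state = out_state, state
    return {w: model.BaseScore(state, w, out_state) for w in candidates}

model = kenlm.Model('test.arpa')
probs = score_candidates(model, 'the quick brown', ['fox', 'dog', 'cat'])
print(max(probs, key=probs.get))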
3 changes: 3 additions & 0 deletions lm/model.hh
@@ -35,6 +35,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
static const ModelType kModelType;

static const unsigned int kVersion = Search::kVersion;

// Declaration of the new predict_next method; the return type matches the
// definition in lm/model.cc (the full word -> log10-probability map)
std::unordered_map<std::string, float> predict_next(const std::string &context) const;

/* Get the size of memory that will be mapped given ngram counts. This
* does not include small non-mapped control structures, such as this class
90 changes: 90 additions & 0 deletions python/kenlm.cpp

Some generated files are not rendered by default.

54 changes: 54 additions & 0 deletions python/kenlm.pyx
@@ -1,6 +1,12 @@
import os
cimport _kenlm

# Declare the C++ method (predict_next) in a `cdef extern from` block.
# The signature mirrors the declaration added to lm/model.hh above.
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map

cdef extern from "lm/model.hh":
    cdef cppclass Model:
        unordered_map[string, float] predict_next(const string &context)


cdef bytes as_str(data):
if isinstance(data, bytes):
return data
@@ -127,6 +133,10 @@ cdef class Model:
cdef public bytes path
cdef _kenlm.const_Vocabulary* vocab

# predict_next: interface of the underlying C++ model (a cppclass declaration
# like this belongs in a `cdef extern from` block rather than inside the class)
cdef cppclass GenericModel:
    unordered_map[string, float] predict_next(const string &context)

def __init__(self, path, Config config = Config()):
"""
Load the language model.
@@ -150,6 +160,38 @@ cdef class Model:
def __get__(self):
return self.model.Order()

# def predict_next(self, context):
#     """
#     Predict the most likely next word given the context.

#     :param context: the context sentence
#     :return: the predicted next word and its probability
#     """
#     cdef State state = State()
#     cdef State out_state = State()
#     cdef float max_prob = float('-inf')
#     cdef str best_word = ""

#     # Initialize the state
#     words = as_str(context).split()
#     self.BeginSentenceWrite(state)

#     # Feed the context through the model
#     for word in words:
#         self.BaseScore(state, word, out_state)
#         state = out_state

#     # Score every word in the vocabulary, keeping the best
#     for i in range(self.vocab.Bound()):
#         word = self.vocab.Word(i)
#         prob = self.BaseScore(state, word, out_state)

#         if prob > max_prob:
#             max_prob = prob
#             best_word = word

#     return best_word, max_prob

def score(self, sentence, bos = True, eos = True):
"""
Return the log10 probability of a string. By default, the string is
@@ -274,6 +316,18 @@ cdef class Model:
cdef _kenlm.FullScoreReturn ret = self.model.BaseFullScore(&in_state._c_state, wid, &out_state._c_state)
return FullScoreReturn(ret.prob, ret.ngram_length, wid == 0)

# # Add the predict_next method
# def predict_next(self, context):
#     cdef string cpp_context = context.encode('utf-8')
#     result = self.thisptr.predict_next(cpp_context)
#     return result.first.decode('utf-8'), result.second

# predict_next: full distribution over the vocabulary for a given context
def predict_next_python(self, context):
    # Pass the whole context string through; the C++ side tokenizes it and
    # returns an unordered_map that coerces to a Python dict (bytes -> float)
    cdef string cpp_context = as_str(context)
    probs = self.c_model.predict_next(cpp_context)
    return {word.decode(): prob for word, prob in probs.items()}

def __contains__(self, word):
cdef bytes w = as_str(word)
return (self.vocab.Index(w) != 0)
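Assuming the extension compiles as intended, calling the new binding would look roughly like this (a sketch; test.arpa is a hypothetical model path, and the returned scores are log10 probabilities):

import kenlm

model = kenlm.Model('test.arpa')

# Full word -> log10-probability map for the next position
probs = model.predict_next_python('the quick brown')

# Single most likely continuation
best = max(probs, key=probs.get)
print(best, probs[best])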