refactor embedding model: Embedding now inherits from Llm and reuses the shared config, runtime, and tokenizer setup.
wangzhaode committed Jun 4, 2024
1 parent c5f3ede commit 2e9d685
Showing 2 changed files with 71 additions and 135 deletions.
72 changes: 12 additions & 60 deletions include/llm.hpp
@@ -177,9 +177,9 @@ class Llm {
modules_.clear();
runtime_manager_.reset();
}
void chat();
static Llm* createLLM(const std::string& config_path);
virtual void load();
void chat();
VARP forward(const std::vector<int>& input_ids);
int sample(VARP logits, const std::vector<int>& pre_ids);
std::string apply_chat_template(const std::string& input_str) const;
@@ -190,35 +190,26 @@ class Llm {
void print_speed();
friend class Pipeline;
public:
// TODO
std::string model_name_ = "";
bool is_single_ = true;
bool is_disk_embedding_ = true;
bool is_visual_ = false;
int layer_nums_ = 0;
int hidden_size_ = 4096;
// config
int max_new_tokens_ = 1024;
int backend_type_ = 0;
// forward info
int prompt_len_ = 0;
int gen_seq_len_ = 0;
int all_seq_len_ = 0;
// time
int64_t prefill_us_ = 0;
int64_t decode_us_ = 0;
bool is_single_ = true;
bool is_disk_embedding_ = true;
std::shared_ptr<LlmConfig> config_;
std::unique_ptr<Tokenizer> tokenizer_;
protected:
std::string decode(int id);
bool is_stop(int token_id);
protected:
std::vector<int> key_value_shape_ = {};
std::vector<VARP> past_key_values_;
VARP inputs_embeds_, attention_mask_, position_ids_;
std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
std::vector<std::shared_ptr<Module>> modules_;
std::vector<VARP> past_key_values_;
protected:
void init_runtime();
std::string decode(int id);
bool is_stop(int token_id);
virtual std::vector<int> tokenizer(const std::string& query);
virtual VARP embedding(const std::vector<int>& input_ids);
virtual VARP gen_attention_mask(int seq_len);
@@ -247,53 +238,14 @@ class Lvlm : public Llm {
// Llm end

// Embedding start
class Embedding {
class Embedding : public Llm {
public:
Embedding() {
// default tokenizer is Bert
tokenizer_.reset(new BertTokenizer);
}
virtual ~Embedding() {
module_.reset();
runtime_manager_.reset();
}
static Embedding* createEmbedding(const std::string& path, std::string model_type = "auto");
Embedding(std::shared_ptr<LlmConfig> config) : Llm(config) {}
static Embedding* createEmbedding(const std::string& config_path);
static float dist(VARP var0, VARP var1);
void load(const std::string& model_dir);
virtual void load() override;
VARP embedding(const std::string& txt);
void print_speed();
int dim() { return hidden_size_; }
public:
// time
int64_t embedding_us_ = 0;
int prompt_len_ = 0;
protected:
// model configs
int layer_nums_ = 0;
int hidden_size_ = 1024;
std::string model_name_ = "";
// tokenizer
std::unique_ptr<Tokenizer> tokenizer_;
private:
virtual std::vector<int> tokenizer(const std::string& query) = 0;
virtual VARP gen_attention_mask(int seq_len) = 0;
virtual VARP gen_position_ids(int seq_len) = 0;
private:
// MNN Modules
std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
std::shared_ptr<Module> module_;
// model dir
std::string model_dir_;
};

// some embedding models
class Bge : public Embedding {
public:
Bge() {
model_name_ = "Bge";
layer_nums_ = 24;
hidden_size_ = 1024;
}
int dim() { return config_->hidden_size(); }
private:
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual VARP gen_attention_mask(int seq_len) override;
134 changes: 59 additions & 75 deletions src/llm.cpp
@@ -33,27 +33,44 @@ Llm* Llm::createLLM(const std::string& config_path) {
return llm;
}

static MNNForwardType backend_type_convert(const std::string& type_str) {
if (type_str == "cpu") return MNN_FORWARD_CPU;
if (type_str == "metal") return MNN_FORWARD_METAL;
if (type_str == "cuda") return MNN_FORWARD_CUDA;
if (type_str == "opencl") return MNN_FORWARD_OPENCL;
if (type_str == "opengl") return MNN_FORWARD_OPENGL;
if (type_str == "vulkan") return MNN_FORWARD_VULKAN;
if (type_str == "npu") return MNN_FORWARD_NN;
return MNN_FORWARD_AUTO;
}

void Llm::init_runtime() {
ScheduleConfig config;
BackendConfig cpuBackendConfig;
config.type = backend_type_convert(config_->backend_type());
config.numThread = config_->thread_num();
if (config_->memory() == "low") {
cpuBackendConfig.memory = BackendConfig::Memory_Low;
}
if (config_->precision() == "low") {
cpuBackendConfig.precision = BackendConfig::Precision_Low;
}
config.backendConfig = &cpuBackendConfig;
runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
runtime_manager_->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0);
}

void Llm::load() {
init_runtime();
// init module status
key_value_shape_ = config_->key_value_shape();
layer_nums_ = config_->layer_nums();
is_single_ = config_->is_single();
{
std::ifstream embedding_bin(config_->embedding_file());
is_disk_embedding_ = embedding_bin.good();
embedding_bin.close();
}
MNN_PRINT("### is_single_ = %d, is_disk_embedding_ = %d\n", is_single_, is_disk_embedding_);
// init runtime
ScheduleConfig config;
BackendConfig cpuBackendConfig;
config.type = static_cast<MNNForwardType>(backend_type_);
config.numThread = config_->thread_num();
cpuBackendConfig.precision = BackendConfig::Precision_Low;
cpuBackendConfig.memory = BackendConfig::Memory_Low;
config.backendConfig = &cpuBackendConfig;
runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
runtime_manager_->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0);
// 1. load vocab
MNN_PRINT("load tokenizer\n");
tokenizer_.reset(Tokenizer::createTokenizer(config_->tokenizer_file()));
@@ -62,9 +79,10 @@ void Llm::load() {
Module::Config module_config;
module_config.shapeMutable = true;
module_config.rearrange = true;
int layer_nums = config_->layer_nums();
if (is_single_) {
// load single model
key_value_shape_.insert(key_value_shape_.begin(), layer_nums_);
key_value_shape_.insert(key_value_shape_.begin(), layer_nums);
modules_.resize(1);
std::string model_path = config_->llm_model();
MNN_PRINT("load %s ... ", model_path.c_str());
@@ -75,14 +93,14 @@
MNN_PRINT("Done!\n");
} else {
// load split models
modules_.resize(layer_nums_ + 2);
modules_.resize(layer_nums + 2);
// load lm model
modules_[layer_nums_].reset(Module::load({}, {}, config_->lm_model().c_str(), runtime_manager_, &module_config));
modules_[layer_nums].reset(Module::load({}, {}, config_->lm_model().c_str(), runtime_manager_, &module_config));
if (!is_disk_embedding_) {
modules_[layer_nums_ + 1].reset(Module::load({}, {}, config_->embedding_model().c_str(), runtime_manager_, &module_config));
modules_[layer_nums + 1].reset(Module::load({}, {}, config_->embedding_model().c_str(), runtime_manager_, &module_config));
}
// load block models
for (int i = 0; i < layer_nums_; i++) {
for (int i = 0; i < layer_nums; i++) {
std::string model_path = config_->block_model(i);
MNN_PRINT("load %s ... ", model_path.c_str());
modules_[i].reset(Module::load(
@@ -129,9 +147,10 @@ VARP Llm::forward(const std::vector<int>& input_ids) {
past_key_values_[0] = outputs[1];
} else {
// split block models
int layer_nums = config_->layer_nums();
auto hidden_states = embedding(input_ids);
ExecutorScope::Current()->gc(Executor::FULL);
for (int i = 0; i < layer_nums_; i++) {
for (int i = 0; i < layer_nums; i++) {
AUTOTIME;
auto outputs = modules_[i]->onForward({hidden_states, attention_mask, position_ids, past_key_values_[i]});
hidden_states = outputs[0];
@@ -140,7 +159,7 @@ VARP Llm::forward(const std::vector<int>& input_ids) {
ExecutorScope::Current()->gc(Executor::FULL);
{
AUTOTIME;
auto outputs = modules_[layer_nums_]->onForward({hidden_states});
auto outputs = modules_[layer_nums]->onForward({hidden_states});
logits = outputs[0];
}
}
@@ -210,7 +229,7 @@ void Llm::generate_init() {
if (is_single_) {
past_key_values_.push_back(_Input(key_value_shape_, NCHW));
} else {
for (int i = 0; i < layer_nums_; i++) {
for (int i = 0; i < config_->layer_nums(); i++) {
past_key_values_.push_back(_Input(key_value_shape_, NCHW));
}
}
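
For intuition, the allocation above pairs with the shape handling in load(), which prepends the layer count for the fused single-model case. A self-contained sketch with hypothetical shape values follows; the real values come from LlmConfig.

    // Sketch of the KV-cache allocation; the shape and layer count are
    // hypothetical placeholders, the real values come from config.json.
    #include <vector>
    #include <MNN/expr/ExprCreator.hpp>
    using namespace MNN::Express;

    void kv_cache_sketch(bool is_single) {
        std::vector<int> key_value_shape = {2, 1, 0, 32, 128}; // per-layer KV shape (hypothetical)
        int layer_nums = 28;                                   // hypothetical
        std::vector<VARP> past_key_values;
        if (is_single) {
            // fused model: prepend the layer count and allocate one tensor for all layers
            key_value_shape.insert(key_value_shape.begin(), layer_nums);
            past_key_values.push_back(_Input(key_value_shape, NCHW));
        } else {
            // split models: one cache tensor per layer
            for (int i = 0; i < layer_nums; i++) {
                past_key_values.push_back(_Input(key_value_shape, NCHW));
            }
        }
    }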
@@ -314,7 +333,7 @@ VARP Llm::embedding(const std::vector<int>& input_ids) {
if (!is_disk_embedding_) {
// using model forward
auto inputs_ids_ = _Const(input_ids.data(), {static_cast<int>(input_ids.size())}, NCHW, halide_type_of<int>());
auto hidden_states = modules_[layer_nums_ + 1]->onForward({inputs_ids_})[0];
auto hidden_states = modules_[config_->layer_nums() + 1]->onForward({inputs_ids_})[0];
return hidden_states;
}
AUTOTIME;
@@ -554,90 +573,55 @@ float Embedding::dist(VARP var0, VARP var1) {
return dist;
}
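
The body of dist() is elided by the diff view. A plausible minimal version, assuming a Euclidean distance built from MNN express ops, is sketched below; the actual implementation may differ.

    // Sketch only: Euclidean distance between two embedding VARPs.
    // The real dist() body is not shown in this diff.
    #include <MNN/expr/ExprCreator.hpp>
    using namespace MNN::Express;

    static float dist_sketch(VARP var0, VARP var1) {
        auto distVar = _Sqrt(_ReduceSum(_Square(var0 - var1)));
        return distVar->readMap<float>()[0];
    }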

Embedding* Embedding::createEmbedding(const std::string& path, std::string model_type) {
auto size = path.size();

Embedding* embedding = nullptr;
if (model_type == "auto") {
model_type = path;
}
if (model_type.find("bge") != std::string::npos) {
embedding = new Bge;
}
if (!embedding) {
std::cerr << "can't determine model type!" << std::endl;
return embedding;
}
std::cout << "### model name : "<< embedding->model_name_ << std::endl;
embedding->load(path);
Embedding* Embedding::createEmbedding(const std::string& config_path) {
std::shared_ptr<LlmConfig> config(new LlmConfig(config_path));
Embedding* embedding = new Embedding(config);
embedding->load();
return embedding;
}

void Embedding::load(const std::string& model_dir) {
// init
ScheduleConfig config;
BackendConfig cpuBackendConfig;
config.type = MNN_FORWARD_CPU;
// config.type = MNN_FORWARD_OPENCL;
config.numThread = 4;
cpuBackendConfig.precision = BackendConfig::Precision_Low;
cpuBackendConfig.memory = BackendConfig::Memory_Low;
config.backendConfig = &cpuBackendConfig;
runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
void Embedding::load() {
init_runtime();
printf("load tokenizer\n");
std::cout << config_->tokenizer_file() << std::endl;
// 1. load vocab
size_t pos = model_dir.find_last_of("/\\");
std::string dir_path = (pos != std::string::npos) ? model_dir.substr(0, pos + 1) : "";
std::string tokenizer_path = dir_path + "/tokenizer.txt";
tokenizer_.reset(Tokenizer::createTokenizer(config_->tokenizer_file()));
printf("load tokenizer Done\n");
// 2. load model
Module::Config module_config;
module_config.shapeMutable = true;
module_config.rearrange = true;
std::string model_path = model_dir;
auto model_path = config_->llm_model();
MNN_PRINT("load %s ... ", model_path.c_str());
module_.reset(Module::load(
modules_.resize(1);
modules_[0].reset(Module::load(
{"input_ids", "attention_mask", "position_ids"},
{"sentence_embeddings"}, model_path.c_str(), runtime_manager_, &module_config));
MNN_PRINT("Done!\n");
}

VARP Embedding::embedding(const std::string& txt) {
auto ids = tokenizer(txt);
prompt_len_ = ids.size();
auto inputs_ids = _Const(ids.data(), {prompt_len_}, NCHW, halide_type_of<int>());
auto attention_mask = gen_attention_mask(prompt_len_);
auto position_ids = gen_position_ids(prompt_len_);
auto st = std::chrono::system_clock::now();
auto outputs = module_->onForward({inputs_ids, attention_mask, position_ids});
auto et = std::chrono::system_clock::now();
embedding_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
int prompt_len = ids.size();
auto inputs_ids = _Const(ids.data(), {prompt_len}, NCHW, halide_type_of<int>());
auto attention_mask = gen_attention_mask(prompt_len);
auto position_ids = gen_position_ids(prompt_len);
auto outputs = modules_[0]->onForward({inputs_ids, attention_mask, position_ids});
auto sentence_embeddings = outputs[0];
// print_speed();
return sentence_embeddings;
}

void Embedding::print_speed() {
auto total_s = embedding_us_ * 1e-6;
printf("\n#################################\n");
printf(" total token = %d\n", prompt_len_);
printf(" total time = %.2f s\n", total_s);
printf(" total speed = %.2f tok/s\n", prompt_len_ / total_s);
printf("##################################\n");
}

std::vector<int> Bge::tokenizer(const std::string& query) {
std::vector<int> Embedding::tokenizer(const std::string& query) {
auto prompt = query;
if (query.size() <= 256) {
prompt = "为这个句子生成表示以用于检索相关文章:" + query;
}
prompt = apply_chat_template(prompt);
auto ids = tokenizer_->encode(prompt);
ids.insert(ids.begin(), 101); // BERT [CLS]
ids.push_back(102); // BERT [SEP]
return ids;
}

VARP Bge::gen_attention_mask(int seq_len) {
VARP Embedding::gen_attention_mask(int seq_len) {
auto attention_mask = _Input({1, 1, 1, seq_len}, NCHW, halide_type_of<int>());
auto ptr = attention_mask->writeMap<int>();
for (int i = 0; i < seq_len; i++) {
Expand All @@ -646,7 +630,7 @@ VARP Bge::gen_attention_mask(int seq_len) {
return attention_mask;
}

VARP Bge::gen_position_ids(int seq_len) {
VARP Embedding::gen_position_ids(int seq_len) {
auto position_ids = _Input({1, seq_len}, NCHW, halide_type_of<int>());
auto ptr = position_ids->writeMap<int>();
for (int i = 0; i < seq_len; i++) {
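
For reference, a minimal caller of the refactored, config-driven API might look like the sketch below. The config path and sentences are placeholders, and error handling is omitted.

    // Hypothetical usage sketch of the refactored Embedding API.
    // The config.json path and the input sentences are placeholders.
    #include <cstdio>
    #include <memory>
    #include "llm.hpp"

    int main() {
        std::unique_ptr<Embedding> embedding(
            Embedding::createEmbedding("bge-large-zh/config.json"));
        auto vec0 = embedding->embedding("What is deep learning?");
        auto vec1 = embedding->embedding("Deep learning is a branch of machine learning.");
        printf("dim = %d, dist = %f\n", embedding->dim(), Embedding::dist(vec0, vec1));
        return 0;
    }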
