From b399e55b6ac0addd8f8cd767d2ad4e7aed03269a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=81=E8=A1=8C?=
Date: Wed, 10 Jan 2024 20:18:10 +0800
Subject: [PATCH] support bge-large-zh embedding model.

---
 CMakeLists.txt          |   2 +
 demo/embedding_demo.cpp |  45 ++++++++++++++++
 include/llm.hpp         |  61 ++++++++++++++++++++-
 src/llm.cpp             | 116 +++++++++++++++++++++++++++++++++++++++-
 4 files changed, 222 insertions(+), 2 deletions(-)
 create mode 100644 demo/embedding_demo.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 20504ce9..7f95d8ba 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -64,6 +64,7 @@ else()
     # cli demo
     add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp)
     add_executable(web_demo ${CMAKE_CURRENT_LIST_DIR}/demo/web_demo.cpp)
+    add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
     if (MSVC)
         target_link_libraries(cli_demo llm)
         target_link_libraries(web_demo llm pthreadVC2)
@@ -71,6 +72,7 @@ else()
         file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/libs/ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/Debug/)
     else()
         target_link_libraries(cli_demo llm)
+        target_link_libraries(embedding_demo llm)
         target_link_libraries(web_demo llm pthread)
     endif()
 endif()

diff --git a/demo/embedding_demo.cpp b/demo/embedding_demo.cpp
new file mode 100644
index 00000000..c97f72a4
--- /dev/null
+++ b/demo/embedding_demo.cpp
@@ -0,0 +1,45 @@
+//
+//  embedding_demo.cpp
+//
+//  Created by MNN on 2024/01/10.
+//  ZhaodeWang
+//
+
+#include "llm.hpp"
+#include <fstream>
+#include <stdlib.h>
+
+static void dumpVARP(VARP var) { // print the first and last five values
+    auto size = var->getInfo()->size;
+    auto ptr = var->readMap<float>();
+    printf("[ ");
+    for (int i = 0; i < 5; i++) {
+        printf("%f, ", ptr[i]);
+    }
+    printf("... ");
+    for (int i = size - 5; i < size; i++) {
+        printf("%f, ", ptr[i]);
+    }
+    printf(" ]\n");
+}
+
+int main(int argc, const char* argv[]) {
+    if (argc < 2) {
+        std::cout << "Usage: " << argv[0] << " model.mnn" << std::endl;
+        return 0;
+    }
+    std::string model_dir = argv[1];
+    std::cout << "model path is " << model_dir << std::endl;
+    std::unique_ptr<Embedding> embedding(Embedding::createEmbedding(model_dir));
+    embedding->load(model_dir);
+    auto vec_0 = embedding->embedding("在春暖花开的季节,走在樱花缤纷的道路上,人们纷纷拿出手机拍照留念。樱花树下,情侣手牵手享受着这绝美的春光。孩子们在树下追逐嬉戏,脸上洋溢着纯真的笑容。春天的气息在空气中弥漫,一切都显得那么生机勃勃,充满希望。");
+    auto vec_1 = embedding->embedding("春天到了,樱花树悄然绽放,吸引了众多游客前来观赏。小朋友们在花瓣飘落的树下玩耍,而恋人们则在这浪漫的景色中尽情享受二人世界。每个人的脸上都挂着幸福的笑容,仿佛整个世界都被春天温暖的阳光和满树的樱花渲染得更加美好。");
+    auto vec_2 = embedding->embedding("在炎热的夏日里,沙滩上的游客们穿着泳装享受着海水的清凉。孩子们在海边堆沙堡,大人们则在太阳伞下品尝冷饮,享受悠闲的时光。远处,冲浪者们挑战着波涛,体验着与海浪争斗的刺激。夏天的海滩,总是充满了活力和热情。");
+    dumpVARP(vec_0);
+    dumpVARP(vec_1);
+    dumpVARP(vec_2);
+    printf("dist_0_1: %f\n", Embedding::dist(vec_0, vec_1));
+    printf("dist_0_2: %f\n", Embedding::dist(vec_0, vec_2));
+    printf("dist_1_2: %f\n", Embedding::dist(vec_1, vec_2));
+    return 0;
+}

diff --git a/include/llm.hpp b/include/llm.hpp
index eb9c11d5..ff8e73e9 100644
--- a/include/llm.hpp
+++ b/include/llm.hpp
@@ -27,8 +27,8 @@
 using namespace MNN;
 using namespace Express;
 class Tokenizer;
+// Llm start
 // llm stream buffer with callback
-
 class LlmStreamBuffer : public std::streambuf {
 public:
     using CallBack = std::function<void(const char* str, size_t len)>;
@@ -217,4 +217,63 @@ class Llama2_7b : public Llm {
     virtual bool is_stop(int token_id) override;
 };
+// Llm end
+
+// Embedding start
+class Embedding {
+public:
+    Embedding() {
+        // default tokenizer is Tiktoken
+        tokenizer_.reset(new Tiktoken);
+    }
+    virtual ~Embedding() {
+        module_.reset();
+        runtime_manager_.reset();
+    }
+    static Embedding* createEmbedding(const std::string& path, std::string model_type = "auto");
+    static float dist(VARP var0, VARP var1);
+    void load(const std::string& model_dir);
+    VARP embedding(const std::string& txt);
+    void print_speed();
+public:
+    // time
+    int64_t embedding_us_ = 0;
+    int prompt_len_ = 0;
+protected:
+    std::vector<int> tokenizer_encode(const std::string& input_str);
+protected:
+    // model configs
+    int layer_nums_ = 0;
+    int hidden_size_ = 1024;
+    std::string model_name_ = "";
+    // tokenizer
+    std::unique_ptr<Tokenizer> tokenizer_;
+private:
+    virtual std::vector<int> tokenizer(const std::string& query) = 0;
+    virtual VARP gen_attention_mask(int seq_len) = 0;
+    virtual VARP gen_position_ids(int seq_len) = 0;
+private:
+    // MNN Modules
+    std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
+    std::shared_ptr<Module> module_;
+    // model dir
+    std::string model_dir_;
+};
+
+// some embedding models
+class Bge : public Embedding {
+public:
+    Bge() {
+        model_name_ = "Bge";
+        layer_nums_ = 24;
+        hidden_size_ = 1024;
+    }
+private:
+    virtual std::vector<int> tokenizer(const std::string& query) override;
+    virtual VARP gen_attention_mask(int seq_len) override;
+    virtual VARP gen_position_ids(int seq_len) override;
+};
+
+// Embedding end
+
 #endif // LLM_hpp

diff --git a/src/llm.cpp b/src/llm.cpp
index a8e77f96..a89ddfb9 100644
--- a/src/llm.cpp
+++ b/src/llm.cpp
@@ -20,6 +20,7 @@
 #include <cv/cv.hpp>
 #endif
 
+// Llm start
 Llm* Llm::createLLM(const std::string& path, std::string model_type) {
     auto size = path.size();
 
@@ -192,6 +193,11 @@ void Llm::load(const std::string& model_dir) {
     printf("load tokenizer\n");
     // 1. load vocab
     std::string tokenizer_path = model_dir + "/tokenizer.txt";
+    if (is_single_) { // single-file model: tokenizer.txt sits next to the .mnn file
+        size_t pos = model_dir.find_last_of("/\\");
+        std::string dir_path = (pos != std::string::npos) ? model_dir.substr(0, pos + 1) : "";
+        tokenizer_path = dir_path + "tokenizer.txt"; // dir_path already ends with the separator
+    }
     load_progress_ += 5.f;
     tokenizer_->load(tokenizer_path);
     load_progress_ += 5.f;
@@ -674,4 +680,112 @@ bool Llama2_7b::is_stop(int token_id) {
         return token_id == 2 || token_id == 103028;
     }
     return token_id == 2;
-}
\ No newline at end of file
+}
+// Llm end
+
+// Embedding start
+float Embedding::dist(VARP var0, VARP var1) {
+    auto distVar = _ReduceSum(_Square(var0 - var1));
+    auto dist = distVar->readMap<float>()[0]; // squared L2 distance
+    return dist;
+}
+
+Embedding* Embedding::createEmbedding(const std::string& path, std::string model_type) {
+    auto size = path.size();
+
+    Embedding* embedding = nullptr;
+    if (model_type == "auto") {
+        model_type = path;
+    }
+    if (model_type.find("bge") != std::string::npos) {
+        embedding = new Bge;
+    }
+    if (!embedding) {
+        std::cerr << "can't determine model type!" << std::endl;
+        return embedding;
+    }
+    std::cout << "### model name : " << embedding->model_name_ << std::endl;
+    return embedding;
+}
+
+void Embedding::load(const std::string& model_dir) {
+    model_dir_ = model_dir;
+    // init
+    ScheduleConfig config;
+    BackendConfig cpuBackendConfig;
+    config.type = MNN_FORWARD_CPU;
+    // config.type = MNN_FORWARD_OPENCL;
+    config.numThread = 4;
+    cpuBackendConfig.precision = BackendConfig::Precision_Low;
+    cpuBackendConfig.memory = BackendConfig::Memory_Low;
+    config.backendConfig = &cpuBackendConfig;
+    runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
+    printf("load tokenizer\n");
+    // 1. load vocab
+    size_t pos = model_dir.find_last_of("/\\");
+    std::string dir_path = (pos != std::string::npos) ? model_dir.substr(0, pos + 1) : "";
+    std::string tokenizer_path = dir_path + "tokenizer.txt"; // dir_path already ends with the separator
+    tokenizer_->load(tokenizer_path);
+    printf("load tokenizer Done\n");
+    // 2. load model
+    Module::Config module_config;
+    module_config.shapeMutable = true;
+    module_config.rearrange = true;
+    std::string model_path = model_dir;
+    MNN_PRINT("load %s ... ", model_path.c_str());
+    module_.reset(Module::load(
+        {"input_ids", "attention_mask", "position_ids"},
+        {"sentence_embeddings"}, model_path.c_str(), runtime_manager_, &module_config));
+    MNN_PRINT("Done!\n");
+}
+
+VARP Embedding::embedding(const std::string& txt) {
+    auto ids = tokenizer(txt);
+    prompt_len_ = ids.size();
+    auto inputs_ids = _Const(ids.data(), {prompt_len_}, NCHW, halide_type_of<int>());
+    auto attention_mask = gen_attention_mask(prompt_len_);
+    auto position_ids = gen_position_ids(prompt_len_);
+    auto outputs = module_->onForward({inputs_ids, attention_mask, position_ids});
+    auto sentence_embeddings = outputs[0];
+    return sentence_embeddings;
+}
+
+void Embedding::print_speed() {
+    auto total_s = embedding_us_ * 1e-6;
+    printf("\n#################################\n");
+    printf(" total token = %d\n", prompt_len_);
+    printf(" total time  = %.2f s\n", total_s);
+    printf(" total speed = %.2f tok/s\n", prompt_len_ / total_s);
+    printf("##################################\n");
+}
+
+std::vector<int> Embedding::tokenizer_encode(const std::string& input_str) {
+    auto ids = tokenizer_->encode(input_str);
+    return ids;
+}
+
+std::vector<int> Bge::tokenizer(const std::string& query) {
+    auto ids = tokenizer_encode(query);
+    ids.insert(ids.begin(), 101); // prepend [CLS]
+    ids.push_back(102);           // append [SEP]
+    return ids;
+}
+
+VARP Bge::gen_attention_mask(int seq_len) {
+    auto attention_mask = _Input({1, 1, 1, seq_len}, NCHW, halide_type_of<int>());
+    auto ptr = attention_mask->writeMap<int>();
+    for (int i = 0; i < seq_len; i++) {
+        ptr[i] = 1;
+    }
+    return attention_mask;
+}
+
+VARP Bge::gen_position_ids(int seq_len) {
+    auto position_ids = _Input({1, seq_len}, NCHW, halide_type_of<int>());
+    auto ptr = position_ids->writeMap<int>();
+    for (int i = 0; i < seq_len; i++) {
+        ptr[i] = i;
+    }
+    return position_ids;
+}
+// Embedding end
\ No newline at end of file
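
Note on the distance metric: Embedding::dist above returns the squared Euclidean distance (there is no square root), which still ranks pairs correctly since it is monotonic in the true L2 distance. bge-style sentence embeddings are usually compared with cosine similarity instead. A minimal sketch of such a helper, built from the same MNN Express ops this patch already uses; cosine_sim is a hypothetical addition, not part of the patch:

static float cosine_sim(VARP var0, VARP var1) {
    // hypothetical helper, not part of this patch:
    // cosine similarity = <v0, v1> / (||v0|| * ||v1||)
    auto dot   = _ReduceSum(_Multiply(var0, var1)); // inner product
    auto norm0 = _Sqrt(_ReduceSum(_Square(var0))); // L2 norm of var0
    auto norm1 = _Sqrt(_ReduceSum(_Square(var1))); // L2 norm of var1
    return dot->readMap<float>()[0] /
           (norm0->readMap<float>()[0] * norm1->readMap<float>()[0]);
}

With the demo's three sentences (the first two describe cherry blossoms in spring, the third a summer beach), dist_0_1 should come out smallest and cosine_sim(vec_0, vec_1) largest, confirming that the two paraphrases embed closest together.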