
Commit

support bge-large-zh embedding model.
wangzhaode committed Jan 10, 2024
1 parent 67d21e1 commit b399e55
Showing 4 changed files with 222 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -64,13 +64,15 @@ else()
    # cli demo
    add_executable(cli_demo ${CMAKE_CURRENT_LIST_DIR}/demo/cli_demo.cpp)
    add_executable(web_demo ${CMAKE_CURRENT_LIST_DIR}/demo/web_demo.cpp)
    add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
    if (MSVC)
        target_link_libraries(cli_demo llm)
        target_link_libraries(embedding_demo llm)
        target_link_libraries(web_demo llm pthreadVC2)
        # copy all lib to target dir
        file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/libs/ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/Debug/)
    else()
        target_link_libraries(cli_demo llm)
        target_link_libraries(embedding_demo llm)
        target_link_libraries(web_demo llm pthread)
    endif()
endif()
45 changes: 45 additions & 0 deletions demo/embedding_demo.cpp
@@ -0,0 +1,45 @@
//
// embedding_demo.cpp
//
// Created by MNN on 2024/01/10.
// ZhaodeWang
//

#include "llm.hpp"
#include <fstream>
#include <stdlib.h>

static void dumpVARP(VARP var) {
    // print the first and last five values (assumes the vector has at least 10 elements)
    auto size = var->getInfo()->size;
    auto ptr = var->readMap<float>();
    printf("[ ");
    for (int i = 0; i < 5; i++) {
        printf("%f, ", ptr[i]);
    }
    printf("... ");
    for (int i = size - 5; i < size; i++) {
        printf("%f, ", ptr[i]);
    }
    printf(" ]\n");
}

int main(int argc, const char* argv[]) {
    if (argc < 2) {
        std::cout << "Usage: " << argv[0] << " model.mnn" << std::endl;
        return 1;
    }
    std::string model_dir = argv[1];
    std::cout << "model path is " << model_dir << std::endl;
    std::unique_ptr<Embedding> embedding(Embedding::createEmbedding(model_dir));
    embedding->load(model_dir);
    // two prompts describing the same spring scene: cherry blossoms, couples and children under the trees
    auto vec_0 = embedding->embedding("在春暖花开的季节,走在樱花缤纷的道路上,人们纷纷拿出手机拍照留念。樱花树下,情侣手牵手享受着这绝美的春光。孩子们在树下追逐嬉戏,脸上洋溢着纯真的笑容。春天的气息在空气中弥漫,一切都显得那么生机勃勃,充满希望。");
    auto vec_1 = embedding->embedding("春天到了,樱花树悄然绽放,吸引了众多游客前来观赏。小朋友们在花瓣飘落的树下玩耍,而恋人们则在这浪漫的景色中尽情享受二人世界。每个人的脸上都挂着幸福的笑容,仿佛整个世界都被春天温暖的阳光和满树的樱花渲染得更加美好。");
    // a contrasting prompt about a summer day at the beach
    auto vec_2 = embedding->embedding("在炎热的夏日里,沙滩上的游客们穿着泳装享受着海水的清凉。孩子们在海边堆沙堡,大人们则在太阳伞下品尝冷饮,享受悠闲的时光。远处,冲浪者们挑战着波涛,体验着与海浪争斗的刺激。夏天的海滩,总是充满了活力和热情。");
    dumpVARP(vec_0);
    dumpVARP(vec_1);
    dumpVARP(vec_2);
    printf("dist_0_1: %f\n", Embedding::dist(vec_0, vec_1));
    printf("dist_0_2: %f\n", Embedding::dist(vec_0, vec_2));
    printf("dist_1_2: %f\n", Embedding::dist(vec_1, vec_2));
    return 0;
}
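
Because the first two prompts describe the same spring cherry-blossom scene while the third describes a summer beach, dist_0_1 should come out noticeably smaller than dist_0_2 or dist_1_2 once the model is loaded correctly.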
61 changes: 60 additions & 1 deletion include/llm.hpp
@@ -27,8 +27,8 @@ using namespace MNN;
using namespace Express;
class Tokenizer;

// Llm start
// llm stream buffer with callback

class LlmStreamBuffer : public std::streambuf {
public:
    using CallBack = std::function<void(const char* str, size_t len)>;
@@ -217,4 +217,63 @@ class Llama2_7b : public Llm {
    virtual bool is_stop(int token_id) override;
};

// Llm end

// Embedding start
class Embedding {
public:
    Embedding() {
        // default tokenizer is Tiktoken
        tokenizer_.reset(new Tiktoken);
    }
    virtual ~Embedding() {
        module_.reset();
        runtime_manager_.reset();
    }
    static Embedding* createEmbedding(const std::string& path, std::string model_type = "auto");
    static float dist(VARP var0, VARP var1);
    void load(const std::string& model_dir);
    VARP embedding(const std::string& txt);
    void print_speed();
public:
    // timing (filled in by the caller; see the sketch below)
    int64_t embedding_us_ = 0;
    int prompt_len_ = 0;
protected:
    std::vector<int> tokenizer_encode(const std::string& input_str);
protected:
    // model configs
    int layer_nums_ = 0;
    int hidden_size_ = 1024;
    std::string model_name_ = "";
    // tokenizer
    std::unique_ptr<Tokenizer> tokenizer_;
private:
    virtual std::vector<int> tokenizer(const std::string& query) = 0;
    virtual VARP gen_attention_mask(int seq_len) = 0;
    virtual VARP gen_position_ids(int seq_len) = 0;
private:
    // MNN Modules
    std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
    std::shared_ptr<Module> module_;
    // model dir
    std::string model_dir_;
};

// some embedding models
class Bge : public Embedding {
public:
    Bge() {
        model_name_ = "Bge";
        layer_nums_ = 24;
        hidden_size_ = 1024;
    }
private:
    virtual std::vector<int> tokenizer(const std::string& query) override;
    virtual VARP gen_attention_mask(int seq_len) override;
    virtual VARP gen_position_ids(int seq_len) override;
};

// Embedding end

#endif // LLM_hpp
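
The header above declares a public embedding_us_ counter and a print_speed() helper, but nothing in this commit updates the counter, so the caller has to fill it in. A minimal timing sketch with std::chrono, assuming an embedding object created as in demo/embedding_demo.cpp (the snippet is illustrative, not part of the commit):

#include <chrono>

// Time one embedding() call and feed the public embedding_us_ field
// so print_speed() reports a real tokens-per-second figure.
auto begin = std::chrono::steady_clock::now();
auto vec = embedding->embedding("hello world");
auto end = std::chrono::steady_clock::now();
embedding->embedding_us_ =
    std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count();
embedding->print_speed();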
116 changes: 115 additions & 1 deletion src/llm.cpp
@@ -20,6 +20,7 @@
#include <cv/cv.hpp>
#endif

// Llm start
Llm* Llm::createLLM(const std::string& path, std::string model_type) {
    auto size = path.size();

@@ -192,6 +193,11 @@ void Llm::load(const std::string& model_dir) {
printf("load tokenizer\n");
// 1. load vocab
std::string tokenizer_path = model_dir + "/tokenizer.txt";
if (is_single_) {
size_t pos = model_dir.find_last_of("/\\");
std::string dir_path = (pos != std::string::npos) ? model_dir.substr(0, pos + 1) : "";
tokenizer_path = dir_path + "/tokenizer.txt";
}
load_progress_ += 5.f;
tokenizer_->load(tokenizer_path);
load_progress_ += 5.f;
@@ -674,4 +680,112 @@ bool Llama2_7b::is_stop(int token_id) {
        return token_id == 2 || token_id == 103028;
    }
    return token_id == 2;
}
// Llm end

// Embedding start
float Embedding::dist(VARP var0, VARP var1) {
    // squared L2 distance (no sqrt); smaller means more similar
    auto distVar = _ReduceSum(_Square(var0 - var1));
    auto dist = distVar->readMap<float>()[0];
    return dist;
}

Embedding* Embedding::createEmbedding(const std::string& path, std::string model_type) {
    Embedding* embedding = nullptr;
    if (model_type == "auto") {
        // infer the model type from the file path
        model_type = path;
    }
    if (model_type.find("bge") != std::string::npos) {
        embedding = new Bge;
    }
    if (!embedding) {
        std::cerr << "model type can't be determined!" << std::endl;
        return embedding;
    }
    std::cout << "### model name : " << embedding->model_name_ << std::endl;
    return embedding;
}

void Embedding::load(const std::string& model_dir) {
    model_dir_ = model_dir;
    // init runtime: 4 CPU threads, low precision / low memory
    ScheduleConfig config;
    BackendConfig cpuBackendConfig;
    config.type = MNN_FORWARD_CPU;
    // config.type = MNN_FORWARD_OPENCL;
    config.numThread = 4;
    cpuBackendConfig.precision = BackendConfig::Precision_Low;
    cpuBackendConfig.memory = BackendConfig::Memory_Low;
    config.backendConfig = &cpuBackendConfig;
    runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config));
    printf("load tokenizer\n");
    // 1. load vocab: tokenizer.txt sits next to the model file,
    //    and dir_path already ends with the separator
    size_t pos = model_dir.find_last_of("/\\");
    std::string dir_path = (pos != std::string::npos) ? model_dir.substr(0, pos + 1) : "";
    std::string tokenizer_path = dir_path + "tokenizer.txt";
    tokenizer_->load(tokenizer_path);
    printf("load tokenizer Done\n");
    // 2. load model
    Module::Config module_config;
    module_config.shapeMutable = true;
    module_config.rearrange = true;
    std::string model_path = model_dir;
    MNN_PRINT("load %s ... ", model_path.c_str());
    module_.reset(Module::load(
        {"input_ids", "attention_mask", "position_ids"},
        {"sentence_embeddings"}, model_path.c_str(), runtime_manager_, &module_config));
    MNN_PRINT("Done!\n");
}

VARP Embedding::embedding(const std::string& txt) {
    auto ids = tokenizer(txt);
    prompt_len_ = ids.size();
    auto input_ids = _Const(ids.data(), {prompt_len_}, NCHW, halide_type_of<int>());
    auto attention_mask = gen_attention_mask(prompt_len_);
    auto position_ids = gen_position_ids(prompt_len_);
    auto outputs = module_->onForward({input_ids, attention_mask, position_ids});
    auto sentence_embeddings = outputs[0];
    return sentence_embeddings;
}

void Embedding::print_speed() {
    auto total_s = embedding_us_ * 1e-6;
    printf("\n#################################\n");
    printf(" total token = %d\n", prompt_len_);
    printf(" total time = %.2f s\n", total_s);
    printf(" total speed = %.2f tok/s\n", prompt_len_ / total_s);
    printf("##################################\n");
}

std::vector<int> Embedding::tokenizer_encode(const std::string& input_str) {
    auto ids = tokenizer_->encode(input_str);
    return ids;
}

std::vector<int> Bge::tokenizer(const std::string& query) {
    auto ids = tokenizer_encode(query);
    ids.insert(ids.begin(), 101); // BERT [CLS]
    ids.push_back(102);           // BERT [SEP]
    return ids;
}

VARP Bge::gen_attention_mask(int seq_len) {
    // all ones: attend to every token
    auto attention_mask = _Input({1, 1, 1, seq_len}, NCHW, halide_type_of<int>());
    auto ptr = attention_mask->writeMap<int>();
    for (int i = 0; i < seq_len; i++) {
        ptr[i] = 1;
    }
    return attention_mask;
}

VARP Bge::gen_position_ids(int seq_len) {
    // sequential positions 0 .. seq_len-1
    auto position_ids = _Input({1, seq_len}, NCHW, halide_type_of<int>());
    auto ptr = position_ids->writeMap<int>();
    for (int i = 0; i < seq_len; i++) {
        ptr[i] = i;
    }
    return position_ids;
}
// Embedding end
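
Embedding::dist above compares vectors by squared L2 distance. BGE-style sentence embeddings are more commonly scored with cosine similarity; a sketch built from the same MNN Express ops, assuming _Sqrt is available alongside the _ReduceSum/_Square used above:

// Sketch, not part of this commit: cosine similarity between two embeddings.
static float cosine(VARP var0, VARP var1) {
    auto dot   = _ReduceSum(var0 * var1)->readMap<float>()[0];
    auto norm0 = _Sqrt(_ReduceSum(_Square(var0)))->readMap<float>()[0];
    auto norm1 = _Sqrt(_ReduceSum(_Square(var1)))->readMap<float>()[0];
    return dot / (norm0 * norm1); // 1 = same direction, 0 = orthogonal
}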
