From c8a540bbb9e8f5d62c25eddd1b494a5d98ca652f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=81=E8=A1=8C?=
Date: Thu, 1 Feb 2024 15:41:13 +0800
Subject: [PATCH] support qwen2

---
Review notes (this section is ignored when the patch is applied):
  * Adds class Qwen2_4b, derived from Llama2_7b, whose constructor sets
    model_name_ = "Qwen2_4b", layer_nums_ = 40, hidden_size_ = 2560 and
    key_value_shape_ = {2, 1, 20, 0, 128}.
  * Qwen2_4b::tokenizer() encodes the query with tokenizer_encode() and then
    wraps it in fixed token ids: {198, 151644, 872, 198} in front and
    {151645, 198, 151644, 77091, 198} behind. Per the in-code comment these
    stand for the template
    "\n<|im_start|>user\n" + query + "<|im_end|>\n<|im_start|>assistant\n".
  * Qwen2_4b::is_stop() reports a stop for token id 151645 or 151643.
  * Llm::createLLM() tests for "qwen2" before the generic "qwen" branch.
    NOTE(review): a model_type that contains "qwen2" but no "4" assigns
    nothing to llm in the new branch and no longer reaches the "qwen"
    branch -- confirm that the code after this chain copes with that.
    NOTE(review): find("4") matches any '4' in model_type, not only a
    "4b" size suffix -- confirm this is intended.
  * The template argument of std::vector<int> and the line structure of this
    patch were restored after an extraction step stripped them; hunk
    indentation is reconstructed as 4 spaces -- verify against the tree.

 include/llm.hpp | 13 +++++++++++++
 src/llm.cpp     | 16 ++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/include/llm.hpp b/include/llm.hpp
index 015b64ee..f98c6397 100644
--- a/include/llm.hpp
+++ b/include/llm.hpp
@@ -238,6 +238,19 @@ class Llama2_7b : public Llm {
     virtual bool is_stop(int token_id) override;
 };
 
+class Qwen2_4b : public Llama2_7b {
+public:
+    Qwen2_4b() {
+        model_name_ = "Qwen2_4b";
+        layer_nums_ = 40;
+        key_value_shape_ = {2, 1, 20, 0, 128};
+        hidden_size_ = 2560;
+    }
+private:
+    virtual std::vector<int> tokenizer(const std::string& query) override;
+    virtual bool is_stop(int token_id) override;
+};
+
 class TinyLlama : public Llama2_7b {
 public:
     TinyLlama() {
diff --git a/src/llm.cpp b/src/llm.cpp
index 129b7f1d..a4a8127f 100644
--- a/src/llm.cpp
+++ b/src/llm.cpp
@@ -47,6 +47,10 @@ Llm* Llm::createLLM(const std::string& path, std::string model_type) {
     } else if (model_type.find("codegeex2") != std::string::npos) {
         llm = new Chatglm2_6b;
         llm->model_name_ = "Codegeex2_6b";
+    } else if (model_type.find("qwen2") != std::string::npos) {
+        if (model_type.find("4") != std::string::npos) {
+            llm = new Qwen2_4b;
+        }
     } else if (model_type.find("qwen") != std::string::npos) {
         if (model_type.find("1.8") != std::string::npos) {
             llm = new Qwen_1_8b;
@@ -717,6 +721,18 @@ bool Llama2_7b::is_stop(int token_id) {
     return token_id == 2;
 }
 
+std::vector<int> Qwen2_4b::tokenizer(const std::string& query) {
+    auto ids = tokenizer_encode(query);
+    // auto prompt = "\n<|im_start|>user\n" + query + "<|im_end|>\n<|im_start|>assistant\n";
+    ids.insert(ids.begin(), {198, 151644, 872, 198});
+    ids.insert(ids.end(), {151645, 198, 151644, 77091, 198});
+    return ids;
+}
+
+bool Qwen2_4b::is_stop(int token_id) {
+    return token_id == 151645 || token_id == 151643;
+}
+
 std::vector<int> TinyLlama::tokenizer(const std::string& query) {
     auto ids = tokenizer_encode(query);
     /*