diff --git a/include/llm.hpp b/include/llm.hpp
index 6d17c50d..015b64ee 100644
--- a/include/llm.hpp
+++ b/include/llm.hpp
@@ -248,6 +248,17 @@ class TinyLlama : public Llama2_7b {
 private:
     virtual std::vector<int> tokenizer(const std::string& query) override;
 };
+
+class Yi_6b : public Llama2_7b {
+public:
+    Yi_6b() {
+        model_name_ = "Yi_6b";
+        key_value_shape_ = {2, 1, 4, 0, 128};
+    }
+private:
+    virtual std::vector<int> tokenizer(const std::string& query) override;
+    virtual bool is_stop(int token_id) override;
+};
 // Llm end
 
 // Embedding start
diff --git a/src/llm.cpp b/src/llm.cpp
index 378c7c48..b385bed9 100644
--- a/src/llm.cpp
+++ b/src/llm.cpp
@@ -68,6 +68,9 @@ Llm* Llm::createLLM(const std::string& path, std::string model_type) {
     } else if (model_type.find("tinyllama") != std::string::npos) {
        llm = new TinyLlama;
        llm->model_name_ = "TinyLlama";
+    } else if (model_type.find("yi") != std::string::npos) {
+        llm = new Yi_6b;
+        llm->model_name_ = "Yi_6b";
     }
     if (!llm) {
         std::cerr << "model type can't judge!" << std::endl;
@@ -716,6 +719,16 @@ std::vector<int> TinyLlama::tokenizer(const std::string& query) {
     ids.insert(ids.end(), {2, 29871, 13, 29966, 29989, 465, 22137, 29989, 29958, 13});
     return ids;
 }
+
+std::vector<int> Yi_6b::tokenizer(const std::string& query) {
+    auto prompt = "<|im_start|> user\n" + query + "<|im_end|>\n<|im_start|> assistant\n";
+    auto ids = tokenizer_encode(prompt);
+    return ids;
+}
+
+bool Yi_6b::is_stop(int token_id) {
+    return token_id == 7 || token_id == 64001;
+}
 // Llm end
 
 // Embedding start
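
For context: with this change, any model_type string containing "yi" routes to the new Yi_6b class, whose tokenizer wraps the query in a ChatML-style template and whose is_stop treats token ids 7 and 64001 as end-of-turn markers. A minimal sketch of how a caller might exercise the new branch; the model path and the load()/response() entry points are assumptions for illustration, not part of this diff:

#include "llm.hpp"
#include <iostream>
#include <memory>

int main() {
    // A model_type containing "yi" selects the new Yi_6b branch in createLLM.
    std::unique_ptr<Llm> llm(Llm::createLLM("../model/yi-6b", "yi"));
    if (!llm) {
        // createLLM leaves the pointer null for unrecognized model types.
        return 1;
    }
    // load() and response() are assumed entry points on the Llm base class.
    llm->load("../model/yi-6b");
    std::cout << llm->response("hello") << std::endl;
    return 0;
}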