support qwen2

wangzhaode committed Feb 1, 2024
1 parent 01ec5f2 commit c8a540b
Showing 2 changed files with 29 additions and 0 deletions.
include/llm.hpp: 13 additions & 0 deletions
@@ -238,6 +238,19 @@ class Llama2_7b : public Llm {
    virtual bool is_stop(int token_id) override;
};

class Qwen2_4b : public Llama2_7b {
public:
    Qwen2_4b() {
        model_name_ = "Qwen2_4b";
        layer_nums_ = 40;
        key_value_shape_ = {2, 1, 20, 0, 128};
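        // Presumed layout: {K and V, batch, num_heads, seq_len (0 = dynamic), head_dim};
        // 20 heads x 128 dims is consistent with hidden_size_ = 2560 below.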
        hidden_size_ = 2560;
    }
private:
    virtual std::vector<int> tokenizer(const std::string& query) override;
    virtual bool is_stop(int token_id) override;
};

class TinyLlama : public Llama2_7b {
public:
    TinyLlama() {

src/llm.cpp: 16 additions & 0 deletions
@@ -47,6 +47,10 @@ Llm* Llm::createLLM(const std::string& path, std::string model_type) {
    } else if (model_type.find("codegeex2") != std::string::npos) {
        llm = new Chatglm2_6b;
        llm->model_name_ = "Codegeex2_6b";
    } else if (model_type.find("qwen2") != std::string::npos) {
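        // The "qwen2" check must come before the generic "qwen" branch below,
        // since any model_type containing "qwen2" also contains "qwen".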
        if (model_type.find("4") != std::string::npos) {
            llm = new Qwen2_4b;
        }
    } else if (model_type.find("qwen") != std::string::npos) {
        if (model_type.find("1.8") != std::string::npos) {
            llm = new Qwen_1_8b;
@@ -717,6 +721,18 @@ bool Llama2_7b::is_stop(int token_id) {
    return token_id == 2;
}

std::vector<int> Qwen2_4b::tokenizer(const std::string& query) {
    auto ids = tokenizer_encode(query);
    // auto prompt = "\n<|im_start|>user\n" + query + "<|im_end|>\n<|im_start|>assistant\n";
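    // The hard-coded ids below presumably encode that template in the Qwen vocabulary:
    // 198 = "\n", 151644 = <|im_start|>, 872 = "user", 77091 = "assistant", 151645 = <|im_end|>.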
    ids.insert(ids.begin(), {198, 151644, 872, 198});
    ids.insert(ids.end(), {151645, 198, 151644, 77091, 198});
    return ids;
}

bool Qwen2_4b::is_stop(int token_id) {
    return token_id == 151645 || token_id == 151643;
}

std::vector<int> TinyLlama::tokenizer(const std::string& query) {
    auto ids = tokenizer_encode(query);
    /*

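Side note, not part of the commit: a minimal sketch of how the new branch would be exercised. createLLM() and its (path, model_type) parameters are taken from the hunk above; treating it as a static factory, the model directory name, and the load()/response() calls are assumptions about the rest of the Llm API rather than anything confirmed by this diff.

#include <string>
#include "llm.hpp"

int main() {
    // Hypothetical model directory; any model_type containing "qwen2" and "4"
    // makes createLLM() return a Qwen2_4b instance.
    std::string model_dir = "../qwen2-4b-int4-mnn";
    Llm* llm = Llm::createLLM(model_dir, "qwen2-4b");
    llm->load(model_dir);     // assumed loader entry point of the Llm base class
    llm->response("Hello");   // assumed chat entry point; the query is wrapped by
                              // Qwen2_4b::tokenizer() into the ChatML template above
    delete llm;
    return 0;
}

At inference time the tokenizer prepends {198, 151644, 872, 198} and appends {151645, 198, 151644, 77091, 198} around the encoded query, and generation stops on token 151645 or 151643.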