Skip to content

Commit

Permalink
support llama2-7b-chat
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Sep 20, 2023
1 parent bae1544 commit e38a0c6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,32 @@ llm模型导出onnx模型请使用[llm-export](https://github.com/wangzhaode/llm
| codegeex2-6b | [![Download][download-codegeex2-6b-onnx]][release-codegeex2-6b-onnx] | [![Download][download-codegeex2-6b-mnn]][release-codegeex2-6b-mnn] |
| Qwen-7B-Chat | [![Download][download-qwen-7b-chat-onnx]][release-qwen-7b-chat-onnx] | [![Download][download-qwen-7b-chat-mnn]][release-qwen-7b-chat-mnn] |
| Baichuan2-7B-Chat | [![Download][download-baichuan2-7b-chat-onnx]][release-baichuan2-7b-chat-onnx] | [![Download][download-baichuan2-7b-chat-mnn]][release-baichuan2-7b-chat-mnn] |
| Llama-2-7b-chat | [![Download][download-llama2-7b-chat-onnx]][release-llama2-7b-chat-onnx] | [![Download][download-llama2-7b-chat-mnn]][release-llama2-7b-chat-mnn] |

[download-chatglm-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm-6b-onnx/total
[download-chatglm2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm2-6b-onnx/total
[download-codegeex2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/codegeex2-6b-onnx/total
[download-qwen-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-7b-chat-onnx/total
[download-baichuan2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/baichuan2-7b-chat-onnx/total
[download-llama2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/llama2-7b-chat-onnx/total
[release-chatglm-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm-6b-onnx
[release-chatglm2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm2-6b-onnx
[release-codegeex2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/codegeex2-6b-onnx
[release-qwen-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-7b-chat-onnx
[release-baichuan2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/baichuan2-7b-chat-onnx
[release-llama2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/llama2-7b-chat-onnx
[download-chatglm-6b-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/chatglm-6b-mnn/total
[download-chatglm2-6b-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/chatglm2-6b-mnn/total
[download-codegeex2-6b-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/codegeex2-6b-mnn/total
[download-qwen-7b-chat-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/qwen-7b-chat-mnn/total
[download-baichuan2-7b-chat-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/baichuan2-7b-chat-mnn/total
[download-llama2-7b-chat-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-llm/llama2-7b-chat-mnn/total
[release-chatglm-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm-6b-mnn
[release-chatglm2-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/chatglm2-6b-mnn
[release-codegeex2-6b-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/codegeex2-6b-mnn
[release-qwen-7b-chat-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/qwen-7b-chat-mnn
[release-baichuan2-7b-chat-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/baichuan2-7b-chat-mnn
[release-llama2-7b-chat-mnn]: https://github.com/wangzhaode/mnn-llm/releases/tag/llama2-7b-chat-mnn


### 下载int4模型
Expand Down
6 changes: 3 additions & 3 deletions include/llm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ class Qwen_7b : public Llm {
virtual bool is_stop(int token_id) override;
};

class Baichuan2_7b : public Llm {
class Llama2_7b : public Llm {
public:
Baichuan2_7b() {
model_name_ = "Baichuan2_7b";
Llama2_7b() {
model_name_ = "Llama2_7b";
layer_nums_ = 32;
key_value_shape_ = {2, 1, 32, 0, 128};
}
Expand Down
26 changes: 18 additions & 8 deletions src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@ Llm* Llm::createLLM(const std::string& path) {
llm->model_name_ = "Codegeex2_6b";
} else if (path.find("qwen") != std::string::npos) {
llm = new Qwen_7b;
} else if (path.find("llama2") != std::string::npos) {
llm = new Llama2_7b;
} else if (path.find("baichuan") != std::string::npos) {
llm = new Baichuan2_7b;
llm = new Llama2_7b;
llm->model_name_ = "Baichuan2_7b";
}
llm->is_single_ = is_single;
return llm;
Expand Down Expand Up @@ -400,15 +403,22 @@ bool Qwen_7b::is_stop(int token_id) {
return token_id >= 151645;
}

// Baichuan2_7b
std::vector<int> Baichuan2_7b::tokenizer(const std::string& query) {
// Llama2_7b
std::vector<int> Llama2_7b::tokenizer(const std::string& query) {
auto ids = tokenizer_encode(query);
ids.insert(ids.begin(), 195);
ids.push_back(196);
if (model_name_ == "Baichuan2_7b") {
// baichuan2: <reserved_106>{query}<reserved_107>: 195, query, 196
ids.insert(ids.begin(), 195);
ids.push_back(196);
return ids;
}
// llama2: <bos>[INST]{query}[/INST]: 1, 5539, 25580, 29962, query, 12452, 25580, 29962
ids.insert(ids.begin(), {1, 5539, 25580, 29962});
ids.insert(ids.end(), {12452, 25580, 29962});
return ids;
}

VARP Baichuan2_7b::gen_attention_mask(int seq_len) {
VARP Llama2_7b::gen_attention_mask(int seq_len) {
if (seq_len == 1) {
auto attention_mask = _Input({1, 1, 1, all_seq_len_ + 1}, NCHW, halide_type_of<float>());
auto ptr = attention_mask->writeMap<float>();
Expand All @@ -428,7 +438,7 @@ VARP Baichuan2_7b::gen_attention_mask(int seq_len) {
}
}

VARP Baichuan2_7b::gen_position_ids(int seq_len) {
VARP Llama2_7b::gen_position_ids(int seq_len) {
auto position_ids = _Input({1, seq_len}, NCHW, halide_type_of<int>());
auto ptr = position_ids->writeMap<int>();
if (seq_len == 1) {
Expand All @@ -441,6 +451,6 @@ VARP Baichuan2_7b::gen_position_ids(int seq_len) {
return position_ids;
}

bool Baichuan2_7b::is_stop(int token_id) {
// Generation stops when the model emits token id 2 — presumably the </s>
// EOS token shared by llama2 and baichuan2; verify against the tokenizer vocab.
bool Llama2_7b::is_stop(int token_id) {
    return 2 == token_id;
}

0 comments on commit e38a0c6

Please sign in to comment.