Skip to content

Commit

Permalink
refactor tokenizer.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed May 11, 2024
1 parent 72b7209 commit db99291
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 211 deletions.
8 changes: 4 additions & 4 deletions demo/tokenizer_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ int main(int argc, const char* argv[]) {
return 0;
}
std::string tokenizer_path = argv[1];
std::unique_ptr<Tokenizer> tokenizer_(new Tiktoken);
tokenizer_->load(tokenizer_path);
std::unique_ptr<Tokenizer> tokenizer(Tokenizer::createTokenizer(tokenizer_path));
const std::string system_str = "Youare a helpful assistant.";
const std::string user_str = "Hello";
// const std::string query = "\n<|im_start|>system\n" + system_str + "<|im_end|>\n<|im_start|>\n" + user_str + "<|im_end|>\n<|im_start|>assistant\n";
const std::string query = "\n<|im_start|>user\n" + user_str + "<|im_end|>\n<|im_start|>assistant\n";
// const std::string query = "<|im_start|>user\n" + user_str + "<|im_end|>\n<|im_start|>assistant\n";
// const std::string query = system_str + "\n" + user_str;
auto tokens = tokenizer_->encode(query);
auto tokens = tokenizer->encode(query);

std::string decode_str;
printf("encode tokens = [ ");
for (auto token : tokens) {
printf("%d, ", token);
decode_str += tokenizer_->decode(token);
decode_str += tokenizer->decode(token);
}
printf("]\n");
printf("decode str = %s\n", decode_str.c_str());
Expand Down
88 changes: 7 additions & 81 deletions include/llm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,16 @@ class LlmConfig {
return config_.value("hidden_size", 4096);
}

int layer_num() const {
return config_.value("layer_num", 32);
}

std::vector<int> key_value_shape() const {
return config_.value("key_value_shape", std::vector<int>{});
}

std::vector<int> stop_ids() const {
return config_.value("stop_ids", std::vector<int>{});
std::string attention_mask() const {
return config_.value("attention_mask", "int");
}

std::string prompt_template() const {
Expand Down Expand Up @@ -257,23 +261,7 @@ class Phi_2 : public Chatglm2_6b {
virtual bool is_stop(int token_id) override;
};

// Qwen-7B chat model configuration.
// Sets 32 transformer layers, hidden size 4096, and a per-layer KV-cache
// shape of {2, 1, 0, 32, 128} — the 0 is presumably the dynamic
// sequence-length axis (TODO confirm against Llm's cache handling).
// Uses the Tiktoken tokenizer.
class Qwen_7b : public Llm {
public:
    Qwen_7b() {
        model_name_ = "Qwen_7b";
        layer_nums_ = 32;
        key_value_shape_ = {2, 1, 0, 32, 128};
        hidden_size_ = 4096;
        tokenizer_.reset(new Tiktoken);
    }
private:
    // Encodes a chat query into token ids (implementation elsewhere).
    virtual std::vector<int> tokenizer(const std::string& query) override;
    // Builds the model-specific attention mask for a prompt of seq_len tokens.
    virtual VARP gen_attention_mask(int seq_len) override;
    // Builds the position-id tensor for a prompt of seq_len tokens.
    virtual VARP gen_position_ids(int seq_len) override;
    // True when token_id is one of this model's stop/EOS tokens.
    virtual bool is_stop(int token_id) override;
};

class Qwen_vl : public Qwen_7b {
class Qwen_vl : public Llm {
public:
Qwen_vl() {
model_name_ = "Qwen_vl";
Expand All @@ -296,17 +284,6 @@ class Qwen_vl : public Qwen_7b {
virtual VARP gen_attention_mask(int seq_len) override;
};

// Qwen-1.8B: a smaller Qwen variant. Reuses all of Qwen_7b's graph logic
// and only overrides the size hyper-parameters (24 layers, hidden size 2048,
// 16 KV heads of head-dim 128).
class Qwen_1_8b : public Qwen_7b {
public:
    Qwen_1_8b() {
        // Shrink the 7B defaults down to the 1.8B configuration.
        hidden_size_ = 2048;
        layer_nums_ = 24;
        key_value_shape_ = {2, 1, 0, 16, 128};
        // Same tokenizer family as the 7B model.
        tokenizer_.reset(new Tiktoken);
        model_name_ = "Qwen_1.8b";
    }
};

class Llama2_7b : public Llm {
public:
Llama2_7b() {
Expand All @@ -321,57 +298,6 @@ class Llama2_7b : public Llm {
virtual bool is_stop(int token_id) override;
};

// Qwen2 family base class.
// Inherits Llama2_7b's graph helpers but swaps in the HuggingFace BPE
// tokenizer and provides Qwen2-specific prompt encoding and stop-token
// detection. Size parameters are set by the concrete Qwen2_* subclasses.
class Qwen2 : public Llama2_7b {
public:
    Qwen2() {
        model_name_ = "Qwen2";
        tokenizer_.reset(new HuggingfaceTokenizer);
    }
private:
    // Encodes a chat query into token ids using the Qwen2 prompt format.
    virtual std::vector<int> tokenizer(const std::string& query) override;
    // True when token_id is a Qwen2 stop/EOS token.
    virtual bool is_stop(int token_id) override;
};

// Qwen2-0.5B: smallest Qwen2 variant (24 layers, hidden size 1024,
// 16 KV heads of head-dim 64). Differs from Qwen2 only in size parameters.
class Qwen2_0_5b : public Qwen2 {
public:
    Qwen2_0_5b() {
        // Override the base defaults with the 0.5B configuration.
        hidden_size_ = 1024;
        layer_nums_ = 24;
        key_value_shape_ = {2, 1, 16, 0, 64};
        model_name_ = "Qwen2_0.5b";
    }
};

// Qwen2-1.8B variant (24 layers, hidden size 2048, 16 KV heads of
// head-dim 128). Differs from Qwen2 only in size parameters.
class Qwen2_1_8b : public Qwen2 {
public:
    Qwen2_1_8b() {
        // Override the base defaults with the 1.8B configuration.
        hidden_size_ = 2048;
        layer_nums_ = 24;
        key_value_shape_ = {2, 1, 16, 0, 128};
        model_name_ = "Qwen2_1.8b";
    }
};

// Qwen2-4B variant (40 layers, hidden size 2560, 20 KV heads of
// head-dim 128). Differs from Qwen2 only in size parameters.
class Qwen2_4b : public Qwen2 {
public:
    Qwen2_4b() {
        // Override the base defaults with the 4B configuration.
        hidden_size_ = 2560;
        layer_nums_ = 40;
        key_value_shape_ = {2, 1, 20, 0, 128};
        model_name_ = "Qwen2_4b";
    }
};

// Qwen2-7B variant (32 layers, hidden size 4096, 32 KV heads of
// head-dim 128). Differs from Qwen2 only in size parameters.
class Qwen2_7b : public Qwen2 {
public:
    Qwen2_7b() {
        // Override the base defaults with the 7B configuration.
        hidden_size_ = 4096;
        layer_nums_ = 32;
        key_value_shape_ = {2, 1, 32, 0, 128};
        model_name_ = "Qwen2_7b";
    }
};

class TinyLlama : public Llama2_7b {
public:
TinyLlama() {
Expand Down
36 changes: 26 additions & 10 deletions include/tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,34 @@

// Abstract base class for all tokenizer backends.
// NOTE(review): this span is a diff rendering that merges the pre- and
// post-refactor API — `createTokenizer` appears twice and `encode` is
// declared both as a public pure-virtual and as a non-virtual public
// wrapper over a protected virtual. As written it will not compile;
// reconcile against the actual post-refactor header before use.
class Tokenizer {
public:
    // File-format sanity value checked when loading a serialized tokenizer
    // (presumably — confirm against the loader implementation).
    static constexpr int MAGIC_NUMBER = 430;
    // Backend discriminator stored in the tokenizer file.
    // NOTE(review): "TIKTOIKEN" is a misspelling of "TIKTOKEN"; renaming
    // would touch every user of the enum, so it is only flagged here.
    enum TokenizerType {
        SENTENCEPIECE = 0,
        TIKTOIKEN = 1,
        BERT = 2,
        HUGGINGFACE = 3
    };
    Tokenizer() = default;
    virtual ~Tokenizer() = default;
    // Old factory signature (pre-refactor diff line).
    static Tokenizer* createTokenizer(const std::string& type);
    // Old API: subclasses loaded and encoded directly (pre-refactor lines).
    virtual bool load(const std::string& filename) = 0;
    virtual std::vector<int> encode(const std::string& str) = 0;
    // New factory: picks the concrete backend from the serialized file.
    static Tokenizer* createTokenizer(const std::string& filename);
    // True when token is in stop_tokens_.
    bool is_stop(int token);
    // New public entry point; delegates to the protected virtual encode().
    std::vector<int> encode(const std::string& str);
    virtual std::string decode(int id) = 0;
protected:
    // Reads the special/stop token id lists from the serialized file.
    virtual void load_special_stop(std::ifstream& file, int special_num, int stop_num);
    // Backend-specific vocabulary loading.
    virtual bool load_vocab(std::ifstream& file) = 0;
    // Backend-specific encoding; appends token ids to `ids`.
    virtual void encode(const std::string& str, std::vector<int>& ids) = 0;
    std::vector<int> special_tokens_;  // ids of special tokens
    std::vector<int> stop_tokens_;     // ids that terminate generation
};

class Sentencepiece : public Tokenizer {
public:
Sentencepiece() = default;
virtual bool load(const std::string& filename) override;
virtual std::vector<int> encode(const std::string& str) override;
virtual std::string decode(int id) override;
protected:
virtual bool load_vocab(std::ifstream& file) override;
virtual void encode(const std::string& str, std::vector<int>& ids) override;
private:
enum ModelType {
UNIGRAM = 1,
Expand Down Expand Up @@ -77,18 +91,19 @@ class Sentencepiece : public Tokenizer {
// Tiktoken-style BPE tokenizer backend.
// NOTE(review): diff rendering merges old (`load`/public `encode`) and new
// (`load_vocab`/protected `encode`) member sets; reconcile against the
// actual post-refactor header before use.
class Tiktoken : public Tokenizer {
public:
    Tiktoken() = default;
    // Old API (pre-refactor diff lines).
    virtual bool load(const std::string& filename) override;
    virtual std::vector<int> encode(const std::string& str) override;
    virtual std::string decode(int id) override;
protected:
    // New API: vocabulary load and id-appending encode.
    virtual bool load_vocab(std::ifstream& file) override;
    virtual void encode(const std::string& str, std::vector<int>& ids) override;
    std::unordered_map<std::string, int> encoder_;  // token string -> id
    std::vector<std::string> decoder_;              // id -> token string (index = id, presumably)
};

// BERT tokenizer: reuses Tiktoken's vocabulary storage but encodes with
// word_piece() sub-token splitting (presumably WordPiece — confirm in the
// implementation).
// NOTE(review): diff rendering shows both the old public encode and the
// new protected encode overload; only one belongs in the final header.
class BertTokenizer : public Tiktoken {
public:
    BertTokenizer() = default;
    // Old API (pre-refactor diff line).
    virtual std::vector<int> encode(const std::string& str) override;
protected:
    // New API: appends token ids to `ids`.
    virtual void encode(const std::string& str, std::vector<int>& ids) override;
private:
    // Splits one whitespace-delimited token into sub-token ids.
    std::vector<int> word_piece(const std::string& token);
};
Expand All @@ -105,9 +120,10 @@ struct hash_pair_wstring {
using BPERanks = std::unordered_map<std::pair<std::wstring, std::wstring>, int, hash_pair_wstring>;
public:
HuggingfaceTokenizer() = default;
virtual bool load(const std::string& filename) override;
virtual std::vector<int> encode(const std::string& str) override;
virtual std::string decode(int id) override;
protected:
virtual bool load_vocab(std::ifstream& file) override;
virtual void encode(const std::string& str, std::vector<int>& ids) override;
private:
void bpe(const std::wstring& token, const BPERanks& bpe_ranks, std::vector<std::wstring>* result);
BPERanks bpe_ranks_;
Expand Down
Loading

0 comments on commit db99291

Please sign in to comment.