tokenization.h
#ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H
#include <cstdint>
#include <string>
#include <vector>
#include <unordered_map>
#include <iostream>

namespace bert {

/**
 * Loads a vocabulary file into `vocab`, mapping each token to its id.
 * @param vocab_file path to the vocabulary file
 * @param vocab      receives the token-to-id map
 */
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab);

/**
 * Checks whether `c` is a whitespace character.
 * @param c character code point
 * @return true if `c` is whitespace
 */
bool _is_whitespace(int c);

/**
 * Checks whether `c` is a control character.
 * @param c character code point
 * @return true if `c` is a control character
 */
bool _is_control(int c);

/**
 * Checks whether `cp` is a punctuation character.
 * @param cp character code point
 * @return true if `cp` is punctuation
 */
bool _is_punctuation(int cp);

/**
 * Runs basic tokenization (punctuation splitting, lower casing, etc.).
 */
class BasicTokenizer {
public:
    /**
     * Constructs a BasicTokenizer.
     * @param do_lower_case Whether to lower case the input.
     */
    explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

    BasicTokenizer(const BasicTokenizer &other) = delete;

    virtual ~BasicTokenizer() = default;

    /**
     * Tokenizes a piece of text. The helpers involved are:
     *
     *   to_lower                 Lower-cases the input (only when do_lower_case is set).
     *   _run_strip_accents       Strips accents from a piece of text.
     *   _clean_text              Performs invalid-character removal and whitespace cleanup on text.
     *   _tokenize_chinese_chars  Adds whitespace around any CJK character.
     *   _run_split_on_punc       Splits punctuation on a piece of text.
     *   whitespace_tokenize      Runs basic whitespace cleaning and splitting on a piece of text.
     *
     * @param text          the input text
     * @param output_tokens receives the resulting tokens
     * @param max_length
     */
    void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);

private:
    const bool do_lower_case;

    /**
     * Checks whether `cp` is the codepoint of a CJK character.
     * @param cp character code point
     * @return true if `cp` is a CJK codepoint
     */
    inline static bool _is_chinese_char(int cp);
};
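
/*
 * Illustrative usage sketch (not part of this header's API). The sample string,
 * the max_length value of 64 and the expected output are assumptions based on
 * the helper descriptions above (lower casing, punctuation splitting, whitespace
 * added around CJK characters):
 *
 *   bert::BasicTokenizer basic(true);             // lower-case the input
 *   std::vector<std::string> tokens;
 *   basic.tokenize("Hello, World! 你好", &tokens, 64);
 *   // expected tokens: "hello", ",", "world", "!", "你", "好"
 */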

/**
 * Runs WordPiece tokenization.
 */
class WordpieceTokenizer {
public:
    explicit WordpieceTokenizer(
            std::unordered_map<std::string, uint64_t> *vocab,
            std::string unk_token = "[UNK]",
            int max_input_chars_per_word = 200
    ) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {}

    WordpieceTokenizer(const WordpieceTokenizer &other) = delete;

    virtual ~WordpieceTokenizer() = default;

    /**
     * Tokenizes a piece of text into its word pieces.
     *
     * This uses a greedy longest-match-first algorithm to perform tokenization
     * using the given vocabulary.
     *
     * For example:
     *   input  = "unaffable"
     *   output = ["un", "##aff", "##able"]
     *
     * @param text          A single token or whitespace-separated tokens. This should have
     *                      already been passed through `BasicTokenizer`.
     * @param output_tokens A list of wordpiece tokens.
     */
    void tokenize(const std::string &text, std::vector<std::string> *output_tokens);

private:
    const std::unordered_map<std::string, uint64_t> *vocab;
    const std::string unk_token;
    const int max_input_chars_per_word;
};
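
/*
 * Sketch of the greedy longest-match-first loop described above, applied to one
 * word already split out by BasicTokenizer. This is an illustrative byte-level
 * outline, not the actual implementation in tokenization.cpp; the variable names
 * and the whole-word [UNK] fallback shown here are assumptions:
 *
 *   std::vector<std::string> pieces;
 *   size_t start = 0;
 *   bool bad = false;
 *   while (start < word.size()) {
 *       size_t end = word.size();
 *       std::string cur;
 *       while (start < end) {
 *           std::string sub = word.substr(start, end - start);
 *           if (start > 0) sub = "##" + sub;        // continuation-piece marker
 *           if (vocab->count(sub)) { cur = sub; break; }
 *           --end;                                  // shrink the candidate and retry
 *       }
 *       if (cur.empty()) { bad = true; break; }     // nothing in the vocab matched
 *       pieces.push_back(cur);                      // longest match starting at `start`
 *       start = end;                                // continue after the match
 *   }
 *   if (bad) output_tokens->push_back(unk_token);   // the whole word becomes [UNK]
 *   else output_tokens->insert(output_tokens->end(), pieces.begin(), pieces.end());
 */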

/**
 * Runs end-to-end tokenization.
 */
class FullTokenizer {
public:
    /*explicit FullTokenizer(const char *vocab_file, bool do_lower_case = false) {
        vocab = new std::unordered_map<std::string, uint64_t>();
        load_vocab(vocab_file, vocab);
        basic_tokenizer = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }*/

    /**
     * Constructs a FullTokenizer.
     *
     * Note: takes ownership of `_vocab`; the map is deleted in the destructor.
     *
     * @param _vocab        token-to-id vocabulary, e.g. as filled by load_vocab()
     * @param do_lower_case whether the BasicTokenizer lower-cases the input
     */
    explicit FullTokenizer(std::unordered_map<std::string, uint64_t> *_vocab, bool do_lower_case = false) {
        vocab = _vocab;
        basic_tokenizer = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }

    FullTokenizer(const FullTokenizer &other) = delete;

    virtual ~FullTokenizer() {
        delete wordpiece_tokenizer;
        delete basic_tokenizer;
        delete vocab;
    }

    void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);

    /**
     * Looks up the id for `token`; unknown tokens map to id 0.
     */
    inline uint64_t convert_token_to_id(const std::string &token) {
        auto item = vocab->find(token);
        if (item == vocab->end()) {
            std::cerr << "vocab missing key: " << token << std::endl;
            return 0;
        } else {
            return item->second;
        }
    }

    void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids);

private:
    std::unordered_map<std::string, uint64_t> *vocab;
    BasicTokenizer *basic_tokenizer;
    WordpieceTokenizer *wordpiece_tokenizer;
};
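
/*
 * Illustrative end-to-end usage sketch (not part of this header). The vocab path
 * "vocab.txt", the sample sentence and max_length = 128 are assumptions. Because
 * FullTokenizer deletes the vocab map in its destructor, the map is allocated
 * with new and not freed by the caller:
 *
 *   auto *vocab = new std::unordered_map<std::string, uint64_t>();
 *   bert::load_vocab("vocab.txt", vocab);
 *   bert::FullTokenizer tokenizer(vocab, true);
 *
 *   std::vector<std::string> tokens;
 *   tokenizer.tokenize("He was unaffable.", &tokens, 128);
 *
 *   std::vector<uint64_t> ids(tokens.size());
 *   tokenizer.convert_tokens_to_ids(tokens, ids.data());
 */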

}  // namespace bert

#endif //CUBERT_TOKENIZATION_H