@@ -44,11 +44,8 @@ class BpeModel {
     }
   }

-  OrtxStatus Load(std::istream& vocab_stream,
-                  std::istream& merges_stream,
-                  const char* unk_token,
-                  const char* special_tokens,
-                  bool spm_converted) {
+  OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
+                  const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
     vocab_stream >> tok_json;
     tok_json.get_to(vocab_map_);
@@ -125,9 +122,7 @@ class BpeModel {
     return {};
   }

-  OrtxStatus Load(const json& bpe_model,
-                  const char* /* special_tokens */,
-                  bool spm_converted) {
+  OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
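For reference, the json-based `Load` overload above consumes a Hugging Face tokenizer.json-style model object: a "vocab" map from token string to id plus a "merges" list. A minimal, self-contained sketch of that shape and of the `get_to` conversion the hunk uses (the sample tokens and the surrounding `main` are illustrative, not from this PR):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

#include <nlohmann/json.hpp>

int main() {
  // Shape assumed from the "vocab"/"merges" keys read in the hunk above.
  nlohmann::json bpe_model = nlohmann::json::parse(R"({
    "vocab":  {"h": 0, "i": 1, "hi": 2},
    "merges": ["h i"]
  })");

  // Same conversion style as vocab_json.get_to(vocab_map_) in the diff.
  std::unordered_map<std::string, uint32_t> vocab_map;
  bpe_model["vocab"].get_to(vocab_map);

  std::cout << vocab_map.at("hi") << "\n";  // prints 2
}
```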
@@ -195,8 +190,7 @@ class BpeModel {
   }

   OrtxStatus Load(std::unordered_map<std::string, uint32_t>& vocab,
-                  std::vector<std::pair<std::string, std::string>>& merges,
-                  const char* /* special_tokens */,
+                  std::vector<std::pair<std::string, std::string>>& merges, const char* /* special_tokens */,
                   bool spm_converted) {
     vocab_map_ = vocab;

@@ -207,7 +201,7 @@ class BpeModel {
     }

     uint32_t index = 0;
-    for (auto& tuple : merges){
+    for (auto& tuple : merges) {
       std::string w1 = tuple.first;
       std::string w2 = tuple.second;
       int token_length = ort_extensions::narrow<int>(w1.length() + w2.length());
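The loop in this hunk assigns each (w1, w2) merge pair a sequential index, which serves as its BPE merge rank: pairs earlier in the merges list are merged first. A stand-alone sketch of that ranking idea (the string-keyed rank table is an illustrative stand-in for the model's internal structure):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, std::string>> merges = {
      {"h", "i"}, {"t", "h"}, {"th", "e"}};

  // Position in the merges list becomes the pair's rank; during BPE the
  // lowest-ranked adjacent pair is merged first.
  std::unordered_map<std::string, uint32_t> merge_rank;
  uint32_t index = 0;
  for (auto& tuple : merges) {
    merge_rank[tuple.first + " " + tuple.second] = index++;
  }

  std::cout << merge_rank.at("th e") << "\n";  // prints 2
}
```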
@@ -269,11 +263,10 @@ class BpeModel {
     return {};
   }

-  std::vector<std::string> BuildDecoder() const {
-    return id2token_map_;
-  }
+  std::vector<std::string> BuildDecoder() const { return id2token_map_; }

-  // REF: https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
+  // REF:
+  // https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
   bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
     // split by added tokens
     bpe::TokenPairs added_result;
@@ -343,9 +336,7 @@ class BpeModel {
     }
   }

-  const auto& ByteEncoder() const {
-    return byte_encoder_;
-  }
+  const auto& ByteEncoder() const { return byte_encoder_; }

   uint32_t GetTokenId(const std::string& key) const {
     auto it = vocab_map_.find(key);
@@ -356,10 +347,18 @@ class BpeModel {
     }
   }

-  const std::string& GetEndOfWordSuffix() const {
-    return end_of_word_suffix_;
+  uint32_t GetAddedTokenId(const std::string& key) const {
+    size_t idx = 0;
+    int id = added_tokens_.FindLongest(ustring(key), idx);
+    if (idx == 0) {
+      return bpe::kInvalidTokenId;
+    }
+
+    return static_cast<uint32_t>(id);
   }

+  const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
+
  private:
   struct BpeNode {
     uint32_t id;
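The new `GetAddedTokenId` accessor is the substantive change in this last hunk: `added_tokens_.FindLongest` reports the matched length through `idx`, so `idx == 0` signals that no added token prefixes the key. A self-contained toy of that contract; the linear-scan `FindLongest` below stands in for the real trie, and every name not in the diff (including the `kInvalidTokenId` value) is assumed for illustration:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

constexpr uint32_t kInvalidTokenId = 0xFFFFFFFF;  // stand-in for bpe::kInvalidTokenId (value assumed)

// Toy stand-in for added_tokens_.FindLongest: find the longest added token
// that prefixes `text`, writing the match length to idx (0 if none).
int FindLongest(const std::map<std::string, int>& added, const std::string& text, size_t& idx) {
  idx = 0;
  int id = -1;
  for (const auto& [tok, tok_id] : added) {
    if (tok.size() > idx && text.compare(0, tok.size(), tok) == 0) {
      idx = tok.size();
      id = tok_id;
    }
  }
  return id;
}

// Mirrors the control flow of the GetAddedTokenId added by this hunk.
uint32_t GetAddedTokenId(const std::map<std::string, int>& added, const std::string& key) {
  size_t idx = 0;
  int id = FindLongest(added, key, idx);
  if (idx == 0) {
    return kInvalidTokenId;  // no added token matched
  }
  return static_cast<uint32_t>(id);
}

int main() {
  std::map<std::string, int> added{{"<|endoftext|>", 50256}};
  std::cout << GetAddedTokenId(added, "<|endoftext|>hello") << "\n";          // 50256
  std::cout << (GetAddedTokenId(added, "hello") == kInvalidTokenId) << "\n";  // 1
}
```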