
Commit b4ebfc9

Fix spm converted FastTokenizer issue on non-ascii char (microsoft#778)

* Fix spm converted tokenizer issue on non-ascii char
* remove pkg_resources in python

1 parent e113ed3 commit b4ebfc9
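Background for the fix: a fast tokenizer converted from a SentencePiece model has no real byte decoder, so non-ASCII characters are emitted as byte-fallback pieces such as <0xE7>, and the word-boundary marker ▁ (U+2581) must be mapped back to a space. Below is a minimal Python sketch of the decode rule the C++ changes implement; the helper names are hypothetical, not library API.

SPM_UNDERSCORE = "\u2581"  # the SentencePiece word-boundary marker '▁'

def is_spm_byte_word(piece: str) -> bool:
    # Byte-fallback pieces look like "<0x20>": exactly six characters.
    return len(piece) == 6 and piece[:3] == "<0x" and piece[5] == ">"

def decode_spm_piece(piece: str) -> bytes:
    if is_spm_byte_word(piece):
        return bytes([int(piece[3:5], 16)])  # "<0xE7>" -> b"\xe7"
    return piece.replace(SPM_UNDERSCORE, " ").encode("utf-8")

# A single non-ASCII character arrives as several byte-fallback pieces that
# only form valid UTF-8 once concatenated:
pieces = ["<0xE7>", "<0xBB>", "<0x99>"]
print(b"".join(decode_spm_piece(p) for p in pieces).decode("utf-8"))  # -> 给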

File tree

4 files changed (+72 -52 lines)


operators/tokenizer/bpe_decoder.hpp (+60 -27)
@@ -42,10 +42,10 @@ struct KernelBpeDecoder {
       return status;
     } else {
       auto um = ParseId2String(byte_decoder);
-      std::transform(um.begin(), um.end(),
-                     std::inserter(byte_decoder_, byte_decoder_.end()),
-                     [](const auto& p) { return std::make_pair(static_cast<char32_t>(p.first),
-                                                               ort_extensions::narrow<unsigned char>(std::stoul(p.second))); });
+      std::transform(um.begin(), um.end(), std::inserter(byte_decoder_, byte_decoder_.end()), [](const auto& p) {
+        return std::make_pair(static_cast<char32_t>(p.first),
+                              ort_extensions::narrow<unsigned char>(std::stoul(p.second)));
+      });
     }
 
     std::string added_tokens;
@@ -59,8 +59,7 @@ struct KernelBpeDecoder {
     ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
     if (!all_special_ids.empty()) {
       auto um = ParseId2String(all_special_ids);
-      std::transform(um.begin(), um.end(),
-                     std::inserter(all_special_ids_, all_special_ids_.end()),
+      std::transform(um.begin(), um.end(), std::inserter(all_special_ids_, all_special_ids_.end()),
                      [](const auto& p) { return p.first; });
     }
 
@@ -116,8 +115,29 @@ struct KernelBpeDecoder {
     arr_vocab_.shrink_to_fit();
   }
 
-  OrtxStatus Compute(const ortc::Tensor<int64_t>& ids,
-                     ortc::Tensor<std::string>& output) const {
+  const std::string spm_underscore{"\xe2\x96\x81"};
+
+  static bool IsSpmByteWord(std::string_view word) {
+    return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
+  }
+
+  static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
+    std::string result;
+    for (size_t pos = 0;; pos += search.length()) {
+      auto new_pos = s.find(search, pos);
+      if (new_pos == std::string::npos) {
+        result += s.substr(pos, s.size() - pos);
+        break;
+      }
+      result += s.substr(pos, new_pos - pos);
+      result += replace;
+      pos = new_pos;
+    }
+
+    return result;
+  }
+
+  OrtxStatus Compute(const ortc::Tensor<int64_t>& ids, ortc::Tensor<std::string>& output) const {
     const int64_t* p_ids = ids.Data();
     const auto& ids_dim = ids.Shape();
     std::vector<int64_t> output_dim = {1};
@@ -126,6 +146,8 @@ struct KernelBpeDecoder {
       std::copy(ids_dim.begin(), ids_dim.begin() + ids_dim.size() - 1, output_dim.begin());
     }
 
+    bool spm_mode = byte_decoder_.count(ustring(spm_underscore)[0]) > 0;
+
     size_t seq_len = ids_dim.back();
     size_t string_batch = ids.NumberOfElement() / seq_len;
     std::vector<std::string> decoded_strings;
@@ -148,24 +170,37 @@ struct KernelBpeDecoder {
 
       if (added_tokens_.count(token)) {
         const std::string& ws = added_tokens_.at(token);
-        decoded_token = (std::string)ws;
+        decoded_token.assign(ws);
       } else if (static_cast<size_t>(token) < arr_vocab_.size()) {
-        const auto str = ustring(arr_vocab_[token]);
-        for (auto wchr : str) {
-          if (byte_decoder_.count(wchr) == 0) {
-            if (wchr <= char32_t(0xFF)) {
-              decoded_token.push_back(static_cast<char>(wchr));
-              continue;
-            }
-            if (skip_special_tokens_) {
-              continue;
-            } else {
-              decoded_token = unk_token_;
-              break;
+        const auto piece = arr_vocab_[token];
+        if (spm_mode) {
+          // sentencepiece case, which doesn't really have a byte decoder
+          if ((IsSpmByteWord(piece))) {
+            char buf[3] = {piece[3], piece[4], 0};  // something like <0x20>
+            char token = {static_cast<char>(strtol(buf, NULL, 16))};
+            decoded_token.push_back(token);
+          } else {
+            decoded_token.append(ReplaceAll(piece, spm_underscore, " "));
+          }
+        } else {
+          // the common bpe case
+          const auto str = ustring(piece);
+          for (auto wchr : str) {
+            if (byte_decoder_.count(wchr) == 0) {
+              if (wchr <= char32_t(0xFF)) {
+                decoded_token.push_back(static_cast<char>(wchr));
+                continue;
+              }
+              if (skip_special_tokens_) {
+                continue;
+              } else {
+                decoded_token = unk_token_;
+                break;
+              }
             }
+            char uchr = byte_decoder_.at(wchr);
+            decoded_token.push_back(uchr);
           }
-          char uchr = byte_decoder_.at(wchr);
-          decoded_token.push_back(uchr);
         }
       } else {
        if (skip_special_tokens_) {
@@ -183,15 +218,13 @@ struct KernelBpeDecoder {
         }
       }
 
-      if (whitespace_token_ &&
-          f_special && (tok_idx > 0 && !f_special_last)) {
+      if (whitespace_token_ && f_special && (tok_idx > 0 && !f_special_last)) {
         text.push_back(' ');
       }
 
       text.append(decoded_token);
 
-      if (whitespace_token_ &&
-          f_special && tok_idx != count - 1) {
+      if (whitespace_token_ && f_special && tok_idx != count - 1) {
         text.push_back(' ');
       }
 
operators/tokenizer/bpe_streaming.hpp (-19)
@@ -53,25 +53,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
     return {};
   }
 
-  static std::string ReplaceAll(std::string_view s, const std::string& search, const std::string& replace) {
-    std::string result;
-    for (size_t pos = 0;; pos += search.length()) {
-      auto new_pos = s.find(search, pos);
-      if (new_pos == std::string::npos) {
-        result += s.substr(pos, s.size() - pos);
-        break;
-      }
-      result += s.substr(pos, new_pos - pos);
-      result += replace;
-      pos = new_pos;
-    }
 
-    return result;
-  }
-
-  static bool IsSpmByteWord(std::string_view word) {
-    return word.size() == 6 && word[0] == '<' && word[1] == '0' && word[2] == 'x' && word[5] == '>';
-  }
 
   OrtxStatus Id2Token(extTokenId_t id,
                       std::string& token,
@@ -119,7 +101,6 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
   }
 
   OrtxStatus SpmId2Token(extTokenId_t id, std::string& token, bool& f_special_last) const {
-    const char spm_underscore[] = "\xe2\x96\x81";
 
     std::string piece = id < arr_vocab_.size() ? arr_vocab_[id] : "";
     bool f_special = false;

test/test_cliptok.py (-5)
@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
 import onnxruntime as _ort
-import pkg_resources
 
 from pathlib import Path
 from onnx import helper, onnx_pb as onnx_proto
@@ -150,8 +149,4 @@ def test_optional_outputs(self):
 
 
 if __name__ == "__main__":
-    try:
-        dist = pkg_resources.get_distribution('ftfy')
-    except pkg_resources.DistributionNotFound:
-        raise Exception("WARNING: ftfy is not installed - it is required for parity between CLIPTokenizer and CLIPTokenizerFast.")
     unittest.main()
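An aside on the removed check: pkg_resources is deprecated, so if a runtime check for ftfy were still wanted, importlib.metadata from the standard library would be the modern replacement. A sketch (not part of this commit):

from importlib.metadata import version, PackageNotFoundError

try:
    version("ftfy")
except PackageNotFoundError:
    raise Exception("ftfy is not installed - it is required for parity "
                    "between CLIPTokenizer and CLIPTokenizerFast.")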

test/test_fast_tokenizer.py (+12 -1)
@@ -21,14 +21,25 @@ def test_llama_tokenizer(self):
         np.testing.assert_array_equal(ids[0], actual_ids[0])
 
     def test_mistral(self):
-        tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-v0.2", use_fast=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            "mistral-community/Mistral-7B-v0.2", use_fast=True)
         text = "\nOnce upon a time, I was really into monochromatic makeup looks. I have a lot of coppery and bronze "
         ids = tokenizer.encode(text, return_tensors="np")
 
         ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
         actual_ids, *_ = ort_inference(ort_tok, [text])
         np.testing.assert_array_equal(ids[0], actual_ids[0])
 
+    def test_phi_3_mini(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "microsoft/Phi-3-mini-128k-instruct", use_fast=True)
+        text = "what are you? \n 给 weiss ich, über was los ist \n"
+        ids = tokenizer.encode(text, return_tensors="np")
+
+        ort_tok, _ = gen_processing_models(tokenizer, pre_kwargs={})
+        actual_ids, *_ = ort_inference(ort_tok, [text])
+        np.testing.assert_array_equal(ids[0], actual_ids[0][1:])
+
 
 if __name__ == '__main__':
     unittest.main()
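Note that the new test compares against actual_ids[0][1:], skipping the first id the ORT-side tokenizer emits. ort_inference is a helper used throughout these tests; a hedged sketch of what such a helper might look like, assuming it wraps onnxruntime with the extensions custom-op library (the input name is resolved from the session rather than assumed):

import numpy as np
import onnxruntime as _ort
from onnxruntime_extensions import get_library_path

def ort_inference(tok_model, inputs):
    # Register the onnxruntime-extensions custom ops, then run the
    # tokenizer graph produced by gen_processing_models.
    so = _ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    sess = _ort.InferenceSession(tok_model.SerializeToString(), so)
    feed = {sess.get_inputs()[0].name: np.asarray(inputs)}
    return sess.run(None, feed)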
