@@ -42,10 +42,10 @@ struct KernelBpeDecoder {
42
42
return status;
43
43
} else {
44
44
auto um = ParseId2String (byte_decoder);
45
- std::transform (um.begin (), um.end (),
46
- std::inserter (byte_decoder_, byte_decoder_. end () ),
47
- []( const auto & p) { return std::make_pair ( static_cast < char32_t >(p. first ),
48
- ort_extensions::narrow< unsigned char >( std::stoul (p. second ))); });
45
+ std::transform (um.begin (), um.end (), std::inserter (byte_decoder_, byte_decoder_. end ()), []( const auto & p) {
46
+ return std::make_pair ( static_cast < char32_t >(p. first ),
47
+ ort_extensions::narrow< unsigned char >( std::stoul (p. second )));
48
+ });
49
49
}
50
50
51
51
std::string added_tokens;
@@ -59,8 +59,7 @@ struct KernelBpeDecoder {
59
59
ORTX_RETURN_IF_ERROR (OrtW::GetOpAttribute (info, " all_special_ids" , all_special_ids));
60
60
if (!all_special_ids.empty ()) {
61
61
auto um = ParseId2String (all_special_ids);
62
- std::transform (um.begin (), um.end (),
63
- std::inserter (all_special_ids_, all_special_ids_.end ()),
62
+ std::transform (um.begin (), um.end (), std::inserter (all_special_ids_, all_special_ids_.end ()),
64
63
[](const auto & p) { return p.first ; });
65
64
}
66
65
@@ -116,8 +115,29 @@ struct KernelBpeDecoder {
116
115
arr_vocab_.shrink_to_fit ();
117
116
}
118
117
119
- OrtxStatus Compute (const ortc::Tensor<int64_t >& ids,
120
- ortc::Tensor<std::string>& output) const {
118
+ const std::string spm_underscore{" \xe2\x96\x81 " };
119
+
120
+ static bool IsSpmByteWord (std::string_view word) {
121
+ return word.size () == 6 && word[0 ] == ' <' && word[1 ] == ' 0' && word[2 ] == ' x' && word[5 ] == ' >' ;
122
+ }
123
+
124
+ static std::string ReplaceAll (std::string_view s, const std::string& search, const std::string& replace) {
125
+ std::string result;
126
+ for (size_t pos = 0 ;; pos += search.length ()) {
127
+ auto new_pos = s.find (search, pos);
128
+ if (new_pos == std::string::npos) {
129
+ result += s.substr (pos, s.size () - pos);
130
+ break ;
131
+ }
132
+ result += s.substr (pos, new_pos - pos);
133
+ result += replace;
134
+ pos = new_pos;
135
+ }
136
+
137
+ return result;
138
+ }
139
+
140
+ OrtxStatus Compute (const ortc::Tensor<int64_t >& ids, ortc::Tensor<std::string>& output) const {
121
141
const int64_t * p_ids = ids.Data ();
122
142
const auto & ids_dim = ids.Shape ();
123
143
std::vector<int64_t > output_dim = {1 };
@@ -126,6 +146,8 @@ struct KernelBpeDecoder {
126
146
std::copy (ids_dim.begin (), ids_dim.begin () + ids_dim.size () - 1 , output_dim.begin ());
127
147
}
128
148
149
+ bool spm_mode = byte_decoder_.count (ustring (spm_underscore)[0 ]) > 0 ;
150
+
129
151
size_t seq_len = ids_dim.back ();
130
152
size_t string_batch = ids.NumberOfElement () / seq_len;
131
153
std::vector<std::string> decoded_strings;
@@ -148,24 +170,37 @@ struct KernelBpeDecoder {
148
170
149
171
if (added_tokens_.count (token)) {
150
172
const std::string& ws = added_tokens_.at (token);
151
- decoded_token = (std::string)ws ;
173
+ decoded_token. assign (ws) ;
152
174
} else if (static_cast <size_t >(token) < arr_vocab_.size ()) {
153
- const auto str = ustring (arr_vocab_[token]);
154
- for (auto wchr : str) {
155
- if (byte_decoder_.count (wchr) == 0 ) {
156
- if (wchr <= char32_t (0xFF )) {
157
- decoded_token.push_back (static_cast <char >(wchr));
158
- continue ;
159
- }
160
- if (skip_special_tokens_) {
161
- continue ;
162
- } else {
163
- decoded_token = unk_token_;
164
- break ;
175
+ const auto piece = arr_vocab_[token];
176
+ if (spm_mode) {
177
+ // sentencepiece case, which doesn't really have a byte decoder
178
+ if ((IsSpmByteWord (piece))) {
179
+ char buf[3 ] = {piece[3 ], piece[4 ], 0 }; // something like <0x20>
180
+ char token = {static_cast <char >(strtol (buf, NULL , 16 ))};
181
+ decoded_token.push_back (token);
182
+ } else {
183
+ decoded_token.append (ReplaceAll (piece, spm_underscore, " " ));
184
+ }
185
+ } else {
186
+ // the common bpe case
187
+ const auto str = ustring (piece);
188
+ for (auto wchr : str) {
189
+ if (byte_decoder_.count (wchr) == 0 ) {
190
+ if (wchr <= char32_t (0xFF )) {
191
+ decoded_token.push_back (static_cast <char >(wchr));
192
+ continue ;
193
+ }
194
+ if (skip_special_tokens_) {
195
+ continue ;
196
+ } else {
197
+ decoded_token = unk_token_;
198
+ break ;
199
+ }
165
200
}
201
+ char uchr = byte_decoder_.at (wchr);
202
+ decoded_token.push_back (uchr);
166
203
}
167
- char uchr = byte_decoder_.at (wchr);
168
- decoded_token.push_back (uchr);
169
204
}
170
205
} else {
171
206
if (skip_special_tokens_) {
@@ -183,15 +218,13 @@ struct KernelBpeDecoder {
183
218
}
184
219
}
185
220
186
- if (whitespace_token_ &&
187
- f_special && (tok_idx > 0 && !f_special_last)) {
221
+ if (whitespace_token_ && f_special && (tok_idx > 0 && !f_special_last)) {
188
222
text.push_back (' ' );
189
223
}
190
224
191
225
text.append (decoded_token);
192
226
193
- if (whitespace_token_ &&
194
- f_special && tok_idx != count - 1 ) {
227
+ if (whitespace_token_ && f_special && tok_idx != count - 1 ) {
195
228
text.push_back (' ' );
196
229
}
197
230
0 commit comments